在训练神经网络时我们可能会同时使用多个dataloader,则需要在原来的enumerate(dataloader)
上加入zip
函数:
for i, data in enumerate(zip(dataloader1, dataloader2)):
pass
此时,data
是一个(2,2)的元组,第一行是dataloader1的data和label,第二行是dataloader2的data和label。
另外,dataloader1和dataloader2的大小很有可能不一样,即len(dataloader1) != len(dataloader2)
,则它会以数量最少的那个dataloader为标准停止,例如len(dataloader1)=85
,len(dataloader2)=80
,则最终的i
就是80。并且dataloader2的最后一个batch的大小可能不够一个batch size。
这时,我们可以调用:
from itertools import cycle
for i, data in enumerate(zip(dataloader1, cycle(dataloader2))):
pass
dataloader2就又会从头开始循环,直到将dataloader1也循环完。
但是你有没有发现,这是另一个死循环啊~!dataloader1的最后一个batch的数据数量也不一定等于batch size。
网友评论