在pytorch的数据加载中,我们一般先定义一个Dataset
,然后再去创建DataLoader
。默认的DataLoader要求Dataset中返回的(即__getitem__
)为numpy 数组或tensor等数据类型,并不支持字符类型,如果我们需要返回字符类型,则需要对DataLoader中的collate _fn
函数做一些更改了:
from torch.utils.data.dataloader import default_collate
def id_collate(batch):
new_batch = []
ids = []
for _batch in batch:
new_batch.append(_batch[:-1])
ids.append(_batch[-1])
return default_collate(new_batch), ids
使用上面的函数作为DataLoader
中的collate _fn
函数便能使dataloader处理字符数据了。
参考:
https://discuss.pytorch.org/t/building-custom-dataset-how-to-return-ids-as-well/22931
网友评论