美文网首页
pytorch DataLoader 如何处理字符数据?

pytorch DataLoader 如何处理字符数据?

作者: 井底蛙蛙呱呱呱 | 来源:发表于2020-11-25 15:45 被阅读0次

在pytorch的数据加载中,我们一般先定义一个Dataset,然后再去创建DataLoader。默认的DataLoader要求Dataset中返回的(即__getitem__)为numpy 数组或tensor等数据类型,并不支持字符类型,如果我们需要返回字符类型,则需要对DataLoader中的collate _fn函数做一些更改了:

from torch.utils.data.dataloader import default_collate 


def id_collate(batch):
    new_batch = []
    ids = []
    for _batch in batch:
        new_batch.append(_batch[:-1])
        ids.append(_batch[-1])
    return default_collate(new_batch), ids

使用上面的函数作为DataLoader中的collate _fn函数便能使dataloader处理字符数据了。

参考:
https://discuss.pytorch.org/t/building-custom-dataset-how-to-return-ids-as-well/22931

相关文章

网友评论

      本文标题:pytorch DataLoader 如何处理字符数据?

      本文链接:https://www.haomeiwen.com/subject/kpkniktx.html