美文网首页
pytorch DataLoader 如何处理字符数据?

pytorch DataLoader 如何处理字符数据?

作者: 井底蛙蛙呱呱呱 | 来源:发表于2020-11-25 15:45 被阅读0次

    在pytorch的数据加载中,我们一般先定义一个Dataset,然后再去创建DataLoader。默认的DataLoader要求Dataset中返回的(即__getitem__)为numpy 数组或tensor等数据类型,并不支持字符类型,如果我们需要返回字符类型,则需要对DataLoader中的collate _fn函数做一些更改了:

    from torch.utils.data.dataloader import default_collate 
    
    
    def id_collate(batch):
        new_batch = []
        ids = []
        for _batch in batch:
            new_batch.append(_batch[:-1])
            ids.append(_batch[-1])
        return default_collate(new_batch), ids
    

    使用上面的函数作为DataLoader中的collate _fn函数便能使dataloader处理字符数据了。

    参考:
    https://discuss.pytorch.org/t/building-custom-dataset-how-to-return-ids-as-well/22931

    相关文章

      网友评论

          本文标题:pytorch DataLoader 如何处理字符数据?

          本文链接:https://www.haomeiwen.com/subject/kpkniktx.html