美文网首页
pytorch学习笔记-dataloader忽略异常值

pytorch学习笔记-dataloader忽略异常值

作者: 升不上三段的大鱼 | 来源:发表于2022-03-24 16:56 被阅读0次

    在使用自己的数据的时候,如果希望输入的数据满足一些条件,不满足条件的数据不会用于训练,一个方法是预处理,把不满足条件的数据去掉,另一种就是重写dataloader 的 collate_fn 函数。

    class DataSet():
         def __init__(self, data):  
              self.data= data
              self.visited = np.zeros(len(data))   # 用来避免重复取值
         def __getitem__(self,idx):
              if self.visited[idx] == 1:   # 避免重复取到不想要的
                  return None
              data = self.data[idx]
              self.visited[idx] = 1
              if data is None:  # 这里写去掉数据的条件
                  return None 
              return data
    
    dataset = Dataset(data)
    
    dataloader = DataLoader(dataset , batch_size=4,
                            shuffle=True, num_workers=1, collate_fn = my_collate )
    
    def my_collate(batch): 
        len_batch = len(batch)  # original batch length
        batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
        if len_batch > len(batch):  # source all the required samples from the original dataset at random
            diff = len_batch - len(batch)
            for i in range(diff):
                item = dataset[np.random.randint(0, len(dataset))]
                while item is None:
                    item = dataset[np.random.randint(0, len(dataset))]
                batch.append(item)
        return torch.utils.data.dataloader.default_collate(batch) 
    

    参考: https://stackoverflow.com/questions/57815001/pytorch-collate-fn-reject-sample-and-yield-another

    相关文章

      网友评论

          本文标题:pytorch学习笔记-dataloader忽略异常值

          本文链接:https://www.haomeiwen.com/subject/rrqljrtx.html