美文网首页
1.pytorch数据集的加载和使用

1.pytorch数据集的加载和使用

作者: 三角绿毛怪 | 来源:发表于2020-10-14 10:38 被阅读0次
    import torch
    from torch.utils.data import Dataset,DataLoader
    import math
    #传入数据的地址,加r为了防止转义字符出错,让他认为后面是字符串
    data_path = r"D:\19黑马\阶段9-人工智能NLP项目\第三天\代码\data\SMSSpamCollection"
    
    #完成数据集
    class MyDataset(Dataset):
        #继承Dataset
        def __init__(self):
            #初始化
            self.lines = open(data_path,errors='ignore').readlines()
    
        def __getitem__(self,index):
            #获取索引对应位置的一条数据
            #删除首尾的空格
            #返回标签和文本
            curline = self.lines[index].strip()
            label = curline[:4].strip()
            content = curline[4:].strip()
            return label,content
    
        def __len__(self):
            #返回数据的总数量
            return len(self.lines)
    
    my_dataset = MyDataset()
    #batch_size表示分组的量,shuffle表示是否每次打乱顺序,drop_last表示向上取整,最后一个小于batch_size的就不要了
    data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True,drop_last=True)
    
    
    if __name__ == '__main__':
        my_dataset = MyDataset()
        print(my_dataset[101])
        print(len(my_dataset))
        for i in data_loader:
            print(i)
            break
        print(len(my_dataset))
        # print(math.ceil(len(my_dataset)/7))
        print(len(data_loader))
    

    相关文章

      网友评论

          本文标题:1.pytorch数据集的加载和使用

          本文链接:https://www.haomeiwen.com/subject/ixbipktx.html