美文网首页
pytorch-geometric 从入门到不放弃 day3

pytorch-geometric 从入门到不放弃 day3

作者: 不太聪明的亚子 | 来源:发表于2021-09-16 19:09 被阅读0次

已经学习了data,dataset和dataloader,不如就先实战根据自己的数据集,写好自定义的dataset吧。

1、首先将每个图数据预处理成Data需要的形式:

x是所有节点的特征,【num_nodes, embed_dim】,要注意这里所有的节点特征维度需要一致;

edge_index是邻接表,有向图:【【0,1】,【1,2】】;无向图:【【0,1,1,2】,【1,2,0,1】】;

y类别标签;

其他自定义的数据,需要是int或者float类型。

最后分别转换成numpy.array类型,使用numpy.savez()保存成npz文件,分别存放在train/eval/test路径下的graph文件夹里,后面要用。

np.savez(os.path.join(path, data_name, 'graph', file_id+'.npz'), x=x, edge_index=edge_idx, y=y, dtype=object)

2、自定义dataset,主要是__getitem__函数,逻辑是传入上面处理好的文件list,然后getitem函数按照列表下标读取,返回Data类型就好。

class GraphDataset(Dataset):

    def __init__(self, root, file_list, treeLenDic, lower = 2, upper = 100000):

        super(GraphDataset, self).__init__()

        self.root = root

        self.file_list = list(filter(lambda id: id.split('.')[0] in treeLenDic.keys() and treeLenDic[id.split('.')[0]] >= lower and treeLenDic[id.split('.')[0]] <= upper, file_list))

    def __len__(self):

        return len(self.file_list)

    def __getitem__(self, idx):

        id = self.file_list[idx]

        data = np.load(os.path.join(self.root, id), allow_pickle=True)

        return Data(x=torch.tensor(data['x'], dtype=torch.float32),

                edge_index=torch.LongTensor(data['edge_index']),

                y=torch.LongTensor([int(data['y'])]))

这里对每个图文件的长度做了筛选,要至少有两个节点,那种只有一个点的就不考虑了,TreeLenDic是个字典,{graph_id: len}.

3. 将Dataset实例化的对象传入DataLoader就可以批量读取数据了

好啦,到这里我数据预处理以及自定义Dataset就搞定了,可以开始学习torch.geometric.nn里面的网络模型啦~

相关文章

网友评论

      本文标题:pytorch-geometric 从入门到不放弃 day3

      本文链接:https://www.haomeiwen.com/subject/tlncgltx.html