美文网首页
pytorch-geometric 从入门到不放弃 day2

pytorch-geometric 从入门到不放弃 day2

作者: 不太聪明的亚子 | 来源:发表于2021-09-14 17:35 被阅读0次

1、Data

首先需要从数据入手,了解怎么把数据处理成能输入到模型中的格式,geometric中是torch_geometric.data包,其中torch_geometric.data.Data是组织数据的具体形式,用来规范单个样本的格式,它包含一下参数(可选):

x: 图中节点的特征矩阵,大小是[num_nodes, num_node_features] ,即[节点数量,节点特征维度] ,数据类型一般是torch.float;

edge_index: COO形式的边邻接表,大小是 [2, num_edges],即只存储存在边的节点对,其中第一行是边出发节点,第二行对应位置是边指向节点,数据类型是torch.long。如果是无向图,那么1->2,2->1,都要存储;

edge_attr: 边的特征矩阵,大小是[num_edges, num_edge_features],即[边数量,边特征维度];

y: 标签,如果是节点级的标签,大小是[num_nodes, *],如果是图级标签,大小是[1, *],数据类型一般是torch.float;

pos:节点的位置矩阵,大小是 [num_nodes, num_dimensions](这个参数我不太懂,一般也用不着);

举个栗子:

有4个结点,每个节点有两个特征,有自己的类别标签,是个有向图

import torch

from torch_geometric.data import Data

x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)

y = torch.tensor([0, 1, 0, 1], dtype=torch.float) 

edge_index = torch.tensor([[0, 1, 2, 0, 3], [1, 0, 1, 3, 2]], dtype=torch.long)  #应该与节点对顺序无关,这个顺序怎么写都行

data = Data(x=x, y=y, edge_index=edge_index)

这个data有一些属性函数:

data.keys : ['x', 'edge_index', 'y'] #可以吧data看成是一个字典类似的数据结构

data.num_nodes :4

data.num_edges :5

还有一些官方文档上的data.num_node_features,data.has_isolated_nodes(),data.has_self_loops(),data.is_directed()等很多功能,我的geometric版本是1.7.2,也可以看源码

以上是同质图,也就是图中节点种类一致,geometric也支持异质图(latest版中),但现在我安装不了latest版本,之后再研究如何升级版本,想研究一下异构,先码住。

from torch_geometric.data import HeteroData

data = HeteroData()

#创建两种类别的节点特征矩阵

num_posts = 5

num_posts_features = 25

num_users = 3

num_users_features = 10

data['post'].x = torch.randn(num_posts, num_posts_features)

data['user'].x = torch.randn(num_users, num_users_features)

#创建一种边的类别

data['user', 'writes', 'post'].edge_index =... # [2, num_edges] 

终于安装高版本的geometric了,折腾了一天,之前一直因为其他依赖包出错而不能升级,妥了。


2、Dataset

因为Data只是单个样本,所以需要构建一个数据集,把它们组织在一起,方便调用,参数如下:

root (string, optional)  :用于存储该数据集的根目录,默认为none;

transform (callable, optional) :默认为none,对Data对象进行变换的函数(类比CV里面对图像进行裁剪啊,缩放啊,旋转啊之类),返回一个变换后的数据,但这种变换不会保存下来,每次读取的时候变换完了存在内存里;

pre_transform (callable, optional)  :默认为none,这个变换后的数据就会存在磁盘里,之后直接读取变换完的数据;

pre_filter (callable, optional) :默认为none,返回一个布尔值,表示这个数据是否会在最终的dataset中,相当于过滤器

先看它具有的属性:

self.raw_dir = self.root + 'raw' #是存放原始数据的路径

self.processed_dir = self.root + 'processed' #是存放处理后的数据的路径

self.raw_file_names 获取self.raw_dir下的文件名信息

self.processed_file_names 获取self.processed_dir下的文件名信息

self.num_node_features 每个节点的特征维度

再看它支持的函数:

download():将数据下载并存放到self.raw_dir下

process():处理数据,并将处理后的数据存放到self.processed_dir下  

len() : 数据包长度

get (idx:int) : 根据idx索引从processed_dir下的数据中返回其中对应的单个图数据

shuffle () : 混淆dataset中的数据样本,默认为false 

dataset支持切片操作,可以用于划分数据集。

dataset还有一类是torch_geometric.data.InMemoryDataset,它是继承torch_geometric.data.Dataset,如果数据是存到CPU内存中计算,就用torch_geometric.data.InMemoryDataset,如果是GPU,就用torch_geometric.data.Dataset,我使用GPU,就先学Dataset.

官方给出的自定义Dataset代码:


3. DataLoader

dataloader的作用是从dataset中分批次取数据样本送入模型,可以一个一个取,也可以一个mini_batch取,也可以全部取。

参数如下:

dataset:上面Dataset类的实例化对象

batch_size :int,你的batch的大小,默认为1

shuffle :布尔值,是否打乱每个epoch时从dataset中取mini_batch的顺序,默认为false

follow_batch : 是一个列表或元组,为列表或元组中的每个值赋一个批处理向量(没太懂)

exclude_keys :是一个列表或元组,其中包含的值会被从dataset中去掉

dataloader取出的每个minibatch的数据类型是torch_geometric.data.Batch (继承自torch_geometric.data.Data)。Batch类型有一个属性是batch,它是一个列向量,表示这个minibatch中每个节点对应的图数据的编号:

假如这个minibatch有n个图,那么batch = [0,..0,1,...,1,2,...,2,...n-1,...,n-1],这个属性的作用是可以结合scatter_mean()函数求每个图的所有节点在各个维度的平均值,scatter_mean(data.x,data.batch,dim=0)。

相关文章

网友评论

      本文标题:pytorch-geometric 从入门到不放弃 day2

      本文链接:https://www.haomeiwen.com/subject/hlunwltx.html