1、Data
首先需要从数据入手,了解怎么把数据处理成能输入到模型中的格式,geometric中是torch_geometric.data包,其中torch_geometric.data.Data是组织数据的具体形式,用来规范单个样本的格式,它包含一下参数(可选):
x: 图中节点的特征矩阵,大小是[num_nodes, num_node_features] ,即[节点数量,节点特征维度] ,数据类型一般是torch.float;
edge_index: COO形式的边邻接表,大小是 [2, num_edges],即只存储存在边的节点对,其中第一行是边出发节点,第二行对应位置是边指向节点,数据类型是torch.long。如果是无向图,那么1->2,2->1,都要存储;
edge_attr: 边的特征矩阵,大小是[num_edges, num_edge_features],即[边数量,边特征维度];
y: 标签,如果是节点级的标签,大小是[num_nodes, *],如果是图级标签,大小是[1, *],数据类型一般是torch.float;
pos:节点的位置矩阵,大小是 [num_nodes, num_dimensions](这个参数我不太懂,一般也用不着);
举个栗子:

有4个结点,每个节点有两个特征,有自己的类别标签,是个有向图
import torch
from torch_geometric.data import Data
x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 2, 0, 3], [1, 0, 1, 3, 2]], dtype=torch.long) #应该与节点对顺序无关,这个顺序怎么写都行
data = Data(x=x, y=y, edge_index=edge_index)
这个data有一些属性函数:
data.keys : ['x', 'edge_index', 'y'] #可以吧data看成是一个字典类似的数据结构
data.num_nodes :4
data.num_edges :5
还有一些官方文档上的data.num_node_features,data.has_isolated_nodes(),data.has_self_loops(),data.is_directed()等很多功能,我的geometric版本是1.7.2,也可以看源码。
以上是同质图,也就是图中节点种类一致,geometric也支持异质图(latest版中),但现在我安装不了latest版本,之后再研究如何升级版本,想研究一下异构,先码住。
from torch_geometric.data import HeteroData
data = HeteroData()
#创建两种类别的节点特征矩阵
num_posts = 5
num_posts_features = 25
num_users = 3
num_users_features = 10
data['post'].x = torch.randn(num_posts, num_posts_features)
data['user'].x = torch.randn(num_users, num_users_features)
#创建一种边的类别
data['user', 'writes', 'post'].edge_index =... # [2, num_edges]
终于安装高版本的geometric了,折腾了一天,之前一直因为其他依赖包出错而不能升级,妥了。
2、Dataset
因为Data只是单个样本,所以需要构建一个数据集,把它们组织在一起,方便调用,参数如下:
root (string, optional) :用于存储该数据集的根目录,默认为none;
transform (callable, optional) :默认为none,对Data对象进行变换的函数(类比CV里面对图像进行裁剪啊,缩放啊,旋转啊之类),返回一个变换后的数据,但这种变换不会保存下来,每次读取的时候变换完了存在内存里;
pre_transform (callable, optional) :默认为none,这个变换后的数据就会存在磁盘里,之后直接读取变换完的数据;
pre_filter (callable, optional) :默认为none,返回一个布尔值,表示这个数据是否会在最终的dataset中,相当于过滤器
先看它具有的属性:
self.raw_dir = self.root + 'raw' #是存放原始数据的路径
self.processed_dir = self.root + 'processed' #是存放处理后的数据的路径
self.raw_file_names 获取self.raw_dir下的文件名信息
self.processed_file_names 获取self.processed_dir下的文件名信息
self.num_node_features 每个节点的特征维度
再看它支持的函数:
download():将数据下载并存放到self.raw_dir下
process():处理数据,并将处理后的数据存放到self.processed_dir下
len() : 数据包长度
get (idx:int) : 根据idx索引从processed_dir下的数据中返回其中对应的单个图数据
shuffle () : 混淆dataset中的数据样本,默认为false
dataset支持切片操作,可以用于划分数据集。
dataset还有一类是torch_geometric.data.InMemoryDataset,它是继承torch_geometric.data.Dataset,如果数据是存到CPU内存中计算,就用torch_geometric.data.InMemoryDataset,如果是GPU,就用torch_geometric.data.Dataset,我使用GPU,就先学Dataset.
官方给出的自定义Dataset代码:
3. DataLoader
dataloader的作用是从dataset中分批次取数据样本送入模型,可以一个一个取,也可以一个mini_batch取,也可以全部取。
参数如下:
dataset:上面Dataset类的实例化对象
batch_size :int,你的batch的大小,默认为1
shuffle :布尔值,是否打乱每个epoch时从dataset中取mini_batch的顺序,默认为false
follow_batch : 是一个列表或元组,为列表或元组中的每个值赋一个批处理向量(没太懂)
exclude_keys :是一个列表或元组,其中包含的值会被从dataset中去掉
dataloader取出的每个minibatch的数据类型是torch_geometric.data.Batch (继承自torch_geometric.data.Data)。Batch类型有一个属性是batch,它是一个列向量,表示这个minibatch中每个节点对应的图数据的编号:
假如这个minibatch有n个图,那么batch = [0,..0,1,...,1,2,...,2,...n-1,...,n-1],这个属性的作用是可以结合scatter_mean()函数求每个图的所有节点在各个维度的平均值,scatter_mean(data.x,data.batch,dim=0)。
网友评论