美文网首页深度学习
PyG构建图对象并转换成networkx图对象

PyG构建图对象并转换成networkx图对象

作者: 马尔代夫Maldives | 来源:发表于2024-01-10 14:37 被阅读0次

    一、写在前面

    PyG 是一款基于PyTorch 的图神经网络库,它提供了很多经典的图神经网络模型和图数据集。
    在使用 PyG 框架来构建和训练图网络模型时,需要事先将图数据换成PyG定义的“图对象”
    PyG 提供多种类型的图对象(在torch_geometric.data下),常用的包括:Data(同构图)和HeteroData(异构图)

    无标题.png

    二、基本用法(以Data对象为例)

    2.1) 构建图对象

    构建一张图的Data对象时,通常需要提供以下基本数据:
    from torch_geometric.data import Data
    Data ( x: Optional[torch.Tensor] = None,
       edge_index: Optional[torch.Tensor] = None,
       edge_attr: Optional[torch.Tensor] = None,
       y: Optional[torch.Tensor] = None,
       pos: Optional[torch.Tensor] = None,
       **kwargs)

    • 节点节点名称用数字序号表示:0,1,2,... ,num _ node-1(共num_nodes个节点),这是默认且固定的,不需要指定

    • x:节点特征矩阵:shape为[num_nodes, num_node_features]
      一张图的所有节点的特征存储于该二维矩阵中,即一行表示一个节点,一列表示一个特征(一个节点可以有多个特征),行序号对应节点序号(0,1,2,... ,num _ node-1)

    • edge_index:边矩阵:shape为[2, num_edges]
      一张图的所有边存储于该二维矩阵中,其中第一行表示所有边的起始节点编号,第二行表示所有边的目标节点编号,类型为 torch.long。
      (注意:edge _ index 中的元素必须在{0,1,2,... ,num _ node-1}

    • edge_attr:边特征矩阵:shape为[num_edges, num_edge_features]
      一张图的所有边的特征存储于该二维矩阵中,即一行表示一条边,一列表示一个特征(一条边可以有多个特征)

    • y:训练标签:(可能具有任意形状)。例如,如果是节点级别的标签,其形状为 [num_nodes, *];如果是图级别的标签, 其形状为为 [1,*]

    • 节点位置(pos):记录每个节点的具体位置,存储于shape为[num_nodes, num_dimensions]的二维矩阵中

    上述信息通常需要用户提前准备好,才能构建一个Data对象,但都不是必须要提供的。一般对于一张图而言,最重要的是节点特征矩阵、边矩阵、边特征矩阵

    Data 对象有点类似 Python 中的字典,属性和数据用键值对表示,因此可以用点“.”或方括号“[]”来访问、修改、增加其内部的数据,就跟字典的操作方式一样。

    2.2) 图对象的方法

    见‘举例1’一节的3.3)

    三、举例1(简单例子)

    目标:为下图创建Data对象:


    原图.png
    import torch
    from torch_geometric.data import Data
    import networkx as nx
    from torch_geometric.utils import to_networkx
    import matplotlib.pyplot as plt
    

    3.1)图的原始图数据准备:

    首先将上图中的图(Graph)转化成对应的tensor(非常重要,决定了后面的图对象是否能正确构建)。

    # 节点特征矩阵(一行对应一个节点的特征,每个节点有3个特征)
    >>my_node_features = torch.tensor([[-1, -1, -1], 
                                     [-2, -2, -2],
                                     [-3, -3, -3],
                                     [-4, -4, -4]],dtype=torch.float)
    
    # 边的节点对,共有7条边(四个节点:0、1、2、3),必须用7组节点对来表示
    >>my_edge_index = torch.tensor([[0, 1, 2, 1, 3, 2, 3],
                                  [1, 2, 1, 3, 1, 3, 2]], dtype=torch.long)
    
    # 边特征矩阵(一行对应一条边的特征,每条边有4个特征)
    >>my_edge_attr = torch.tensor([[11, 11, 11, 11],
                                 [22, 22, 22, 22],
                                 [33, 33, 33, 33],
                                 [44, 44, 44, 44],
                                 [55, 55, 55, 55],
                                 [66, 66, 66, 66],
                                 [77, 77, 77, 77]], dtype=torch.float)
    
    # 边权重,共有7个边权重,一条边一个
    >>my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)
    

    3.2)根据图的原始数据构建PyG图对象(Data对象):

    >>pyg_G = Data(x=my_node_features, 
                 edge_index=my_edge_index, 
                 edge_attr=my_edge_attr, 
                 edge_weight=my_edge_weight)
    >>print(pyg_G)
    输出:
    Data(x=[4, 3], edge_index=[2, 7], edge_attr=[7, 4], edge_weight=[7])
    
    PyG对象输出解读.jpg
    对PyG对象输出信息的解读很重要(特别是对于无法图像化的大图)!

    3.3)图对象(Data对象)提供的几种常用方法(其他方法使用‘dir(图对象)’获取):

    • .num_nodes:返回节点个数(int)
    • .num_node_types:返回节点种类数(int)
    • .num_node_features:返回节点特征数(int)
    • .node_attrs():返回与节点相关的属性名列表(str list)
    >>pyg_G.node_attrs()
    ['x']
    
    • .x:★返回节点特征矩阵(tensor array)
    >>pyg_G.x
    tensor([[-1., -1., -1.],
            [-2., -2., -2.],
            [-3., -3., -3.],
            [-4., -4., -4.]])
    
    • .num_edges:返回边条数(int)

    • .num_edge_types:返回边种类数(int)

    • .num_edge_features:返回边特征数(int)

    • .edge_index:★返回边的节点对(tensor array)

    >>pyg_G.edge_index
    tensor([[0, 1, 2, 1, 3, 2, 3],
            [1, 2, 1, 3, 1, 3, 2]])
    
    • edge_attrs():返回与边相关的属性名列表(str list)
    pyg_G.edge_attrs()
    ['edge_weight', 'edge_attr', 'edge_index']
    
    • pyg_G.edge_weight:返回边权重(tensor array)
    >>pyg_G.edge_weight
    tensor([1., 2., 3., 4., 5., 6., 7.])
    
    • .edge_attr:返回边的特征矩阵(tensor array)
    >>pyg_G.edge_attr
    tensor([[11., 11., 11., 11.],
            [22., 22., 22., 22.],
            [33., 33., 33., 33.],
            [44., 44., 44., 44.],
            [55., 55., 55., 55.],
            [66., 66., 66., 66.],
            [77., 77., 77., 77.]])
    
    • .edge_stores 和.node_stores:返回存储了整个图的信息(dict list)
    >>pyg_G.node_stores
    [{'x': tensor([[-1., -1., -1.],
             [-2., -2., -2.],
             [-3., -3., -3.],
             [-4., -4., -4.]]), 'edge_index': tensor([[0, 1, 2, 1, 3, 2, 3],
             [1, 2, 1, 3, 1, 3, 2]]), 'edge_attr': tensor([[11., 11., 11., 11.],
             [22., 22., 22., 22.],
             [33., 33., 33., 33.],
             [44., 44., 44., 44.],
             [55., 55., 55., 55.],
             [66., 66., 66., 66.],
             [77., 77., 77., 77.]]), 'edge_weight': tensor([1., 2., 3., 4., 5., 6., 7.])}]
    

    3.4)PyG图对象与networkx图对象的转换(检查我们创建的PyG对象是否与原图一致)

    https://blog.csdn.net/zzy_NIC/article/details/127996911
    https://zhuanlan.zhihu.com/p/92482339
    PyG主要用于图网络计算,本身没有可视化功能。可利用PyG的to_networkx()方法将PyG同构图对象转化成networkx对象,然后可视化。
    to_networkx(
       data: PyG的Data或HeteroData对象,
       node_attrs: 节点属性名(可迭代str对象,默认None),
       edge_attrs: 边属性名(可迭代str对象,默认None),
       graph_attrs: 图属性名(可迭代str对象,默认None),
       to_undirected: 转换成无向图还是有向图(True/False,默认False),
       remove_self_loops: 是否将图中的loop移除(True/False,默认False),
    )

    ■■Case1:转换时,不指定 node_attrs、edge_attrs、graph_attrs参数。
    从输出结果来看,这种情况to_networkx()只会把PyG对象的节点(nodes)和边(edges)转换到networkx对象中,其他属性信息不会包含(下图中全是空{ })。其次,从输出的节点名、边的节点对以及图像来看,与最前面的‘原图’是相同的,说明我们构建的PyG是对的。

    # Case1
    >>nx_G = to_networkx(data=pyg_G, to_undirected=False)  # 将PyG的Data对象转化成networkx的数据对象
    
    >>print(f'节点名:{nx_G.nodes}')
    >>print(f'边的节点对:{nx_G.edges}')
    >>print('每个节点的属性:')
    # print(nx_G.nodes(data=True))
    >>for node in nx_G.nodes(data=True):
        print(node)
    >>print('每条边的属性:')
    # print(nx_G.edges(data=True))
    >>for edge in nx_G.edges(data=True):
        print(edge)
    
    # 画图
    >>pos = nx.spring_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
    >>nx.draw(nx_G, pos, node_size=800, with_labels=True, font_size=20)  # 绘图
    >>plt.show()
    
    输出:如下图所示
    
    图片1.png

    ■■Case2:转换时,指定 node_attrs、edge_attrs、graph_attrs参数。
    这种情况,首先得查看原PyG对象有哪些属性:

    >>print(pyg_G.node_attrs())
    >>print(pyg_G.edge_attrs())
    输出:
    ['x']
    ['edge_weight', 'edge_attr', 'edge_index']
    

    可见,该PyG对象有节点属性有['x'],边属性有['edge_weight', 'edge_attr', 'edge_index'],
    于是可以在to_networkx()转换时进行指定(特别注意:'edge_index'这个属性不能写在to_networkx()的edge_attrs变量中,否则出错),见下面代码:

    # Case2
    >>nx_G = to_networkx(data=pyg_G, 
                       node_attrs=['x'],
                       edge_attrs=['edge_weight', 'edge_attr'],
                       to_undirected=True)  # 将PyG的Data对象转化成networkx的数据对象
    
    >>print(f'节点名:{nx_G.nodes}')
    >>print(f'边的节点对:{nx_G.edges}')
    >>print('每个节点的属性:')
    # print(nx_G.nodes(data=True))
    >>for node in nx_G.nodes(data=True):
        print(node)
    >>print('每条边的属性:')
    # print(nx_G.edges(data=True))
    >>for edge in nx_G.edges(data=True):
        print(edge)
    
    # 画图
    >>pos = nx.spring_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
    >>nx.draw(nx_G, pos, node_size=400, with_labels=True)  # 绘图
    >>plt.show()
    
    图片2.png

    从上图的输出结果看,已经把PyG对象的节点和边的各种属性同时转化成networkx对象的属性了。

    四、举例2(PyG对象节点、边、节点特征、边特征之间的对应关系剖析)

    import torch
    from torch_geometric.data import Data
    import networkx as nx
    from torch_geometric.utils import to_networkx
    import matplotlib.pyplot as plt
    

    4.1)原始图数据准备

    与前面不同的是,此例中事先并不知道图的结构,只有数据。
    而且注意:
    my_node_features的shape=[5,3],即节点序号为:0、1、2、3、4;
    但边的节点对my_edge_index 指定的节点为:10、11、12、13。

    # 节点特征矩阵(一行对应一个节点的特征,每个节点有3个特征)
    >>my_node_features = torch.tensor([[-1, -1, -1], 
                                     [-2, -2, -2],
                                     [-3, -3, -3],
                                     [-4, -4, -4],
                                     [-5, -5, -5]],
                                    dtype=torch.float)
    
    # 边矩阵(这里共有7条边,必须用7组节点对来表示,节点对的前后位置可以任意调换,对结果没有影响)
    >>my_edge_index = torch.tensor([[10, 11, 12, 11, 13, 13, 12],
                                  [11, 12, 11, 13, 11, 12, 13]], dtype=torch.long)
    
    # 边特征矩阵(一行对应一条边的特征,每条边有4个特征)
    >>my_edge_attr = torch.tensor([[11, 11, 11, 11],
                                 [22, 22, 22, 22],
                                 [33, 33, 33, 33],
                                 [44, 44, 44, 44],
                                 [55, 55, 55, 55],
                                 [66, 66, 66, 66],
                                 [77, 77, 77, 77]], dtype=torch.float)
    
    # 边权重,共设置了7个边权重
    >>my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)
    

    4.2)根据原始数据构建PyG图对象

    >>pyg_G = Data(x=my_node_features, 
                 edge_index=my_edge_index, 
                 edge_attr=my_edge_attr, 
                 edge_weight=my_edge_weight)
    >>print(pyg_G)
    输出:
    Data(x=[5, 3], edge_index=[2, 7], edge_attr=[7, 4], edge_weight=[7])
    

    从PyG对象的输出结果看,该图有5个节点,每个节点3个特征;共有7条边,每条边4个特征,1个权重。

    输出节点和边的属性名列表:
    >>print(pyg_G.node_attrs())
    >>print(pyg_G.edge_attrs())
    输出:
    ['x']
    ['edge_index', 'edge_weight', 'edge_attr']
    

    4.3)将PyG对象转换成networkx对象,并成图

    >>nx_G = to_networkx(data=pyg_G, 
                       node_attrs=['x'],
                       edge_attrs=['edge_weight', 'edge_attr'],
                       to_undirected=False)  # 将PyG的Data对象转化成networkx的数据对象
    
    >>print(f'节点名:{nx_G.nodes}')
    >>print(f'边的节点对:{nx_G.edges}')
    >>print('每个节点的属性:')
    # print(nx_G.nodes(data=True))
    >>for node in nx_G.nodes(data=True):
        print(node)
    >>print('每条边的属性:')
    # print(nx_G.edges(data=True))
    >>for edge in nx_G.edges(data=True):
        print(edge)
    
    # 画图
    >>pos = nx.circular_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
    >>nx.draw(nx_G, pos, node_size=800, with_labels=True, font_size=20)  # 绘图
    >>plt.show()
    

    Case1:参数to_undirected=False,即有向图
    从输出结果的节点名来看,该图共有9个节点,前面的[0,1,2,3,4]五个节点(注意,代码中我们并没有指定这些节点名)是to_networkx()根据节点特征矩阵my_node_features的行数按0,1,2……顺序自动分配的(这是PyG固定的);后面四个节点[10,11,12,13]是to_networkx()根据用户给的边的节点对矩阵my_edge_index中自动抽取并生成的
    ★★可见,在利用to_networkx()将PyG对象转换成networkx对象时,to_networkx会自动补充一些节点,比如这里的[0,1,2,3,4],我们将其称为冗余节点!可以写额外的代码来将这些冗余节点删除,见子图抽取的‘2.2.5 将冗余节点从子图的networkx图对象中删除’

    关于边的特征和权重,PyG会自动将边特征矩阵my_edge_attr的
    第1行作为第1条边【这里是(10,11)】的特征;
    第2行作为第2条边【这里是(11,12)】的特征;
    第3行作为第3条边【这里是(12,11)】的特征;
    ……
    同理,PyG会自动将边权重向量my_edge_weight的
    第1个值作为第1条边【这里是(10,11)】的权重;
    第2个值作为第2条边【这里是(11,12)】的权重;
    第3个值作为第3条边【这里是(12,11)】的权重;
    ……

    特别注意:边特征矩阵(my_edge_attr)的行数、边权重向量(my_edge_weight)的元素个数都必须和边节点对矩阵(my_edge_index )的列数相同,否则结果会出错

    14029140-5305108e9eb8121b.jpg

    Case2:参数to_undirected=True,即无向图
    Case2除了边有所变化以外,其他都与Cas1一样。
    Case2主要为了说明to_networkx()这个函数的参数to_undirected=False/True(有向图和无向图)的区别。
    Cas1是有向图,根据给定的节点对矩阵my_edge_index从起点到终点画图即可,这个没啥疑问。
    Cas2是无向图:

    • 如果两个节点之间只有1条边,则有向图和无向图都用这条边,比如这里的(10,11);
    • 如果两个节点之间有2条边,则使用小节点序号到大节点序号的边作为无向边,比如这里的(11,12)和(12,11),选择(11,12)作为无向边,(11,13)和(13,11),选择(11,13)作为无向边,(13,12)和(12,13),选择(12,13)作为无向边。
    新建 Microsoft Visio 绘图.jpg

    参考:
    https://zhuanlan.zhihu.com/p/599104296
    https://blog.csdn.net/ARPOSPF/article/details/128398393

    相关文章

      网友评论

        本文标题:PyG构建图对象并转换成networkx图对象

        本文链接:https://www.haomeiwen.com/subject/lecvidtx.html