当需要从一个图中抽取某一个或一系列节点(目标节点)周围的k级连续节点时,可以用PyG的 k_hop_subgraph() 方法。
k=1:与目标节点直接相连的节点;
k=2:与目标节点隔1个节点相连的节点;
……
一、构建原始大图
第一部分内容纯粹是为了准备原始大图(相关内容见:PyG构建图对象并转换成networkx图对象),与本文的核心内容无关。
1.1 原始数据准备
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx, k_hop_subgraph
import matplotlib.pyplot as plt
# 节点特征矩阵(一行对应一个节点的特征,共7个节点(=节点特征矩阵的行数),每个节点有3个特征)
my_node_features = torch.tensor([[0, 0, 0],
[-1, -1, -1],
[-2, -2, -2],
[-3, -3, -3],
[-4, -4, -4],
[-5, -5, -5],
[-6, -6, -6]],dtype=torch.float)
# 边的节点对,共有6条边(7个节点:0、1、2、3、4、5、6,与节点特征矩阵的行标一一对应)
my_edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
[2, 2, 4, 4, 6, 6]])
# 边特征矩阵(一行对应一条边的特征,每条边有4个特征)
my_edge_attr = torch.tensor([[11, 11, 11, 11],
[22, 22, 22, 22],
[33, 33, 33, 33],
[44, 44, 44, 44],
[55, 55, 55, 55],
[66, 66, 66, 66]], dtype=torch.float)
# 边权重,共有6个边权重,一条边一个
my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)
1.2 根据原始数据构建PyG对象
# 构建Pyg对象
pyg_G = Data(x=my_node_features,
edge_index=my_edge_index,
edge_attr=my_edge_attr,
edge_weight=my_edge_weight)
print(pyg_G)
输出:Data(x=[7, 3], edge_index=[2, 6], edge_attr=[6, 4], edge_weight=[6])
# 输出信息,为to_networkx()的参数提供参考
print(pyg_G.node_attrs())
print(pyg_G.edge_attrs())
输出:
['x']
['edge_attr', 'edge_index', 'edge_weight']
1.3 将PyG对象转化成networkx对象,用于成图
这里需要注意的是如果原始数据准备不恰当,可能会导致to_networkx()将PyG对象转化成networkx对象后,多出来一些节点,详见:PyG构建图对象并转换成networkx图对象
# PyG对象转化成networkx对象
nx_G = to_networkx(data=pyg_G,
node_attrs=['x'],
edge_attrs=['edge_weight', 'edge_attr'],
to_undirected=False) # 将PyG的Data对象转化成networkx的数据对象
print(f'节点名:{nx_G.nodes}')
print(f'边的节点对:{nx_G.edges}')
print('每个节点的属性:')
# print(nx_G.nodes(data=True))
for node in nx_G.nodes(data=True):
print(node)
print('每条边的属性:')
# print(nx_G.edges(data=True))
for edge in nx_G.edges(data=True):
print(edge)
# 画图
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(nx_G) # 迭代计算‘可视化图片’上每个节点的坐标
nx.draw(nx_G, pos, node_size=400, with_labels=True) # 绘图
plt.show()

二、k_hop_subgraph()抽取子图
2.1 方法解释
result = k_hop_subgraph(
node_idx=, 目标节点(int,或list int);
num_hops=, 待获取的目标节点的几级周围节点(int);
edge_index=, 原图的边节点对矩阵(tensor),其shape=[2, 边数];
relabel_nodes=, 是否对获取的节点从0开始重新顺序编号(True/False)(要注意★★);
flow=, 根据边的方向选择节点(str = target_to_source(目标节点到其他节点)或source_to_target(其他节点到目标节点)(要注意★★);
directed=, 如果=False,将包括所有所有采样节点之间的边。(默认=True)
)
k_hop_subgraph()的返回值result是一个包含四个元素的tuple:
result[0]:抽取出来的节点(包括目标节点)list,已经按照从小到大顺序排列好了;
result[1]:抽取的节点的边对,是个shape=[2, 抽取的边的条数]的tensor;
result[2]:每个目标节点在result[0]中的位置,是一个长度与目标节点个数相同的一维tensor;
result[3]:抽取的每条边在原图边对矩阵中的位置,一个由True和False组成的list,长度等原图边对矩阵的列数。
2.2 抽取子图并绘图
下面看具体例子(紧接着前面代码):
2.2.1 抽取子图信息
我们希望找到6号节点周围的k=2的节点(即找到6号节点的1、2级节点)。
设置:relabel_nodes=False,不对找到的节点重新命名;
设置:flow='source_to_target'),只要求边是指向目标节点的节点;
(上面这两个参数的设置对结果影响很大,特别注意。)
target_node_idx = [6] # 确定目标节点序列
k = 2 # 目标节点往周围跳跃的次数(即几级节点)
# 设置重要参数
relabel_nodes=False
flow='source_to_target'
# 抽取节点
result = k_hop_subgraph(node_idx=target_node_idx,
num_hops=k,
edge_index=pyg_G.edge_index,
relabel_nodes=relabel_nodes,
flow=flow,
directed=False)
sub_nodes_names = result[0]
sub_edge_index = result[1]
target_node_map = result[2]
sub_edge_mask = result[3]
print(f'抽取的节点序列:{sub_nodes_names}')
print(f'抽取的边节点对:{sub_edge_index}')
print(f'目标节点在抽取节点序列中的位置:{target_node_map}')
print(f'选中的边在原图的边序列中的位置:{sub_edge_mask}')
print(f'抽取目标节点:{sub_nodes_names[target_node_map]}')

从上述‘输出结果’看relabel_nodes=False时,和‘我们的目标’是完全一致的。
relabel_nodes=True时,只有边的节点对的序号被从0开始重新命名了,其他没变。
2.2.2 计算抽取边在原图边中的序号
这一步是为了从原图的一些其他数据中抽取跟子图相关的数据。
# 计算抽取的边对矩阵sub_edge_index的每一条边对在原边对矩阵my_edge_index中的位置序号
match_indices_list = []
for row in sub_edge_index.t():
res = torch.where(torch.all(torch.isin(my_edge_index.t(), row), dim=1))
if res[0].numel()!=0:
print(res)
match_indices_list.append(res[0].item())
#match_indices_list
2.2.3 构建子图PyG对象
# 创建子图的 Data 对象
sub_pyg_G = Data(x=my_node_features[sub_nodes_names,:], # 在原节点特征矩阵中抽取被选中的节点的特征
edge_index=sub_edge_index, # 在原边节点对矩阵中抽取被选中的边节点对
edge_attr=my_edge_attr[match_indices_list,:], # 在原边特征矩阵中抽取被选中的边特征
edge_weight=my_edge_weight[match_indices_list]) # 在原边权重矩阵中抽取被选中的边权重
sub_pyg_G
输出:
Data(x=[3, 3], edge_index=[2, 2], edge_attr=[2, 4], edge_weight=[2])
# 输出信息作为to_networkx()设置参数时的参考。
print(sub_pyg_G.node_attrs())
print(sub_pyg_G.edge_attrs())
输出:
['x']
['edge_attr', 'edge_index', 'edge_weight']
2.2.4 将子图的PyG对象转换为networkx对象
# 将PyG对象转换为networkx对象
sub_nx_G = to_networkx(data=sub_pyg_G,
node_attrs=['x'],
edge_attrs=['edge_attr', 'edge_weight'],
to_undirected=False)
2.2.5 将冗余节点从子图的networkx图对象中删除
这一步是需要的,这里我们抽取的子图节点为‘抽取的节点序列:tensor([2, 3, 4, 5, 6])’,显然是不包括0和1号节点,但因为to_networkx()方法本身的一些原因,会在转化时把0和1号节点也加上(注意转化时加上的0和1号节点与原图中的0和1号节点是完全不同的,只是名称相同),称其为冗余节点。这样一来,就会导致networkx对象的节点数比PyG对象的节点数多。因此冗余节点需要删除。
如果k_hop_subgraph()函数的 relabel_nodes=Ture,则无论被选取的节点是什么,都会被从0开始重新命名,这时就不会出现冗余节点,如果执行下述代码,反而会删除有效节点。
(注意:这一步不是必须的,只有存在冗余节点时需要执行。当。)
# 将没有被选中的节点从networkx图对象中删除
if not relabel_nodes:
nodes_to_remove = list(set(list(sub_nx_G.nodes)) - set(sub_nodes_names.tolist()))
sub_nx_G.remove_nodes_from(nodes_to_remove)
2.2.6 子图绘图
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(sub_nx_G) # 定义节点的布局
nx.draw(sub_nx_G, pos, with_labels=True, node_color='red', edge_color="green", node_size=100, font_size=10)

网友评论