美文网首页
连续子图抽取:根据指定目标节点抽取连续子图k_hop_subgr

连续子图抽取:根据指定目标节点抽取连续子图k_hop_subgr

作者: 马尔代夫Maldives | 来源:发表于2024-01-15 10:39 被阅读0次

当需要从一个图中抽取某一个或一系列节点(目标节点)周围的k级连续节点时,可以用PyG的 k_hop_subgraph() 方法。
k=1:与目标节点直接相连的节点;
k=2:与目标节点隔1个节点相连的节点;
……

一、构建原始大图

第一部分内容纯粹是为了准备原始大图(相关内容见:PyG构建图对象并转换成networkx图对象),与本文的核心内容无关。

1.1 原始数据准备

import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx,  k_hop_subgraph

import matplotlib.pyplot as plt
# 节点特征矩阵(一行对应一个节点的特征,共7个节点(=节点特征矩阵的行数),每个节点有3个特征)
my_node_features = torch.tensor([[0, 0, 0], 
                                 [-1, -1, -1], 
                                 [-2, -2, -2],
                                 [-3, -3, -3],
                                 [-4, -4, -4],
                                 [-5, -5, -5],
                                 [-6, -6, -6]],dtype=torch.float)

# 边的节点对,共有6条边(7个节点:0、1、2、3、4、5、6,与节点特征矩阵的行标一一对应)
my_edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                              [2, 2, 4, 4, 6, 6]])

# 边特征矩阵(一行对应一条边的特征,每条边有4个特征)
my_edge_attr = torch.tensor([[11, 11, 11, 11],
                             [22, 22, 22, 22],
                             [33, 33, 33, 33],
                             [44, 44, 44, 44],
                             [55, 55, 55, 55],
                             [66, 66, 66, 66]], dtype=torch.float)

# 边权重,共有6个边权重,一条边一个
my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)

1.2 根据原始数据构建PyG对象

# 构建Pyg对象
pyg_G = Data(x=my_node_features,
             edge_index=my_edge_index,
             edge_attr=my_edge_attr,
             edge_weight=my_edge_weight)
print(pyg_G)
输出:Data(x=[7, 3], edge_index=[2, 6], edge_attr=[6, 4], edge_weight=[6])
# 输出信息,为to_networkx()的参数提供参考
print(pyg_G.node_attrs())
print(pyg_G.edge_attrs())
输出:
['x']
['edge_attr', 'edge_index', 'edge_weight']

1.3 将PyG对象转化成networkx对象,用于成图

这里需要注意的是如果原始数据准备不恰当,可能会导致to_networkx()将PyG对象转化成networkx对象后,多出来一些节点,详见:PyG构建图对象并转换成networkx图对象

# PyG对象转化成networkx对象
nx_G = to_networkx(data=pyg_G, 
                   node_attrs=['x'],
                   edge_attrs=['edge_weight', 'edge_attr'],
                   to_undirected=False)  # 将PyG的Data对象转化成networkx的数据对象

print(f'节点名:{nx_G.nodes}')
print(f'边的节点对:{nx_G.edges}')
print('每个节点的属性:')
# print(nx_G.nodes(data=True))
for node in nx_G.nodes(data=True):
    print(node)
print('每条边的属性:')
# print(nx_G.edges(data=True))
for edge in nx_G.edges(data=True):
    print(edge)

# 画图
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
nx.draw(nx_G, pos, node_size=400, with_labels=True)  # 绘图
plt.show()
原大图.png

二、k_hop_subgraph()抽取子图

2.1 方法解释

result = k_hop_subgraph(
node_idx=, 目标节点(int,或list int);
num_hops=, 待获取的目标节点的几级周围节点(int);
edge_index=, 原图的边节点对矩阵(tensor),其shape=[2, 边数];
relabel_nodes=, 是否对获取的节点从0开始重新顺序编号(True/False)(要注意★★);
flow=, 根据边的方向选择节点(str = target_to_source(目标节点到其他节点)或source_to_target(其他节点到目标节点)(要注意★★);
directed=, 如果=False,将包括所有所有采样节点之间的边。(默认=True)
)

k_hop_subgraph()的返回值result是一个包含四个元素的tuple:
result[0]:抽取出来的节点(包括目标节点)list,已经按照从小到大顺序排列好了;
result[1]:抽取的节点的边对,是个shape=[2, 抽取的边的条数]的tensor;
result[2]:每个目标节点在result[0]中的位置,是一个长度与目标节点个数相同的一维tensor;
result[3]:抽取的每条边在原图边对矩阵中的位置,一个由True和False组成的list,长度等原图边对矩阵的列数。

2.2 抽取子图并绘图

下面看具体例子(紧接着前面代码):

2.2.1 抽取子图信息

我们希望找到6号节点周围的k=2的节点(即找到6号节点的1、2级节点)
设置:relabel_nodes=False,不对找到的节点重新命名
设置:flow='source_to_target'),只要求边是指向目标节点的节点
(上面这两个参数的设置对结果影响很大,特别注意。)

target_node_idx = [6]  # 确定目标节点序列
k = 2  # 目标节点往周围跳跃的次数(即几级节点)

# 设置重要参数
relabel_nodes=False
flow='source_to_target'

# 抽取节点
result = k_hop_subgraph(node_idx=target_node_idx,
                        num_hops=k,
                        edge_index=pyg_G.edge_index,
                        relabel_nodes=relabel_nodes,
                        flow=flow,
                        directed=False)

sub_nodes_names = result[0]
sub_edge_index  = result[1]
target_node_map = result[2]
sub_edge_mask   = result[3]

print(f'抽取的节点序列:{sub_nodes_names}')
print(f'抽取的边节点对:{sub_edge_index}')
print(f'目标节点在抽取节点序列中的位置:{target_node_map}')
print(f'选中的边在原图的边序列中的位置:{sub_edge_mask}')
print(f'抽取目标节点:{sub_nodes_names[target_node_map]}')
抽取结果.png

从上述‘输出结果’看relabel_nodes=False时,和‘我们的目标’是完全一致的。
relabel_nodes=True时,只有边的节点对的序号被从0开始重新命名了,其他没变。

2.2.2 计算抽取边在原图边中的序号

这一步是为了从原图的一些其他数据中抽取跟子图相关的数据。

# 计算抽取的边对矩阵sub_edge_index的每一条边对在原边对矩阵my_edge_index中的位置序号
match_indices_list = []
for row in sub_edge_index.t():
    res = torch.where(torch.all(torch.isin(my_edge_index.t(), row), dim=1))
    if res[0].numel()!=0:
        print(res)
        match_indices_list.append(res[0].item())
        
#match_indices_list

2.2.3 构建子图PyG对象

# 创建子图的 Data 对象
sub_pyg_G = Data(x=my_node_features[sub_nodes_names,:],  # 在原节点特征矩阵中抽取被选中的节点的特征
                 edge_index=sub_edge_index,  # 在原边节点对矩阵中抽取被选中的边节点对
                 edge_attr=my_edge_attr[match_indices_list,:],  # 在原边特征矩阵中抽取被选中的边特征
                 edge_weight=my_edge_weight[match_indices_list])   # 在原边权重矩阵中抽取被选中的边权重

sub_pyg_G
输出:
Data(x=[3, 3], edge_index=[2, 2], edge_attr=[2, 4], edge_weight=[2])
# 输出信息作为to_networkx()设置参数时的参考。
print(sub_pyg_G.node_attrs())
print(sub_pyg_G.edge_attrs())
输出:
['x']
['edge_attr', 'edge_index', 'edge_weight']

2.2.4 将子图的PyG对象转换为networkx对象

# 将PyG对象转换为networkx对象
sub_nx_G = to_networkx(data=sub_pyg_G, 
                       node_attrs=['x'],
                       edge_attrs=['edge_attr', 'edge_weight'],
                       to_undirected=False)

2.2.5 将冗余节点从子图的networkx图对象中删除

这一步是需要的,这里我们抽取的子图节点为‘抽取的节点序列:tensor([2, 3, 4, 5, 6])’,显然是不包括0和1号节点,但因为to_networkx()方法本身的一些原因,会在转化时把0和1号节点也加上(注意转化时加上的0和1号节点与原图中的0和1号节点是完全不同的,只是名称相同),称其为冗余节点。这样一来,就会导致networkx对象的节点数比PyG对象的节点数多。因此冗余节点需要删除。
如果k_hop_subgraph()函数的 relabel_nodes=Ture,则无论被选取的节点是什么,都会被从0开始重新命名,这时就不会出现冗余节点,如果执行下述代码,反而会删除有效节点。
(注意:这一步不是必须的,只有存在冗余节点时需要执行。当。

# 将没有被选中的节点从networkx图对象中删除
if not relabel_nodes:
    nodes_to_remove = list(set(list(sub_nx_G.nodes)) - set(sub_nodes_names.tolist()))
    sub_nx_G.remove_nodes_from(nodes_to_remove)

2.2.6 子图绘图

plt.figure(figsize=(4, 4))

pos = nx.spring_layout(sub_nx_G)  # 定义节点的布局
nx.draw(sub_nx_G, pos, with_labels=True, node_color='red', edge_color="green", node_size=100, font_size=10)
最终结果.png

相关文章

  • SQL根据指定节点ID获取所有父级节点和子级节点

    根据指定节点ID获取所有父节点 根据指定节点ID获取所有子节点

  • 关系抽取(分类)总结

    关系抽取(分类)总结 关系抽取研究现状 基于路径的实体图关系抽取模型 ChineseNRE 关系抽取(关系学习)综...

  • MVP 终极抽取(下)

    流程图: 通过MVP终极抽取(上)让我们可以知道Model 的抽取。那么开始view ,presenter 的抽取...

  • 异质图网络

    了解了一丢丢异质图网络的文章,做下记录。 一、基础概念 异质图:各个网络节点的类型不一样。 元路径:从异质图抽取出...

  • 知识图谱学习笔记(五)——实体识别(1)

    实体识别(信息抽取) 1. 信息抽取概述 信息抽取定义:从自然语言文本中抽取指定类型的实体、关系、事件等事实信息,...

  • Relation Extraction Survey

    实体关系抽取介绍 实体关系抽(RE, Relation Extraction)取任务是信息抽取中重要的一个子任务,...

  • 电路基础第二章

    图是节点和支路的集合 连通图,非连通图,分离度ρ/有向图,无向图/子图/平面图 节点的次数(度数)/路径→回路 割...

  • 电路基础第二章 - 草稿

    图是节点(n)和支路(b)的集合 连通图,非连通图,分离度ρ/有向图,无向图/子图/平面图 节点的次数(度数)/路...

  • 优先队列关键字总结

    性质 优先队列为二叉树,用连续的数组存储 分类 小顶堆:每个节点的子节点大于等于其父节点 大顶堆:每个节点的子节点...

  • python3 通过pyspark抽取Hive数据 进行线性回

    背景:抽取hive上的数据,搭建线性回归模型,进行预测。目标:抽取hive数据,并进行预测。 一、数据抽取 本次为...

网友评论

      本文标题:连续子图抽取:根据指定目标节点抽取连续子图k_hop_subgr

      本文链接:https://www.haomeiwen.com/subject/elvtodtx.html