美文网首页人工智能
GraphSAGE:大型图的归纳式表示学习

GraphSAGE:大型图的归纳式表示学习

作者: 酷酷的群 | 来源:发表于2021-08-16 08:53 被阅读0次

论文标题:Inductive Representation Learning on Large Graphs
论文链接:https://arxiv.org/abs/1706.02216
论文来源:NIPS 2017

一、概述

图节点的低维向量embedding对于一些图分析任务是非常有用的,节点embedding技术的基本思想是使用降维技术来蒸馏节点的邻域信息到一个稠密向量embedding,这些embedding可以用在一些下游任务上比如节点分类、聚类等。

然而,以前的工作都集中在从单个固定的图中学习节点embedding(直推式的方法,transductive),而许多现实世界的应用程序都需要为未见的节点全新的(子)图节点快速生成节点的embedding。这种归纳(inductive)能力对于高吞吐量的生产机器学习系统是至关重要的,这些系统运行在不断演化的图上,并不断遇到未见节点(例如Reddit上的帖子、Youtube上的用户和视频)。一种生成节点embedding的归纳方法也有助于在具有相同特征形式的图之间进行泛化:例如,一个人可以在一个模型生物的蛋白质-蛋白质相互作用图上训练一个embedding生成器,然后很容易地使用训练的模型生成从新生物中收集的数据的节点embedding。

这种归纳式的方法相较于直推式的方法是更加困难的,因为泛化到未见节点需要对齐(align)新观察到的子图到算法已经优化过的图上。这叫要求归纳式的算法必须学习识别节点邻域的结构特性,这既包括节点在图中的局部(local)角色以及它的全局(global)位置。

大多数现有方法是直推式的,也就是只能应用在单个固定的图上,大多数这类方法采用基于矩阵分解的策略,这就导致在面对新的节点或新的图时需要重新进行矩阵的分解操作,计算上是昂贵的。另外一些图卷积网络(GCN)的方法也是直推式的,不过在本文中会将这一类方法拓展成归纳式的,并且本文提出了一个框架,该框架将GCN方法推广到使用可训练的聚合函数(而非简单的卷积)。

本文提出的框架,命名为 GraphSAGE(SAmple and aggreGatE)。与基于矩阵分解的embedding方法不同,我们利用节点特征(例如文本属性、节点概要信息、节点度)来学习一个可推广到未见节点的embedding函数。通过在学习算法中引入节点特征,我们同时学习了每个节点的邻域拓扑结构以及邻域节点特征的分布情况。

GraphSAGE的思路是训练一系列聚合函数来从节点的邻域聚合邻域节点的特征信息,不同的聚合函数对应不同的hops(也就是与当前节点的距离),该过程如下图所示:

GraphSAGE

在测试或者推断时,我们使用学习到的聚合函数来为未见节点来生成其embedding向量。另外,本文为GraphSAGE设计了一种无监督的loss,同时也可以使用有监督的方式来训练GraphSAGE。

二、方法

  1. embedding生成(前向传播)算法

首先,GraphSAGE有K个需要学习的聚合函数,记作\mathrm{AGGREGATE}_{k},\forall k\in \left \{1,\cdots ,K\right \},这些聚合函数用于聚合邻域节点信息。另外需要学习的参数有W_k,\forall k\in \left \{1,\cdots ,K\right \},用于在不同的层之间传播信息。这里的K可以认为是神经网络的层数,每层含有一个聚合函数\mathrm{AGGREGATE}_{k}W_k。前向传播的算法如下:

embedding生成(前向传播)算法

这里算法第4行表示将k-1次迭代后当前节点v的邻域节点信息聚合到向量h_{N(v)}^{k}上。第5行则表示将h_{N(v)}^{k}与将k-1次迭代后当前节点v的隐层表示向量h_{v}^{k-1}拼接然后通过W^{k}(也就是一个全连接网络)进行融合,最终通过一个非线性函数以后得到k次迭代后当前节点v的隐层表示向量h_{v}^{k}。第7行表示对h_{v}^{k}进行标准化。第9行表示网络最终输出的节点的embedding向量就是h_{v}^{K}

注意这里每一步都会聚合邻域信息到当前节点,也就是说,第一次迭代当前节点会接受到1阶邻域的信息,第二次迭代当前节点就会接受到2阶邻域的信息,依次类推,随着迭代的继续,当前节点会接触到距离自己越来越远的节点的信息。另外,这里的聚合函数可以采用多种类型,后续会介绍本文所采用的几种聚合函数。

在实际操作上,上述算法的N(v)并非使用所有的邻域节点,而是采用邻域节点集合的一个固定大小的采样以避免过高的计算复杂度,并且在每一层迭代中都采用不同的采样。这样计算复杂度就为O(\prod_{i=1}^{K}S_{i})S_{i}是每层采样的大小。在本文的实验中使用K=2并且S_{1}\cdot S_{2}\leq 500可以取得一个较高的性能。

  1. embedding生成(前向传播)算法的minibatch版本

minibatch版本解决的是对于一些给定的节点而非整张图,如何计算这些节点的embedding。思想是首先采样需要的所有节点,然后再运行上述算法的循环。具体过程如下:

minibatch版本
  1. GraphSAGE的参数学习

GraphSAGE可以采用无监督或者有监督的方式进行参数的学习。对于无监督的方式,我们希望采用邻近节点有相似的表示,而距离较远的节点的表示能够是不同的,因而设计了以下损失函数:

J_{G}(z_{u})=-log(\sigma (z_{u}^{T}z_{v}))-Q\cdot E_{v_{n}\sim P_{n}(v)}log(\sigma (-z_{u}^{T}z_{v_{n}}))

v是节点u的固定长度随机游走得到的节点,\sigma是sigmoid函数,P_{n}(v)是节点v的负采样分布,Q是负采样的个数。采样里目标节点z_v较远的点就叫做负采样:

负采样
  1. 聚合函数

由于节点的邻域是没有顺序的,因而对聚合函数的一个要求就是它必须应用在向量的无序集合上,也就是说聚合函数必须具有对称属性(与输入的排列顺序无关)。接下来将介绍本文使用的3种聚合函数。

  • 平均聚合函数

平均聚合函数就是简单地将\left \{h_{u}^{k-1},\forall u\in N(v)\right \}进行平均,类似直推式的GCN操作,最终实现的过程相当于一个归纳式的GCN:

h_{v}^{k}\leftarrow \sigma (W\cdot MEAN(\left \{h_{v}^{k-1}\right \}\cup \left \{h_{u}^{k-1},\forall u\in N(v)\right \}))

因此这个聚合函数也可以叫做卷积聚合函数。不过在这个聚合函数里没有执行拼接操作,这种拼接操作可以看做是网络的不同层之间的“跳跃连接”(残差连接),可以显著提高性能。

  • LSTM聚合函数

相比于平均的方式,采用LSTM作为聚合函数可以使模型具有更大的表达能力,不过LSTM并没有对称属性,我们通过将LSTM应用于随机排列的邻域节点集合上来使它能够应用于无序集合。

  • Pooling聚合函数

这种聚合函数既是对称的也是可训练的,在这种pooling方法中,每个邻域节点向量被独立输入到一个全连接网络中,然后应用一个max-pooling操作:

AGGREGATE_{k}^{pool}=max(\left \{\sigma (W_{pool}h_{u_{i}}^{k}+b),\forall u_{i}\in N(v)\right \})

这里的全连接网络可以是任意深层多层感知机,不过本文只关注单层架构。通过对每个计算特征应用max-pooling,该模型有效地捕获了邻域集的不同方面特征。事实上任何对称函数都能替换max-pooling(比如平均),不过对比平均与max-pooling的效果差别不大,因此本文采用max-pooling。

三、实验

本文在Citation、Reddit、Protein-protein interactions三个数据集上进行实验测试效果,其中前两个是不断演化的图数据,最后一个是多图数据集。同时也对比了其他baseline方法,具体实验设置参照原文。实验结果如下:

实验

同时也测试了算法的时间效率以及邻域采样大小对模型效果的影响:

实验

图B表明随着邻域采样大小的增长,模型的性能收益是在下降的,运行时间却大量提高,这表明我们需要合理选择邻域采样的大小。

相关文章

网友评论

    本文标题:GraphSAGE:大型图的归纳式表示学习

    本文链接:https://www.haomeiwen.com/subject/zaoybltx.html