美文网首页
Pytorch Geometric中的图神经网络GAT是如何实现

Pytorch Geometric中的图神经网络GAT是如何实现

作者: 四碗饭儿 | 来源:发表于2021-09-02 00:35 被阅读0次

    最近在使用Pytorch Geometric, 这个包收集了最新的图神经网络的Pytorch实现。这篇文章想研究下它是怎么实现GAT(Graph Attention Network)。在PyG中实现图神经网络,主要依靠MessagePassing这个类。在继承或使用MessagePassing类时,你可以指明使用哪一种消息合并方式

    MessagePassing(aggr="add", flow="src_to_tgt", node_dim=-2)
    

    MessagePassing自带propagate函数,一般你只需在forwad里调用一下就好了

    propagate(edge_index, size=None, **kwargs)# 输入边和其他必要数据,然后构造消息,更新节点的表示
    

    message函数一般是需要你自定义的

    message(...)#对于图中的每一条边$(j,i)$创建一个消息,传送个节点$i$,在pytroch geometric的代码库中,通常i指central node,j指neighboring node
    

    Pytorch Geometric中的GAT实现源码在这里。我这里写了个精简版,方便阅读理解,代码中添加了相关注释。

    class GATConv(MessagePassing):
    def __init__(self):
        # 超参
        self.in_channels = in_channels # 节点特征的维度
        self.out_channels = out_channels # 每个attention head的输出维度
        self.heads = heads
        self.dropout = dropout
        # 模型参数
        self.lin = Linear(in_channels, heads * out_channels, False)
        self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) # 有点像一个Mask
    
    def forward(self, x, edge_index, size, return_attention_weights):
        # MLP
        H, C = self.heads, self.out_channels
        x_src = x_dst = self.lin(x).view(-1, H, C) # MLP输出整理成H个head的格式
        # 计算H个head的attention
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        # 添加`self_loop`到`edge_index`
        edge_index, _  = add_self_loops(edge_index, num_nodes=num_nodes)
        # 消息传播和更新
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
        # 拼接和输出
        out = out.view(-1, self.heads * self.out_channels)
        return out
    
    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i)
        # 把两个节点的attention weight 相加
        alpha =  alpha_j + alpha_i
        # 经过一个非线性
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # 经过一个Softmax    
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha  # Save for later use.
        # 经过一个dropout
         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
         # attention weight 乘以node feature
         return x_j * alpha.unsqueeze(-1)
    

    我不太喜欢的点在于将GAT的实现拆到每条边上, 特别是attention weight,完整写起来其实是个N \times N的矩阵,但是硬要写成message的话,就只能从矩阵中取元素了。

    相关文章

      网友评论

          本文标题:Pytorch Geometric中的图神经网络GAT是如何实现

          本文链接:https://www.haomeiwen.com/subject/dwotwltx.html