GCN输出的H'矩阵,最后怎么令其作节点分类。即,GCN输出H’如何让节点分类的?
以pytorch的GCN模型为例:GCN
GCN已经将计算简化为:
假设一个图的顶点数目为:
import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
class GCN(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout):
super(GCN, self).__init__()
self.gc1 = GraphConvolution(nfeat, nhid)
self.gc2 = GraphConvolution(nhid, nclass)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj)) ###注: X = AXW1 A=[n,n] X[n,nfeat] W=[nfeat,nhid] ==> X=[n,nhid]
x = F.dropout(x, self.dropout, training=self.training)
x = self.gc2(x, adj) ###注:X=AXW2 A=[n,n] X=[n,nhid] W=[nhid,nclass] ==> X=[n,nclass]
return F.log_softmax(x, dim=1)
代码里的x就是与公式里的H对应,x是图顶点的原始特征矩阵,x输入gc1层时的维度是:[n,nfeat],n是图节点数,nfeat是图节点原始特征的维度;
第一次计算,即,A矩阵维度[n,n],X矩阵就是x维度[n,nfeat],变量维度[nfeat,nhid],所以是新特征矩阵。
第二次计算,即,A矩阵维度[n,n],X矩阵就是x维度[n,nhid],变量维度[nhid,nclass],所以是新矩阵,它就对应nclass分类。
后面return F.log_softmax(x, dim=1),即对分类上分数进行softmax归一化处理,即可以和真实的标签向量进行对标,计算损失值。
网友评论