Metapath2vec Graph Embedding: A Walkthrough of the PGL Source Code

Author: lodestar | Published 2023-03-22 23:14

    Metapath2vec is a network representation learning method that learns vector representations for the nodes of a complex network. It builds on the notion of a metapath: a sequence of specific node types in the network; in a social network, for example, the relation between users can be expressed by the metapath "user-group-user". Metapath2vec extracts node sequences from metapath-guided walks, then feeds these sequences to a Skip-gram model and trains a neural network to learn node embeddings. In this way Metapath2vec captures the relations between nodes in a complex network. Once node embeddings are available, adding a fully connected layer plus softmax yields a node classifier, and computing similarities between node embeddings yields a recommender.

    Goal

      Generate node embeddings

    The metapath2vec algorithm

      A metapath guides random walks over a heterogeneous graph. For example, A-P-A denotes a paper shared by two co-authors, while A1-P1-C1-P2-A3 denotes two different authors publishing at the same conference. A metapath is a symmetric structure. A walk terminates when it reaches the maximum length or when no suitable next node can be found.


    Figure: metapath

      The overall framework is shown below. After the metapath-guided random walks complete, we obtain a set of paths. Each path is analogous to a sentence in NLP, where skip-gram is used to predict surrounding words from a center word. The training data consists of pairs: in the figure below, (A4, P3) is a positive pair with label 1, while (A4, P5) is a negative sample with label 0. The process is essentially binary classification, as in NLP. Once training finishes, the resulting model is used to infer the embedding of every node.


    Figure: the overall metapath2vec framework
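      A minimal sketch of this positive/negative pair construction in plain Python (the function name, the toy walk, and the uniform negative sampling here are illustrative, not the repo's actual code):

```python
import random

def make_training_pairs(walk, win_size, all_nodes, neg_num=1, rng=None):
    """From one metapath walk, emit (center, context, label) triples:
    label 1 for real pairs within the window, label 0 for negatives."""
    rng = rng or random.Random(0)
    pairs = []
    for i, center in enumerate(walk):
        lo, hi = max(0, i - win_size), min(len(walk), i + win_size + 1)
        for j in range(lo, hi):
            if j == i:
                continue
            pairs.append((center, walk[j], 1))       # positive pair, label 1
            for _ in range(neg_num):                 # random negatives, label 0
                pairs.append((center, rng.choice(all_nodes), 0))
    return pairs

walk = ["A4", "P3", "C1", "P5", "A2"]                # one toy a/p/c walk
pairs = make_training_pairs(walk, win_size=1, all_nodes=["A1", "P9", "C7"])
```

      In the real pipeline the negatives are not sampled at this stage; as the model.py code later shows, they are drawn inside the model from other positives in the batch.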

      Next we walk through the PGL implementation of metapath2vec (metapath2vec source code). Most of the work consists of converting the graph into the center-word/context-word pairs that the skip-gram algorithm needs; if skip-gram is unfamiliar, see the previous article.

    The datasets folder
    ├── dataset.py
    ├── helper.py
    ├── node.py
    ├── pair.py
    ├── sampling.py
    └── walk.py
    
    Figure: call relationships among the data-generation files
    # config.yaml
    task_name: distributed_metapath2vec
    
    # --------------------------- data config ------------------------------------------#
    # for data preprocessing
    data_path: ./data/net_aminer
    author_label_file: ./data/label/googlescholar.8area.author.label.txt
    venue_label_file: ./data/label/googlescholar.8area.venue.label.txt
    processed_path: ./graph_data
    
    # for pgl graph engine
    etype2files: "p2a:./graph_data/paper2author_edges.txt,p2c:./graph_data/paper2conf_edges.txt"
    ntype2files: "p:./graph_data/node_types.txt,a:./graph_data/node_types.txt,c:./graph_data/node_types.txt"
    # undirected graph: each edge is stored in both directions
    symmetry: True
    # the metapath is symmetric
    meta_path: "c2p-p2a-a2p-p2c"
    first_node_type: "c"
    
    shard_num: 100
    
    # maximum length of each random walk
    walk_len: 24
    # skip-gram window size
    win_size: 3
    # number of negative samples per positive pair
    neg_num: 5
    # number of walks started from each node
    walk_times: 20
    
    
    # --------------------------- model config -----------------------------------------#
    model_type: SkipGramModel
    warm_start_from: null
    num_nodes: 5000000
    embed_size: 64
    sparse_embed: False
    
    # --------------------------- training config --------------------------------------#
    epochs: 1
    num_workers: 4
    lr: 0.001
    lazy_mode: False
    batch_node_size: 200
    batch_pair_size: 1000
    pair_stream_shuffle_size: 100000
    log_dir: ./logs
    output_dir: ./outputs
    save_dir: ./checkpoints
    log_steps: 1000
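      The etype2files and ntype2files strings above map each edge/node type to a data file. A quick sketch of how such a string splits into a dict (parse_type2files is a hypothetical helper name, not PGL's API):

```python
def parse_type2files(spec):
    """Turn 'p2a:./a.txt,p2c:./b.txt' into {'p2a': './a.txt', 'p2c': './b.txt'}."""
    return dict(item.split(":", 1) for item in spec.split(","))

etype2files = parse_type2files(
    "p2a:./graph_data/paper2author_edges.txt,"
    "p2c:./graph_data/paper2conf_edges.txt")
```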
    

    The dropbox file is hard to download; the dataset has been uploaded to Baidu Netdisk: net_aminer dataset, extraction code: s9iv

    Preprocessing the dataset

    python data_preprocess.py --config config.yaml
    

    node_types.txt format: node_type "\t" node_id

    c       0
    c       1
    c       2
    c       3
    a       3885
    a       3886
    p       4891796
    p       4891797
    

    paper2author_edges.txt format: paper_id "\t" author_id

    1738139 1105483
    1963494 1629565
    2128630 418483
    2509017 841304
    3536281 1611393
    

    paper2conf_edges.txt format: paper_id "\t" conf_id

    2090976 1108
    4666445 2808
    4704329 2055
    1951251 3195
    3680120 779
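      These files are plain tab-separated text. A minimal standard-library sketch of parsing them (the function names are illustrative, not from the repo):

```python
import io

def read_node_types(fh):
    """Parse node_types.txt lines of the form 'node_type<TAB>node_id'."""
    return [(t, int(i)) for t, i in (line.split() for line in fh if line.strip())]

def read_edges(fh):
    """Parse an edge file: lines 'src_id<TAB>dst_id' -> list of int pairs."""
    return [tuple(int(x) for x in line.split()) for line in fh if line.strip()]

types = read_node_types(io.StringIO("c\t0\na\t3885\n"))
edges = read_edges(io.StringIO("1738139\t1105483\n1963494\t1629565\n"))
```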
    
    # dataset.py
    import os
    
    import numpy as np
    from pgl.utils.data import StreamDataset
    # DistGraphClient and PairGenerator come from the repo's graph-service
    # and pair modules (import lines abbreviated here)
    class TrainPairDataset(StreamDataset):
        def __init__(self, config, ip_list_file, mode="train"):
            self.config = config
            self.ip_list_file = ip_list_file
            self.mode = mode
    
        def __iter__(self):
            client_id = os.getpid()
            self.graph = DistGraphClient(self.config, self.config.shard_num,
                                         self.ip_list_file, client_id)
    
            self.generator = PairGenerator(
                self.config,
                self.graph,
                mode=self.mode,
                rank=self._worker_info.fid,
                nrank=self._worker_info.num_workers)
    
            for data in self.generator():
                yield data
    
    class CollateFn(object):
        def __init__(self):
            pass
    
        def __call__(self, batch_data):
            src_list = []
            pos_list = []
            for src, pos in batch_data:
                src_list.append(src)
                pos_list.append(pos)
            # the model reads this feed dict in forward()
            src_list = np.array(src_list, dtype="int64").reshape(-1, 1)
            pos_list = np.array(pos_list, dtype="int64").reshape(-1, 1)
            return {'src': src_list, 'pos': pos_list}
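      To see what CollateFn produces, here is the same stacking logic run standalone on three made-up (src, pos) node-id pairs:

```python
import numpy as np

# mimic CollateFn.__call__ on a batch of three (src, pos) pairs
batch_data = [(4, 12), (7, 3), (9, 21)]
src_list = np.array([s for s, _ in batch_data], dtype="int64").reshape(-1, 1)
pos_list = np.array([p for _, p in batch_data], dtype="int64").reshape(-1, 1)
feed_dict = {"src": src_list, "pos": pos_list}   # each has shape (3, 1)
```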
    
    # pair.py
    class PairGenerator(object):
        #...
        def __call__(self):
            iterval = 20000000 * 24 // self.config.walk_len
            pair_count = 0
            for walks in self.walk_generator():
                try:
                    for walk in walks:
                        index = np.arange(0, len(walk), dtype="int64")
                        batch_s, batch_p = skip_gram_gen_pair(index,
                                                              self.config.win_size)
                        for s, p in zip(batch_s, batch_p):
                            # yielded to the dataloader, then batched by CollateFn
                            yield walk[s], walk[p]
                            pair_count += 1
                            if pair_count % iterval == 0 and self.rank == 0:
                                log.info("[%s] pairs have been loaded in rank [%s]" \
                                        % (pair_count, self.rank))
    
                except Exception as e:
                    log.exception(e)
    
            log.info("total [%s] pairs in rank [%s]" % (pair_count, self.rank))
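      skip_gram_gen_pair lives in helper.py; here is a pure-Python reference sketch of what it computes (my reconstruction, assuming a fixed window; the real helper may shrink the window randomly):

```python
import numpy as np

def skip_gram_gen_pair_ref(index, win_size):
    """For every center position, pair it with each position at most
    win_size steps away on either side. Returns (centers, contexts)."""
    batch_s, batch_p = [], []
    n = len(index)
    for i in range(n):
        for j in range(max(0, i - win_size), min(n, i + win_size + 1)):
            if j != i:
                batch_s.append(index[i])
                batch_p.append(index[j])
    return np.array(batch_s), np.array(batch_p)

s, p = skip_gram_gen_pair_ref(np.arange(4), win_size=1)
```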
    

    Random walk on the heterogeneous graph, returning metapath node sequences

    #sampling.py
    def metapath_randomwalk_with_walktimes(graph,
                                           start_nodes,
                                           metapath,
                                           walk_length,
                                           walk_times=10,
                                           alias_name=None,
                                           events_name=None):
        """Implementation of metapath random walk in heterogeneous graph.
    
        Args:
            graph: instance of pgl heterogeneous graph
            start_nodes: start nodes to generate walk
            metapath: meta path for sample nodes.
                e.g: "c2p-p2a-a2p-p2c"
            walk_length: the walk length
    
        Return:
            a list of metapath walks.
    
        """
    
        edge_types = metapath.split('-')
        walk = []
        cur_nodes = []
        # start_nodes size=200
        neighbors = graph.sample_successor(
            np.array(
                start_nodes, dtype="uint64"),
            max_degree=walk_times,
            edge_type=edge_types[0])
        # append each start node and its sampled successors to the returned
        # walks; walk size = 200 * 20
        for neigh, walk_id in zip(neighbors, start_nodes):
            for node_id in neigh:
                walk.append([walk_id, node_id])
                cur_nodes.append(node_id)
      
        if len(walk) == 0:
            return walk
    
        cur_walk_ids = np.arange(0, len(walk))
        cur_nodes = np.array(cur_nodes, dtype="uint64")
        #  if np.random.random() - 0.02 < 0:
        #      sys.stderr.write("length of walks %s\n" % (len(walk)))
    
        mp_len = len(edge_types)
        for i in range(1, walk_length - 1):
            cur_succs = graph.sample_successor(
                cur_nodes, max_degree=1, edge_type=edge_types[i % mp_len])
            mask = np.array([len(succ) > 0 for succ in cur_succs], dtype="bool")
            # mask: array([ True,  True,  True, ...,  True,  True,  True])
            # np.any() is a logical OR: it is True if any element is True,
            # so the walk stops only when no node has an outgoing successor
            if np.any(mask):
                # keep only the nodes whose mask entry is True
                cur_walk_ids = cur_walk_ids[mask]
                cur_nodes = cur_nodes[mask]
                cur_succs = np.array(cur_succs, dtype="object")[mask]
            else:
                # stop when all nodes have no successor
                break
            # walk[0] is one complete metapath walk
            nxt_cur_nodes = []
            for s, walk_id in zip(cur_succs, cur_walk_ids):
                walk[walk_id].append(s[0])
                nxt_cur_nodes.append(s[0])
            cur_nodes = np.array(nxt_cur_nodes, dtype="uint64")
        return walk
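      A toy single-walk version of this routine, using a dict-based adjacency (names and structure are illustrative; the PGL version advances a whole batch of walks at once):

```python
import random

def metapath_walk(neighbors, start, metapath, walk_length, rng=None):
    """Follow the edge types of `metapath` cyclically, stopping at
    walk_length or at a dead end. `neighbors` maps
    (node, edge_type) -> list of successor nodes."""
    rng = rng or random.Random(0)
    walk = [start]
    for i in range(walk_length - 1):
        etype = metapath[i % len(metapath)]
        succs = neighbors.get((walk[-1], etype), [])
        if not succs:            # dead end: stop early
            break
        walk.append(rng.choice(succs))
    return walk

# tiny heterograph: conference c0, papers p1/p2, authors a1/a2
neighbors = {
    ("c0", "c2p"): ["p1", "p2"],
    ("p1", "p2a"): ["a1"], ("p2", "p2a"): ["a2"],
    ("a1", "a2p"): ["p2"], ("a2", "a2p"): ["p1"],
    ("p1", "p2c"): ["c0"], ("p2", "p2c"): ["c0"],
}
walk = metapath_walk(neighbors, "c0", ["c2p", "p2a", "a2p", "p2c"], walk_length=5)
```

      With meta_path "c2p-p2a-a2p-p2c" and first_node_type "c", every walk alternates conference, paper, author, paper, conference, matching the symmetric metapath in the config.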
    
    
    # model.py
    class SkipGramModel(nn.Layer):
        #...
        def forward(self, feed_dict):
            src_embed = self.embedding(feed_dict['src'])
            pos_embed = self.embedding(feed_dict['pos'])
    
            # batch neg sample
            # negative samples are generated here, by reusing in-batch positives
            batch_size = feed_dict['pos'].shape[0]
            neg_idx = paddle.randint(
                low=0, high=batch_size, shape=[batch_size, self.neg_num])
    
            negs = []
            for i in range(self.neg_num):
                tmp = paddle.gather(pos_embed, neg_idx[:, i])
                tmp = paddle.reshape(tmp, [-1, 1, self.embed_size])
                negs.append(tmp)
    
            neg_embed = paddle.concat(negs, axis=1)
            src_embed = paddle.reshape(src_embed, [-1, 1, self.embed_size])
            pos_embed = paddle.reshape(pos_embed, [-1, 1, self.embed_size])
    
            # [batch_size, 1, 1]
            pos_logits = paddle.matmul(src_embed, pos_embed, transpose_y=True)
            # [batch_size, 1, neg_num]
            neg_logits = paddle.matmul(src_embed, neg_embed, transpose_y=True)
    
            ones_label = paddle.ones_like(pos_logits)
            pos_loss = self.loss_fn(pos_logits, ones_label)
    
            zeros_label = paddle.zeros_like(neg_logits)
            neg_loss = self.loss_fn(neg_logits, zeros_label)
    
            loss = (pos_loss + neg_loss) / 2
            return loss
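      The forward pass can be mirrored in NumPy to check the shapes. The sketch below uses the same in-batch negative trick (reusing other rows' positive embeddings as negatives) with a numerically stable sigmoid cross-entropy standing in for paddle's loss_fn; all values and names here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, embed_size, neg_num = 4, 8, 5

src = rng.normal(size=(batch_size, 1, embed_size))   # center-node embeddings
pos = rng.normal(size=(batch_size, 1, embed_size))   # context-node embeddings

# in-batch negatives: index other rows' positive embeddings
neg_idx = rng.integers(0, batch_size, size=(batch_size, neg_num))
neg = pos[neg_idx, 0, :]                             # (batch, neg_num, embed)

pos_logits = np.einsum("bie,bje->bij", src, pos)     # (batch, 1, 1)
neg_logits = np.einsum("bie,bje->bij", src, neg)     # (batch, 1, neg_num)

def bce_with_logits(x, y):
    """Stable elementwise sigmoid cross-entropy, averaged."""
    return np.mean(np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x))))

loss = (bce_with_logits(pos_logits, 1.0) + bce_with_logits(neg_logits, 0.0)) / 2
```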
    


          Original link: https://www.haomeiwen.com/subject/tiomrdtx.html