美文网首页
DLRM代码理解

DLRM代码理解

作者: CPinging | 来源:发表于2021-06-07 22:42 被阅读0次

    在DLRM中有对训练集做处理的函数,我们对训练序列做了研究,

        def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
            # WARNING: notice that we are processing the batch at once. We implicitly
            # assume that the data is laid out such that:
            # 1. each embedding is indexed with a group of sparse indices,
            #   corresponding to a single lookup
            # 2. for each embedding the lookups are further organized into a batch
            # 3. for a list of embedding tables there is a list of batched lookups
    
            ly = []
            for k, sparse_index_group_batch in enumerate(lS_i):
                sparse_offset_group_batch = lS_o[k]
    
                # embedding lookup
                # We are using EmbeddingBag, which implicitly uses sum operator.
                # The embeddings are represented as tall matrices, with sum
                # happening vertically across 0 axis, resulting in a row vector
                # E = emb_l[k]
    
                if v_W_l[k] is not None:
                    per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
                else:
                    per_sample_weights = None
    
                if:
                    ....
                else:
                    E = emb_l[k]
                    V = E(
                        sparse_index_group_batch,
                        sparse_offset_group_batch,
                        per_sample_weights=per_sample_weights,
                    )
                
                    ly.append(V)
    

    重点是这个地方,其中E是所有打包好的Embedding:

    image.png

    其中第一维为这个Embedding table中包括的vector的数量,第二维64为vector的维度(有64个float)。

    sparse_index_group_batch以及sparse_offset_group_batch为训练时需要的index以及offset,Embedding table会根据index找具体的vector。

    offset需要注意,offset = torch.LongTensor([0,1,4]).to(0)代表三个样本,第一个样本是0 ~ 1,第二个是1 ~ 4,第三个是4(网上解释的都不够清楚,所以我这里通过代码实际跑了一下测出来是这个结果) 。且左闭右开[0,1)这种形式取整数(已经根据代码进行过验证)。

    详细解释一下流程:

    首先在apply_emb函数中每次循环会取出当前第k个Emb table:E = emb_l[k],其中k是当前所在轮数。

    对于index数组与offset数组:

    image.png

    我们能看到,第一个tensor是index,有五个元素,代表我要取的当前table中的vector的编号(共5个)。

    而后面的offset就代表我取出来的这5个数组哪些要进行reduce操作(加和等)。

    例如我如果取offset为[0,3],则代表0,1,2相加进行reduce,3,4进行reduce。所以最终出来的数字个数就是offset的size。

    IS_I以及IS_O生成的位置

    在dlrm_data_pytorch.py中的collate_wrapper_criteo_offset()函数里:

    def collate_wrapper_criteo_offset(list_of_tuples):
        # where each tuple is (X_int, X_cat, y)
        transposed_data = list(zip(*list_of_tuples))
        X_int = torch.log(torch.tensor(transposed_data[0], dtype=torch.float) + 1)
        X_cat = torch.tensor(transposed_data[1], dtype=torch.long)
        T = torch.tensor(transposed_data[2], dtype=torch.float32).view(-1, 1)
    
        batchSize = X_cat.shape[0]
        featureCnt = X_cat.shape[1]
        lS_i = [X_cat[:, i] for i in range(featureCnt)]
        lS_o = [torch.tensor(range(batchSize)) for _ in range(featureCnt)]
        return X_int, torch.stack(lS_o), torch.stack(lS_i), T
    

    在这里生成访问序列,首先将传入的数据解析为X_cat,当bs=2时,X_cat为:

    tensor([[    0,    17, 36684, 11838,     1,     0,   145,     9,     0,  1176,
                24, 34569,    24,     5,    24, 15109,     0,    19,    14,     3,
             32351,     0,     1,  4159,    32,  5050],
            [    3,    12, 33818, 19987,     0,     5,  1426,     1,     0,  8616,
               729, 31879,   658,     1,    50, 26833,     1,    12,    89,     0,
             29850,     0,     1,  1637,     3,  1246]])
    

    其中每一个tensor有26个数字,代表26个Embedding table。每一个数字代表其中每个table需要访问的vector。(比如0代表访问第一个table的0号vector)

    下面将访问序列打包,IS_i为:

    [tensor([0, 3]), tensor([17, 12]), tensor([36684, 33818]), tensor([11838, 19987]), tensor([1, 0]), tensor([0, 5]), tensor([ 145, 1426]), tensor([9, 1]), tensor([0, 0]), tensor([1176, 8616]), tensor([ 24, 729]), tensor([34569, 31879]), tensor([ 24, 658]), tensor([5, 1]), tensor([24, 50]), tensor([15109, 26833]), tensor([0, 1]), tensor([19, 12]), tensor([14, 89]), tensor([3, 0]), tensor([32351, 29850]), tensor([0, 0]), tensor([1, 1]), tensor([4159, 1637]), tensor([32,  3]), tensor([5050, 1246])]
    

    这里bs为2,所以[tensor([0, 3])代表访问第一个table的0,3个vactor。

    这里我们要再次理解一下数据集的含义,这里每一个table都是用户的一个特征(所在城市、年龄等),所以每一个用户也就是每个table拥有一个数值,所以当bs=2时,这里的tensor[0,3]代表对两个用户进行训练,其中第一个用户的第一个table取值是0号vector,第二个用户第一个table取值是3号vector。

    相关文章

      网友评论

          本文标题:DLRM代码理解

          本文链接:https://www.haomeiwen.com/subject/njwpsltx.html