美文网首页
DLRM代码理解

DLRM代码理解

作者: CPinging | 来源:发表于2021-06-07 22:42 被阅读0次

在DLRM中有对训练集做处理的函数,我们对训练序列做了研究,

    def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
        # WARNING: notice that we are processing the batch at once. We implicitly
        # assume that the data is laid out such that:
        # 1. each embedding is indexed with a group of sparse indices,
        #   corresponding to a single lookup
        # 2. for each embedding the lookups are further organized into a batch
        # 3. for a list of embedding tables there is a list of batched lookups

        ly = []
        for k, sparse_index_group_batch in enumerate(lS_i):
            sparse_offset_group_batch = lS_o[k]

            # embedding lookup
            # We are using EmbeddingBag, which implicitly uses sum operator.
            # The embeddings are represented as tall matrices, with sum
            # happening vertically across 0 axis, resulting in a row vector
            # E = emb_l[k]

            if v_W_l[k] is not None:
                per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
            else:
                per_sample_weights = None

            if:
                ....
            else:
                E = emb_l[k]
                V = E(
                    sparse_index_group_batch,
                    sparse_offset_group_batch,
                    per_sample_weights=per_sample_weights,
                )
            
                ly.append(V)

重点是这个地方,其中E是所有打包好的Embedding:

image.png

其中第一维为这个Embedding table中包括的vector的数量,第二维64为vector的维度(有64个float)。

sparse_index_group_batch以及sparse_offset_group_batch为训练时需要的index以及offset,Embedding table会根据index找具体的vector。

offset需要注意,offset = torch.LongTensor([0,1,4]).to(0)代表三个样本,第一个样本是0 ~ 1,第二个是1 ~ 4,第三个是4(网上解释的都不够清楚,所以我这里通过代码实际跑了一下测出来是这个结果) 。且左闭右开[0,1)这种形式取整数(已经根据代码进行过验证)。

详细解释一下流程:

首先在apply_emb函数中每次循环会取出当前第k个Emb table:E = emb_l[k],其中k是当前所在轮数。

对于index数组与offset数组:

image.png

我们能看到,第一个tensor是index,有五个元素,代表我要取的当前table中的vector的编号(共5个)。

而后面的offset就代表我取出来的这5个数组哪些要进行reduce操作(加和等)。

例如我如果取offset为[0,3],则代表0,1,2相加进行reduce,3,4进行reduce。所以最终出来的数字个数就是offset的size。

IS_I以及IS_O生成的位置

在dlrm_data_pytorch.py中的collate_wrapper_criteo_offset()函数里:

def collate_wrapper_criteo_offset(list_of_tuples):
    # where each tuple is (X_int, X_cat, y)
    transposed_data = list(zip(*list_of_tuples))
    X_int = torch.log(torch.tensor(transposed_data[0], dtype=torch.float) + 1)
    X_cat = torch.tensor(transposed_data[1], dtype=torch.long)
    T = torch.tensor(transposed_data[2], dtype=torch.float32).view(-1, 1)

    batchSize = X_cat.shape[0]
    featureCnt = X_cat.shape[1]
    lS_i = [X_cat[:, i] for i in range(featureCnt)]
    lS_o = [torch.tensor(range(batchSize)) for _ in range(featureCnt)]
    return X_int, torch.stack(lS_o), torch.stack(lS_i), T

在这里生成访问序列,首先将传入的数据解析为X_cat,当bs=2时,X_cat为:

tensor([[    0,    17, 36684, 11838,     1,     0,   145,     9,     0,  1176,
            24, 34569,    24,     5,    24, 15109,     0,    19,    14,     3,
         32351,     0,     1,  4159,    32,  5050],
        [    3,    12, 33818, 19987,     0,     5,  1426,     1,     0,  8616,
           729, 31879,   658,     1,    50, 26833,     1,    12,    89,     0,
         29850,     0,     1,  1637,     3,  1246]])

其中每一个tensor有26个数字,代表26个Embedding table。每一个数字代表其中每个table需要访问的vector。(比如0代表访问第一个table的0号vector)

下面将访问序列打包,IS_i为:

[tensor([0, 3]), tensor([17, 12]), tensor([36684, 33818]), tensor([11838, 19987]), tensor([1, 0]), tensor([0, 5]), tensor([ 145, 1426]), tensor([9, 1]), tensor([0, 0]), tensor([1176, 8616]), tensor([ 24, 729]), tensor([34569, 31879]), tensor([ 24, 658]), tensor([5, 1]), tensor([24, 50]), tensor([15109, 26833]), tensor([0, 1]), tensor([19, 12]), tensor([14, 89]), tensor([3, 0]), tensor([32351, 29850]), tensor([0, 0]), tensor([1, 1]), tensor([4159, 1637]), tensor([32,  3]), tensor([5050, 1246])]

这里bs为2,所以[tensor([0, 3])代表访问第一个table的0,3个vactor。

这里我们要再次理解一下数据集的含义,这里每一个table都是用户的一个特征(所在城市、年龄等),所以每一个用户也就是每个table拥有一个数值,所以当bs=2时,这里的tensor[0,3]代表对两个用户进行训练,其中第一个用户的第一个table取值是0号vector,第二个用户第一个table取值是3号vector。

相关文章

  • DLRM代码理解

    在DLRM中有对训练集做处理的函数,我们对训练序列做了研究, 重点是这个地方,其中E是所有打包好的Embeddin...

  • NullSafe代码理解

    相信大家多NullSafe并不陌生,但是为什么就一个文件,还不用导入就能解决NULL导致的崩溃呢. 首先.不用导入...

  • 代码实战理解

    Python Python中要注意代码的可复用性。能够提取出来的公共部分尽量写成一个函数或者一个类,注意代码之间的...

  • 编写可维护的代码(一)

    一,可维护的代码二,保证代码性能三,部署代码 一,什么是可维护的代码1,可理解性的代码-其他人可以接受代码并且理解...

  • <五>React组件样式

    理论理解 代码实现 代码说明 运行结果

  • 贪心-选择排序

    更好理解的代码

  • 编写可读性代码

    脏话的频率是衡量代码好坏的标准 一、代码应该易于理解 可读性定理 别人理解它的时间最小化并不是意味着代码块越小理解...

  • ROIPooling代码理解(CPU)

    MXNet中ROIPooling的具体实现。代码来自https://github.com/apache/incub...

  • Unsupervised NMT 代码理解

    对应论文:Phrase-based & Neural Unsupervised Machine Translati...

  • 代码应当易于理解

    大多数程序员依靠直觉和灵感来决定如何编程。我们都知道这样的代码: 比下面的代码好: (尽管两个例子的行为完全一样)...

网友评论

      本文标题:DLRM代码理解

      本文链接:https://www.haomeiwen.com/subject/njwpsltx.html