美文网首页
pytorch.nn.Embadding 详解

pytorch.nn.Embadding 详解

作者: 何哀何欢 | 来源:发表于2020-07-29 17:07 被阅读0次

数据和枚举的对应关系:{A:1, B:2, C:3, ...}

网络层输入按照枚举方式,比如是A的话,那么输入层就是
A: [1, 0, 0, 0, 0, 0, ...], 如果是B,输入就是:
B: [0, 1, 0, 0, 0, 0, ...], 依次类推:
C: [0, 0, 1, 0, 0, 0, ...]
有多少枚举,就有多少个输入。

从输入到Hidden层,因为只有一个1,其他都是0,如下图:


其实没必要把所有输入i*w+b都计算了。因为其余都是0,只计算所选的那个i就好了。

这就是 nn.Embadding(num_embaddings, num_dim)的意义。

  • num_embaddings就是枚举的个数,也是输入节点数,他会根据输入自动转换为枚举,比如输入2,输入为[0,0,1,0,0,0, \cdots]
  • num_dim是hidden层的数量。
  • padding_idx 就是说这个index不用,作为补齐的。那么遇到这个index,所有输入都是[0,\cdots],就等于什么也不运算。

相关文章

网友评论

      本文标题:pytorch.nn.Embadding 详解

      本文链接:https://www.haomeiwen.com/subject/llckrktx.html