美文网首页
Embedding的理解

Embedding的理解

作者: sretik | 来源:发表于2024-04-23 23:56 被阅读0次

Embedding :一个简单的查找表,存储固定字典和大小的嵌入。

>>> import torch
>>> import torch.nn as nn
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
    tensor([[
                [-0.0251, -1.6902,  0.7172],
                [-0.6431,  0.0748,  0.6969],
                [ 1.4970,  1.3448, -0.9685],         
                [-0.3677, -2.7265, -0.1685]
            ],        
            [
                [ 1.4970,  1.3448, -0.9685],         
                [ 0.4362, -0.4004,  0.9400],         
                [-0.6431,  0.0748,  0.6969],         
                [ 0.9124, -2.3616,  1.1151]
            ]])

如上面的例子所示,nn.Embedding生成了一个shape=(10,3)的向量,分别表示0-9十个数字。
可以看到input向量中两个2的向量表示是一样的,4的向量表示也是一样的。

相关文章

网友评论

      本文标题:Embedding的理解

      本文链接:https://www.haomeiwen.com/subject/rpwwxjtx.html