Skip-gram Source Code Implementation and Walkthrough

Author: lodestar | Published 2023-03-19 21:18

Skip-gram is an algorithm for training word vectors: it represents each word as a vector while preserving the semantic relationships between words. It is a neural-network algorithm that predicts the surrounding context words from a center word.
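The script below hard-codes its training pairs. As a minimal sketch (not part of the original code), all (center, context) pairs within a window of 2 can be enumerated from the tokenized sentence like this; the indices are token positions, and the hard-coded list used later is a subset of these pairs:

sentence = "the quick brown fox jumps over the lazy dog".split()
window = 2
pairs = []
for center in range(len(sentence)):
    for offset in range(-window, window + 1):
        context = center + offset
        if offset != 0 and 0 <= context < len(sentence):
            pairs.append((center, context))
# e.g. pairs starts with (0, 1), (0, 2), (1, 0), (1, 2), (1, 3), ...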

import paddle
import paddle.nn as nn


class SkipGram(paddle.nn.Layer):

    def __init__(self, vocab_size, embedding_dim):
        super(SkipGram, self).__init__()

        # Two embedding layers: one for the input (center) word,
        # the other for the output (context) word
        self.embedding_in = nn.Embedding(vocab_size, embedding_dim)
        self.embedding_out = nn.Embedding(vocab_size, embedding_dim)
        # A fully connected layer applied to both embeddings before the dot product
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_word, context_word):
        # Look up the embedding vectors for the input word and the context word
        input_embed = self.embedding_in(input_word)
        context_embed = self.embedding_out(context_word)

        # Optionally pass both embeddings through a fully connected layer
        input_embed = self.linear(input_embed)
        context_embed = self.linear(context_embed)

        # Adding a fully connected layer to the Skip-gram model increases its
        # expressive power and non-linearity, which can improve its performance:
        # the layer applies a more complex transformation to the embedding and
        # yields a richer feature representation.
        # Concretely, the input embedding can be passed through a fully connected
        # layer and then through a softmax to obtain the prediction, making the
        # model more flexible for complex semantic relationships. The trade-off is
        # higher model complexity and a harder optimization problem, which needs
        # more training data and compute.
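        # A hypothetical alternative (not used in this script): treat the linear
        # layer's output for the input word as logits over the whole vocabulary
        # and apply a softmax to predict the context word directly:
        # probs = paddle.nn.functional.softmax(input_embed, axis=-1)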

        # Compute the score (logit) as the dot product of the two vectors
        # Option 1
        score = paddle.mm(input_embed, context_embed.t())

        # Option 2
        # score = paddle.matmul(input_embed, context_embed, transpose_y=True)
        # score = paddle.sum(score, axis=-1)
        return score
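As a quick sanity check (a sketch, not part of the original article), instantiating the model and running a single (center, context) pair shows the shape of the score that the loss function below receives:

demo_model = SkipGram(9, 64)               # 9 token positions, 64-dimensional vectors
center = paddle.to_tensor([3])             # position of the center word
context = paddle.to_tensor([4])            # position of one neighbouring word
print(demo_model(center, context).shape)   # [1, 1]: a single logit for this pair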


num_epochs = 10
str1 = "the quick brown fox jumps over the lazy dog"
# Window size of 2, i.e. two words to the left and right of the center word.
# The indices below are token positions in the sentence (0-8), not word2idx
# indices, which is why the model is later built with vocab_size + 1 embeddings.
training_data = [
    (0, 1), (0, 2), (1, 0), (1, 2), (1, 3),
    (2, 0), (2, 1), (2, 3), (3, 1), (3, 2),
    (3, 4), (4, 3), (4, 5), (5, 4), (5, 6),
    (6, 5), (6, 7), (7, 6), (7, 8), (8, 7)
]

word2idx = {"the": 0, "quick": 1, "brown": 2, "fox": 3, "jumps": 4, "over": 5, "lazy": 6, "dog": 7}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(word2idx)

embedding_dim = 64

# +1 because the training pairs index token positions 0-8 (9 positions in total)
model = SkipGram(vocab_size + 1, embedding_dim)
criterion = nn.BCEWithLogitsLoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

for epoch in range(num_epochs):
    total_loss = 0

    for input_word, context_word in training_data:
        # Convert the indices to tensors
        input_word = paddle.to_tensor([input_word])
        context_word = paddle.to_tensor([context_word])

        # Put the model in training mode
        model.train()
        # Forward pass
        output = model(input_word, context_word)

        ones_label = paddle.ones_like(output)
        # Compute the loss (every pair is treated as a positive example, label 1)
        # loss = nn.functional.binary_cross_entropy_with_logits(output, ones_label)
        loss = criterion(output, ones_label)
        total_loss += loss.item()

        # Backpropagation and optimizer step
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()
    # Print the average loss for this epoch
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, total_loss / len(training_data)))

# Extract the word vectors and save them to a file
embeddings = model.embedding_in.weight.numpy()
with open('embeddings.txt', 'w') as f:
    for i, word in idx2word.items():
        embedding = ' '.join(map(str, embeddings[i]))
        f.write('{} {}\n'.format(word, embedding))




The loss values from one run are as follows:

Epoch [1/10], Loss: 0.6915
Epoch [2/10], Loss: 0.6712
Epoch [3/10], Loss: 0.6448
Epoch [4/10], Loss: 0.6096
Epoch [5/10], Loss: 0.5639
Epoch [6/10], Loss: 0.5069
Epoch [7/10], Loss: 0.4408
Epoch [8/10], Loss: 0.3709
Epoch [9/10], Loss: 0.3037
Epoch [10/10], Loss: 0.2444

embeddings.txt contains one embedding per word, in the following format:

the 0.23742944 -0.09895072 -0.11146876 0.2605418 0.024332121 0.0910439 0.12424937 -0.013771858 0.116671495 0.0015973783 -0.10863184 -0.10972429 0.07718096 -0.0033314745 0.2224411 -0.22004403 0.1281483 -0.12591755 0.14134666 -0.034466334 0.24389501 -0.07925096 0.10862582 -0.22061633 0.18360858 -0.17183 0.117620915 -0.23871568 -0.21196923 -0.014604413 0.040446073 0.17018412 -0.24544948 0.10585982 0.05756965 0.058975074 -0.2550219 0.2217722 -0.11203941 0.21279284 0.26438716 -0.17886016 -0.17222002 0.036797255 0.00933418 0.07391108 -0.20267555 -0.21875764 -0.30134645 0.25732276 -0.12506652 -0.060186304 -0.058356136 0.1225115 0.043293785 0.040848006 0.008795115 0.15603863 -0.23737802 -0.045909956 0.14689028 -0.01215158 0.2031173 0.101658516
quick 0.22985205 -0.28023568 0.17331894 -0.18404774 0.19435923 0.054511975 -0.12446486 -0.20461103 -0.20024063 0.074314184 -0.19651282 -0.15052138 -0.25369388 -0.0021391092 -0.2760222 0.10754039 0.11228328 -0.019922553 0.24608843 -0.2582981 -0.11957916 -0.18738061 0.018962713 -0.096384935 -0.26450405 0.066627055 -0.0071602613 -0.077308446 0.26354805 0.07547034 0.058478173 -0.19880083 -0.29015306 0.034329493 0.2207786 -0.11239037 0.049043965 -0.021390196 -0.004287906 -0.28705558 -0.1730856 -0.27100953 0.1121744 -0.25906146 -0.074053064 0.10330311 0.04657338 -0.119830996 -0.17361426 0.17114878 0.1927943 -0.2083592 -0.1774211 -0.2773358 -0.114716105 -0.011761455 -0.1675885 0.1555276 -0.15725754 -0.00861447 -0.27093074 -0.24180736 -0.18109317 0.27589953
brown 0.15149793 -0.19986486 0.2548086 0.020637682 -0.11013863 0.024790183 -0.02504396 0.037789762 -0.020729668 -0.23532745 0.28116202 -0.04157986 -0.29003182 0.29875976 0.16469309 0.23130749 0.17639601 -0.23869719 -0.13300861 0.27599373 0.02700885 0.05513569 0.26320535 -0.22142021 0.013878512 0.10758007 0.22711909 -0.18499781 0.070877045 0.079043075 -0.24289952 -0.2636248 0.0006990259 0.18134123 -0.023455022 -0.034577943 -0.25355765 0.29205313 0.23203316 0.04200985 -0.039580178 -0.21799651 0.20781282 0.083057314 -0.22915262 0.21067782 -0.21856064 0.16073515 0.10993917 -0.14174365 0.097185716 -0.17790347 0.18403171 0.012047063 0.20417404 0.05510201 0.135194 -0.0029973947 0.007548025 0.04317737 0.12034502 0.05921867 0.030197665 0.061334394
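As a minimal sketch (not in the original article), the file can be read back and the vectors compared by cosine similarity, for example:

import numpy as np

vectors = {}
with open('embeddings.txt') as f:
    for line in f:
        parts = line.rstrip().split(' ')
        vectors[parts[0]] = np.array(parts[1:], dtype=np.float32)

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine(vectors['quick'], vectors['brown']))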
