
GPT图解 (GPT Illustrated): Code Notes - Word2Vec Skip-Gram Model

Author: 万州客 | Published 2024-01-09 08:35

    1. Code

    sentences = [
        "Kage is Teacher",
        "Mazong is Boss",
        "Niuzong is Boss",
        "Xiaobing is Student",
        "Xiaoxue is Student"
    ]
    # Build the vocabulary: flatten the sentences into words, deduplicate,
    # and create the word <-> index mappings.
    words = ' '.join(sentences).split()
    word_list = list(set(words))
    word_to_idx = {word: idx for idx, word in enumerate(word_list)}
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    voc_size = len(word_list)
    
    print("词汇表:", word_list)
    print("词汇到索引的字典: ", word_to_idx)
    print("索引到词汇的字典: ", idx_to_word)
    print("词汇表大小: ", voc_size)
    
    def create_skipgram_dataset(sentences, window_size=2):
        data = []
        for sentence in sentences:
            sentence = sentence.split()
            for idx, word in enumerate(sentence):
                # Every word within window_size of the center word becomes a
                # (context, center) pair; the center word itself is skipped.
                # (Comparing by value also drops duplicate tokens equal to the
                # center word, which is harmless for this toy corpus.)
                for neighbor in sentence[max(idx-window_size, 0):min(idx+window_size+1, len(sentence))]:
                    if neighbor != word:
                        data.append((neighbor, word))
        return data
    
    skipgram_data = create_skipgram_dataset(sentences)
    print("Skip-Gram 数据样例(未编码): ", skipgram_data)
    
    import torch
    def one_hot_encoding(word, word_to_idx):
        # Build a vector of zeros with a single 1 at the word's index.
        tensor = torch.zeros(len(word_to_idx))
        tensor[word_to_idx[word]] = 1
        return tensor
    
    word_example = "Teacher"
    print("Word before One-Hot encoding:", word_example)
    print("Vector after One-Hot encoding:", one_hot_encoding(word_example, word_to_idx))
    print("Skip-Gram sample data (encoded):", [(one_hot_encoding(context, word_to_idx), word_to_idx[target])
                                               for context, target in skipgram_data[:3]])
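
    # Quick check (not in the original code): a one-hot vector contains a
    # single 1, located exactly at the word's index.
    vec = one_hot_encoding(word_example, word_to_idx)
    assert vec.sum().item() == 1.0 and vec[word_to_idx[word_example]].item() == 1.0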
    
    import torch.nn as nn
    class SkipGram(nn.Module):
        def __init__(self, voc_size, embedding_size):
            super(SkipGram, self).__init__()
            # Center word index -> dense embedding vector (the lookup table).
            self.input_to_hidden = nn.Embedding(voc_size, embedding_size)
            # Embedding vector -> raw scores over the whole vocabulary.
            self.hidden_to_output = nn.Linear(embedding_size, voc_size, bias=False)
        def forward(self, X):
            hidden_layer = self.input_to_hidden(X)              # (batch, embedding_size)
            output_layer = self.hidden_to_output(hidden_layer)  # (batch, voc_size)
            return output_layer
    
    embedding_size = 2
    skipgram_model = SkipGram(voc_size, embedding_size)
    print("Skip-Gram类:", skipgram_model)
    
    learning_rate = 0.001
    epochs = 1000
    criterion = nn.CrossEntropyLoss()  # applies log-softmax to the raw scores internally
    import torch.optim as optim
    optimizer = optim.SGD(skipgram_model.parameters(), lr=learning_rate)
    loss_values = []
    for epoch in range(epochs):
        loss_sum = 0
        for context, target in skipgram_data:
            # Skip-Gram predicts the context word from the center word:
            # the input X is the center word ("target"), the label the context word.
            X = torch.tensor([word_to_idx[target]], dtype=torch.long)
            y_true = torch.tensor([word_to_idx[context]], dtype=torch.long)
            y_pred = skipgram_model(X)
            loss = criterion(y_pred, y_true)
            loss_sum += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if (epoch+1) % 100 == 0:
            print(f"Epoch: {epoch+1}, Loss: {loss_sum/len(skipgram_data)}")
            loss_values.append(loss_sum/len(skipgram_data))
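
    # A common speed-up (not in the original code; a sketch under the same setup):
    # instead of one optimizer step per pair, take one batched step per epoch over
    # all pairs at once. Disabled here so the model is not trained a second time.
    """
    all_centers = torch.tensor([word_to_idx[center] for _, center in skipgram_data], dtype=torch.long)
    all_contexts = torch.tensor([word_to_idx[context] for context, _ in skipgram_data], dtype=torch.long)
    for epoch in range(epochs):
        optimizer.zero_grad()
        batch_loss = criterion(skipgram_model(all_centers), all_contexts)
        batch_loss.backward()
        optimizer.step()
    """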
    
    import matplotlib.pyplot as plt
    # The loss is logged every 100 epochs, so each x-axis step is 100 epochs.
    plt.plot(range(1, epochs//100 + 1), loss_values)
    plt.title('Training Loss Curve')
    plt.xlabel('Epochs (x100)')
    plt.ylabel('Loss')
    plt.show()
    
    """
    print("Skip-Gram词嵌入:")
    for word, idx in word_to_idx.items():
        print(f"{word}: {skipgram_model.input_to_hidden.weight[:, idx].detach().numpy()}")
    
    fig, ax = plt.subplots()
    for word, idx in word_to_idx.items():
        vec = skipgram_model.input_to_hidden.weight[:, idx].detach().numpy()
        ax.scatter(vec[0], vec[1])
        ax.annotate(word, (vec[0], vec[1]), fontsize=12)
    
    plt.title('二维词嵌入')
    plt.xlabel('向量维度1')
    plt.ylabel('向量维度2')
    plt.show()
    """
    
    

    2. Screenshot

    (Screenshot of the run's console output; image not reproduced here.)
