Getting Started with LDA Topic Models

Author: 海天一树X | Published 2018-05-22 11:35

    1. Introduction to LDA

    LDA (Latent Dirichlet Allocation) is a generative topic model for documents, often described as a three-layer Bayesian probabilistic model over words, topics, and documents. "Generative" means that every word in an article is assumed to arise from the process "pick a topic with some probability, then pick a word from that topic with some probability." The document-to-topic distribution is multinomial, and so is the topic-to-word distribution.

    LDA is an unsupervised machine learning technique for discovering latent topic information in a large document collection (corpus). It uses a bag-of-words representation: each document becomes a vector of word frequencies, turning text into numeric data that is easy to model. Bag of words ignores the order of words, which simplifies the problem while also leaving room for improved models. Under LDA, each document is a probability distribution over topics, and each topic is a probability distribution over words.
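    For intuition, here is a minimal sketch (independent of the lda library) of how a bag-of-words representation turns raw text into a document-term count matrix; the toy corpus and variable names are illustrative only.

    from collections import Counter
    import numpy as np
    
    # A toy corpus; real input would be tokenized documents.
    docs = [["church", "pope", "church"],
            ["pope", "years", "people"]]
    
    # Build a vocabulary: one column per distinct word.
    vocab = sorted({w for doc in docs for w in doc})
    col = {w: j for j, w in enumerate(vocab)}
    
    # Fill the document-term matrix with raw counts; word order is discarded.
    X = np.zeros((len(docs), len(vocab)), dtype=int)
    for i, doc in enumerate(docs):
        for w, c in Counter(doc).items():
            X[i, col[w]] = c
    
    print(vocab)    # ['church', 'people', 'pope', 'years']
    print(X)        # [[2 0 1 0], [0 1 1 1]]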

    2. Installing the lda library

    pip install lda
    

    After installation, you can find the lda package under Lib/site-packages in your Python installation directory.

    3. Understanding the dataset


    The dataset lives in the tests folder of the lda installation directory and consists of three files: reuters.ldac, reuters.titles, and reuters.tokens.
    reuters.titles holds the titles of 395 documents.
    reuters.tokens holds every word appearing in those 395 documents, 4258 in total.
    reuters.ldac has 395 lines; line i records the word frequencies of document i. Take line 0 as an example: it describes document 0, whose title in reuters.titles is "UK: Prince Charles spearheads British royal revolution. LONDON 1996-08-20".
    Line 0 reads:
    159 0:1 2:1 6:1 9:1 12:5 13:2 20:1 21:4 24:2 29:1 ……
    The first number, 159, says that 159 distinct words occur in document 0 (each appearing one or more times).
    0:1 means word 0 occurs once; reuters.tokens shows word 0 is church.
    2:1 means word 2 occurs once; reuters.tokens shows word 2 is years.
    6:1 means word 6 occurs once; reuters.tokens shows word 6 is told.
    9:1 means word 9 occurs once; reuters.tokens shows word 9 is year.
    12:5 means word 12 occurs five times; reuters.tokens shows word 12 is charles.
    ……
    Words 1, 3, 4, 5, 7, 8, 10, 11, … carry no id:count pair, meaning they do not occur in this document. A small parsing sketch follows the note below.

    Note:
    The original text of the 395 documents is not included; the three files above were produced by preprocessing those documents.
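
    To make the format concrete, here is a minimal sketch of parsing one LDA-C line into (word_id, count) pairs; parse_ldac_line is a hypothetical helper for illustration, not part of the lda library.

    def parse_ldac_line(line):
        # An LDA-C line looks like "159 0:1 2:1 6:1 ...": the first field
        # is the number of distinct words, and each id:count pair gives a
        # word's index in reuters.tokens and its frequency in the document.
        fields = line.split()
        n_unique = int(fields[0])
        pairs = [tuple(map(int, f.split(":"))) for f in fields[1:]]
        assert len(pairs) == n_unique
        return pairs
    
    print(parse_ldac_line("3 0:1 12:5 20:2"))   # [(0, 1), (12, 5), (20, 2)]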

    4. Implementation

    4.1 Loading the data

    (1) Viewing word frequencies in the documents

    import numpy as np
    import lda
    import lda.datasets
    
    # document-term matrix
    X = lda.datasets.load_reuters()
    print("type(X): {}".format(type(X)))
    print("shape: {}\n".format(X.shape))
    print(X[:5, :5])        # first 5 rows, first 5 columns
    

    Output:

    type(X): <class 'numpy.ndarray'>
    shape: (395, 4258)
    
    [[ 1  0  1  0  0]
     [ 7  0  2  0  0]
     [ 0  0  0  1 10]
     [ 6  0  1  0  0]
     [ 0  0  0  2 14]]
    

    Comparing with the first five columns of the first five lines of reuters.ldac:
    In row 0, the counts for words 0 through 4 are indeed 1, 0, 1, 0, 0.
    In row 1, the counts for words 0 through 4 are indeed 7, 0, 2, 0, 0.
    ……

    (2) Viewing the vocabulary

    # the vocab
    vocab = lda.datasets.load_reuters_vocab()
    print("type(vocab): {}".format(type(vocab)))
    print("len(vocab): {}\n".format(len(vocab)))
    print(vocab[:5])
    

    Output:

    type(vocab): <class 'tuple'>
    len(vocab): 4258
    
    ('church', 'pope', 'years', 'people', 'mother')
    

    As expected, reuters.tokens contains 4258 words; the first five are church, pope, years, people, and mother.

    (3) Viewing the document titles

    # titles for each story
    titles = lda.datasets.load_reuters_titles()
    print("type(titles): {}".format(type(titles)))
    print("len(titles): {}\n".format(len(titles)))
    print(titles[:5])       # print the titles of the first five documents
    

    Output:

    type(titles): <class 'tuple'>
    len(titles): 395
    
    ('0 UK: Prince Charles spearheads British royal revolution. LONDON 1996-08-20', 
    '1 GERMANY: Historic Dresden church rising from WW2 ashes. DRESDEN, Germany 1996-08-21',
    "2 INDIA: Mother Teresa's condition said still unstable. CALCUTTA 1996-08-23", 
    '3 UK: Palace warns British weekly over Charles pictures. LONDON 1996-08-25', 
    '4 INDIA: Mother Teresa, slightly stronger, blesses nuns. CALCUTTA 1996-08-25')
    

    (4) Counting how often word 0 appears in the first five documents

    doc_id = 0
    word_id = 0
    while doc_id < 5:
        print("doc id: {} word id: {}".format(doc_id, word_id))
        print("-- count: {}".format(X[doc_id, word_id]))
        print("-- word : {}".format(vocab[word_id]))
        print("-- doc  : {}\n".format(titles[doc_id]))
        doc_id += 1
    

    Output:

    doc id: 0 word id: 0
    -- count: 1
    -- word : church
    -- doc  : 0 UK: Prince Charles spearheads British royal revolution. LONDON 1996-08-20
    
    doc id: 1 word id: 0
    -- count: 7
    -- word : church
    -- doc  : 1 GERMANY: Historic Dresden church rising from WW2 ashes. DRESDEN, Germany 1996-08-21
    
    doc id: 2 word id: 0
    -- count: 0
    -- word : church
    -- doc  : 2 INDIA: Mother Teresa's condition said still unstable. CALCUTTA 1996-08-23
    
    doc id: 3 word id: 0
    -- count: 6
    -- word : church
    -- doc  : 3 UK: Palace warns British weekly over Charles pictures. LONDON 1996-08-25
    
    doc id: 4 word id: 0
    -- count: 0
    -- word : church
    -- doc  : 4 INDIA: Mother Teresa, slightly stronger, blesses nuns. CALCUTTA 1996-08-25
    

    4.2 Training the model

    Train with 20 topics and 500 iterations:

    model = lda.LDA(n_topics=20, n_iter=500, random_state=1)
    model.fit(X)          # model.fit_transform(X) is also available
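
    To check whether 500 iterations are enough, you can watch the sampler's log likelihood flatten out. A minimal sketch, assuming the lda package's logging output and its loglikelihood() method:

    import logging
    
    # The lda sampler reports its log likelihood every few iterations
    # through the logging module; INFO level makes those reports visible.
    logging.basicConfig(level=logging.INFO)
    
    model = lda.LDA(n_topics=20, n_iter=500, random_state=1)
    model.fit(X)
    print("final log likelihood: {}".format(model.loglikelihood()))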
    

    4.3 Topic-word distribution

    Compute the weight of the first three words in each of the 20 topics:

    topic_word = model.topic_word_
    print("type(topic_word): {}".format(type(topic_word)))
    print("shape: {}".format(topic_word.shape))
    print(vocab[:3])
    print(topic_word[:, :3])    # first 3 columns of all 20 rows
    

    Output:

    type(topic_word): <class 'numpy.ndarray'>
    shape: (20, 4258)
    ('church', 'pope', 'years')
    [[2.72436509e-06 2.72436509e-06 2.72708945e-03]
     [2.29518860e-02 1.08771556e-06 7.83263973e-03]
     [3.97404221e-03 4.96135108e-06 2.98177200e-03]
     [3.27374625e-03 2.72585033e-06 2.72585033e-06]
     [8.26262882e-03 8.56893407e-02 1.61980569e-06]
     [1.30107788e-02 2.95632328e-06 2.95632328e-06]
     [2.80145003e-06 2.80145003e-06 2.80145003e-06]
     [2.42858077e-02 4.66944966e-06 4.66944966e-06]
     [6.84655429e-03 1.90129250e-06 6.84655429e-03]
     [3.48361655e-06 3.48361655e-06 3.48361655e-06]
     [2.98781661e-03 3.31611166e-06 3.31611166e-06]
     [4.27062069e-06 4.27062069e-06 4.27062069e-06]
     [1.50994982e-02 1.64107142e-06 1.64107142e-06]
     [7.73480150e-07 7.73480150e-07 1.70946848e-02]
     [2.82280146e-06 2.82280146e-06 2.82280146e-06]
     [5.15309856e-06 5.15309856e-06 4.64294180e-03]
     [3.41695768e-06 3.41695768e-06 3.41695768e-06]
     [3.90980357e-02 1.70316633e-03 4.42279319e-03]
     [2.39373034e-06 2.39373034e-06 2.39373034e-06]
     [3.32493234e-06 3.32493234e-06 3.32493234e-06]]
    

    Check that the weights in every row sum to 1 (each row is a probability distribution):

    for n in range(20):
        sum_pr = sum(topic_word[n, :])   # the weights in row n should sum to 1
        print("topic: {} sum: {}".format(n, sum_pr))
    

    Output:

    topic: 0 sum: 1.0000000000000875
    topic: 1 sum: 1.0000000000001148
    topic: 2 sum: 0.9999999999998656
    topic: 3 sum: 1.0000000000000042
    topic: 4 sum: 1.0000000000000928
    topic: 5 sum: 0.9999999999999372
    topic: 6 sum: 0.9999999999999049
    topic: 7 sum: 1.0000000000001694
    topic: 8 sum: 1.0000000000000906
    topic: 9 sum: 0.9999999999999195
    topic: 10 sum: 1.0000000000001261
    topic: 11 sum: 0.9999999999998876
    topic: 12 sum: 1.0000000000001268
    topic: 13 sum: 0.9999999999999034
    topic: 14 sum: 1.0000000000001892
    topic: 15 sum: 1.0000000000000984
    topic: 16 sum: 1.0000000000000768
    topic: 17 sum: 0.9999999999999146
    topic: 18 sum: 1.0000000000000364
    topic: 19 sum: 1.0000000000001434
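
    The same check can be written as one vectorized line; a minimal NumPy sketch reusing topic_word from above:

    # Each row of topic_word is a probability distribution over the
    # vocabulary, so every row sum should equal 1 up to floating-point error.
    print(np.allclose(topic_word.sum(axis=1), 1.0))    # True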
    

    4.4 Top-N words per topic

    Find the five highest-weighted words in each topic:

    n = 5
    for i, topic_dist in enumerate(topic_word):
        topic_words = np.array(vocab)[np.argsort(topic_dist)][:-(n+1):-1]
        print('*Topic {}\n- {}'.format(i, ' '.join(topic_words)))
    

    Output:

    *Topic 0
    - government british minister west group
    *Topic 1
    - church first during people political
    *Topic 2
    - elvis king wright fans presley
    *Topic 3
    - yeltsin russian russia president kremlin
    *Topic 4
    - pope vatican paul surgery pontiff
    *Topic 5
    - family police miami versace cunanan
    *Topic 6
    - south simpson born york white
    *Topic 7
    - order church mother successor since
    *Topic 8
    - charles prince diana royal queen
    *Topic 9
    - film france french against actor
    *Topic 10
    - germany german war nazi christian
    *Topic 11
    - east prize peace timor quebec
    *Topic 12
    - n't told life people church
    *Topic 13
    - years world time year last
    *Topic 14
    - mother teresa heart charity calcutta
    *Topic 15
    - city salonika exhibition buddhist byzantine
    *Topic 16
    - music first people tour including
    *Topic 17
    - church catholic bernardin cardinal bishop
    *Topic 18
    - harriman clinton u.s churchill paris
    *Topic 19
    - century art million museum city
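
    The slice [:-(n+1):-1] walks the ascending argsort result backwards, picking out the indices of the n largest weights. An equivalent, arguably more readable form of the same step inside the loop above:

    # np.argsort sorts ascending; reverse it and keep the first n entries
    # to get the indices of the n highest-weighted words.
    top_n = np.argsort(topic_dist)[::-1][:n]
    topic_words = np.array(vocab)[top_n]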
    

    4.5 Document-topic distribution

    There are 395 documents in total; compute the most likely topic for the first ten:

    doc_topic = model.doc_topic_
    print("type(doc_topic): {}".format(type(doc_topic)))
    print("shape: {}".format(doc_topic.shape))
    for n in range(10):
        topic_most_pr = doc_topic[n].argmax()
        print("doc: {} topic: {}".format(n, topic_most_pr))
    

    Output:

    type(doc_topic): <class 'numpy.ndarray'>
    shape: (395, 20)
    doc: 0 topic: 8
    doc: 1 topic: 1
    doc: 2 topic: 14
    doc: 3 topic: 8
    doc: 4 topic: 14
    doc: 5 topic: 14
    doc: 6 topic: 14
    doc: 7 topic: 14
    doc: 8 topic: 14
    doc: 9 topic: 8
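
    To sanity-check these assignments, you can print each title next to the top words of its most probable topic; a minimal sketch reusing the variables defined above:

    # Pair each document's title with the top words of its most likely
    # topic to eyeball whether the assignment is plausible.
    for n in range(5):
        k = doc_topic[n].argmax()
        top_words = np.array(vocab)[np.argsort(topic_word[k])][:-6:-1]
        print("{} -> topic {}: {}".format(titles[n], k, " ".join(top_words)))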
    

    4.6 Visualization

    (1) Plot the word distributions of topics 0, 5, 9, 14, and 19

    import matplotlib.pyplot as plt
    
    f, ax = plt.subplots(5, 1, figsize=(8, 6), sharex=True)
    for i, k in enumerate([0, 5, 9, 14, 19]):
        print(i, k)
        ax[i].stem(topic_word[k, :], linefmt='b-',
                   markerfmt='bo', basefmt='w-')
        ax[i].set_xlim(-50, 4350)
        ax[i].set_ylim(0, 0.08)
        ax[i].set_ylabel("Prob")
        ax[i].set_title("topic {}".format(k))
    
    ax[4].set_xlabel("word")
    
    plt.tight_layout()
    plt.show()
    

    Output:

    [Figure: five stem plots of word probabilities, one per panel, for topics 0, 5, 9, 14, and 19]

    (2) Plot the topic distributions of documents 1, 3, 4, 8, and 9

    f, ax = plt.subplots(5, 1, figsize=(8, 6), sharex=True)
    for i, k in enumerate([1, 3, 4, 8, 9]):
        ax[i].stem(doc_topic[k, :], linefmt='r-',
                   markerfmt='ro', basefmt='w-')
        ax[i].set_xlim(-1, 21)
        ax[i].set_ylim(0, 1)
        ax[i].set_ylabel("Prob")
        ax[i].set_title("Document {}".format(k))
    
    ax[4].set_xlabel("Topic")
    
    plt.tight_layout()
    plt.show()
    

    Output:

    [Figure: five stem plots of topic probabilities, one per panel, for documents 1, 3, 4, 8, and 9]

    5. Complete code

    import numpy as np
    import lda
    import lda.datasets
    
    # document-term matrix
    X = lda.datasets.load_reuters()
    print("type(X): {}".format(type(X)))
    print("shape: {}\n".format(X.shape))
    print(X[:5, :5])        # first 5 rows, first 5 columns
    
    # the vocab
    vocab = lda.datasets.load_reuters_vocab()
    print("type(vocab): {}".format(type(vocab)))
    print("len(vocab): {}\n".format(len(vocab)))
    print(vocab[:5])
    
    # titles for each story
    titles = lda.datasets.load_reuters_titles()
    print("type(titles): {}".format(type(titles)))
    print("len(titles): {}\n".format(len(titles)))
    print(titles[:5])       # print the titles of the first five documents
    
    print("\n************************************************************")
    doc_id = 0
    word_id = 0
    while doc_id < 5:
        print("doc id: {} word id: {}".format(doc_id, word_id))
        print("-- count: {}".format(X[doc_id, word_id]))
        print("-- word : {}".format(vocab[word_id]))
        print("-- doc  : {}\n".format(titles[doc_id]))
        doc_id += 1
    
    topicCnt = 20
    model = lda.LDA(n_topics=topicCnt, n_iter=500, random_state=1)
    model.fit(X)          # model.fit_transform(X) is also available
    
    print("\n************************************************************")
    topic_word = model.topic_word_
    print("type(topic_word): {}".format(type(topic_word)))
    print("shape: {}".format(topic_word.shape))
    print(vocab[:3])
    print(topic_word[:, :3])    # first 3 columns of all 20 rows
    
    for n in range(20):
        sum_pr = sum(topic_word[n, :])   # the weights in row n should sum to 1
        print("topic: {} sum: {}".format(n, sum_pr))
    
    print("\n************************************************************")
    n = 5
    for i, topic_dist in enumerate(topic_word):
        topic_words = np.array(vocab)[np.argsort(topic_dist)][:-(n+1):-1]
        print('*Topic {}\n- {}'.format(i, ' '.join(topic_words)))
    
    print("\n************************************************************")
    doc_topic = model.doc_topic_
    print("type(doc_topic): {}".format(type(doc_topic)))
    print("shape: {}".format(doc_topic.shape))
    for n in range(10):
        topic_most_pr = doc_topic[n].argmax()
        print("doc: {} topic: {}".format(n, topic_most_pr))
    
    print("\n************************************************************")
    import matplotlib.pyplot as plt
    
    f, ax = plt.subplots(5, 1, figsize=(8, 6), sharex=True)
    for i, k in enumerate([0, 5, 9, 14, 19]):
        print(i, k)
        ax[i].stem(topic_word[k, :], linefmt='b-',
                   markerfmt='bo', basefmt='w-')
        ax[i].set_xlim(-50, 4350)
        ax[i].set_ylim(0, 0.08)
        ax[i].set_ylabel("Prob")
        ax[i].set_title("topic {}".format(k))
    
    ax[4].set_xlabel("word")
    
    plt.tight_layout()
    plt.show()
    
    print("\n************************************************************")
    f, ax = plt.subplots(5, 1, figsize=(8, 6), sharex=True)
    for i, k in enumerate([1, 3, 4, 8, 9]):
        ax[i].stem(doc_topic[k, :], linefmt='r-',
                   markerfmt='ro', basefmt='w-')
        ax[i].set_xlim(-1, 21)
        ax[i].set_ylim(0, 1)
        ax[i].set_ylabel("Prob")
        ax[i].set_title("Document {}".format(k))
    
    ax[4].set_xlabel("Topic")
    
    plt.tight_layout()
    plt.show()
    

    6. References

    (1) https://blog.csdn.net/eastmount/article/details/50824215
    (2) http://chrisstrelioff.ws/sandbox/2014/11/13/getting_started_with_latent_dirichlet_allocation_in_python.html

    7. Further reading

    《LDA漫游指南》
