Pitfall Notes (pytorch-bert, DataFrame, cross-entropy)

Author: 锦绣拾年 | Published 2021-01-21 15:56

    20210121 note

1. The PyTorch BERT output pitfall.

    # model is a BERT model loaded from the PyTorch transformers library
    loss, logits = model(input_ids, attention_mask=masks)
    # pitfall: this unpacks the dict-like output object, so loss and
    # logits end up as str ('loss', 'logits'), not tensors

    tmpx = model(input_ids, attention_mask=masks)
    loss = tmpx[0]
    logits = tmpx[1]  # breaks when loss is None (e.g. no labels were passed)
    # The robust way is attribute access, tmpx.loss and tmpx.logits,
    # which give the output values shown below:
    
    SequenceClassifierOutput(loss=tensor(1.2209, device='cuda:0', grad_fn=<NllLossBackward>), logits=tensor([[-0.0275, -0.2090, -0.1251, -0.2942],
            [ 0.0310, -0.2028, -0.1399, -0.3605],
            [-0.0671, -0.3543, -0.1225, -0.4625],
            [ 0.1389, -0.1244, -0.2310, -0.3664]], device='cuda:0',
           grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
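    The "str type" surprise above comes from how the output object behaves: SequenceClassifierOutput is dict-like, and unpacking a dict iterates over its keys. A minimal stand-in (a plain OrderedDict with made-up values, no transformers needed) reproduces the pitfall:

    ```python
    from collections import OrderedDict

    # Stand-in for transformers' SequenceClassifierOutput, which is
    # dict-like: iterating (and therefore unpacking) a dict yields its KEYS.
    out = OrderedDict(loss=1.2209, logits=[[-0.0275, -0.2090]])

    loss, logits = out   # unpacking iterates the mapping...
    print(loss, logits)  # -> loss logits  (the strings, not the values!)

    # Key access returns the actual stored values:
    print(out["loss"], out["logits"])
    ```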
    

2. A DataFrame filtering pitfall.

    import numpy as np
    import pandas as pd

    # build a boolean mask over the first 12 rows
    s = np.asarray([True] * 12)
    s[3] = False
    s[6] = False

    x = pd.read_csv("boold.csv")
    print(x[:12][s])       # boolean indexing keeps only rows where s is True
    print(len(x[:12][s]))
    

Output: the rows whose mask value is False (positions 3 and 6) are not selected.

     Unnamed: 0                                              train  answer
    0            0  <start> does ethanol take more energy make tha...       0
    1            1  <start> is house tax and property tax are same...       1
    2            2  <start> is pain experienced in a missing body ...       1
    4            4  <start> is there a difference between hydroxyz...       1
    5            5  <start> is barq s root beer a pepsi product <e...       0
    7            7  <start> is there a word with q without u <end>...       1
    8            8  <start> can u drive in canada with us license ...       1
    9            9  <start> is there a play off for third place in...       1
    10          10  <start> can minors drink with parents in new y...       1
    11          11  <start> is the show bloodline based on a 1 sto...       0
    10
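    The same selection can be reproduced without the CSV file; in the sketch below a made-up single-column DataFrame stands in for boold.csv, and boolean-array indexing drops exactly the rows whose mask entry is False:

    ```python
    import numpy as np
    import pandas as pd

    # Hypothetical stand-in for boold.csv: a 12-row DataFrame.
    df = pd.DataFrame({"answer": [0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]})

    mask = np.asarray([True] * 12)
    mask[3] = False
    mask[6] = False

    subset = df[:12][mask]     # keeps only rows where mask is True
    print(len(subset))         # -> 10
    print(list(subset.index))  # -> [0, 1, 2, 4, 5, 7, 8, 9, 10, 11]
    ```

    Note that the surviving rows keep their original index labels (4 follows 2), which is exactly the gap visible in the printed table above.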
    

3. The cross-entropy loss function

    \mathcal{L} = -\frac{1}{N}\sum_{n=1}^{N}\log p(y^{(n)} \mid s^{(n)})

    This can be read directly as an averaged sum of log-probabilities of the true labels.

    For multi-class classification, suppose the true labels are [0, 1, 2, 1, 1, 2, 1].

    Then the loss is -(log p(0|x) + log p(1|x) + log p(2|x) + log p(1|x) + log p(1|x) + log p(2|x) + log p(1|x)), divided by N = 7 when the mean reduction in the formula is applied.
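    A short pure-Python check of the formula (the raw scores are made up for illustration); dividing by N matches the default mean reduction of PyTorch's nn.CrossEntropyLoss:

    ```python
    import math

    def cross_entropy(logits, targets):
        """Averaged cross-entropy: -(1/N) * sum_n log p(y_n | x_n)."""
        total = 0.0
        for scores, y in zip(logits, targets):
            z = sum(math.exp(s) for s in scores)         # softmax denominator
            total += -math.log(math.exp(scores[y]) / z)  # -log p(y | x)
        return total / len(targets)

    # Labels from the note; one row of raw class scores per sample (made up).
    targets = [0, 1, 2, 1, 1, 2, 1]
    logits = [[2.0, 0.5, -1.0]] * len(targets)

    print(cross_entropy(logits, targets))
    ```

    A sanity check: with uniform scores over 3 classes, every p(y|x) = 1/3, so the loss is exactly log 3.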

Source: https://www.haomeiwen.com/subject/orhdzktx.html