美文网首页
chatbot-retrieval源码分析(一)

chatbot-retrieval源码分析(一)

作者: 微澜55 | 来源:发表于2018-10-11 17:12 被阅读0次

 项目git路径

https://github.com/dennybritz/chatbot-retrieval

分析文件

TFIDF Baseline Evaluation.py

文件任务

基于检索的智能对话,baseline效果获取

训练及测试数据说明

train.cxv: Context,Utterance,Label(label表示utterance是否为context的正确回复,正确为1)

test.csv: Context,Ground Truth Utterance,Distractor_0,Distractor_1,Distractor_2,Distractor_3,Distractor_4,Distractor_5,Distractor_6,Distractor_7,Distractor_8

代码分析

# coding: utf-8

import pandasas pd #pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的。

import numpyas np

from sklearn.feature_extraction.textimport TfidfVectorizer #将一组原始文本转换为由tf-idf组成的矩阵

from sklearn.feature_extraction.textimport TfidfTransformer #将一个数字矩阵转换为标准tf-idf表示

数据加载

train_df = pd.read_csv("../data/train.csv")

test_df = pd.read_csv("../data/test.csv")

validation_df = pd.read_csv("../data/valid.csv")

y_test = np.zeros(len(test_df)) # 编号为0的是正确答案

效果评估

def evaluate_recall(y, y_test, k=1):

        num_examples =float(len(y))

        num_correct =0

        for predictions, labelin zip(y, y_test):

            if labelin predictions[:k]:

            num_correct +=1

        return num_correct/num_examples

随机效果预测

def predict_random(context, utterances):

        return np.random.choice(len(utterances),10,replace=False)

随机预测正确回答

y_random = [predict_random(test_df.Context[x], test_df.iloc[x,1:].values)for xin range(len(test_df))]

for nin [1,2,5,10]:

        print("Recall @ ({}, 10): {:g}".format(n, evaluate_recall(y_random, y_test, n)))

TFIDF效果预测

class TFIDFPredictor:##并没有用到训练数据中的label

    def __init__(self):

        self.vectorizer = TfidfVectorizer()

    def train(self, data): ##从训练集中学习词表及idf

        self.vectorizer.fit(np.append(data.Context.values,data.Utterance.values))

        # Context和Utterance是train.csv文件第一行的内容,标识的是后续所有行中每一列的字段名

    def predict(self, context, utterances):

        # Convert context and utterances into tfidf vector

        vector_context =self.vectorizer.transform([context])

        vector_doc =self.vectorizer.transform(utterances)

        # The dot product measures the similarity of the resulting vectors

        # vector_doc:(10 395829)10个候选答案,每个答案标识为一个395829维的向量    vector_context.T:(395829 1)    np.dot:矩阵相乘

        result = np.dot(vector_doc, vector_context.T)  # result:(10 1)  (0, 0)        0.04412,  (1, 0)        0.03286......

        result = result.todense() # [[0.04412],[0.03286]......]

        result = np.asarray(result).flatten() # [0.04412 0.03286......]

        return np.argsort(result,axis=0)[::-1]# Sort by top results and return the indices in descending order

TFIDF预测正确答案

pred = TFIDFPredictor()

pred.train(train_df)

y = [pred.predict(test_df.Context[x], test_df.iloc[x,1:].values)for xin range(len(test_df))]

for nin [1,2,5,10]:

        print("Recall @ ({}, 10): {:g}".format(n, evaluate_recall(y, y_test, n)))

相关文章

网友评论

      本文标题:chatbot-retrieval源码分析(一)

      本文链接:https://www.haomeiwen.com/subject/euzpaftx.html