美文网首页
2019-01-16 解析bert代码

2019-01-16 解析bert代码

作者: 猫和下午茶 | 来源:发表于2019-01-16 16:47 被阅读0次

代码文件为bert_lstm_ner.py,下面进行逐行解析:

tf.logging.set_verbosity(tf.logging.INFO)#运行代码时,将会看到info日志输出INFO:tensorflow:loss = 1.18812, step = 1INFO:tensorflow:loss = #0.210323, step = 101INFO:tensorflow:loss = 0.109025, step = 201

processors = {

        "ner": NerProcessor

    }#定义一个ner:NerProcessor的字典

bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)#将bert参数传到bert_config中

if FLAGS.max_seq_length > bert_config.max_position_embeddings:#假如最大总输入序列长度大于bert最大的wordembedding长度,报错

    raise ValueError(

        "Cannot use sequence length %d because the BERT model "

        "was only trained up to sequence length %d" %

        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

# 在train 的时候,才删除上一轮产出的文件,在predicted 的时候不做clean

if FLAGS.clean and FLAGS.do_train:#默认是两个ture

    if os.path.exists(FLAGS.output_dir):#假如输出文件位置存在

        def del_file(path):#设置个删文件的函数

            ls = os.listdir(path)#listdir函数返回文件夹中的所有文件名字

            for i in ls:

                c_path = os.path.join(path, i)#os.path.join()函数用于路径拼接文件路径

                if os.path.isdir(c_path):#如果该文件存在

                    del_file(c_path)#删除文件

                else:

                    os.remove(c_path)#删除文件

        try:

            del_file(FLAGS.output_dir)#尝试删除文件,否则报错

        except Exception as e:

            print(e)

            print('pleace remove the files of output dir and data.conf')

            exit(-1)

    if os.path.exists(FLAGS.data_config_path):#如果保存数据的位置存在

        try:

            os.remove(FLAGS.data_config_path)#尝试删除

        except Exception as e:

            print(e)

            print('pleace remove the files of output dir and data.conf')

            exit(-1)

task_name = FLAGS.task_name.lower()#task_name是要训练的任务的名称,值为ner

if task_name not in processors:#如果processor里面没有ner,报错

    raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()#返回NerProcessor()函数

label_list = processor.get_labels()#label_list值为["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]

tokenizer = tokenization.FullTokenizer(

    vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)#输出函数(对bert的词汇文件在进行变小写后进行fulltokenizer)

tpu_cluster_resolver = None#不使用tpu集群

if FLAGS.use_tpu and FLAGS.tpu_name:#不考虑

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(

        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2#如果为PER_HOST_V1或PER_HOST_V2,则在每个主机上调用一次input_fn。 #使用每核心输入管道配置,每个核心调用一次。 具有全局批量大小

run_config = tf.contrib.tpu.RunConfig(#定义tpu函数

    cluster=tpu_cluster_resolver,#false

    master=FLAGS.master,#none‘TensorFlow master URL.’

    model_dir=FLAGS.output_dir,#输出位置

    save_checkpoints_steps=FLAGS.save_checkpoints_steps,#" 保存模型checkpoint的频率."为1000

    tpu_config=tf.contrib.tpu.TPUConfig(#定义tpu函数2

        iterations_per_loop=FLAGS.iterations_per_loop,#"在每个评估单元调用中要执行多少步骤."1000

        num_shards=FLAGS.num_tpu_cores,#tpu核数,8

        per_host_input_for_training=is_per_host))#PER_HOST_V2

train_examples = None#none

num_train_steps = None#none

num_warmup_steps = None#none

if os.path.exists(FLAGS.data_config_path):#如果data config 文件,保存训练和dev config存在

    with codecs.open(FLAGS.data_config_path) as fd:#打开文件路径

        data_config = json.load(fd)#加载数据到data_config中

else:

    data_config = {}#否则设为空

if FLAGS.do_train:

        # 加载训练数据

    if len(data_config) == 0:#如果为空

        train_examples = processor.get_train_examples(FLAGS.data_dir)#将训练样本输入到变量中

        num_train_steps = int(

            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)#训练执行总批次数为样本长度/训练总批次*训练总次数

        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)#上面数值*进行线性学习率热身训练的比例。

        data_config['num_train_steps'] = num_train_steps#数据参数设定1

        data_config['num_warmup_steps'] = num_warmup_steps#数据参数设定2

        data_config['num_train_size'] = len(train_examples)#数据参数设定3(数据长度)

    else:

        num_train_steps = int(data_config['num_train_steps'])#直接调用1

        num_warmup_steps = int(data_config['num_warmup_steps'])#直接调用2

    # 返回的model_dn 是一个函数,其定义了模型,训练,评测方法,并且使用钩子参数,加载了BERT模型的参数进行了自己模型的参数初始化过程

    # tf 新的架构方法,通过定义model_fn 函数,定义模型,然后通过EstimatorAPI进行模型的其他工作,Es就可以控制模型的训练,预测,评估工作等。

model_fn = model_fn_builder(

    bert_config=bert_config,#从bert文件中获得

    num_labels=len(label_list) + 1,#标签数量

    init_checkpoint=FLAGS.init_checkpoint,#r'D:\bert\chinese_L-12_H-768_A-12\bert_model.ckpt "初始检查点(通常来自预先训练的bert模型)."

    learning_rate=FLAGS.learning_rate,#学习率 5e-5,

    num_train_steps=num_train_steps,#总批次

    num_warmup_steps=num_warmup_steps,#warmup数

#warmup就是先采用小的学习率(0.01)进行训练,训练了400iterations之后将学习率调整至0.1开始正式训练

    use_tpu=FLAGS.use_tpu,#none

    use_one_hot_embeddings=FLAGS.use_tpu)#none

print(model_fn)

estimator = tf.contrib.tpu.TPUEstimator(#定义评估器

    use_tpu=FLAGS.use_tpu,#none

    model_fn=model_fn,#将上面定义的model加入

    config=run_config,#将上面定义的runconfig参数加入

    train_batch_size=FLAGS.train_batch_size,#训练批次 64

    eval_batch_size=FLAGS.eval_batch_size,#评估批次 8

    predict_batch_size=FLAGS.predict_batch_size)# 预测批次 8

train_file =r'C:\Users\dell\Desktop\Name-Entity-Recognition-master\BERT-BiLSTM-CRF-NER\train.tf_record'

filed_based_convert_examples_to_features(

    train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)#将数据转化为TF_Record 结构,作为模型数据输入:样本,标签,最#大长度,tokenizer,数据

num_train_size = num_train_size = int(data_config['num_train_size'])

tf.logging.info("***** Running training *****")

tf.logging.info("  Num examples = %d", num_train_size)#20864

tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)#64

tf.logging.info("  Num steps = %d", num_train_steps)#978

train_input_fn = file_based_input_fn_builder(

    input_file=train_file,#训练文件

    seq_length=FLAGS.max_seq_length,#最大序列长度 128

    is_training=True,#确定训练

    drop_remainder=True)#没查到。。。

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)#进行训练

if FLAGS.do_eval:#进行评估

    if data_config.get('eval.tf_record_path', '') == '':#如果字典中没有评估路径

        eval_examples = processor.get_dev_examples(FLAGS.data_dir)#读到data_dir的dev文件

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")#获得输出位置的eval.tf_record文件

        filed_based_convert_examples_to_features(

            eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)#将评估文件转换

        data_config['eval.tf_record_path'] = eval_file#将评估文件加入数据

        data_config['num_eval_size'] = len(eval_examples)#将评估文件长度加入数据

    else:

        eval_file = data_config['eval.tf_record_path']#将评估数据文件读出

        # 打印验证集数据信息

    num_eval_size = data_config.get('num_eval_size', 0)#将评估文件长度读出

    tf.logging.info("***** Running evaluation *****")

    tf.logging.info("  Num examples = %d", num_eval_size)#2318

    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)#8

    eval_steps = None

    if FLAGS.use_tpu:#none

        eval_steps = int(num_eval_size / FLAGS.eval_batch_size)#不管

    eval_drop_remainder = True if FLAGS.use_tpu else False#false

    eval_input_fn = file_based_input_fn_builder(

        input_file=eval_file,#评估文件

        seq_length=FLAGS.max_seq_length,#最大序列长度

        is_training=False,#不训练

        drop_remainder=eval_drop_remainder)#none

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)#step=none(这里报错)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")#输出文件

    with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

        tf.logging.info("***** Eval results *****")

        for key in sorted(result.keys()):

            tf.logging.info("  %s = %s", key, str(result[key]))#报出文件

            writer.write("%s = %s\n" % (key, str(result[key])))#写入文件

# 保存数据的配置文件,避免在以后的训练过程中多次读取训练以及测试数据集,消耗时间

if not os.path.exists(FLAGS.data_config_path):

    with codecs.open(FLAGS.data_config_path, 'a', encoding='utf-8') as fd:

        json.dump(data_config, fd)#把a作为data_config_path存入data_config

if FLAGS.do_predict:#开始预测

    token_path = os.path.join(FLAGS.output_dir, "token_test.txt")#导入测试集输出位置

    if os.path.exists(token_path):#如果测试集存在

        os.remove(token_path)#删了

    with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf:#打开label2id的文件

        label2id = pickle.load(rf)

        id2label = {value: key for key, value in label2id.items()}#转成字典

    predict_examples = processor.get_test_examples(FLAGS.data_dir)#得到test文件

    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")#得到预测的tf_record文件

    filed_based_convert_examples_to_features(predict_examples, label_list,

                                                FLAGS.max_seq_length, tokenizer,

                                                predict_file, mode="test")#建立测试的tf_record文件

    tf.logging.info("***** Running prediction*****")

    tf.logging.info("  Num examples = %d", len(predict_examples))#4636

    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)#8

    if FLAGS.use_tpu:

            # Warning: According to tpu_estimator.py Prediction on TPU is an

            # experimental feature and hence not supported here

            raise ValueError("Prediction in TPU not supported")

    predict_drop_remainder = True if FLAGS.use_tpu else False#false

    predict_input_fn = file_based_input_fn_builder(

        input_file=predict_file,#输入文件

        seq_length=FLAGS.max_seq_length,#最大序列

        is_training=False,#不训练

        drop_remainder=predict_drop_remainder)#none

    predicted_result = estimator.evaluate(input_fn=predict_input_fn)#报错。。。

    output_eval_file = os.path.join(FLAGS.output_dir, "predicted_results.txt")#输出预测结果

    with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

        tf.logging.info("***** Predict results *****")

        for key in sorted(predicted_result.keys()):

            tf.logging.info("  %s = %s", key, str(predicted_result[key]))

            writer.write("%s = %s\n" % (key, str(predicted_result[key])))#写入文件

    result = estimator.predict(input_fn=predict_input_fn)#预测

    output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")#输出文件

    def result_to_pair(writer):#这里是写入函数

        for predict_line, prediction in zip(predict_examples, result):

            idx = 0

            line = ''

            line_token = str(predict_line.text).split(' ')

            label_token = str(predict_line.label).split(' ')

            if len(line_token) != len(label_token):

                tf.logging.info(predict_line.text)

                tf.logging.info(predict_line.label)

            for id in prediction:

                if id == 0:

                    continue

                curr_labels = id2label[id]

                if curr_labels in ['[CLS]', '[SEP]']:

                    continue

                    # 不知道为什么,这里会出现idx out of range 的错误。。。do not know why here cache list out of range exception!

                try:

                    line += line_token[idx] + ' ' + label_token[idx] + ' ' + curr_labels + '\n'

                except Exception as e:

                    tf.logging.info(e)

                    tf.logging.info(predict_line.text)

                    tf.logging.info(predict_line.label)

                    line = ''

                    break

                idx += 1

            writer.write(line + '\n')

    with codecs.open(output_predict_file, 'w', encoding='utf-8') as writer:

        result_to_pair(writer)#写入文件

    from conlleval import return_report

    eval_result = return_report(output_predict_file)#百度找不到,猜测是得到评估结果的函数

    print(eval_result)

相关文章

网友评论

      本文标题:2019-01-16 解析bert代码

      本文链接:https://www.haomeiwen.com/subject/dirmdqtx.html