美文网首页
[fairseq] generate.py

[fairseq] generate.py

作者: VanJordan | 来源:发表于2019-04-28 21:31 被阅读0次

    [TOC]

    generate.py

    • args.dataargs.gen_subset的区别?
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
    
    • 原来输入的path可以用:隔开可以完成ensemble,然后后面就是各自的model进行各种操作
    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )
    
    • 需要ensemble的模型都有各自的操作。
    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
    

    replace_unk的原理

    • 可以指定一个字典加载
    align_dict = utils.load_align_dict(args.replace_unk)
    
    • 函数里面,replace_unk在命令行可以传入一个参数,是一个字典,每行使用空格隔开表示源单词和要替换的目标单词。
    • 如果没有传入参数的话那么字典是空的,但是是使用源语言中的word来替代unk
    def load_align_dict(replace_unk):
        if replace_unk is None:
            align_dict = None
        elif isinstance(replace_unk, str):
            # Load alignment dictionary for unknown word replacement if it was passed as an argument.
            align_dict = {}
            with open(replace_unk, 'r') as f:
                for line in f:
                    cols = line.split()
                    align_dict[cols[0]] = cols[1]
        else:
            # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
            # original source word.
            align_dict = {}
        return align_dict
    
    • itrsample的关系
    • 可以通过.get_original_text将id转换成raw word
    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
    
    • 现在的疑问src_dict.string在哪里定义的。
    src_str = src_dict.string(src_tokens, args.remove_bpe)
    
    • 控制是否输出 sample_idsrc_str是通过args.quiet来控制的
    if not args.quiet:
        if src_dict is not None:
            print('S-{}\t{}'.format(sample_id, src_str))
        if has_target:
            print('T-{}\t{}'.format(sample_id, target_str))
    
    • hypos = task.inference_step(generator, models, sample, prefix_tokens)这一句是干啥的
    • args.prefix_size,是将target的这么长的部分直接给generator看到,默认是0。
    prefix_tokens = None
    if args.prefix_size > 0:
        prefix_tokens = sample['target'][:, :args.prefix_size]
    
    gen_timer.start()
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
    
    • 看到原来hypo是推断的结果,里面还存放者每个位置的的得分。
    if not args.quiet:
        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
        print('P-{}\t{}'.format(
            sample_id,
            ' '.join(map(
                lambda x: '{:.4f}'.format(x),
                hypo['positional_scores'].tolist(),
            ))
        ))
    
    • hyposdecoder得到的数据,args.nbest是推断的个数,即decoder句子的个数。
    for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]):
    
    • utils.post_process_prediction对已经decode出来的数据进行后处理。
    # Process top predictions
    for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]):
        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
            hypo_tokens=hypo['tokens'].int().cpu(),
            src_str=src_str,
            alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
            align_dict=align_dict,
            tgt_dict=tgt_dict,
            remove_bpe=args.remove_bpe,
        )
    
    • 可以看到generator里面就直接将score得到了因此不用再设置score
    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
    
    • args.print_alignment 可以看到每一个target string中的每一个单词对应source string中概率最大的那个单词。
    if args.print_alignment:
        print('A-{}\t{}'.format(
            sample_id,
            ' '.join(map(lambda x: str(utils.item(x)), alignment))
        ))
    
    • 现在好奇是怎么算分数的即score的计算方法是什么

    其他部分

    • build_progress_bar的时候可以选择args.log_format==tqdm

    相关文章

      网友评论

          本文标题:[fairseq] generate.py

          本文链接:https://www.haomeiwen.com/subject/ulzdnqtx.html