美文网首页
fasttext的源码阅读

fasttext的源码阅读

作者: 小小兰哈哈 | 来源:发表于2019-08-11 22:57 被阅读0次

    最近做了fasttext的  源码阅读,分享一下心得。

    1.所用数据结构:

      1)Matrix(父类)->DenseMatrix(子类)

        DenseMatrix类里面有vector<real> data_的变量,用一个vector保存二维矩阵的信息

        初始化变量(m:int64_t, n:int64_t) 表明矩阵的维数。

        使用DenseMatrix的模型参数:wi:隐藏层 wo:输出层,fasttext就这两层参数,全部都是用DenseMatrix表示。

        模型的输出output变量也是用DenseMatrix定义。

      2)Vector

          Vector类里面也是vector<real> data_的变量,与DenseMatrix不同的是,它是用vector保存一维矩阵的信息

          初始化变量(m:int64_t) 表明矩阵的维数

          使用Vector的模型参数: 模型的状态类State中的变量:包括:hidden,output,grad

      3)另外的变量:

          除了上面提到的wi,wo,hidden,output, graed; input是由一个vector<int64_t>构造的

    2. 封装的类:

      1)fasttext

      2)Model

      3)Loss

      分别描述:

      1)fasttext: fasttext类提供整个模型训练、预测的入口。其内部变量是模型训练过程中所有参数。

        1.模型参数model_ 2. 训练参数 args_ 3. 词典 dict_, 4 模型输入 input_ 5. 模型输出 output. 6. loss_

       源码如下:

        class FastText {

    protected:

      std::shared_ptr<Args> args_;

      std::shared_ptr<Dictionary> dict_;

      std::shared_ptr<Matrix> input_;

      std::shared_ptr<Matrix> output_;

      std::shared_ptr<Model> model_;

      std::atomic<int64_t> tokenCount_{};

      std::atomic<real> loss_{};


        fasttext中共有四种类型的内部函数:

        1. 词典生成及序列转换函数:getInputMatrix, getOutputMatrix, getDictionary等

        2. 训练函数:cbow,skip,supervise,其中,cbow,skip是训练词向量的, surpervise是训练分类的

        3. 预测和实验: test, predict

        4. 保存和加载:保存模型,加载模型 saveModel loadModel

      2)Model:Model类提供Model训练、预测的方法,隐藏层的计算ComputeHidden, predict,update。

          其中,内部变量包括:模型状态变量hidden,output,grad,lossvalue等, wi_(第一层模型参数),wo_(第二层模型参数)

          并且内部定义了一个loss对象。

         源码举例如下:

    void predict(

          const std::vector<int32_t>& input,

          int32_t k,

          real threshold,

          Predictions& heap,

          State& state) const;

      void update(

          const std::vector<int32_t>& input,

          const std::vector<int32_t>& targets,

          int32_t targetIndex,

          real lr,

          State& state);

      void computeHidden(const std::vector<int32_t>& input, State& state) const;

     3)Loss:loss类

          Loss类是由model类引用并使用的。其中封装了四种loss的计算方法:

          1.OneVsAllLoss

          2.NegativeSamplingLoss (默认的loss求法)

          3.HierarchicalSoftmaxLoss (hs的求法,fasttext的创新求法:使用了哈夫曼树)

          4.SoftmaxLoss (softmax)

          另外,还有计算output的ComputOutput函数。

    3. 三个模块的关系图可绘如下:

    相关文章

      网友评论

          本文标题:fasttext的源码阅读

          本文链接:https://www.haomeiwen.com/subject/xspsjctx.html