A Walkthrough of the Official gcForest Code


Author: wendy_要努力努力再努力 | Published 2018-05-14 18:37

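    This article uses version v1.1; the GitHub repository is https://github.com/kingfengji/gcForest
    The code has two main parts: the examples folder holds the main scripts (.py) and the config files (.json); the lib folder holds the library code those scripts use.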

    The main code

    from gcforest.gcforest import GCForest
    gc = GCForest(config) # should be a dict
    X_train_enc = gc.fit_transform(X_train, y_train)
    y_pred = gc.predict(X_test)
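The comment above says `config` should be a dict, while the examples folder ships its configs as .json files. A minimal sketch of turning such a file's text into a dict follows. It assumes that lines starting with `//` (as in the fine-grained config later in this article) are comments to be dropped before parsing. `load_config_text` is a hypothetical helper name, not part of gcForest's API.

```python
import json


def load_config_text(text):
    """Parse a JSON config string, ignoring whole-line `//` comments.

    Hypothetical helper for illustration; not gcForest's own loader.
    """
    kept = [line for line in text.splitlines()
            if not line.strip().startswith("//")]
    return json.loads("\n".join(kept))


sample = """
{
// cascade part only
"cascade": {
    "random_state": 0,
    "max_layers": 100,
    "early_stopping_rounds": 3,
    "n_classes": 10,
    "estimators": [
        {"n_folds": 5, "type": "LogisticRegression"}
    ]
}
}
"""

config = load_config_text(sample)
print(config["cascade"]["n_classes"])  # 10
```

The resulting dict is what would be passed to `GCForest(config)`.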
    

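    The lib directory in detail

    gcforest.py: implementation of the overall framework
    fgnet.py: the multi-grained part (the FineGrained implementation)
    cascade/cascade_classifier: implementation of the cascade classifier
    datasets/...: definitions of a number of datasets
    estimators/...: the functions used by the base learners for fitting and prediction (wrappers around the various classifiers)
    layers/...: the different layer operations, such as concatenation, pooling, and sliding windows
    utils/...: assorted utility functions, for example accuracy computation, win_vote, win_avg, get_windows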

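    The JSON config file in detail

    Parameters

    • max_depth: maximum depth of each decision tree. The default is None, in which case depth is not limited while the tree is built: nodes keep splitting until every leaf contains a single class or holds fewer samples than min_samples_split. With little data or few features this value can usually be left alone. With many samples and many features, limiting the depth is recommended; the right value depends on the data distribution, and values between 10 and 100 are common.
    • estimators: the list of classifiers to use
    • n_estimators: the number of trees in the forest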
    • n_jobs: int (default=1)
      The number of jobs to run in parallel for any Random Forest fit and predict.
      If -1, then the number of jobs is set to the number of cores.
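To make the roles of these keys concrete, the sketch below splits one estimator entry into the part that selects and wraps the classifier (`type`, `n_folds`) and the remaining hyperparameters. It assumes the remaining keys are forwarded unchanged to the classifier's constructor, which is consistent with the scikit-learn and XGBoost parameter names used in the configs here. `split_estimator_entry` is a hypothetical name.

```python
def split_estimator_entry(entry):
    """Separate framework-level keys from classifier hyperparameters.

    Assumption: `type` and `n_folds` are consumed by gcForest itself and
    everything else is passed to the classifier constructor.
    """
    params = dict(entry)  # copy so the original config is not mutated
    est_type = params.pop("type")
    n_folds = int(params.pop("n_folds"))
    return est_type, n_folds, params


entry = {"n_folds": 5, "type": "RandomForestClassifier",
         "n_estimators": 10, "max_depth": None, "n_jobs": -1}
est_type, n_folds, params = split_estimator_entry(entry)
print(est_type, n_folds, params)
```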

    Training configuration, in three cases:

    1. Using the default model
    def get_toy_config():
        config = {}
        ca_config = {}
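        ca_config["random_state"] = 0  # random seed, fixed for reproducibility
        ca_config["max_layers"] = 100  # maximum number of layers; a layer corresponds to a "level" in the paper
        ca_config["early_stopping_rounds"] = 3  # stop adding layers once accuracy has not improved for 3 consecutive layers
        ca_config["n_classes"] = 3      # number of classes to predict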
        ca_config["estimators"] = []  
        ca_config["estimators"].append(
                {"n_folds": 5, "type": "XGBClassifier", "n_estimators": 10, "max_depth": 5,
                 "objective": "multi:softprob", "silent": True, "nthread": -1, "learning_rate": 0.1} )
        ca_config["estimators"].append({"n_folds": 5, "type": "RandomForestClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})
        ca_config["estimators"].append({"n_folds": 5, "type": "ExtraTreesClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})
        ca_config["estimators"].append({"n_folds": 5, "type": "LogisticRegression"})
        config["cascade"] = ca_config    # four base learners are used in total
        return config
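The interaction between `max_layers` and `early_stopping_rounds` can be sketched as follows. This illustrates the stopping rule as the config comments describe it and is not the library's code: `layers_trained` is a hypothetical function that takes a made-up list of per-layer accuracies and reports how many layers get trained and which one was best.

```python
def layers_trained(accuracies, early_stopping_rounds=3, max_layers=100):
    """Return (number of layers trained, index of the best layer).

    Training stops once the best accuracy has not improved for
    `early_stopping_rounds` consecutive layers, or at `max_layers`.
    """
    best_idx = 0
    best_acc = float("-inf")
    trained = 0
    for idx, acc in enumerate(accuracies[:max_layers]):
        trained = idx + 1
        if acc > best_acc:
            best_idx, best_acc = idx, acc
        elif idx - best_idx >= early_stopping_rounds:
            break
    return trained, best_idx


# Illustrative accuracies only, not measured results.
accs = [0.90, 0.93, 0.95, 0.95, 0.94, 0.95, 0.96]
print(layers_trained(accs))  # (6, 2)
```

In this made-up run the seventh layer would have scored higher, but it is never trained because layers 3 to 5 failed to beat layer 2.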
    

    Supported base classifiers:
    RandomForestClassifier
    XGBClassifier
    ExtraTreesClassifier
    LogisticRegression
    SGDClassifier

    You can add any other classifier manually by editing:

    lib/gcforest/estimators/__init__.py
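The source only names the file to edit, so what follows is a generic sketch of the idea rather than the contents of that file: a lookup from the `type` string in the config to a classifier class. `ESTIMATOR_REGISTRY`, `register_estimator`, `build_estimator` and `MajorityClassifier` are all hypothetical names, and the stand-in classifier exists only so the example runs without extra dependencies.

```python
from collections import Counter

ESTIMATOR_REGISTRY = {}  # hypothetical: maps config "type" -> class


def register_estimator(name):
    """Class decorator that makes a classifier selectable by name."""
    def decorator(cls):
        ESTIMATOR_REGISTRY[name] = cls
        return cls
    return decorator


@register_estimator("MajorityClassifier")
class MajorityClassifier:
    """Stand-in classifier: always predicts the most frequent label."""

    def fit(self, X, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


def build_estimator(entry):
    """Instantiate the class named by entry["type"] with the other keys."""
    params = {k: v for k, v in entry.items() if k not in ("type", "n_folds")}
    try:
        cls = ESTIMATOR_REGISTRY[entry["type"]]
    except KeyError:
        raise ValueError("unknown estimator type: %r" % entry["type"])
    return cls(**params)


clf = build_estimator({"n_folds": 5, "type": "MajorityClassifier"})
clf.fit([[0], [1], [2]], ["a", "b", "b"])
print(clf.predict([[3], [4]]))  # ['b', 'b']
```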
    
    2. Cascade part only
    {
    "cascade": {
        "random_state": 0,
        "max_layers": 100,
        "early_stopping_rounds": 3,
        "n_classes": 10,
        "estimators": [
            {"n_folds":5,"type":"XGBClassifier","n_estimators":10,"max_depth":5,"objective":"multi:softprob", "silent":true, "nthread":-1, "learning_rate":0.1},
            {"n_folds":5,"type":"RandomForestClassifier","n_estimators":10,"max_depth":null,"n_jobs":-1},
            {"n_folds":5,"type":"ExtraTreesClassifier","n_estimators":10,"max_depth":null,"n_jobs":-1},
            {"n_folds":5,"type":"LogisticRegression"}
        ]
    }
    }
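Every estimator entry above carries `n_folds`. In the gcForest paper, each forest's class vector is produced by k-fold cross-validation, and `n_folds` presumably sets that k. The sketch below shows only the index bookkeeping of a plain k-fold split. `kfold_indices` is a hypothetical helper, and the real library may split differently (for example with stratification or shuffling).

```python
def kfold_indices(n_samples, n_folds):
    """Yield (train_idx, val_idx) lists for a plain, unshuffled k-fold split."""
    if not 2 <= n_folds <= n_samples:
        raise ValueError("need 2 <= n_folds <= n_samples")
    # Spread the remainder over the first folds so sizes differ by at most 1.
    base, extra = divmod(n_samples, n_folds)
    start = 0
    for fold in range(n_folds):
        stop = start + base + (1 if fold < extra else 0)
        val_idx = list(range(start, stop))
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, val_idx
        start = stop


for train_idx, val_idx in kfold_indices(7, 3):
    print(val_idx)
```

Each sample lands in exactly one validation fold, so every training sample receives a prediction from a model that did not see it during fitting.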
    
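    3. Both parts: "multi fine-grained + cascade"
      Sliding window sizes: {[d/16], [d/8], [d/4]}, where d is the number of input features;
      "look_indexs_cycle": [
      [0, 1],
      [2, 3],
      [4, 5]]
      specifies how the multi-grained outputs are fed into the cascade: the first layer takes the outputs of forests 0 and 1, the second layer those of forests 2 and 3, and the third layer those of forests 4 and 5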
    {
    "net":{
    "outputs": ["pool1/7x7/ets", "pool1/7x7/rf", "pool1/10x10/ets", "pool1/10x10/rf", "pool1/13x13/ets", "pool1/13x13/rf"],
    "layers":[
    // win1/7x7
        {
            "type":"FGWinLayer",
            "name":"win1/7x7",
            "bottoms": ["X","y"],
            "tops":["win1/7x7/ets", "win1/7x7/rf"],
            "n_classes": 10,
            "estimators": [
                {"n_folds":3,"type":"ExtraTreesClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10},
                {"n_folds":3,"type":"RandomForestClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10}
            ],
            "stride_x": 2,
            "stride_y": 2,
            "win_x":7,
            "win_y":7
        },
    // win1/10x10
        {
            "type":"FGWinLayer",
            "name":"win1/10x10",
            "bottoms": ["X","y"],
            "tops":["win1/10x10/ets", "win1/10x10/rf"],
            "n_classes": 10,
            "estimators": [
                {"n_folds":3,"type":"ExtraTreesClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10},
                {"n_folds":3,"type":"RandomForestClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10}
            ],
            "stride_x": 2,
            "stride_y": 2,
            "win_x":10,
            "win_y":10
        },
    // win1/13x13
        {
            "type":"FGWinLayer",
            "name":"win1/13x13",
            "bottoms": ["X","y"],
            "tops":["win1/13x13/ets", "win1/13x13/rf"],
            "n_classes": 10,
            "estimators": [
                {"n_folds":3,"type":"ExtraTreesClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10},
                {"n_folds":3,"type":"RandomForestClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10}
            ],
            "stride_x": 2,
            "stride_y": 2,
            "win_x":13,
            "win_y":13
        },
    // pool1
        {
            "type":"FGPoolLayer",
            "name":"pool1",
            "bottoms": ["win1/7x7/ets", "win1/7x7/rf", "win1/10x10/ets", "win1/10x10/rf", "win1/13x13/ets", "win1/13x13/rf"],
            "tops": ["pool1/7x7/ets", "pool1/7x7/rf", "pool1/10x10/ets", "pool1/10x10/rf", "pool1/13x13/ets", "pool1/13x13/rf"],
            "pool_method": "avg",
            "win_x":2,
            "win_y":2
        }
    ]
    
    },
    
    "cascade": {
        "random_state": 0,
        "max_layers": 100,
        "early_stopping_rounds": 3,
        "look_indexs_cycle": [
            [0, 1],
            [2, 3],
            [4, 5]
        ],
        "n_classes": 10,
        "estimators": [
            {"n_folds":5,"type":"ExtraTreesClassifier","n_estimators":1000,"max_depth":null,"n_jobs":-1},
            {"n_folds":5,"type":"RandomForestClassifier","n_estimators":1000,"max_depth":null,"n_jobs":-1}
        ]
    }
    }
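Two pieces of arithmetic are implied by this config. The sketch below computes how many window positions an `FGWinLayer` scans along one axis, and which entries of `outputs` a given cascade layer reads according to `look_indexs_cycle`. It assumes 28x28 inputs (the source does not state the image size), no padding, and that the cycle wraps around once the cascade has more layers than the cycle has entries. Both function names are hypothetical.

```python
def n_windows(size, win, stride):
    """Number of sliding-window positions along one axis, without padding."""
    if win > size:
        raise ValueError("window larger than input")
    return (size - win) // stride + 1


def look_indexes(layer_id, look_indexs_cycle):
    """Indices into the net outputs used by cascade layer `layer_id`.

    Assumption: the cycle wraps around after its last entry.
    """
    return look_indexs_cycle[layer_id % len(look_indexs_cycle)]


outputs = ["pool1/7x7/ets", "pool1/7x7/rf", "pool1/10x10/ets",
           "pool1/10x10/rf", "pool1/13x13/ets", "pool1/13x13/rf"]
cycle = [[0, 1], [2, 3], [4, 5]]

for win in (7, 10, 13):
    per_axis = n_windows(28, win, stride=2)  # 28 is an assumed image side
    print(win, per_axis, per_axis * per_axis)

for layer_id in range(4):
    print(layer_id, [outputs[i] for i in look_indexes(layer_id, cycle)])
```

Under these assumptions the 7x7, 10x10 and 13x13 windows give 11, 10 and 8 positions per axis, and cascade layer 3 reads the same outputs as layer 0.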
    
