机器学习模型交叉验证脚本

作者: ShallowLearner | 来源:发表于2022-09-19 17:59 被阅读0次

    机器学习模型交叉验证脚本

    本文以阿里云机器学习平台 PAI 上的 ps_smart（GBDT）算法为例，提供一个通过交叉验证挑选最佳超参数的 bash 脚本。

    《机器学习模型超参数网格搜索脚本》一文已经提供了超参数网格搜索的能力。然而，当验证集的样本量较少时，网格搜索选出的最优超参数很容易过拟合验证集，在实际生产环境中的效果往往不如预期。为缓解数据量少的问题，我们把网格搜索得到的 Top N 组最优超参数保存下来，再用交叉验证的方式评估每组超参数对应模型的效果指标。

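    本文的示例是一个 LTV 预测的回归任务，评估指标为 MAE、RMSE 和 WAPE 三项，其中 WAPE = Σ|y − ŷ| / Σ|y|。脚本把数据随机划分为 10 折：对每组超参数，依次留出其中 1 折，用其余 9 折训练模型，并在留出的那一折上计算上述指标；最后把 10 折的指标取平均，写入结果表 ps_smart_ltv 中 k='mean' 的分区。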

    #!/bin/bash
    #set -x
    # ODPS client command line used for every SQL/PAI call. Set ODPS_CMD to use a
    # different client or config file. It is expanded unquoted at the call
    # sites on purpose, so that the --config flag becomes its own argument.
    odps=${ODPS_CMD:-'.odpscmd/bin/odpscmd --config=odps_config.ini'}
    # Hyper-parameter combinations to cross-validate, one per line.
    # The script's first argument overrides the default file name.
    hyper_params_file=${1:-hyper_params.txt}
    
    # Print an INFO line to stdout in the form
    #   "<YYYY-mm-dd HH:MM:SS> [INFO] (<shell pid>)(<user>): <message>"
    # The message is suppressed when LOG_LEVEL is WARN or ERROR. Any other
    # value, or an unset LOG_LEVEL, lets it through. Always returns 0.
    log_info() {
        case "$LOG_LEVEL" in
            WARN|ERROR)
                ;;
            *)
                printf '%s [INFO] (%s)(%s): %s\n' "$(date +"%Y-%m-%d %H:%M:%S")" "$$" "$USER" "$*"
                ;;
        esac
    }
    
    # Download and unpack the ODPS command-line client into ./.odpscmd, unless
    # that directory already exists.
    # Returns: 0 on success (or when already installed), 1 if the download or
    #          the extraction fails. On failure no partial .odpscmd is left
    #          behind, so the next run retries the install.
    function prepare()
    {
        log_info "function [$FUNCNAME] begin"
        if [ ! -d ".odpscmd" ]; then
            local zip_file="odpscmd_public.zip"
            # -O pins the output name. Without it, a leftover zip from an earlier
            # run makes wget save "odpscmd_public.zip.1", and unzip would then
            # extract the stale (possibly partial) file.
            if ! wget -O "$zip_file" https://odps-repo.oss-cn-hangzhou.aliyuncs.com/odpscmd/latest/odpscmd_public.zip; then
                echo "[ERROR] $FUNCNAME: failed to download odpscmd" >&2
                rm -f -- "$zip_file"
                return 1
            fi
            if ! unzip -d .odpscmd "$zip_file"; then
                echo "[ERROR] $FUNCNAME: failed to unzip ${zip_file}" >&2
                # A half-extracted directory would make the -d check above skip
                # the install forever, so remove it.
                rm -rf -- .odpscmd
                rm -f -- "$zip_file"
                return 1
            fi
            rm -f -- "$zip_file"
        fi
        log_info "function [$FUNCNAME] end"
    }
    
    # Build the SQL IN-list of all fold partitions except fold k.
    # Arguments:
    #   $1 n - number of folds (partitions '0'..'n-1')
    #   $2 k - fold index to leave out
    # Globals:
    #   exclude_pt (written) - e.g. n=4, k=1 -> "'0','2','3'". It is empty
    #                          when n=1.
    # The accumulator is local, so the caller's shell state is not touched
    # beyond exclude_pt.
    function gen_partition() {
        log_info "function [$FUNCNAME] begin"
        local n=$1
        local k=$2
        local i
        local pt=""
        for ((i=0;i<n;i++))
        do
            if [ "$i" -eq "$k" ]; then
                continue
            fi
            pt+=",'${i}'"
        done
        # Drop the leading comma added by the first iteration.
        exclude_pt=${pt#,}
        log_info "function [$FUNCNAME] end"
    }
    
    # Create the result and dataset tables, then build the k-fold splits (n=10).
    #
    # Each row of the two source tables is assigned to a random fold partition
    # pt='0'..pt='9' via FLOOR(rand() * n). Then, for every fold k, partition
    # pt='exclude_k' is filled with all rows whose pt is NOT k. run_job trains
    # on pt='exclude_k' and predicts on pt='k'.
    # NOTE(review): the first source table uses unseeded rand() while the second
    # uses rand(20220826), so fold assignment is not reproducible across runs.
    # Confirm whether that is intended.
    # NOTE(review): only the second table's kv has ',' replaced by ' '. This
    # presumably assumes the first table's kv is already space-separated.
    # Verify.
    # Globals: odps (read) - client command line, word-split on purpose.
    # Returns: 0 on success. If table creation or fold assignment fails, it
    #          returns the failing command's status. It returns 1 if any
    #          exclude_* partition could not be built.
    function prepare_cv_data() {
        log_info "function [$FUNCNAME] begin"
        $odps -e "CREATE TABLE IF NOT EXISTS ps_smart_ltv
        (
            mae DOUBLE,
            rmse DOUBLE,
            wape DOUBLE
        )
        PARTITIONED BY (pt STRING COMMENT '实验参数', k STRING);" || return $?
    
        $odps -e "CREATE TABLE IF NOT EXISTS userfeature_v2_googleplay_mergekv_freedom_day3_dataset
        (
            dt  STRING,
            uid STRING,
            kv  STRING,
            targetprice DOUBLE,
            ispay BIGINT
        )
        COMMENT '训练数据集'
        PARTITIONED BY (pt STRING COMMENT '分区')
        LIFECYCLE 7;" || return $?
    
        local n=10
        # Randomly assign every row to one of n folds (dynamic partition pt).
        $odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt)
        SELECT *
        FROM (
            SELECT dt,uid,kv,targetprice,ispay, FLOOR(rand() * ${n}) as pt
            FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_train_20220905_jp_m1
            UNION ALL
            SELECT dt,uid,replace(kv,',',' ') kv,targetprice,ispay, FLOOR(rand(20220826) * ${n}) as pt
            FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_test_20220905_jp_m1
        ) T;" || return $?
    
        local k
        local -a pids=()
        # Build the n training partitions in parallel. Each background job is a
        # subshell, so it gets its own copy of exclude_pt from gen_partition.
        for ((k=0;k<${n};k++))
        do
        {
            gen_partition $n $k
            $odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt='exclude_${k}')
            SELECT \`(pt)?+.+\`
            FROM userfeature_v2_googleplay_mergekv_freedom_day3_dataset
            WHERE pt IN (${exclude_pt});"
        } &
        pids+=("$!")
        done
        # A bare `wait` always returns 0. Waiting on each PID surfaces failures.
        local pid
        local failed=0
        for pid in "${pids[@]}"; do
            wait "$pid" || failed=$((failed+1))
        done
        if [ "$failed" -ne 0 ]; then
            echo "[ERROR] $FUNCNAME: ${failed}/${n} exclude_* partitions failed" >&2
            return 1
        fi
        log_info "function [$FUNCNAME] end"
    }
    
    # Train, predict and score one fold for one hyper-parameter combination.
    # Arguments:
    #   $1 k_fold     - fold index; trains on pt='exclude_<k>', predicts pt='<k>'
    #   $2 tree_count $3 max_depth $4 l1 $5 l2
    #   $6 lr (passed as shrinkage)  $7 eps (passed as sketchEps)
    # The model tag joins the parameters with '_' and replaces the first "0." of
    # each float with "p" (0.05 -> p05), so the tag can appear in model/table
    # names.
    # NOTE(review): a value without "0." (e.g. 1.5) keeps its dot in the tag.
    # run_cross_validation builds the same tag, so both must change together.
    # Globals: odps (read).
    # Output: the per-fold MAE/RMSE/WAPE row is written to ps_smart_ltv
    #         partition (pt=<model>, k=<k_fold>).
    # Returns: 0 on success, otherwise the exit status of the first failing
    #          odps call.
    function run_job() {
        log_info "function [$FUNCNAME] begin"
        local k_fold=$1
        local tree_count=$2
        local max_depth=$3
        local l1=$4
        local l2=$5
        local lr=$6
        local eps=$7
        local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
        local rc
        log_info "run model: $model, k_fold: ${k_fold}"
    
        $odps -e "PAI -name ps_smart
        -project algo_public
        -DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
        -DinputTablePartitions='pt=exclude_${k_fold}'
        -DmodelName='smart_${k_fold}_${model}'
        -DoutputTableName='smart_table_${k_fold}_${model}'
        -DoutputImportanceTableName='smart_imp_${k_fold}_${model}'
        -DlabelColName='targetprice'
        -DfeatureColNames='kv'
        -DenableSparse='true'
        -Dobjective='reg:tweedie'
        -Dmetric='tweedie-nloglik'
        -DfeatureImportanceType='gain'
        -DtreeCount='${tree_count}'
        -DmaxDepth='${max_depth}'
        -Dshrinkage='${lr}'
        -Dl2='${l2}'
        -Dl1='${l1}'
        -Dlifecycle='31'
        -DsketchEps='${eps}'
        -DsampleRatio='1.0'
        -DfeatureRatio='1.0'
        -DbaseScore='0.0'
        -DminSplitLoss='0'
        "
        # Save the status before testing it. `return $?` inside the if would
        # return the status of the `[` test, which is always 0 there.
        rc=$?
        if [ $rc -ne 0 ]; then
            return $rc
        fi
    
        $odps -e "drop table if exists smart_output_${k_fold}_${model};" || return $?
        $odps -e "PAI -name prediction
        -project algo_public
        -DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
        -DinputTablePartitions='pt=${k_fold}'
        -DmodelName='smart_${k_fold}_${model}'
        -DoutputTableName='smart_output_${k_fold}_${model}'
        -DfeatureColNames='kv'
        -DappendColNames='targetprice'
        -DenableSparse='true'
        -DitemDelimiter=' '
        -Dlifecycle='128'
        "
        rc=$?
        if [ $rc -ne 0 ]; then
            return $rc
        fi
        
        # WAPE = sum|y - y_hat| / sum|y|.
        $odps -e "INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='${k_fold}')
        SELECT AVG(ABS(targetprice-prediction_result)) MAE,
            SQRT(AVG((targetprice-prediction_result)*(targetprice-prediction_result))) RMSE,
            SUM(ABS(targetprice-prediction_result))/SUM(ABS(targetprice)) WAPE
        FROM smart_output_${k_fold}_${model};" || return $?
        log_info "function [$FUNCNAME] end"
    }
    
    
    # Run 10-fold cross-validation for one hyper-parameter combination.
    # All folds run in parallel via run_job. The per-fold metrics are then
    # averaged into ps_smart_ltv partition (pt=<model>, k='mean').
    # Arguments: $1 tree_count $2 max_depth $3 l1 $4 l2 $5 lr $6 eps.
    #            They are forwarded unchanged to run_job after the fold index.
    # Globals: odps (read).
    # Returns: 1 if any fold fails. The mean is then NOT written, so an average
    #          over a partial or stale set of folds is never recorded.
    #          Otherwise it returns the status of the final insert.
    function run_cross_validation()
    {
        log_info "function [$FUNCNAME] begin"
        local tree_count=$1
        local max_depth=$2
        local l1=$3
        local l2=$4
        local lr=$5
        local eps=$6
        # Must match the tag run_job builds; the mean is selected by pt=<model>.
        local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
    
        local n=10
        local i pid
        local failed=0
        local -a pids=()
        for ((i=0;i<n;i++))
        do
            run_job "$i" "$@" &
            pids+=("$!")
        done
        # A bare `wait` always returns 0. Waiting on each PID surfaces failures.
        for pid in "${pids[@]}"; do
            wait "$pid" || failed=$((failed+1))
        done
        if [ "$failed" -ne 0 ]; then
            echo "[ERROR] $FUNCNAME: ${failed}/${n} folds failed for model ${model}, mean not updated" >&2
            return 1
        fi
    
        $odps -e "
        INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='mean')
        select avg(MAE), avg(RMSE), avg(WAPE)
        from ps_smart_ltv
        where pt='${model}' and k!='mean';
        " || return $?
        log_info "function [$FUNCNAME] end"
    }
    
    # Call run_cross_validation once for every line of a hyper-parameter file,
    # with at most N combinations in flight at once.
    # Arguments:
    #   $1 - file with one combination per line, whitespace separated:
    #        tree_count max_depth l1 l2 lr eps
    #   $2 - optional maximum number of concurrent combinations (positive
    #        integer, default 1, i.e. sequential)
    # Blank lines are skipped. A final line without a trailing newline is
    # still processed.
    # Concurrency is bounded by a token pool: a FIFO opened read-write on fd 9
    # is seeded with N newline tokens. The loop takes one token before starting
    # a job, and the job puts it back when it finishes.
    # Returns: 1 if the file is unreadable, the concurrency value is invalid,
    #          or the FIFO cannot be set up; otherwise 0.
    function run_from_file()
    {
        log_info "function [$FUNCNAME] begin"
        local params_file=$1
        local thread_task=${2:-1}  # number of concurrent combinations
        if [ ! -r "$params_file" ]; then
            echo "[ERROR] $FUNCNAME: cannot read hyper-parameter file '${params_file}'" >&2
            return 1
        fi
        # With zero tokens the first `read -u 9` below would block forever.
        if ! [[ "$thread_task" =~ ^[1-9][0-9]*$ ]]; then
            echo "[ERROR] $FUNCNAME: concurrency must be a positive integer, got '${thread_task}'" >&2
            return 1
        fi
    
        # Use a private temp dir rather than a fixed name in the CWD, so that
        # concurrent runs cannot collide and no user file gets deleted.
        local fifo_dir
        fifo_dir=$(mktemp -d) || return 1
        if ! mkfifo "${fifo_dir}/fifo" || ! exec 9<> "${fifo_dir}/fifo"; then
            rm -rf -- "$fifo_dir"
            return 1
        fi
        # The open fd keeps the FIFO usable after its path is removed.
        rm -rf -- "$fifo_dir"
    
        local i
        for ((i=0;i<thread_task;i++))
        do
            echo "" >&9
        done
        
        log_info "wait all task finish,then exit!!!"
        local line
        local -a params
        # The `|| [ -n ... ]` keeps a last line that lacks a trailing newline.
        while IFS= read -r line || [ -n "$line" ]
        do
            read -r -a params <<< "$line"
            if [ ${#params[@]} -eq 0 ]; then
                continue
            fi
            read -r -u 9  # take a token; blocks while N jobs are running
            {
                run_cross_validation "${params[@]}"
                echo "" >&9  # return the token
            } &
        done < "$params_file"
        wait
    
        exec 9>&-  # close the token pool (single read-write descriptor)
        log_info "function [$FUNCNAME] end"
    }
    
    # Entry point: install the ODPS client, build the CV folds, then evaluate
    # every combination listed in ${hyper_params_file}. Stop if a setup step
    # reports failure, since every later step depends on it.
    prepare || exit 1
    prepare_cv_data || exit 1
    run_from_file "${hyper_params_file}"
    #run_from_file $1
    

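    备注：请结合《机器学习模型超参数网格搜索脚本》一起使用。网格搜索得到的 Top N 组最优超参数需要预先保存到 hyper_params.txt 文件中。文件每行一组，依次为 tree_count、max_depth、l1、l2、lr（学习率，即 shrinkage）、eps（sketchEps）六个以空格分隔的数值，例如：100 6 0.1 0.2 0.05 0.01。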

    本文由mdnice多平台发布

    相关文章

      网友评论

        本文标题:机器学习模型交叉验证脚本

        本文链接:https://www.haomeiwen.com/subject/zpbportx.html