Cross-Validation Script for Machine Learning Models
This article uses the ps_smart (GBDT) algorithm on Alibaba Cloud's machine learning platform PAI as an example and provides a bash script that runs a cross-validation job to search for the best hyperparameters.
The earlier post 机器学习模型超参数网格搜索脚本 (a hyperparameter grid-search script for machine learning models) provides the ability to grid-search hyperparameters. However, when the validation set is small, the hyperparameters that grid search picks as optimal overfit it very easily, and in a real production environment the results often fall short of expectations. To mitigate the small-data problem, we save the Top N hyperparameter combinations from the grid search and then evaluate the model trained with each combination by cross validation.
The example in this article is a regression task for LTV (lifetime value) prediction, and it computes three evaluation metrics: MAE, RMSE, and WAPE.
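For reference, with y_i the true targetprice, ŷ_i the model prediction, and N the number of evaluation samples, the three metrics as computed by the SQL inside run_job below are:

$$
\mathrm{MAE}=\frac{1}{N}\sum_{i=1}^{N}\lvert y_i-\hat{y}_i\rvert,\qquad
\mathrm{RMSE}=\sqrt{\frac{1}{N}\sum_{i=1}^{N}(y_i-\hat{y}_i)^2},\qquad
\mathrm{WAPE}=\frac{\sum_{i=1}^{N}\lvert y_i-\hat{y}_i\rvert}{\sum_{i=1}^{N}\lvert y_i\rvert}
$$

The full script follows.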
#!/bin/bash
#set -x
odps='.odpscmd/bin/odpscmd --config=odps_config.ini' # odpscmd client and its config file (expected in the working directory)
hyper_params_file='hyper_params.txt' # Top N hyperparameter combinations from the grid search, one per line
function log_info()
{
if [ "$LOG_LEVEL" != "WARN" ] && [ "$LOG_LEVEL" != "ERROR" ]
then
echo "`date +"%Y-%m-%d %H:%M:%S"` [INFO] ($$)($USER): $*";
fi
}
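# prepare: download and unpack the odpscmd command-line client into .odpscmd on first run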
function prepare()
{
log_info "function [$FUNCNAME] begin"
if [ ! -d ".odpscmd" ]; then
wget https://odps-repo.oss-cn-hangzhou.aliyuncs.com/odpscmd/latest/odpscmd_public.zip
unzip -d .odpscmd odpscmd_public.zip
fi
log_info "function [$FUNCNAME] end"
}
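# gen_partition <n> <k>: build exclude_pt, a comma-separated, quoted list of the n fold ids
# with fold k removed (e.g. n=3, k=1 gives '0','2'); these are the training folds for fold k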
function gen_partition() {
log_info "function [$FUNCNAME] begin"
local n=$1
local k=$2
local i
pt=""
for ((i=0;i<$n;i++))
do
if [ "$i" -eq "$k" ]; then
continue
fi
pt=${pt}",'"${i}"'"
done
exclude_pt=${pt#,}
log_info "function [$FUNCNAME] end"
}
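# prepare_cv_data: create the metrics table and the dataset table, randomly assign every
# sample to one of n=10 folds (pt=0..9), then materialize one exclude_<k> partition per fold
# holding the other 9 folds as that fold's training data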
function prepare_cv_data() {
log_info "function [$FUNCNAME] begin"
$odps -e "CREATE TABLE IF NOT EXISTS ps_smart_ltv
(
mae DOUBLE,
rmse DOUBLE,
wape DOUBLE
)
PARTITIONED BY (pt STRING COMMENT 'hyperparameter combination', k STRING);"
$odps -e "CREATE TABLE IF NOT EXISTS userfeature_v2_googleplay_mergekv_freedom_day3_dataset
(
dt STRING,
uid STRING,
kv STRING,
targetprice DOUBLE,
ispay BIGINT
)
COMMENT 'training dataset'
PARTITIONED BY (pt STRING COMMENT 'partition')
LIFECYCLE 7;"
local n=10
$odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt)
SELECT *
FROM (
SELECT dt,uid,kv,targetprice,ispay, FLOOR(rand() * ${n}) as pt
FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_train_20220905_jp_m1
UNION ALL
SELECT dt,uid,replace(kv,',',' ') kv,targetprice,ispay, FLOOR(rand(20220826) * ${n}) as pt
FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_test_20220905_jp_m1
) T;"
local k
for ((k=0;k<${n};k++))
do
{
gen_partition $n $k
$odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt='exclude_${k}')
SELECT \`(pt)?+.+\`
FROM userfeature_v2_googleplay_mergekv_freedom_day3_dataset
WHERE pt IN (${exclude_pt});"
} &
done
wait
log_info "function [$FUNCNAME] end"
}
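# run_job <k_fold> <tree_count> <max_depth> <l1> <l2> <lr> <eps>: train a ps_smart model on the
# exclude_<k_fold> partition, predict on fold <k_fold>, and write MAE/RMSE/WAPE into ps_smart_ltv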
function run_job() {
log_info "function [$FUNCNAME] begin"
local k_fold=$1
local tree_count=$2
local max_depth=$3
local l1=$4
local l2=$5
local lr=$6
local eps=$7
local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
log_info "run model: $model, k_fold: ${k_fold}"
$odps -e "PAI -name ps_smart
-project algo_public
-DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
-DinputTablePartitions='pt=exclude_${k_fold}'
-DmodelName='smart_${k_fold}_${model}'
-DoutputTableName='smart_table_${k_fold}_${model}'
-DoutputImportanceTableName='smart_imp_${k_fold}_${model}'
-DlabelColName='targetprice'
-DfeatureColNames='kv'
-DenableSparse='true'
-Dobjective='reg:tweedie'
-Dmetric='tweedie-nloglik'
-DfeatureImportanceType='gain'
-DtreeCount='${tree_count}'
-DmaxDepth='${max_depth}'
-Dshrinkage='${lr}'
-Dl2='${l2}'
-Dl1='${l1}'
-Dlifecycle='31'
-DsketchEps='${eps}'
-DsampleRatio='1.0'
-DfeatureRatio='1.0'
-DbaseScore='0.0'
-DminSplitLoss='0'
"
local ret=$? # capture the training job's exit status before the test overwrites $?
if [ $ret -ne 0 ]; then
return $ret
fi
$odps -e "drop table if exists smart_output_${k_fold}_${model};"
$odps -e "PAI -name prediction
-project algo_public
-DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
-DinputTablePartitions='pt=${k_fold}'
-DmodelName='smart_${k_fold}_${model}'
-DoutputTableName='smart_output_${k_fold}_${model}'
-DfeatureColNames='kv'
-DappendColNames='targetprice'
-DenableSparse='true'
-DitemDelimiter=' '
-Dlifecycle='128'
"
local ret=$? # capture the prediction job's exit status before the test overwrites $?
if [ $ret -ne 0 ]; then
return $ret
fi
$odps -e "INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='${k_fold}')
SELECT AVG(ABS(targetprice-prediction_result)) MAE,
SQRT(AVG((targetprice-prediction_result)*(targetprice-prediction_result))) RMSE,
SUM(ABS(targetprice-prediction_result))/SUM(ABS(targetprice)) WAPE
FROM smart_output_${k_fold}_${model};"
log_info "function [$FUNCNAME] end"
}
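# run_cross_validation <tree_count> <max_depth> <l1> <l2> <lr> <eps>: run run_job on all 10 folds
# in parallel, then average the per-fold metrics into the k='mean' partition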
function run_cross_validation()
{
log_info "function [$FUNCNAME] begin"
local args=$@
local tree_count=$1
local max_depth=$2
local l1=$3
local l2=$4
local lr=$5
local eps=$6
local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
local n=10
local i
for ((i=0;i<$n;i++))
do
{
run_job ${i} $args
} &
done
wait
$odps -e "
INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='mean')
select avg(MAE), avg(RMSE), avg(WAPE)
from ps_smart_ltv
where pt='${model}' and k!='mean';
"
log_info "function [$FUNCNAME] end"
}
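# run_from_file <file>: read one hyperparameter combination per line and cross-validate each,
# limiting concurrency to threadTask jobs via a FIFO-based semaphore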
function run_from_file()
{
log_info "function [$FUNCNAME] begin"
threadTask=1 # concurrency: number of hyperparameter combinations evaluated at the same time
fifoFile="test_fifo"
rm -f ${fifoFile}
mkfifo ${fifoFile} # create the FIFO pipe used as a semaphore
exec 9<> ${fifoFile}
rm -f ${fifoFile}
# pre-fill the pipe with one token per concurrent task
for ((i=0;i<${threadTask};i++))
do
echo "" >&9
done
log_info "wait all task finish,then exit!!!"
while read line
do
read -u9
{
run_cross_validation $line
echo "" >&9
} &
done < $1
wait
exec 9<&- # close the read end of file descriptor 9
exec 9>&- # close the write end of file descriptor 9
log_info "function [$FUNCNAME] end"
}
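# main flow: set up the odpscmd client, build the CV dataset, then evaluate every
# hyperparameter combination listed in hyper_params.txt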
prepare
prepare_cv_data
run_from_file ${hyper_params_file}
#run_from_file $1
Note: use this script together with the hyperparameter grid-search script (机器学习模型超参数网格搜索脚本); the Top N hyperparameter combinations found by the grid search must be saved to the hyper_params.txt file beforehand.
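For illustration, hyper_params.txt holds one combination per line, in the order consumed by run_cross_validation: tree_count max_depth l1 l2 lr eps. The values below are hypothetical placeholders, not tuned settings:

# hyper_params.txt (hypothetical example): tree_count max_depth l1 l2 lr eps
500 7 1 1 0.05 0.03
800 5 0 1 0.1 0.05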