在使用 go-xgboost 做模型预测时，大家一定对下面这条语句非常熟悉。但这条语句非常占用内存，因为 NewPredictor 在加载模型时会在内部执行一个循环（源码见后文）：
// Arguments map to NewPredictor's parameters in order: xboostSavedModelPath =
// modelDir+fileName, workerCount = runtime.NumCPU(), optionMask = 0,
// nTreeLimit = 5000, missingValue = -1.
// NewPredictor loads the model once per worker, so this call keeps
// runtime.NumCPU() copies of the model in memory.
// NOTE(review): modelDir+fileName is plain string concatenation, so modelDir
// presumably must end with a path separator — confirm (filepath.Join avoids this).
predictor, err := xgboost.NewPredictor(modelDir+fileName, runtime.NumCPU(), 0, 5000, -1)
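如果 workerCount=4，就意味着同一个模型要被加载四次：每个 worker 协程都会各自调用 core.XGBoosterCreate 创建 booster 并执行 LoadModel。而上面的调用把 workerCount 设为 runtime.NumCPU()，因此 CPU 核数越多，模型在内存中的副本就越多。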
// NewPredictor creates a Predictor backed by workerCount worker goroutines.
//
// Each worker locks itself to an OS thread and creates its own booster with
// core.XGBoosterCreate. It then loads the model from xboostSavedModelPath, so
// the model is loaded workerCount times and memory use grows linearly with
// workerCount.
//
// Workers are initialized one at a time. The first initialization error is
// returned, and every worker that had already started is shut down.
//
// After initialization, every worker serves requests from one shared
// channel. For each request it builds a DMatrix from the request's matrix
// data, using missingValue as the missing-value marker. It then calls
// booster.Predict with optionMask and nTreeLimit, and sends the result or
// error back on the request's resultChan.
//
// It returns an error if workerCount is not positive.
func NewPredictor(xboostSavedModelPath string, workerCount int, optionMask int, nTreeLimit uint, missingValue float32) (Predictor, error) {
	if workerCount <= 0 {
		return nil, errors.New("worker count needs to be larger than zero")
	}

	requestChan := make(chan multiBoosterRequest)
	initErrors := make(chan error)
	defer close(initErrors)

	for i := 0; i < workerCount; i++ {
		go func() {
			// Keep this worker on a single OS thread for its whole lifetime.
			// NOTE(review): presumably required by the cgo-backed booster; confirm.
			runtime.LockOSThread()
			defer runtime.UnlockOSThread()

			booster, err := core.XGBoosterCreate(nil)
			if err != nil {
				initErrors <- err
				return
			}
			if err := booster.LoadModel(xboostSavedModelPath); err != nil {
				initErrors <- err
				return
			}
			// No errors occurred during init.
			initErrors <- nil

			// Serve prediction requests until requestChan is closed.
			for req := range requestChan {
				data, rowCount, columnCount := req.matrix.Data()
				matrix, err := core.XGDMatrixCreateFromMat(data, rowCount, columnCount, missingValue)
				if err != nil {
					req.resultChan <- multiBoosterResponse{err: err}
					continue
				}
				res, err := booster.Predict(matrix, optionMask, nTreeLimit)
				req.resultChan <- multiBoosterResponse{err: err, result: res}
			}
		}()

		// Wait for this worker to finish loading before starting the next one.
		if err := <-initErrors; err != nil {
			// Shut down the workers that already initialized. Without this they
			// would block on requestChan forever, leaking their goroutines,
			// their locked OS threads and their loaded boosters.
			close(requestChan)
			return nil, err
		}
	}
	return &multiBooster{reqChan: requestChan}, nil
}
网友评论