美文网首页
优化文件加载从20s到3s

优化文件加载从20s到3s

作者: 董泽润 | 来源:发表于2017-02-25 12:35 被阅读285次
    背景

    现在大多业务都使用机器学习,程序启动时加载训练好的模型文件,运行期也会触发模型的 reload。 在程序启动时如果加载耗时比较长,那么程序自然有段时间不可服务(模型没有准备好),但是运行期由于是双 buffer 切换,耗时长些也无所谓。

    优化前

    加载 14 个模型文件,并行加载,文件最小的几k,最大有 300m,加载时间取短板最长耗时 20s

    优化后

    对最大的四个文件,采用并行加载,耗时最大减少到 3s,优化完成

    代码串行处理逻辑

    原有单个文件也是串行处理逻辑

    • os.Open 打开文件
    • buf.ReadString 按行读取数据
    • 根据业务需求,解析各个字段
    • 追加到模型字典

    这个逻辑非常简单,符合人的直觉思维,但同时也非常低效。

    并行优化1

    思路很简单:模型文件没有顺序,可以一次性全读到内存中,然后按行去并行解析,最后合并到字典,非常类似 MapReduce

    • ioutil.ReadFile 全部读到内存中
    • bytes.Split 根据 '\n' 分隔符打散
    • 开启 n 个 goroutine 并行解析每行数据
    • 合并到模型字典

    第一次优化后耗时降为 10s,初步成效,但是仍然不理想。纺计每一步耗时后发现,对于最大 300m 的文件,bytes.Split 打散耗时 4s, 模型 Map 合并耗时 5s

    并行优化2

    和同事探讨下如何继续优化,对于 Map 无法并行。当前模型实现方式用单一 Map,如果加锁就和串行合并行为是一致的。当初始化 Map 指定大小时,合并时间从 5s 降到 2s,避免了 rehash copy 的开销,效果很明显。

    另外 bytes.Split 打散耗时超长是没有想到的,看了下源码,内部两次遍历,耗时自然和数量成正比。同事提义将打散移到并行阶段,由每个 goroutine 去完成,预估并行数量,然后按 batch 打散。有几点需要注意:

    1. 无所提前知道总数据量大小,模型 Map 初始化要预估大小,按 30 byte 一行猜测即可
    2. 每个 gorouinte 划分数据也是不均等的,但一定要以 '\n' 分隔符打散,不能打数据截断

    最后共耗时 3s,一次性加载内存维 150ms,并行解析 1s,合并 Map 2s

    代码示例:

    当前性价比最高的优化,如果大家有更好的方式可以共同交流一下,第一个是抽象的执行函数,第二个是示例使用方示

    // ParallelLoadModelFile 并行加载模型文件
    // @params  data         文件二进制数据
    // @params  sep          分隔符
    // @params  name         识别标记
    // @params  parallel     并发数目, 一般不超过20, 过大没用
    // @params  parse        用户处理函数
    // @params  merge        用户合并函数
    // 原类类似 MapReduce, 先将文件并行处理, 最后 reduce 合并。使用请参考 loadPassengerFeatures2
    // 原则:尽量将耗时操作并行化
    // 注意:
    // 1. map 初始化时一定要指定大小,否则 rehash copy 成本非常高 测试 800W 条记录合并消耗 2s
    // 2. 数据在 parse 和 merge 函数流动要用 channel, 具体类型及解析合并由调用方决定
    // 3. 需要特殊处理行不能使用这个函数, 要单独处理
    //
    // 流程优化:
    // 读文件 |  解析每行数据并写到map
    // +------------------------+
    //
    // load内存并打散  分片     聚合
    //             +-----+
    // +--------+  |-----| +------+
    //             +-----+
    // load   打散分片 聚合
    //
    //        +----+
    // +----+ |----| +---+
    //        +----+
    //
    // 1. ioutil.ReadFile一次性读入内存 2. bytes.Split 按\n打散 3. 分片计算  4. 合并merge
    // 在大文件时 bytes.Split 非常耗时, 将第2步移到并行阶段, 和3一起算。合并 map 非常耗时
    // Map 操作只能串行, 并发也需要加锁来互斥, 等同于串行, 暂时没想到好的合并方法
    func ParallelLoadModelFile(data []byte, sep []byte, name string, parallel int, parse func([]byte), merge func()) {
        if parallel <= 0 || parallel > 30 || parse == nil || merge == nil || len(sep) == 0 {
            panic("ParallelLoadModelFile params illegal")
        }
    
        var (
            wait  = sync.WaitGroup{} // sync
            size  = len(data)        // file size
            batch = size / parallel  // batch size
            num   = size/batch + 1   // parallel goroutine
            start = 0
            end   = batch
        )
    
        for i := 0; i < num; i++ {
            wait.Add(1)
            // 获取第一个 sep 所在的 index
            idx := bytes.Index(data[end:], sep)
            if idx == -1 {
                end = len(data) - 1
            } else {
                end += idx
            }
    
            go parse(data[start:end])
    
            start = end
            if (end + batch) < len(data) {
                end += batch
            } else {
                end = len(data) - 1
            }
        }
    
        go func() {
            for i := 0; i < num; i++ {
                merge()
                wait.Done()
            }
        }()
    
        // 同步阻塞,等待所有 MapReduce
        wait.Wait()
    }
    
    
    //加载小时特征 并行版本
    func LoadHourGEOInfo2(model_data_center *ModelDataCenter, file_name string) error {
        now := time.Now().UnixNano()
        defer func() {
            logger.Info("load[%s] time=%dms", file_name, (time.Now().UnixNano()-now)/1e6)
        }()
    
        content, err := ioutil.ReadFile(file_name)
        if err != nil {
            logger.Error("ioutil readfile error, file_name=%s", file_name)
            return err
        }
    
        // 预估map大小
        model_data_center.HourGEOInfoData = make(map[string]DynamicDiscountGEOInfo, len(content)/30)
    
        // model 消息
        modelChan := make(chan map[string]DynamicDiscountGEOInfo, 10)
    
        // map 并行处理函数
        mapParse := func(content []byte) {
            var (
                data = bytes.Split(content, SepLine)
                m    = make(map[string]DynamicDiscountGEOInfo, len(data))
            )
    
            defer func() {
                // 将数据扔到 chan 待合并
                // 用 defer 防止遗望
                modelChan <- m
            }()
    
            for _, l := range data {
    
                line := string(l)
                // 兼容\r\n换行的情况
                line = strings.Replace(line, "\r", "", -1)
                list := strings.Split(line, ",")
    
                var hour_geo_info DynamicDiscountGEOInfo
                if len(list) != 8 && len(list) != 10 {
                    logger.Warn("wrong fomat file=%s line=%s cols.Size=%d", file_name, line, len(list))
                    continue
                }
    
                lng_lat, err := strconv.Atoi(list[0])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 0)
                    continue
                }
                hour, err := strconv.Atoi(list[1])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 1)
                    continue
                }
                hour_geo_key := GetGEOKey(hour, lng_lat, "HOUR", 0)
                hour_geo_info.StartGEOInfo.CarpoolNum, err = strconv.Atoi(list[2])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 2)
                    continue
                }
                hour_geo_info.StartGEOInfo.SucCarpoolNum, err = strconv.Atoi(list[3])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 3)
                    continue
                }
                hour_geo_info.StartGEOInfo.SucCarpoolRate, err = strconv.Atoi(list[4])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 4)
                    continue
                }
                hour_geo_info.DestGEOInfo.CarpoolNum, err = strconv.Atoi(list[5])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 5)
                    continue
                }
                hour_geo_info.DestGEOInfo.SucCarpoolNum, err = strconv.Atoi(list[6])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 6)
                    continue
                }
                hour_geo_info.DestGEOInfo.SucCarpoolRate, err = strconv.Atoi(list[7])
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 7)
                    continue
                }
    
                if len(list) == 10 {
                    hour_geo_info.StartGEOInfo.InComeRate, err = strconv.ParseFloat(list[8], 64)
                    if err != nil {
                        logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 8)
                        continue
                    }
                    hour_geo_info.DestGEOInfo.InComeRate, err = strconv.ParseFloat(list[9], 64)
                    if err != nil {
                        logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 9)
                        continue
                    }
                } else {
                    hour_geo_info.StartGEOInfo.InComeRate = -1.0
                    hour_geo_info.DestGEOInfo.InComeRate = -1.0
                }
                // 更新 map
                m[hour_geo_key] = hour_geo_info
            }
        }
    
        // reduce 最终合并函数
        mergeReduce := func() {
            select {
            // merge model msg
            case m := <-modelChan:
                logger.Info("parallel load[%s]||line_num=%d", file_name, len(m))
                for k := range m {
                    model_data_center.HourGEOInfoData[k] = m[k]
                }
            }
        }
    
        ParallelLoadModelFile(content, SepLine, file_name, 3, mapParse, mergeReduce)
        return nil
    }
    

    相关文章

      网友评论

          本文标题:优化文件加载从20s到3s

          本文链接:https://www.haomeiwen.com/subject/qxnpwttx.html