美文网首页
使用golang做最小二乘法的线性拟合

使用golang做最小二乘法的线性拟合

作者: FredricZhu | 来源:发表于2020-10-29 15:44 被阅读0次

    const.go

    package main
    
    // Column-name lists used while preparing the classification data frame.
    var (
        // ColNames holds the raw CSV columns that are kept for processing,
        // in the order the row callbacks index them.
        ColNames = []string{
            "feature",
            "document",
            "machine",
            "load_time",
            "search_time",
            "reduce_and_save",
        }

        // ResColNames holds the columns, in order, that remain once the
        // three timing columns have been folded into "total".
        ResColNames = []string{
            "feature",
            "document",
            "machine",
            "total",
        }
    )
    

    fit_classification.go

    package main
    
    import (
        "fmt"
        "log"
        "math"
        "os"
    
        "github.com/go-gota/gota/dataframe"
        "github.com/go-gota/gota/series"
        "gonum.org/v1/gonum/optimize"
        "gonum.org/v1/plot"
        "gonum.org/v1/plot/plotter"
        "gonum.org/v1/plot/plotutil"
        "gonum.org/v1/plot/vg"
    )
    
    // getTotal is a row callback for DataFrame.Rapply that folds the three
    // timing columns of one row into a single value.
    //
    // s is one row laid out in ColNames order, so positions 3, 4 and 5 are
    // load_time, search_time and reduce_and_save. Their sum is divided by 60
    // (presumably seconds to minutes; confirm against the data source) and
    // returned as a one-element float series.
    //
    // The values are read through Element.Float instead of a type assertion on
    // Val(). The previous `.(int)` assertions discarded their ok result, so a
    // row that gota typed as anything other than int (for example because one
    // column contains a decimal) silently contributed 0 for every term.
    // NOTE(review): per gota's Element.Float, a missing value now propagates as
    // NaN instead of being counted as 0.
    func getTotal(s series.Series) series.Series {
        loadTime := s.Elem(3).Float()
        searchTime := s.Elem(4).Float()
        rAsTime := s.Elem(5).Float()

        return series.Floats((loadTime + searchTime + rAsTime) / 60)
    }
    
    // getDoc is a row callback for DataFrame.Rapply that rescales the document
    // count of one row.
    //
    // s is one row laid out in ResColNames order, so position 1 is "document".
    // The result is document*2/1000, returned as a one-element float series.
    //
    // The value is read through Element.Float instead of `Val().(float64)`.
    // That assertion discarded its ok result, so any row gota did not type as
    // float silently produced 0.
    func getDoc(s series.Series) series.Series {
        document := s.Elem(1).Float()
        return series.Floats(2 * document / 1000)
    }
    
    // dataPrepare reshapes the raw frame in place into the ResColNames layout.
    //
    // Step 1 keeps only ColNames, derives a "total" column from every row via
    // getTotal and narrows the frame to ResColNames. Step 2 derives a rescaled
    // document column via getDoc and swaps it in for the original "document"
    // column. The frame behind clsDF is overwritten with the result.
    func dataPrepare(clsDF *dataframe.DataFrame) {
        // Step 1: append the per-row total and drop the individual timings.
        frame := clsDF.Select(ColNames)
        total := frame.Rapply(getTotal)
        total.SetNames("total")
        frame = frame.CBind(total).Select(ResColNames)

        // Step 2: compute document*2/1000 per row, then replace the old
        // "document" column with it and restore the ResColNames order.
        scaledDoc := frame.Rapply(getDoc)
        scaledDoc.SetNames("new_doc")
        frame = frame.CBind(scaledDoc).
            Drop([]string{"document"}).
            Rename("document", "new_doc").
            Select(ResColNames)

        *clsDF = frame
    }
    
    // dataOptimize fits the line y = fa*x + fb to the prepared frame by least
    // squares.
    //
    // clsDF must be in ResColNames layout: column 1 is document (float64),
    // column 2 is machine (int) and column 3 is total (float64). Each row yields
    // one observed point with x = document/machine and y = total.
    //
    // It returns the observed points, the fitted points evaluated at the same x
    // values, and the fitted slope fa and intercept fb.
    //
    // The objective is the sum of squared residuals. The previous version summed
    // absolute residuals, which is a least-absolute-deviations fit and not the
    // least-squares fit this code is documented to perform.
    //
    // It panics if a cell has an unexpected dynamic type or if the optimizer
    // reports an error.
    func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
        rows := clsDF.Nrow()

        // Observed points: one per row.
        actPoints = make(plotter.XYs, 0, rows)
        for i := 0; i < rows; i++ {
            document := clsDF.Elem(i, 1).Val().(float64)
            machine := clsDF.Elem(i, 2).Val().(int)
            val := clsDF.Elem(i, 3).Val().(float64)

            actPoints = append(actPoints, plotter.XY{
                X: document / float64(machine),
                Y: val,
            })
        }

        result, err := optimize.Minimize(optimize.Problem{
            Func: func(x []float64) float64 {
                if len(x) != 2 {
                    panic("illegal x")
                }
                a, b := x[0], x[1]
                var sum float64
                for _, point := range actPoints {
                    // Squared residual: this is what makes the fit least squares.
                    sum += math.Pow(a*point.X+b-point.Y, 2)
                }
                return sum
            },
        }, []float64{1, 1}, &optimize.Settings{}, &optimize.NelderMead{})
        if err != nil {
            panic(err)
        }

        // Fitted slope and intercept.
        fa, fb = result.X[0], result.X[1]

        // Evaluate the fitted line at the observed x values. actPoints already
        // holds them, so the frame does not need to be read again.
        expPoints = make(plotter.XYs, 0, len(actPoints))
        for _, point := range actPoints {
            expPoints = append(expPoints, plotter.XY{
                X: point.X,
                Y: fa*point.X + fb,
            })
        }

        return actPoints, expPoints, fa, fb
    }
    
    // dataPlot draws the fitted and the observed points as two labelled
    // line-point series and saves the figure as classification-fit.png
    // (5in x 5in). Both axes are fixed to the range [0, 10].
    //
    // It panics if the plot cannot be created, populated or saved.
    func dataPlot(actPoints, expPoints plotter.XYs) {
        canvas, err := plot.New()
        if err != nil {
            panic(err)
        }

        // Fixed axis limits.
        canvas.X.Min, canvas.X.Max = 0, 10
        canvas.Y.Min, canvas.Y.Max = 0, 10

        err = plotutil.AddLinePoints(canvas,
            "expPoints", expPoints,
            "actPoints", actPoints,
        )
        if err != nil {
            panic(err)
        }

        side := 5 * vg.Inch
        err = canvas.Save(side, side, "classification-fit.png")
        if err != nil {
            panic(err)
        }
    }
    
    // FitClassification runs the whole pipeline: it loads
    // classification_data.csv from the working directory, prepares the frame,
    // fits the line total = fa*(document/machine) + fb, prints the coefficients
    // and writes the plot to classification-fit.png.
    //
    // Any failure to open, parse or prepare the data terminates the process via
    // log.Fatal. Previously the frame's Err field was never inspected, so a
    // malformed CSV or a missing column only surfaced later as an unrelated
    // type-assertion panic inside the fitting step.
    func FitClassification() {
        clsData, err := os.Open("classification_data.csv")
        if err != nil {
            log.Fatal(err)
        }

        clsDF := dataframe.ReadCSV(clsData)
        // The file is not used after ReadCSV returns. Close it explicitly
        // rather than deferring, because log.Fatal below exits the process
        // without running deferred calls.
        if err := clsData.Close(); err != nil {
            log.Print(err)
        }
        if clsDF.Err != nil {
            log.Fatal(clsDF.Err)
        }

        // Data preparation.
        dataPrepare(&clsDF)
        if clsDF.Err != nil {
            log.Fatal(clsDF.Err)
        }
        fmt.Println("数据预处理完成...")
        fmt.Println(clsDF)

        // Fit, then report slope and intercept.
        actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
        fmt.Println("Fa", fa, "Fb", fb)

        // Plot observed versus fitted points.
        dataPlot(actPoints, expPoints)
        fmt.Println("绘制完成,图形地址: classification-fit.png")
    }
    

    main.go

    package main
    
    // main is intentionally empty; the pipeline is started from the test in
    // main_test.go.
    func main() {}
    

    main_test.go

    package main
    
    import "testing"
    
    // TestFitClassification runs the classification curve fit end to end by
    // calling FitClassification. It makes no assertions of its own, so it only
    // fails if FitClassification panics or exits the process.
    func TestFitClassification(t *testing.T) {
        FitClassification()
    }
    

    运行数据(保存为 classification_data.csv,放在运行 go test 的目录下)

    feature,document,machine,load_time,search_time,reduce_and_save
    
    100,5000,4,19,130,67
    
    100,5000,4,12,130,61
    
    100,5000,4,13,127,61
    
    100,5000,4,13,124,63
    
    100,5000,4,13,129,59
    
    100,5000,4,13,125,60
    
    100,5000,4,13,123,63
    
    100,5000,4,13,129,61
    
    100,5000,4,12,127,61
    
    100,5000,4,12,125,62
    
    100,5000,4,13,128,59
    
    100,5000,4,13,128,61
    
    100,5000,4,12,130,60
    
    100,5000,4,12,125,61
    
    100,5000,4,13,127,60
    
    100,5000,4,13,126,63
    
    100,5000,4,13,127,64
    
    100,5000,3,18,160,67
    
    100,5000,3,13,166,59
    
    100,5000,3,12,167,61
    
    100,5000,3,12,168,60
    
    100,5000,3,12,170,61
    
    100,5000,3,12,154,63
    
    100,5000,3,13,168,60
    
    100,5000,3,12,167,60
    
    100,5000,3,12,148,64
    
    100,5000,3,12,167,65
    
    100,5000,3,12,164,60
    
    100,5000,3,12,150,59
    
    100,5000,2,20,217,65
    
    100,5000,2,13,205,63
    
    100,5000,2,14,204,60
    
    100,5000,2,13,205,55
    
    100,5000,2,14,210,59
    
    100,5000,2,13,201,59
    
    100,5000,2,13,211,59
    
    100,5000,2,13,217,59
    
    100,5000,2,14,207,60
    
    100,5000,2,14,209,59
    
    100,5000,2,14,214,61
    
    100,5000,2,13,210,61
    
    100,5000,1,24,376,60
    
    100,5000,1,21,393,58
    
    100,5000,1,20,386,58
    
    100,5000,1,22,384,59
    
    100,5000,1,21,387,59
    
    100,4000,4,18,112,70
    
    100,4000,4,13,118,62
    
    100,4000,4,12,114,63
    
    100,4000,4,14,112,65
    
    100,4000,4,12,113,62
    
    100,4000,4,14,109,61
    
    100,4000,4,13,118,63
    
    100,4000,4,12,112,61
    
    100,4000,4,12,110,61
    
    100,4000,4,11,111,63
    
    100,4000,4,13,112,67
    
    100,4000,4,12,110,60
    
    100,4000,4,12,113,60
    
    100,3000,4,19,100,66
    
    100,3000,4,13,99,64
    
    100,3000,4,12,100,65
    
    100,3000,4,13,103,61
    
    100,3000,4,14,104,63
    
    100,3000,4,14,99,63
    
    100,2000,4,18,90,67
    
    100,2000,4,13,86,65
    
    100,2000,4,13,85,63
    
    100,2000,4,14,87,62
    
    100,2000,4,13,85,61
    
    100,2000,4,13,86,64
    
    100,2000,4,13,85,61
    
    100,2000,4,13,89,58
    
    100,2000,4,13,85,61
    
    100,2000,4,12,85,60
    
    100,2000,4,13,85,66
    
    100,2000,4,12,86,59
    
    100,2000,4,12,86,61
    
    100,2000,4,12,82,60
    
    100,2000,4,13,87,62
    
    100,2000,4,12,83,65
    
    100,2000,4,12,85,60
    
    100,2000,4,13,87,60
    
    100,2000,4,12,86,59
    
    100,1000,4,19,75,63
    
    100,1000,4,14,71,61
    
    100,1000,4,13,75,59
    
    100,1000,4,14,72,61
    
    100,1000,4,13,72,59
    
    100,1000,4,13,71,59
    
    100,1000,4,13,72,59
    
    100,1000,4,14,70,62
    
    100,1000,4,12,72,58
    
    100,1000,4,13,71,59
    
    100,1000,4,13,70,62
    
    100,1000,4,12,72,59
    
    100,1000,4,20,71,58
    
    100,1000,4,13,69,60
    
    100,1000,4,12,73,60
    
    100,1000,4,13,69,59
    
    100,1000,4,13,71,60
    
    100,1000,4,13,73,62
    
    100,1000,4,12,71,59
    
    100,1000,4,12,70,56
    
    100,1000,4,13,70,58
    
    100,1000,4,12,69,57
    
    

    运行方法
    go test -v -run=FitClass
    最终输出的数据和scipy的结果差不多

    程序输出


    图片.png

    输出的拟合图像如下


    图片.png

    相关文章

      网友评论

          本文标题:使用golang做最小二乘法的线性拟合

          本文链接:https://www.haomeiwen.com/subject/lzokvktx.html