美文网首页
使用golang做最小二乘法的线性拟合

使用golang做最小二乘法的线性拟合

作者: FredricZhu | 来源:发表于2020-10-29 15:44 被阅读0次

const.go

package main

var (
    // ColNames lists the columns that dataPrepare selects from the raw CSV
    // before computing the per-row total. getTotal addresses the last three
    // entries by position (3, 4, 5), so the order here matters.
    ColNames = []string{"feature", "document", "machine", "load_time",
        "search_time", "reduce_and_save"}

    // ResColNames lists the columns, in order, of the data frame that
    // dataPrepare produces. dataOptimize addresses them by position
    // (1 = document, 2 = machine, 3 = total), so the order here matters.
    ResColNames = []string{"feature", "document", "machine", "total"}
)

fit_classification.go

package main

import (
    "fmt"
    "log"
    "math"
    "os"

    "github.com/go-gota/gota/dataframe"
    "github.com/go-gota/gota/series"
    "gonum.org/v1/gonum/optimize"
    "gonum.org/v1/plot"
    "gonum.org/v1/plot/plotter"
    "gonum.org/v1/plot/plotutil"
    "gonum.org/v1/plot/vg"
)

// getTotal is a row callback for DataFrame.Rapply. It adds the values at
// positions 3, 4 and 5 of the row (load_time, search_time and
// reduce_and_save when the frame has the ColNames layout) and returns the
// sum divided by 60 as a one-element float series.
// NOTE(review): the division by 60 presumably converts seconds to minutes;
// confirm the unit of the input columns.
//
// The cells are read through Element.Float rather than an unchecked
// .(int) type assertion: Rapply builds each row with a single common type,
// so if any column of the frame is not an integer column the assertion
// would fail and the total would silently become 0.
func getTotal(s series.Series) series.Series {
	loadTime := s.Elem(3).Float()
	searchTime := s.Elem(4).Float()
	rAsTime := s.Elem(5).Float()

	return series.Floats((loadTime + searchTime + rAsTime) / 60)
}

// getDoc is a row callback for DataFrame.Rapply. It reads the value at
// position 1 of the row (the document column in the ResColNames layout) and
// returns it multiplied by 2 and divided by 1000 as a one-element float
// series.
//
// The cell is read through Element.Float rather than an unchecked
// .(float64) type assertion, so a row whose common type is not Float (for
// example an all-integer frame) is still converted instead of silently
// producing 0.
func getDoc(s series.Series) series.Series {
	document := s.Elem(1).Float()

	return series.Floats(2 * document / 1000)
}

// dataPrepare rewrites *clsDF in place so that it ends up with exactly the
// columns named by ResColNames, in that order.
//
// It works in two passes:
//  1. keep only the ColNames columns, apply getTotal to every row and
//     append the result as a column named "total";
//  2. keep only the ResColNames columns, apply getDoc to every row, append
//     the result as "new_doc", drop the old "document" column, give
//     "new_doc" the name "document" and restore the ResColNames order.
func dataPrepare(clsDF *dataframe.DataFrame) {
	// Pass 1: derive the total column.
	*clsDF = clsDF.Select(ColNames)
	totals := clsDF.Rapply(getTotal)
	totals.SetNames("total")
	*clsDF = clsDF.CBind(totals)

	// Pass 2: replace document with its rescaled value (document*2/1000).
	*clsDF = clsDF.Select(ResColNames)
	scaled := clsDF.Rapply(getDoc)
	scaled.SetNames("new_doc")
	*clsDF = clsDF.
		CBind(scaled).
		Drop([]string{"document"}).
		Rename("document", "new_doc").
		Select(ResColNames)
}

// dataOptimize fits the straight line y = a*x + b to the prepared data by
// least squares.
//
// Every row of clsDF yields one observed point: X is column 1 (document)
// divided by column 2 (machine) and Y is column 3 (total). The sum of
// squared residuals over those points is minimized with gonum's Nelder-Mead
// method, starting from a = 1, b = 1.
//
// Returns:
//   - actPoints: the observed points, one per row;
//   - expPoints: the points of the fitted line at the same X values;
//   - fa, fb:    the fitted slope and intercept.
//
// The function panics if a cell does not hold the asserted type (float64
// for document and total, int for machine) or if the optimizer reports an
// error. A machine value of 0 is not guarded against and produces a
// non-finite X.
func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
	rows := clsDF.Nrow()

	// Observed points: one per data row.
	actPoints = make(plotter.XYs, 0, rows)
	for i := 0; i < rows; i++ {
		document := clsDF.Elem(i, 1).Val().(float64)
		machine := clsDF.Elem(i, 2).Val().(int)
		total := clsDF.Elem(i, 3).Val().(float64)

		actPoints = append(actPoints, plotter.XY{
			X: document / float64(machine),
			Y: total,
		})
	}

	result, err := optimize.Minimize(optimize.Problem{
		Func: func(x []float64) float64 {
			if len(x) != 2 {
				panic("illegal x")
			}
			a, b := x[0], x[1]

			// Least-squares objective: the sum of squared residuals. The
			// previous version summed absolute residuals, which is a
			// least-absolute-deviations fit and gives different a and b.
			var sum float64
			for _, point := range actPoints {
				sum += math.Pow(a*point.X+b-point.Y, 2)
			}
			return sum
		},
	}, []float64{1, 1}, &optimize.Settings{}, &optimize.NelderMead{})
	if err != nil {
		panic(err)
	}

	// Slope and intercept found by the least-squares fit.
	fa, fb = result.X[0], result.X[1]

	// Evaluate the fitted line at the observed X values; these are the same
	// X values the original recomputed from the data frame.
	expPoints = make(plotter.XYs, 0, len(actPoints))
	for _, point := range actPoints {
		expPoints = append(expPoints, plotter.XY{
			X: point.X,
			Y: fa*point.X + fb,
		})
	}

	return actPoints, expPoints, fa, fb
}

// dataPlot draws the fitted points and the observed points as two
// line-and-point series (labelled "expPoints" and "actPoints") on a plot
// whose X and Y axes both run from 0 to 10, and saves it as a 5x5 inch
// image named classification-fit.png.
//
// It panics if the plot cannot be created, a series cannot be added, or
// the image cannot be saved.
func dataPlot(actPoints, expPoints plotter.XYs) {
	canvas, err := plot.New()
	if err != nil {
		panic(err)
	}

	// Fixed axis window for both axes.
	canvas.X.Min = 0
	canvas.X.Max = 10
	canvas.Y.Min = 0
	canvas.Y.Max = 10

	err = plotutil.AddLinePoints(canvas,
		"expPoints", expPoints,
		"actPoints", actPoints,
	)
	if err != nil {
		panic(err)
	}

	side := 5 * vg.Inch
	err = canvas.Save(side, side, "classification-fit.png")
	if err != nil {
		panic(err)
	}
}

// FitClassification runs the whole curve-fitting pipeline: it reads
// classification_data.csv from the current working directory, preprocesses
// it with dataPrepare, fits a line with dataOptimize, prints the fitted
// coefficients and plots the result with dataPlot.
//
// Any failure to open, parse or preprocess the CSV terminates the process
// through log.Fatal. gota reports such failures in DataFrame.Err instead of
// returning an error, so that field is checked explicitly; without the
// checks a broken input would be fitted as an empty data set.
func FitClassification() {
	clsData, err := os.Open("classification_data.csv")
	if err != nil {
		log.Fatal(err)
	}
	// log.Fatal below exits without running this deferred Close; the handle
	// is then released by process exit.
	defer clsData.Close()

	clsDF := dataframe.ReadCSV(clsData)
	if clsDF.Err != nil {
		log.Fatal(clsDF.Err)
	}

	// Preprocessing.
	dataPrepare(&clsDF)
	if clsDF.Err != nil {
		log.Fatal(clsDF.Err)
	}
	fmt.Println("数据预处理完成...")
	fmt.Println(clsDF)

	// Fitting; print the resulting slope and intercept.
	actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
	fmt.Println("Fa", fa, "Fb", fb)

	// Plotting.
	dataPlot(actPoints, expPoints)
	fmt.Println("绘制完成,图形地址: classification-fit.png")
}

main.go

package main

// main does nothing. The fitting pipeline is not started from the binary;
// it is run through TestFitClassification (go test -v -run=FitClass).
func main() {

}

main_test.go

package main

import "testing"

// TestFitClassification runs the classification curve-fitting pipeline by
// calling FitClassification. It makes no assertions on the result.
func TestFitClassification(t *testing.T) {
    FitClassification()
}

运行数据

feature,document,machine,load_time,search_time,reduce_and_save

100,5000,4,19,130,67

100,5000,4,12,130,61

100,5000,4,13,127,61

100,5000,4,13,124,63

100,5000,4,13,129,59

100,5000,4,13,125,60

100,5000,4,13,123,63

100,5000,4,13,129,61

100,5000,4,12,127,61

100,5000,4,12,125,62

100,5000,4,13,128,59

100,5000,4,13,128,61

100,5000,4,12,130,60

100,5000,4,12,125,61

100,5000,4,13,127,60

100,5000,4,13,126,63

100,5000,4,13,127,64

100,5000,3,18,160,67

100,5000,3,13,166,59

100,5000,3,12,167,61

100,5000,3,12,168,60

100,5000,3,12,170,61

100,5000,3,12,154,63

100,5000,3,13,168,60

100,5000,3,12,167,60

100,5000,3,12,148,64

100,5000,3,12,167,65

100,5000,3,12,164,60

100,5000,3,12,150,59

100,5000,2,20,217,65

100,5000,2,13,205,63

100,5000,2,14,204,60

100,5000,2,13,205,55

100,5000,2,14,210,59

100,5000,2,13,201,59

100,5000,2,13,211,59

100,5000,2,13,217,59

100,5000,2,14,207,60

100,5000,2,14,209,59

100,5000,2,14,214,61

100,5000,2,13,210,61

100,5000,1,24,376,60

100,5000,1,21,393,58

100,5000,1,20,386,58

100,5000,1,22,384,59

100,5000,1,21,387,59

100,4000,4,18,112,70

100,4000,4,13,118,62

100,4000,4,12,114,63

100,4000,4,14,112,65

100,4000,4,12,113,62

100,4000,4,14,109,61

100,4000,4,13,118,63

100,4000,4,12,112,61

100,4000,4,12,110,61

100,4000,4,11,111,63

100,4000,4,13,112,67

100,4000,4,12,110,60

100,4000,4,12,113,60

100,3000,4,19,100,66

100,3000,4,13,99,64

100,3000,4,12,100,65

100,3000,4,13,103,61

100,3000,4,14,104,63

100,3000,4,14,99,63

100,2000,4,18,90,67

100,2000,4,13,86,65

100,2000,4,13,85,63

100,2000,4,14,87,62

100,2000,4,13,85,61

100,2000,4,13,86,64

100,2000,4,13,85,61

100,2000,4,13,89,58

100,2000,4,13,85,61

100,2000,4,12,85,60

100,2000,4,13,85,66

100,2000,4,12,86,59

100,2000,4,12,86,61

100,2000,4,12,82,60

100,2000,4,13,87,62

100,2000,4,12,83,65

100,2000,4,12,85,60

100,2000,4,13,87,60

100,2000,4,12,86,59

100,1000,4,19,75,63

100,1000,4,14,71,61

100,1000,4,13,75,59

100,1000,4,14,72,61

100,1000,4,13,72,59

100,1000,4,13,71,59

100,1000,4,13,72,59

100,1000,4,14,70,62

100,1000,4,12,72,58

100,1000,4,13,71,59

100,1000,4,13,70,62

100,1000,4,12,72,59

100,1000,4,20,71,58

100,1000,4,13,69,60

100,1000,4,12,73,60

100,1000,4,13,69,59

100,1000,4,13,71,60

100,1000,4,13,73,62

100,1000,4,12,71,59

100,1000,4,12,70,56

100,1000,4,13,70,58

100,1000,4,12,69,57

运行方法
go test -v -run=FitClass
最终输出的数据和scipy的结果差不多

程序输出


图片.png

输出的拟合图像如下


图片.png

相关文章

网友评论

      本文标题:使用golang做最小二乘法的线性拟合

      本文链接:https://www.haomeiwen.com/subject/lzokvktx.html