美文网首页
公众号-科研私家菜学习记录(5)

公众号-科研私家菜学习记录(5)

作者: 明眸意海 | 来源:发表于2021-08-06 19:26 被阅读0次

    交叉验证与模型选择

    • 定义:亦称循环估计,是一种统计学上将数据样本切割成较小子集的实用方法。可以先在一个子集上做建模分析,而其它子集则用来做后续对此分析的效果评价及验证。一开始的子集被称为训练集(Train set)。而其它的子集则被称为验证集(Validation set)或测试集(Test set)。
    • 交叉验证是一种评估统计分析、机器学习算法对独立于训练数据的数据集的泛化能力(Generalize)。交叉验证大致分为三种:简单交叉验证(hold-outcross validation)、k-折交叉验证(k-fold cross validation)和留一交叉验证(leave one out cross validation)。
    • R包实现:ISLR
    1. 示例
    library(ISLR)
    data('Auto') ##载入数据集
    head(Auto)
    set.seed(111)
    
    n=nrow(Auto) ## 共392个样本
    train=sample(n,n/2) ## 选择50%的样本作为训练集
    test=(-train) ## 50%的样本作为测试集
    
    
    lm.fit=lm(mpg~horsepower, data=Auto, subset=train) 
    ## 使用训练集拟合线性回归模型
    ## 以horsepower来拟合mpg的值
    mean((Auto[test,'mpg']-predict(lm.fit, newdata=Auto[test,]))^2) ## 计算均方误差
    
    lm.fit2=lm(mpg~poly(horsepower,2), data=Auto, subset=train) 
    ## 使用训练集拟合多项式回归
    mean((Auto[test,'mpg']-predict(lm.fit2, newdata=Auto[test,]))^2) ## 均方误差
    

    简单交叉验证

    MSE=matrix(NA,10,10) ## 建10行10列的Matrix
    
    for(seed in 1:10){
      set.seed(seed)
      train=sample(n,n/2)
      test=(-train)
      for(degree in 1:10){
        lm.fit=lm(mpg~poly(horsepower,degree), data=Auto,subset=train)
        MSE[seed,degree]=mean((Auto[test,'mpg']-predict(lm.fit,newdata=Auto[test,]))^2)  
      }
    }
    ## 结果可视化
    plot(MSE[1,],
         ylim=range(MSE),
         type='l',
         lwd=2,
         col=rainbow(10)[1],
         xlab='degree',
         ylab='the estimated test MSE')
    
    for(seed in 2:10){
      points(MSE[seed,],
             type='l',
             lwd=2,
             col=ggsci::pal_npg()(10)[seed])
    } 
    

    K-折交叉验证

    library(boot)
    cv.error.10=rep(NA,10)
    for(degree in 1:10){
      glm.fit=glm(mpg ~ poly(horsepower,degree), data=Auto)
      set.seed(1234)
      cv.error.10[degree] <- cv.glm(Auto,glm.fit,K=10)$delta[1]
    }
    cv.error.10
    
    plot(cv.error.10,type='b',xlab='degree',col=ggsci::pal_npg()(10))
    

    留一交叉验证

    loocv.error=rep(NA,10)
    for(degree in 1:10){
      glm.fit=glm(mpg ~ poly(horsepower,degree), data=Auto)
      loocv.error[degree]=cv.glm(Auto,glm.fit,K = nrow(Auto))$delta[1]
    }
    loocv.error 
    
    plot(loocv.error,type='b',xlab='degree',col=ggsci::pal_npg()(10))
    

    相关文章

      网友评论

          本文标题:公众号-科研私家菜学习记录(5)

          本文链接:https://www.haomeiwen.com/subject/mhrmvltx.html