美文网首页
数据集分割方法

数据集分割方法

作者: Byte猫 | 来源:发表于2019-03-18 12:11 被阅读0次

在机器学习建模过程中,通行的做法通常是将数据集分为训练集和测试集。测试集是与训练独立的数据,完全不参与训练,用于最终模型的评估。
在训练过程中,经常会出现过拟合的问题,就是模型可以很好的匹配训练数据,却不能很好在预测训练集外的数据。如果此时就使用测试数据来调整模型参数,就相当于在训练时已知部分测试数据的信息,会影响最终评估结果的准确性。通常的做法是在训练数据再中分出一部分做为验证(Validation)数据,用来评估模型的训练效果。
验证数据取自训练数据,但不参与训练,这样可以相对客观的评估模型对于训练集之外数据的匹配程度。模型在验证数据中的评估常用的是交叉验证,又称循环验证。它将原始数据分成K组(K-Fold),将每个子集数据分别做一次验证集,其余的K-1组子集数据作为训练集,这样会得到K个模型。这K个模型分别在验证集中评估结果,最后的误差MSE(Mean Squared Error)加和平均就得到交叉验证误差。交叉验证有效利用了有限的数据,并且评估结果能够尽可能接近模型在测试集上的表现,可以做为模型优化的指标使用。

一、随机划分

train_test_split是最简单的数据集分割函数,功能是将样本按比例分割。

# coding = utf-8
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data
y = iris.target

x_train,x_test,y_train,y_test = train_test_split(X, y, test_size=0.2,
                                                 stratify=y,  # 按照标签来分层采样
                                                 shuffle=True, # 是否先打乱数据的顺序再划分
                                                 random_state=1)   # 控制将样本随机打乱
x_train,x_valid,y_train,y_valid = train_test_split(x_train, y_train, test_size=0.4,
                                                 stratify=y_train,
                                                 shuffle=True,
                                                 random_state=1) 

然而这种方式并不是很好,有两大缺点:一是浪费数据,二是容易过拟合且矫正方式不方便

二、K折划分

1、KFold

K-Fold是最简单的K折交叉,n-split就是K值,shuffle指是否对数据洗牌,random_state为随机种子
K值的选取会影响bias和viriance。K越大,每次投入的训练集的数据越多,模型的Bias越小。但是K越大,又意味着每一次选取的训练集之前的相关性越大,而这种大相关性会导致最终的test error具有更大的Variance。一般来说,根据经验我们一般选择k=5或10。

# coding = utf-8
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier as GBDT
from sklearn.metrics import precision_score

iris = datasets.load_iris()
X = iris.data
y = iris.target

x_train,x_test,y_train,y_test = train_test_split(X, y, test_size=0.2,
                                                 stratify=y,  # 按照标签来分层采样
                                                 shuffle=True, # 是否先打乱数据的顺序再划分
                                                 random_state=1)   # 控制将样本随机打乱

clf = GBDT(n_estimators=100)
precision_scores = []

kf = KFold(n_splits=5, random_state=0, shuffle=False)
for train_index, valid_index in kf.split(x_train, y_train):
    x_train, x_valid = X[train_index], X[valid_index]
    y_train, y_valid = y[train_index], y[valid_index]
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_valid)
    precision_scores.append(precision_score(y_valid, y_pred, average='micro'))

print('Precision', np.mean(precision_scores))

2、StratifiedKFold

StratifiedKFold用法类似Kfold,但是他是分层采样,确保训练集、验证集中各类别样本的比例与原始数据集中相同。

# coding = utf-8
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier as GBDT
from sklearn.metrics import precision_score

iris = datasets.load_iris()
X = iris.data
y = iris.target

x_train,x_test,y_train,y_test = train_test_split(X, y, test_size=0.2,
                                                 stratify=y,  # 按照标签来分层采样
                                                 shuffle=True, # 是否先打乱数据的顺序再划分
                                                 random_state=1)   # 控制将样本随机打乱

clf = GBDT(n_estimators=100)
precision_scores = []

kf = StratifiedKFold(n_splits=5, random_state=0, shuffle=False)
for train_index, valid_index in kf.split(x_train, y_train):
    x_train, x_valid = X[train_index], X[valid_index]
    y_train, y_valid = y[train_index], y[valid_index]
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_valid)
    precision_scores.append(precision_score(y_valid, y_pred, average='micro'))

print('Precision', np.mean(precision_scores))

3、StratifiedShuffleSplit

StratifiedShuffleSplit允许设置设置train/valid对中train和valid所占的比例

# coding = utf-8
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier as GBDT
from sklearn.metrics import precision_score

iris = datasets.load_iris()
X = iris.data
y = iris.target

x_train,x_test,y_train,y_test = train_test_split(X, y, test_size=0.2,
                                                 stratify=y,  # 按照标签来分层采样
                                                 shuffle=True, # 是否先打乱数据的顺序再划分
                                                 random_state=1)   # 控制将样本随机打乱

clf = GBDT(n_estimators=100)
precision_scores = []

kf = StratifiedShuffleSplit(n_splits=10, train_size=0.6, test_size=0.4, random_state=0)
for train_index, valid_index in kf.split(x_train, y_train):
    x_train, x_valid = X[train_index], X[valid_index]
    y_train, y_valid = y[train_index], y[valid_index]
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_valid)
    precision_scores.append(precision_score(y_valid, y_pred, average='micro'))

print('Precision', np.mean(precision_scores))

其他的方法如RepeatedStratifiedKFold、GroupKFold等详见sklearn官方文档。
拓展阅读

相关文章

  • 数据集分割方法

    在机器学习建模过程中,通行的做法通常是将数据集分为训练集和测试集。测试集是与训练独立的数据,完全不参与训练,用于最...

  • 4种语义分割数据集Cityscapes上SOTA方法总结

    摘要:当前语义分割方法面临3个挑战。 本文分享自华为云社区《语义分割数据集Cityscapes上SOTA方法总结[...

  • 分割数据集的方法一

    手撕数据集 1.随机数 2.哈希表 使用工具 1.sklearn.model_selection Signatur...

  • CVPR2019|In Defense of Pre-train

    用于道路驾驶的实时语义分割 Abstract 在要求苛刻的道路驱动数据集上, 语义分割方法最近取得了成功, 激发了...

  • 数据集的分割与sklearn实现

    今天聊一下数据集分割的问题,在使用机器学习算法的时候,我们需要对原始数据集进行分割,分为训练集、验证集、测试集。训...

  • 数据集分割

    一、单个文件分割训练集、测试集和验证集 一、单个文件分割多个训练集、测试集和验证集(5折) 有用的话,点个小红心哦!

  • 基于Keras实现Kaggle2013--Dogs vs. Ca

    【下载数据集】 下载链接--百度网盘关于猫的部分数据集示例 【整理数据集】 将训练数据集分割成训练集、验证集、测试...

  • 常用数据集介绍及转换

    研究背景 在深度学习中常用的数据集进行归纳和总结 语义分割的数据集 1、COCO 数据集 COCO(Common ...

  • 2.封装kNN算法之数据分割

    训练数据集与测试数据集 当我们拿到一组数据之后,通常我们需要把数据分割成两部分,即训练数据集和测试数据集。训练数据...

  • scikit-learn 中的交叉验证方法

    scikit-learn中提供了多种用于交叉验证的数据集分割方法。这里对这些方法的区别和应用场景做一个梳理。 基本...

网友评论

      本文标题:数据集分割方法

      本文链接:https://www.haomeiwen.com/subject/mkqvmqtx.html