基本的使用情况差不多介绍的差不多了,我也是边学习边写博客,其中难免有很多理解错误的地方或者理解不到位的地方,还请各位博友多多指点。
介绍完了使用,就应该自己动手去实践了,因此,这里再介绍一下实验数据的问题。Keras提供了常用的几种数据集的下载,可以直接拿来用,非常方便。下面我们来看一下。
一、CIFAR10小图分类
keras.datasets.cifar10
CIFAR10数据集包含有5万张32*32的训练彩色图,共标记了超过10个分类;还有1万张测试图片。
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
返回: 2个元组
X_train, X_test : 每个元素是一个uint8类型(0~255),代表RGB图像数据的每个像素值。其shape是(nb_samples, 3, 32, 32),nb_samples是样本数量。整体意思就是nb_samples张图片,每张图片有3个通道(代表RGB),每个通道数据包含一个32*32的像素值矩阵。
y_train, y_test: 每个元素是一个uint8类型,代表分类编号(0-9),其shape是(nb_samples, ),就是只有一个列向量,每个值表示对应样本的分类标号。
二、CIFAR100小图分类
keras.datasets.cifar100
CIFAR100数据集包含有5万张32*32的训练彩色图,共标记了超过100个分类;还有1万张测试图片。
(X_train, y_train), (X_test, y_test) =cifar100.load_data(label_mode='fine')
参数:
label_mode: “fine” 或 “coarse”,分别表示分类标准比较严格或者分类标准比较宽泛。
返回 : 2个元组
X_train, X_test: 同CIFAR10
y_train, y_test: 同CIFAR10
三、IMDB数据集,影评情感分析
keras.datasets. imdb
IMDB数据集包含有2.5万条电影评论,被标记为积极和消极。影评会经过预处理,把每一条影评编码为一个词索引(数字)sequence。为了方便起见,单词根据在整个数据集中的出现频率大小建立索引,所以”3”就代表在数据中出现频率第三的单词。这样可以快速筛选出想要的结果,比如想要top10000,但是排除top20的单词。
同时约定,”0”不代表特定的单词,而是代表一些未知词。
(X_train, y_train),(X_test, y_test) = imdb.load_data(path="imdb.pkl",nb_words=None, skip_top=0, maxlen=None, test_split=0.1, seed=113)
参数:
path: 如果本地(‘~/.keras/datasets/’ + path)已经有该数据集,则使用本地的;则否会从联网下载该数据集(cPickle格式)到本地。
nb_words : 表示频率最高的前nb_words个单词,其他频率的词用”0”表示。如果为None,则为每个单词都建立索引号。
skip_top : 表示忽略频率最高的前skip_top个单词,用”0”表示它们。
maxlen : sequence最大长度,过长的会被截断。如果为None,则表示不限制最大长度。
test_split: 测试数据占总数据的比例。
seed : Seed for reproducible datashuffling
返回 : 2个元组
X_train, X_test : sequence列表,就是一列索引号。如果nb_words参数明确定义,那么sequence最大索引是nb_words-1。如果maxlen参数明确定义,那么最大的sequence长度就是maxlen。
y_train, y_test: 0或1序列。
四、新闻主题分类(数据集来源于路透社新闻专线)
keras.datasets. reuters
数据集包含有来自于路透社的11228条新闻数据,被标记了超过46个分类。和IMDB数据集一样,每一条数据被编码为一条索引序列。
(X_train, y_train),(X_test, y_test) = reuters.load_data(path="reuters.pkl",nb_words=None, skip_top=0, maxlen=None, test_split=0.1, seed=113)
使用说明同IMDB。该数据集可以通过以下代码获取单词的索引。
word_index = reuters.get_word_index(path="reuters_word_index.pkl")
返回字典实例,键为单词,值为索引。比如,word_index[“giraffe”]会返回1234.
五、MNIST分类(手写数字识别)
keras.datasets. mnist
数据集有6万张2828的灰度图,共分为10类,含1万张测试图。
(X_train, y_train), (X_test, y_test) = mnist.load_data()
返回: 2个元组
X_train, X_test : 每个元素是一个uint8类型(0~255),代表灰度图像的每个像素值。其shape是(nb_samples, 28, 28),nb_samples是样本数量。整体意思就是nb_samples张图片,每张图片包含一个2828的像素值矩阵。(因为是灰度图,所以只有一个通道,可以理解shape为[nb_samples, 1 , 28 , 28])
y_train, y_test: 每个元素是一个uint8类型,代表数字0-9,其shape是(nb_samples, ),就是只有一个列向量,每个值表示对应图片中的数字是多少。
网友评论