datasets
是huggingface维护的一个轻量级可扩展的数据加载库,其兼容pandas、numpy、pytorch和tensorflow,使用简便。根据其官方简介:Datasets
originated from a fork of the awesome TensorFlow Datasets
,datasets是源自于tf.data的,两者之间的主要区别可参考这里。
tf.data相较于pytorch的dataset/dataloader来说,(个人认为)其最强大的一点是可以处理大数据集,而不用将所有数据加载到内存中。datasets的序列化基于Apache Arrow(tf.data基于tfrecord),熟悉spark的应该对apache arrow略有了解。datasets使用的是内存地址映射的方式来直接从磁盘上来读取数据,这使得他能够处理大量的数据。用法简介可参考Quick tour。下面对datasets用法做一些简单的记录。
1、载入数据
datasets
提供了许多NLP相关的数据集,使用list_datasets()
可查看提供的相关数据集。这里使用自己的是自己本地的数据集:
# jupyter notebook中设置交互式输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import pandas as pd
import numpy as np
from datasets import load_dataset
import torch
import torch.nn as nn
import transformers
import os
from pprint import pprint
dpath = '/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/'
files = [os.path.join(dpath, f) for f in os.listdir(dpath) if f.endswith('.csv')]
files
# 输出
['/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00001-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00008-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00017-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00003-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00010-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00009-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00015-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00007-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00012-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00019-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00004-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00021-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00014-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00016-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00011-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00022-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00023-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00020-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00006-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00013-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00005-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00000-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00018-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv',
'/home/jovyan/media/huiting_1102/sequence_data_month45_with_dense/part-00002-07a7f935-a21c-4c8c-9e3d-ad2d7fc2aedf-c000.csv']
载入本地csv数据集:
dataset = load_dataset('/opt/miniconda3/lib/python3.7/site-packages/datasets/csv.py',
data_files=files, delimiter='\t')
这里有一点需要注意,原始用法是load_dataset('csv', files)
,然后再load数据集的时候会从datasets github库中拉取读取csv数据的脚本,用此脚本来读取本地数据。但是在读取的过程中非常容易出现网络错误,这里的做法是直接将github 库中的csv读取脚本直接下载到本地datasets安装库中,如上所示,将csv.py放入datasets pip安装的位置即可解决此问题。
简单查看一下datasets中的信息:
pprint(dataset)
# 输出
{'train': Dataset(features: {'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None), 'camp_id': Value(dtype='string', id=None), 'spot_id': Value(dtype='string', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None)}, num_rows: 24423486)}
dataset.shape
# {'train': (24423486, 11)}
dataset.num_rows
# {'train': 24423486}
dataset.column_names
{'train': ['id',
'label',
'source',
'ts4_cnt',
'ts5_cnt',
'camp_id',
'spot_id',
'gender_0_spot_id_mean',
'gender_1_spot_id_mean',
'gender_0_camp_id_mean',
'gender_1_camp_id_mean']}
dataset['train'].features
{'id': Value(dtype='string', id=None),
'label': Value(dtype='float64', id=None),
'source': Value(dtype='string', id=None),
'ts4_cnt': Value(dtype='int64', id=None),
'ts5_cnt': Value(dtype='int64', id=None),
'camp_id': Value(dtype='string', id=None),
'spot_id': Value(dtype='string', id=None),
'gender_0_spot_id_mean': Value(dtype='float64', id=None),
'gender_1_spot_id_mean': Value(dtype='float64', id=None),
'gender_0_camp_id_mean': Value(dtype='float64', id=None),
'gender_1_camp_id_mean': Value(dtype='float64', id=None)}
dataset['train'][:3]
{'id': ['c94069f423b47e163502a7a418d076ae',
'b2c72851e56c03f05b8d75174a32149d',
'74ba33298cbced79e6029676c4016dca'],
'label': [999.0, 1.0, 999.0],
'source': ['test', 'old', 'test'],
'ts4_cnt': [1, 22, 105],
'ts5_cnt': [1, 30, 108],
'camp_id': ['75 113',
'398 517 473 473 473 264 264 473 844 844 473 133 133 436 133 133 5574 36 133 436 607 133 133 607 1017 133 161 8392 1624 1624 1461 2765 4586 718 718 1298 355 1304 607 221 4466 795 12593 536 536 1228 1284 536 536 536 608 795',
'36 36 36 36 36 36 36 36 2618 36 2618 36 36 36 36 1454 1454 36 36 36 36 36 36 36 1454 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 2618 36 36 36 36 2618 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 2618 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 5859 36 36 5859 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 129 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 5037 36 36 36 36 36 36 36 36 36 36 319 36 36 36 36'],
'spot_id': ['8662 22475',
'5576 26620 2205 2205 1405 35889 35889 1405 11915 12487 1405 230 275 836 275 275 14776 199 275 836 1225 230 275 1225 5772 275 1656 163445 3057 3057 3171 4354 134965 1276 1276 93168 138 1664 1225 141196 41233 1956 78753 15869 13467 797 7159 9269 13467 13519 1647 1956',
'633 132 245 245 245 245 245 245 2159 132 2159 633 633 633 132 23252 23252 245 245 245 245 245 245 245 29987 633 245 633 199 245 245 245 199 245 199 245 132 633 245 245 245 132 132 132 245 245 132 199 245 633 633 245 245 245 245 245 245 245 132 199 245 2159 199 132 633 199 2159 132 132 245 132 132 633 199 633 633 245 5346 245 132 633 633 2159 245 132 633 132 199 132 199 199 633 199 633 633 132 132 132 199 5346 132 5346 5346 5346 633 5346 5346 245 245 199 199 199 199 132 199 245 245 245 199 199 199 132 132 199 122 132 132 5346 633 633 633 633 633 633 5346 199 132 245 245 199 199 132 11004 199 5346 11004 132 245 5346 199 245 199 132 199 132 633 132 5346 5346 199 5346 199 132 199 132 245 245 199 199 199 132 245 199 29165 633 633 5346 5346 199 132 132 245 245 633 245 633 132 132 633 132 132 132 132 199 199 199 633 5695 245 5346 5346 199 245 132 245 132 199 199 13996 132 132 132 199'],
'gender_0_spot_id_mean': [0.6473253581552576,
0.5286873564118875,
0.3833915059748672],
'gender_1_spot_id_mean': [0.35267464184474245,
0.4713126435881125,
0.6166084940251326],
'gender_0_camp_id_mean': [0.7617982406658169,
0.5803965671172192,
0.38961354242060287],
'gender_1_camp_id_mean': [0.2382017593341831,
0.41960343288278074,
0.610386457579397]}
在上面简单查看了dataset的内部数据结构类型和一些元信息。
虽然上面我们读取的是csv格式文件,但datasets支持多种数据格式:csv、json、text文件,以及in-memory中的pandas dataframe、python dictionary均可以。
详细用法可参考:https://huggingface.co/docs/datasets/loading_datasets.html。
2、处理数据集
2.1 数据选取
为方便这里仅使用10000个数据来进行演示。
# 从train dataset中选取10000行来作为demo数据
dataset = dataset['train'].select(np.arange(10000))
len(dataset)
# 10000
由于上面读取数据集时没有使用split参数,也没有分开指定各个文件所属数据集,因此所有文件均读入一个dataset中。各个dataset组成一个dict,如上所示使用字典取值方式可以获取想要的dataset。
pprint(dataset)
# 输出
Dataset(features: {'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None), 'camp_id': Value(dtype='string', id=None), 'spot_id': Value(dtype='string', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None)}, num_rows: 10000)
dataset.features
# 输出
{'id': Value(dtype='string', id=None),
'label': Value(dtype='float64', id=None),
'source': Value(dtype='string', id=None),
'ts4_cnt': Value(dtype='int64', id=None),
'ts5_cnt': Value(dtype='int64', id=None),
'camp_id': Value(dtype='string', id=None),
'spot_id': Value(dtype='string', id=None),
'gender_0_spot_id_mean': Value(dtype='float64', id=None),
'gender_1_spot_id_mean': Value(dtype='float64', id=None),
'gender_0_camp_id_mean': Value(dtype='float64', id=None),
'gender_1_camp_id_mean': Value(dtype='float64', id=None)}
使用dataset.features
可以获取dataset中各个字段的数据类型。
我们也可以通过指定dataset中的set_format()
方法来指定各个字段的输出类型并限制输出的字段,这一般在load数据时用到。通过reset_format()
可以重置数据类型:
dataset.set_format(type='torch', columns=['label', 'gender_0_spot_id_mean', 'gender_1_spot_id_mean',
'gender_0_camp_id_mean', 'gender_1_camp_id_mean'])
dataset[:3]
# 输出
{'label': tensor([999., 1., 999.], dtype=torch.float64),
'gender_0_spot_id_mean': tensor([0.6473, 0.5287, 0.3834], dtype=torch.float64),
'gender_1_spot_id_mean': tensor([0.3527, 0.4713, 0.6166], dtype=torch.float64),
'gender_0_camp_id_mean': tensor([0.7618, 0.5804, 0.3896], dtype=torch.float64),
'gender_1_camp_id_mean': tensor([0.2382, 0.4196, 0.6104], dtype=torch.float64)}
dataset.format
# 输出
{'type': 'torch',
'format_kwargs': {},
'columns': ['label',
'gender_0_spot_id_mean',
'gender_1_spot_id_mean',
'gender_0_camp_id_mean',
'gender_1_camp_id_mean'],
'output_all_columns': False}
dataset.reset_format()
dataset.format
# 输出
{'type': None,
'format_kwargs': {},
'columns': ['id',
'label',
'source',
'ts4_cnt',
'ts5_cnt',
'camp_id',
'spot_id',
'gender_0_spot_id_mean',
'gender_1_spot_id_mean',
'gender_0_camp_id_mean',
'gender_1_camp_id_mean'],
'output_all_columns': False}
我们也可以使用类似pandas的方法来对dataset进行过滤和查看:
# 过滤得到label为0的样本
label_mask = np.array(dataset['label']) == 0
# 挑选label列中的前10个进行查看
dataset[label_mask]['label'][:10]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
2.2 数据处理
在这一部分我们将对数据进行简单的处理:划分数据集,数据变换,数据保存以及重载。
数据集划分
# Be sure to shard before using any randomizing operator (such as shuffle).
# It is best if the shard operator is used early in the dataset pipeline.
dataset_train = dataset.filter(lambda example: example['label']!=999)
dataset_test = dataset.filter(lambda example: example['label']==999)
dataset_train
# 输出
Dataset(features: {'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None), 'camp_id': Value(dtype='string', id=None), 'spot_id': Value(dtype='string', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None)}, num_rows: 6127)
使用train_test_split
将训练集划分为训练集和验证集, 这里与sklearn中api的唯一不同是没有stratify参数,无法按label比例来进行层次采样:
train = dataset_train.train_test_split(test_size=0.1, seed=123)
dataset_trn = train['train']
dataset_val = train['test']
print('dataset train:\n', dataset_trn)
print('\n')
print('dataset validation:\n', dataset_val)
# 输出
dataset train:
Dataset(features: {'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None), 'camp_id': Value(dtype='string', id=None), 'spot_id': Value(dtype='string', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None)}, num_rows: 5514)
dataset validation:
Dataset(features: {'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None), 'camp_id': Value(dtype='string', id=None), 'spot_id': Value(dtype='string', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None)}, num_rows: 613)
使用shard
可以将数据集划分为更小的数据集,加载时可分shard加载:
dataset_trn.shard(num_shards=50, index=1)
# 输出
Dataset(features: {'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None), 'camp_id': Value(dtype='string', id=None), 'spot_id': Value(dtype='string', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None)}, num_rows: 111)
类似的,使用shuffle
来打乱数据集:
shuffle_trn = dataset_trn.shuffle(seed=123)
数据变换
在上面可以看到,数据集中camp_id和spot_id列是字符串类型的,在NLP任务中,输入一般为indices,因此这里做一些简单的转换示例:
def word2idx_camp(example, col='camp_id'):
lst = []
for i in example[col]:
lst.append([int(j)+1 for j in i.split()])
example[col] = lst
return example
# 编码字符
encoded_trn = shuffle_trn.map(lambda x:word2idx_camp(x, 'camp_id'), num_proc=8, batched=True)
encoded_trn = encoded_trn.map(lambda x:word2idx_camp(x, 'spot_id'), num_proc=8, batched=True)
使用batched表示一次处理多行,也可将其设置为False每次仅处理一行,这在数据增强中可能更常用。因为dataset允许我们每次处理的输入行数不等于输出行数,因此数据增强如单词替换一个变多个数据时,进行将label 字段进行copy多分即可,也即如下:
>>> def chunk_examples(example):
# 长句边多个短句
... chunks = []
... chunks += [example[sentence][i:i + 50] for i in range(0, len(sentence), 50)]
labels = [example['label']]*len(chunks)
... return {'chunks': chunks, 'label':labels}
查看转换后的数据:
encoded_trn
Dataset(features: {'camp_id': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'spot_id': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None)}, num_rows: 5514)
pprint(encoded_trn[0])
# 输出
{'camp_id': [1592,
1592,
16,
4280,
16,
4916,
356,
356,
48,
2766,
3729,
16,
16,
1635,
300,
16,
6188,
3540,
5273,
1274,
2713,
1462,
54,
463,
6214,
16,
9491,
1299,
7,
3321,
2713],
'gender_0_camp_id_mean': 0.7141306714111253,
'gender_0_spot_id_mean': 0.7170624102313645,
'gender_1_camp_id_mean': 0.2858693285888747,
'gender_1_spot_id_mean': 0.2829375897686355,
'id': '38a12968339ccd5acebb3bdeed2fdb0f',
'label': 0.0,
'source': 'old',
'spot_id': [9988,
9988,
17534,
17551,
8197,
8208,
139,
139,
3433,
5223,
5435,
2101,
6165,
14157,
1848,
10467,
24136,
22641,
6228,
6327,
9933,
1849,
332,
2291,
24214,
10467,
55522,
317148,
30,
42465,
9933],
'ts4_cnt': 2,
'ts5_cnt': 29}
处理测试集:
encoded_val = dataset_val.map(lambda x:word2idx_camp(x, 'camp_id'), num_proc=8, batched=True)
encoded_val = encoded_val.map(lambda x:word2idx_camp(x, 'spot_id'), num_proc=8, batched=True)
datasets也提供了合并数据集的方法concatenate_datasets
,当然合并的前提是各个数据集数据类型相同。如两个数据集来自于同一份数据抽取得到的(或者是同时读取的,如在data_files
中指定不同dataset名字,或通过split指定名字),则最好先通过flatten_indices()
来刷新indices,否则合并时容易报错:
from datasets import concatenate_datasets
encoded_train = concatenate_datasets([encoded_trn, encoded_val])
encoded_train
# 输出
Dataset(features: {'camp_id': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'spot_id': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None)}, num_rows: 6127)
保存、重新加载数据集
datasets允许我们将处理好的数据集直接进行保存以方便复用。
from datasets import load_from_disk
encoded_train.save_to_disk('demo_data/')
reloaded_encoded_dataset = load_from_disk('demo_data')
reloaded_encoded_dataset
# 输出
Dataset(features: {'camp_id': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'gender_0_camp_id_mean': Value(dtype='float64', id=None), 'gender_0_spot_id_mean': Value(dtype='float64', id=None), 'gender_1_camp_id_mean': Value(dtype='float64', id=None), 'gender_1_spot_id_mean': Value(dtype='float64', id=None), 'id': Value(dtype='string', id=None), 'label': Value(dtype='float64', id=None), 'source': Value(dtype='string', id=None), 'spot_id': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'ts4_cnt': Value(dtype='int64', id=None), 'ts5_cnt': Value(dtype='int64', id=None)}, num_rows: 6127)
数据加载+pytorch模型训练
encoded_trn.set_format(type='torch', columns=['label',
'gender_0_camp_id_mean', 'gender_0_spot_id_mean',
'gender_1_camp_id_mean', 'gender_1_spot_id_mean'],
device='cuda')
dataloader = torch.utils.data.DataLoader(encoded_trn, batch_size=4)
batch = next(iter(dataloader))
batch
# 输出
{'gender_0_camp_id_mean': tensor([0.7141, 0.6522, 0.6465, 0.8104], device='cuda:0'),
'gender_0_spot_id_mean': tensor([0.7171, 0.6234, 0.6305, 0.8245], device='cuda:0'),
'gender_1_camp_id_mean': tensor([0.2859, 0.3478, 0.3535, 0.1896], device='cuda:0'),
'gender_1_spot_id_mean': tensor([0.2829, 0.3766, 0.3695, 0.1755], device='cuda:0'),
'label': tensor([0., 1., 1., 0.], device='cuda:0')}
# 合并多个字段,这一步也可以放到map中
batch_features = torch.cat([batch['gender_0_camp_id_mean'].reshape(-1,1),
batch['gender_0_spot_id_mean'].reshape(-1,1),
batch['gender_1_camp_id_mean'].reshape(-1,1),
batch['gender_1_spot_id_mean'].reshape(-1,1)],
dim=-1)
batch_features
# 输出
tensor([[0.7141, 0.7171, 0.2859, 0.2829],
[0.6522, 0.6234, 0.3478, 0.3766],
[0.6465, 0.6305, 0.3535, 0.3695],
[0.8104, 0.8245, 0.1896, 0.1755]], device='cuda:0')
在上面中我们将dataset的输出设置为pytorch tensor,并设置了device为gpu。
下面我们来训练个简单的模型:
from torch import optim
import torch.functional as F
model = nn.Sequential(nn.Linear(4, 20),
nn.BatchNorm1d(20),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(20, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
model.to('cuda')
model.zero_grad()
model.train()
losses = []
dataloader = torch.utils.data.DataLoader(encoded_trn, batch_size=128)
for _ in range(5):
for batch in dataloader:
batch_features = torch.cat([batch['gender_0_camp_id_mean'].reshape(-1,1),
batch['gender_0_spot_id_mean'].reshape(-1,1),
batch['gender_1_camp_id_mean'].reshape(-1,1),
batch['gender_1_spot_id_mean'].reshape(-1,1)],
dim=-1)
labels = batch['label'].long()
out = model(batch_features)
loss = loss_fn(out, labels)
loss.backward()
optimizer.zero_grad()
optimizer.step()
# print(loss.item())
losses.append(loss.item())
可视化一下loss:
import matplotlib.pyplot as plt
plt.plot(losses)
太垃圾了!!!
上面简单介绍了使用datasets作为pytorch数据加载器的简单用法,除此之外,datasets还提供了丰富的NLP相关评估metric,详细可参见:loading_metrics和using_metric。
更多更详细的使用可以参见官网https://huggingface.co。
参考:
https://huggingface.co
https://github.com/huggingface/datasets
网友评论