以Resnet为例,找到了一个微软的resnet-50的预训练模型
from transformers import AutoImageProcessor, ResNetForImageClassification
# 加载前处理处理器,自动构建;功能:输入图片,输出Tenser(1, 3, 224, 224)
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
# 加载与训练模型
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
from datasets import load_dataset
import numpy as np
# 目录结构加载数据集
"""
目录结构如下:
flowers102_dir
├── train
│ ├── 0
│ │ ├── image_06736.jpg
│ │ └── image_06771.jpg
│ ├── 1
│ │ ├── image_05091.jpg
│ │ └── image_05146.jpg
│ └── 99
│ ├── image_07900.jpg
│ └── image_07941.jpg
└── validation
├── 0
│ ├── image_06738.jpg
│ └── image_06773.jpg
├── 1
│ ├── image_05100.jpg
│ └── image_05138.jpg
└── 99
├── image_07902.jpg
└── image_07935.jpg
"""
dataset = load_dataset("imagefolder", data_dir="/test/flowers102_dir")
# 定义转换函数,这个地方遇到了坑。
# 重要:前处理的输出为(1, 3, 224, 224),但是此处的pixel_values需要的是(3, 224, 224),所以需要reshape或者降维
def trans(row):
row["pixel_values"] = np.asarray(processor(row["image"])["pixel_values"][0])
return row
train_dataset = dataset['train'].map(trans)
val_dataset = dataset['validation'].map(trans)
from transformers import TrainingArguments, Trainer
# 可以自行根据需要设置训练参数
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
# 构建Trainer,参数还有很多,此处为最基本的。
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)
# 开始训练
trainer.train()
网友评论