下载数据集
解压数据集
unzip OCR_Dataset.zip -d data/
启动jupterhub
mkdir jupterhub
cd jupterhub
docker run -p 80:80 --rm --env USER_PASSWD="password you set" -v $PWD:/home/paddle -v /data:/data -v /data:/data registry.baidubce.com/paddlepaddle/paddle:2.4.2-jupyter
拷贝需要预测的图片
cd /data
mkdir sample-img
复制需要预测的图片到文件夹sample-img
新建notebook输入如下代码
import paddle
import os
import PIL.Image as Image
import numpy as np
from paddle.io import Dataset
图片信息配置 - 通道数、高度、宽度
IMAGE_SHAPE_C = 3
IMAGE_SHAPE_H = 30
IMAGE_SHAPE_W = 70
数据集图片中标签长度最大值设置 - 因图片中均为4个字符,故该处填写为4即可
LABEL_MAX_LEN = 4
class Reader(Dataset):
def init(self, data_path: str, is_val: bool = False):
"""
数据读取Reader
:param data_path: Dataset路径
:param is_val: 是否为验证集
"""
super().init()
self.data_path = data_path
# 读取Label字典
with open(os.path.join(self.data_path, "label_dict.txt"), "r", encoding="utf-8") as f:
self.info = eval(f.read())
# 获取文件名列表
self.img_paths = [img_name for img_name in self.info]
# 将数据集后1024张图片设置为验证集,当is_val为真时img_path切换为后1024张
self.img_paths = self.img_paths[-1024:] if is_val else self.img_paths[:-1024]
def __getitem__(self, index):
# 获取第index个文件的文件名以及其所在路径
file_name = self.img_paths[index]
file_path = os.path.join(self.data_path, file_name)
# 捕获异常 - 在发生异常时终止训练
try:
# 使用Pillow来读取图像数据
img = Image.open(file_path)
# 转为Numpy的array格式并整体除以255进行归一化
img = np.array(img, dtype="float32").reshape((IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)) / 255
except Exception as e:
raise Exception(file_name + "\t文件打开失败,请检查路径是否准确以及图像文件完整性,报错信息如下:\n" + str(e))
# 读取该图像文件对应的Label字符串,并进行处理
label = self.info[file_name]
label = list(label)
# 将label转化为Numpy的array格式
label = np.array(label, dtype="int32")
return img, label
def __len__(self):
# 返回每个Epoch中图片数量
return len(self.img_paths)
分类数量设置 - 因数据集中共包含0~9共10种数字+分隔符,所以是11分类任务
CLASSIFY_NUM = 11
定义输入层,shape中第0维使用-1则可以在预测时自由调节batch size
input_define = paddle.static.InputSpec(shape=[-1, IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W],
dtype="float32",
name="img")
定义网络结构
class Net(paddle.nn.Layer):
def init(self, is_infer: bool = False):
super().init()
self.is_infer = is_infer
# 定义一层3x3卷积+BatchNorm
self.conv1 = paddle.nn.Conv2D(in_channels=IMAGE_SHAPE_C,
out_channels=32,
kernel_size=3)
self.bn1 = paddle.nn.BatchNorm2D(32)
# 定义一层步长为2的3x3卷积进行下采样+BatchNorm
self.conv2 = paddle.nn.Conv2D(in_channels=32,
out_channels=64,
kernel_size=3,
stride=2)
self.bn2 = paddle.nn.BatchNorm2D(64)
# 定义一层1x1卷积压缩通道数,输出通道数设置为比LABEL_MAX_LEN稍大的定值可获取更优效果,当然也可设置为LABEL_MAX_LEN
self.conv3 = paddle.nn.Conv2D(in_channels=64,
out_channels=LABEL_MAX_LEN + 4,
kernel_size=1)
# 定义全连接层,压缩并提取特征(可选)
self.linear = paddle.nn.Linear(in_features=429,
out_features=128)
# 定义RNN层来更好提取序列特征,此处为双向LSTM输出为2 x hidden_size,可尝试换成GRU等RNN结构
self.lstm = paddle.nn.LSTM(input_size=128,
hidden_size=64,
direction="bidirectional")
# 定义输出层,输出大小为分类数
self.linear2 = paddle.nn.Linear(in_features=64 * 2,
out_features=CLASSIFY_NUM)
def forward(self, ipt):
# 卷积 + ReLU + BN
x = self.conv1(ipt)
x = paddle.nn.functional.relu(x)
x = self.bn1(x)
# 卷积 + ReLU + BN
x = self.conv2(x)
x = paddle.nn.functional.relu(x)
x = self.bn2(x)
# 卷积 + ReLU
x = self.conv3(x)
x = paddle.nn.functional.relu(x)
# 将3维特征转换为2维特征 - 此处可以使用reshape代替
x = paddle.tensor.flatten(x, 2)
# 全连接 + ReLU
x = self.linear(x)
x = paddle.nn.functional.relu(x)
# 双向LSTM - [0]代表取双向结果,[1][0]代表forward结果,[1][1]代表backward结果,详细说明可在官方文档中搜索'LSTM'
x = self.lstm(x)[0]
# 输出层 - Shape = (Batch Size, Max label len, Signal)
x = self.linear2(x)
# 在计算损失时ctc-loss会自动进行softmax,所以在预测模式中需额外做softmax获取标签概率
if self.is_infer:
# 输出层 - Shape = (Batch Size, Max label len, Prob)
x = paddle.nn.functional.softmax(x)
# 转换为标签
x = paddle.argmax(x, axis=-1)
return x
数据集路径设置
DATA_PATH = "/data/OCR_Dataset"
训练轮数
EPOCH = 10
每批次数据大小
BATCH_SIZE = 16
label_define = paddle.static.InputSpec(shape=[-1, LABEL_MAX_LEN],
dtype="int32",
name="label")
class CTCLoss(paddle.nn.Layer):
def init(self):
"""
定义CTCLoss
"""
super().init()
def forward(self, ipt, label):
input_lengths = paddle.full(shape=[BATCH_SIZE],fill_value=LABEL_MAX_LEN + 4,dtype= "int64")
label_lengths = paddle.full(shape=[BATCH_SIZE],fill_value=LABEL_MAX_LEN,dtype= "int64")
# 按文档要求进行转换dim顺序
ipt = paddle.tensor.transpose(ipt, [1, 0, 2])
# 计算loss
loss = paddle.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=10)
return loss
实例化模型
model = paddle.Model(Net(), inputs=input_define, labels=label_define)
定义优化器
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
为模型配置运行环境并设置该优化策略
model.prepare(optimizer=optimizer,
loss=CTCLoss())
执行训练
model.fit(train_data=Reader(DATA_PATH),
eval_data=Reader(DATA_PATH, is_val=True),
batch_size=BATCH_SIZE,
epochs=EPOCH,
save_dir="output/",
save_freq=1,
verbose=1,
drop_last=True)
与训练近似,但不包含Label
class InferReader(Dataset):
def init(self, dir_path=None, img_path=None):
"""
数据读取Reader(预测)
:param dir_path: 预测对应文件夹(二选一)
:param img_path: 预测单张图片(二选一)
"""
super().init()
if dir_path:
# 获取文件夹中所有图片路径
self.img_names = [i for i in os.listdir(dir_path) if os.path.splitext(i)[1] == ".jpg"]
self.img_paths = [os.path.join(dir_path, i) for i in self.img_names]
elif img_path:
self.img_names = [os.path.split(img_path)[1]]
self.img_paths = [img_path]
else:
raise Exception("请指定需要预测的文件夹或对应图片路径")
def get_names(self):
"""
获取预测文件名顺序
"""
return self.img_names
def __getitem__(self, index):
# 获取图像路径
file_path = self.img_paths[index]
# 使用Pillow来读取图像数据并转成Numpy格式
img = Image.open(file_path)
img = np.array(img, dtype="float32").reshape((IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)) / 255
return img
def __len__(self):
return len(self.img_paths)
待预测目录 - 可在测试数据集中挑出\b3张图像放在该目录中进行推理
INFER_DATA_PATH = "/data/sample_img"
训练后存档点路径 - final 代表最终训练所得模型
CHECKPOINT_PATH = "./output/final.pdparams"
每批次处理数量
BATCH_SIZE = 32
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
sample_idxs = np.random.choice(50000, size=25, replace=False)
for img_id, img_name in enumerate(os.listdir(INFER_DATA_PATH)):
plt.subplot(1, 3, img_id + 1)
plt.xticks([])
plt.yticks([])
im = Image.open(os.path.join(INFER_DATA_PATH, img_name))
plt.imshow(im, cmap=plt.cm.binary)
plt.xlabel("Img name: " + img_name)
plt.show()
编写简易版解码器
def ctc_decode(text, blank=10):
"""
简易CTC解码器
:param text: 待解码数据
:param blank: 分隔符索引值
:return: 解码后数据
"""
result = []
cache_idx = -1
for char in text:
if char != blank and char != cache_idx:
result.append(char)
cache_idx = char
return result
实例化推理模型
model = paddle.Model(Net(is_infer=True), inputs=input_define)
加载训练好的参数模型
model.load(CHECKPOINT_PATH)
设置运行环境
model.prepare()
加载预测Reader
infer_reader = InferReader(INFER_DATA_PATH)
img_names = infer_reader.get_names()
results = model.predict(infer_reader, batch_size=BATCH_SIZE)
index = 0
for text_batch in results[0]:
for prob in text_batch:
out = ctc_decode(prob, blank=10)
print(f"文件名:{img_names[index]},推理结果为:{out}")
index += 1
运行,打印结果
![](https://img.haomeiwen.com/i5653267/ef0a8c33d0e98438.jpg)
![](https://img.haomeiwen.com/i5653267/92ec7b64c29f2c8a.jpg)
网友评论