原创:悬鱼铭
图像分类是人工智能中重要的基础任务,也是目标检测、图像分割、目标跟踪等视觉进阶任务的基础,是人工智能从业者必须掌握的知识点。本文通过以下几点阐述:
-
图像分类有哪些落地场景?
-
图像分类有哪些细分任务?
-
图像分类如何实现?(有代码注释)
-
图像分类的损失函数
全文总共3千字左右,阅读时间10分钟!
一、图像分类有哪些落地场景?
图像分类是计算机视觉领域的基础任务,也是应用比较广泛的任务。图像分类用来解决“是什么”的问题,如给定一张图片,用标签描述图片的主要内容,下图中有三个企鹅,标签为企鹅。
来自花瓣网【1】,侵权联系删除
基于图像分类的智能应用也日渐成熟,如在安防监控、智慧交通、医疗影像诊断等。安防监控中,上下班的人脸打卡机器,人们走到机器前,相机采集到人脸图像,机器里的算法判断人脸图像与人脸图像库里的某个人脸最相似,就判断出人脸图像属于那个人。
来自花瓣网【2】,侵权联系删除
交通领域中的交通标识识别,可以辅助驾驶;手机拍照识别花的种类,智能整理相册;电商平台里的输入标签,返回含有标签的商品等。
二、图像分类有哪些细分任务?
图像分类根据标签的不同,大致可分为二分类任务、多分类任务、多标签分类任务。
应对复杂的生活场景,图像分类会有更加细致的任务。在安防监控中,在监控系统中寻找逃犯,从全国各地摄像头采集的图像,面对庞大的图像库,只要找到是逃犯的图像,其他图像都不是逃犯。这里就对图像进行二分类,是逃犯与不是逃犯,这是二分类的任务。
在智能整理手机相册时,每张图像会设置一个标签,同一个标签的图像会放在一个文件夹下,并且以标签来命名,整理完之后会有多个文件夹,这里标签有多个,是多分类的任务。
在短视频推荐中,真实生活场景,图像包含丰富的内容,每张图像不在局限单标签,而是把图像包含的主要内容展示,往多标签发展。对图像进行多标签分类,提供丰富多样的标签,可以促进个性化推荐。
三、图像分类如何实现?
经典的图像分类一般包括预处理、特征提取、分类器,其中特征提取一般通过手工精心设计。研究者会花费大量的精力去探索如何提取到鲁棒性较好的图像特征。深度学习中的卷积神经网络(Convolution Neural Network, CNN)可在大量数据中自动学习到数据的层次化表示。近年来,得益于强大的计算机、更大的数据集,CNN提取图像特征成为主流方法。
基于深度学习的图像分类,将传统的图像分类流程(预处理、特征提取、分类器),全部体现在各种层的组合,有卷积层、池化层、全连接层,图像分类流程如图1所示。训练过程中主要是求解模型的参数,一个输入图片经过多个卷积、池化,它们提取图像特征图,图像特征图拉伸为一维特征向量,连接全连接层,将特征图映射到标签(类别),可知输入图片属于每个标签的概率值。选取概率值最大的标签作为预测的结果。根据推理的结果与图片的真实标签的差距,即为损失函数,再通过梯度下降的方法求解模型参数。参数确定之后,模型就确定了,可以推理测试集中新的图片。
图像多分类任务
图像分类中常用的数据集有CIFAR10【3】,有6万张图像,其中5万张训练集图像,1万张测试集图像,图像大小为 的彩色图像。每张图片一个标签,数据集总共10个标签,有鸟(bird)、太阳(sunset)、狗(dog)、猫(cat)等。
下面进入图像多类别分类实践,以VGG16网络为基准模型,在Pytorch中展示图像多类别分类的训练流程,并且将Pytorch中的数据流展示在下图中。
接下来是使用VGG16,进行图像分类,数据集是CIFAR10。
import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
# 预处理的设置
# 图片转化为 backbone网络规定的图片大小
# 归一化是减去均值,除以方差
# 把 numpy array 转化为 tensor 的格式
image_size = 224
r_mean, g_mean, b_mean = 0.4914, 0.4822, 0.4465
r_std, g_std, b_std = 0.247, 0.243, 0.261
my_tf = transforms.Compose([
transforms.Resize((image_size,image_size)),
transforms.ToTensor(),
transforms.Normalize([r_mean,g_mean,b_mean], [r_std,g_std,b_std])])
# 读取数据集 CIFAR-10 的图,有10个标签,5万张图片,进行预处理。
train_dataset= torchvision.datasets.CIFAR10(root='./',train=True,transform=my_tf,download=True)
test_dataset= torchvision.datasets.CIFAR10(root='./',train=False,transform=my_tf,download=True)
# 调用预训练模型vgg16
my_vgg = torchvision.models.vgg16(pretrained=True)
# 固定网络框架全连接层之前的参数
for param in my_vgg.parameters():
param.requires_grad=False
# 将vgg最后一层输出的类别数,改为cifar-10的类别数(10)
class_size = 10
in_f = my_vgg.classifier[6].in_features
my_vgg.classifier[6] = nn.Linear(in_f,class_size)
# 超参数设置
learn_rate = 0.001
num_epoches = 10
batch_size = 32
momentum = 0.9
# 多分类损失函数,使用默认值
criterion = nn.CrossEntropyLoss()
# 梯度下降,求解模型最后一层参数
optimizer = optim.SGD(my_vgg.classifier[6].parameters(),lr=learn_rate,momentum=momentum)
# 判断使用CPU还是GPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# 图片分批次送入内存(32张图片,batch_size),进行计算。
train_dataloader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset)
# 训练阶段
my_vgg.to(device)
my_vgg.train()
for epoch in range(num_epoches):
print(f"epoch: {epoch+1}")
for idx,(img,label)in enumerate(train_dataloader):
images = img.to(device)
labels = label.to(device)
output = my_vgg(images)
loss = criterion(output,labels)
loss.backward() # 损失反向传播
optimizer.step() # 更新梯度
optimizer.zero_grad() # 梯度清零
if idx%100==0:
print(f"current loss = {loss.item()}")
# 测试阶段
my_vgg.to(device)
my_vgg.eval() # 把训练好的模型的参数冻结
total,correct = 0 , 0
for img,label in test_dataloader:
images = img.to(device)
labels = label.to(device)
output = my_vgg(images)
_,idx = torch.max(output.data,1) # 输出最大值的位置
total += labels.size(0) # 全部图片
correct +=(idx==labels).sum() # 正确的图片
print(f"accuracy:{100.*correct/total}")
运行以上代码,epoch为1的时候,损失整体是逐步下降的,训练的结果如下图。
四、图像分类的损失函数
最后回顾一下图像分类的三种常见情况,扩展其对应的损失函数。
4.1 二分类
数据集有两个类别,每张图片是两个类别中的一个,标签为0或1。比如,训练一个图像分类器,判断一张输入图片是否为鸟。
二分类的损失函数的计算方式如公式1所示,其中是sigmoid激活函数。激活函数是将数值变换到(0,1)之间,图展示了sigmoid函数,是递增的函数,当x等于0的时候,sigmoid函数的值为0.5,,随着x的增大,函数趋近1,随着x的减小,函数趋近0。
Sigmoid函数
4.2 多分类
数据集有个类别,每张图片是个类别中的一个,每张图片只有一个标签。比如,判断图片中的一个物体是鸟、人、狗、猫。
sigmoid的函数主要是在二分类中,而多分类的网络一般用softmax作为最后一层,输出为预测结果(每个的概率值),然后计算交叉熵损失,公式如4所示。
4.3 多标签分类
数据集有个类别,每张图片可有多个标签,标签是来源于个类别。比如:图片有多个物体,判断这些物体是鸟、人、狗、猫等。
二分类和多分类的情况是限定了一张图片一个类别。事实上,类别下面会有更细致的分类,图片里也会有多个物体,一张图片对应多个标签(一个标签表示一个物体),并且每张图片的标签数量是不固定的。当我们把标签看成相互独立,可以把多标签问题看成对每个标签的二分类。网络框架的最后一层是sigmoid激活函数,然后计算二分类的交叉熵损失,公式如5所示。
今天分享到这里,希望对你有所帮助!
参考资料:
【1】多卷菌采集到这一刻,文字太轻。(61图)_花瓣 (huaban.com)
【2】Dribbble - face_recognition.png by Rico
【3】CIFAR-10 and CIFAR-100 datasets (toronto.edu)
【4】开课吧人工智能核心课程
网友评论