在学习了使用 TensorFlow 的 CNN 进行图像分类之后,现在对这些方法做一个简单的拓展,即来处理多任务多标签的情形。为了便于说明,我们假设现在要对 0-9 这 10 个数字 和 A-Z (排除 I、O) 这 24 个字母进行识别,所有的数据都使用 captcha 生成(读过 TensorFlow 训练 CNN 分类器 这篇文章的读者应该不陌生了)。以下的代码(命名为 generate_train_data.py)使用 captcha 生成了 100000 万张 28 x 28 的图像,每张图像都是带有大量噪声的一个字符(所有字符见下面代码中的 alphabets 列表,所有的图像保存在文件夹 ./datasets/images 中,每张图像命名为 image图像序号_类标号.jpg,其中的类标号为该字符在列表 alphabets 中的下标)。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018
@author: shirhe-lyh
"""
import cv2
import numpy as np
from captcha.image import ImageCaptcha
def generate_captcha(text='1'):
capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
image = capt.generate_image(text)
image = np.array(image, dtype=np.uint8)
return image
if __name__ == '__main__':
output_dir = './datasets/images/'
alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T',
'U', 'V', 'W', 'X', 'Y', 'Z']
for i in range(100000):
label = np.random.randint(0, 34)
image = generate_captcha(alphabets[label])
image_name = 'image{}_{}.jpg'.format(i+1, label)
output_path = output_dir + image_name
cv2.imwrite(output_path, image)
我们的目的是训练一个简单的 CNN 模型将对这些图像进行分类,由于这个问题很简单,直接训练一个 34 类的分类器就达成目标了。但类别数越大,训练就越困难,因此我们采取另一种分化的策略,将这个 34 类的问题分为两个子问题,分别是:1.只识别数字;2.只识别字母。之所以可以这么分,是因为 数字 和 字母 的差别很大,完全可以认为它们属于两种不同的范畴,从而可以看成独立的分类任务来处理。这样我们现在的问题是:怎样同时识别 10 个数字和 24 个字母?这是一个多任务多标签问题:我们要处理识别数字和识别字母这两个任务,其中每个任务都是涉及多个标签(分别是 10 个标签和 24 个标签)。
虽然这篇文章举例的这个问题非常简单,但这个方法(再加上预训练模型技巧)可以用于更加复杂的问题,比如 阿里的 FashionAI 服饰属性识别全球挑战赛,感兴趣的朋友可以用 ResNet-50 预训练模型去微调一个 8 任务模型。
本文的所有代码见 github:multi_task_test,欢迎访问交流并反馈问题!
一、多分支 CNN 模型定义
虽然我们要处理的是两个独立的任务,但我们希望这两个任务共用大部分的神经网络层,这样既可以节省计算量,一般来说,也可以提升准确率。因此,我们将要定义的神经网络结构设计为(所有共用的层在文章 TensorFlow-slim 训练 CNN 分类模型 中用来识别 0-9 这 10 个数字):
当获取了一张图像(数字或字母)之后,将它送入第一个卷积层(conv1)、第二个卷积层(conv2)、······,直到第二个全连接层(fc2),到此为止,这些层都是两个任务共用的,它们的作用是用来提取图像特征。然后,针对两个不同的任务,将网络分为两个分支,一个用于输出该图像是各个数字的概率(digits_output),另一个用于输出该图像是各个字母的概率(letters_output)。网络的具体定义如下(网络各层的名字可能和上图不一致):
def predict(self, preprocessed_inputs):
"""Predict prediction tensors from inputs tensor.
Outputs of this function can be passed to loss or postprocess functions.
Args:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, width, num_channels] representing a batch of images.
Returns:
prediction_dict: A dictionary holding prediction tensors to be
passed to the Loss or Postprocess functions.
"""
net = preprocessed_inputs
net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
net = slim.max_pool2d(net, [2, 2], scope='pool1')
net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [2, 2], scope='pool2')
net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
net = slim.flatten(net, scope='flatten')
net = slim.dropout(net, keep_prob=0.5,
is_training=self._is_training)
net = slim.fully_connected(net, 512, scope='fc1')
net = slim.fully_connected(net, 512, scope='fc2')
prediction_dict = {}
for class_name, num_classes in self.num_classes_dict.items():
logits = slim.fully_connected(net, num_outputs=num_classes,
activation_fn=None,
scope='Predict/' + class_name)
prediction_dict[class_name] = logits
return prediction_dict
从以上代码可以看到,多任务多标签任务的 CNN 定义也非常简单,只需要引入一个 for 循环即可。接下来,要定义损失函数和准确率函数。
在生成图像的时候,图片名字命名的模式是 image图像序号_类标号.jpg,比如,假设第 1 张图像是字母 G,那么它的类标号是 16 = 10 + 7 - 1,因此它的名字是 image1_16.jpg。但这个类标号 16 是基于所有 34 个类来说的,实际上,如果只限于字母来说,它的类标号应该是 6。之所以对数字和字母使用统一的类标号,其实是为了便于定义损失和准确率函数。原因在于:对字母 G,因为我们现在是独立处理数字和字母两个分支任务,因此 G 应该只对分类字母的分支贡献损失,而不应当对分类数字的分支产生损失。如果统一对数字和字母分配类标号,那么 G 的类标号 16 的独热(one-hot)编码是:
0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
其中的 - 是为了便于看清两个任务的分界线,实际请忽略。此时,在计算损失时,将这个独热编码一分为二:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
前一部分对应于 G 在分类 0-9 这 10 个数字的任务内的(非严格独热)编码,因为全部为 0,因此在计算分类交叉熵的时候损失为 0,这是我们期望的;后一部分恰好是 G 在分类 A-Z(排除 I、O)这 24 个字母的任务内的独热编码,正好用于计算分类交叉熵,也是我们期望的,可见统一分配类标号在计算损失时是非常方便的。了解了这一点之后,损失函数的定义如下:
def loss(self, prediction_dict, groundtruth_lists):
"""Compute scalar loss tensors with respect to provided groundtruth.
Args:
prediction_dict: A dictionary holding prediction tensors.
groundtruth_lists: A list of tensors holding groundtruth
information, with one entry for each task.
Returns:
A dictionary mapping strings (loss names) to scalar tensors
representing loss values.
"""
onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
for class_name in self.num_classes_dict:
weights = tf.cast(tf.greater(
tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
dtype=tf.float32)
slim.losses.softmax_cross_entropy(
logits=prediction_dict[class_name],
onehot_labels=onehot_labels_dict[class_name],
weights=weights,
scope='Loss/' + class_name)
loss = slim.losses.get_total_loss()
loss_dict = {'loss': loss}
return loss_dict
def _onehot_groundtruth_dict(self, groundtruth_lists):
"""Transform groundtruth lables to one-hot formats.
Args:
groundtruth_lists: A dict of tensors holding groundtruth
information, with one entry for task.
Returns:
onehot_labels_dict: A dictionary mapping strings (class names)
to one-hot lable tensors.
"""
one_hot = tf.one_hot(
groundtruth_lists, depth=sum(self.num_classes_dict.values()))
onehot_labels_dict = {}
start_index = 0
for class_name in self._class_order:
onehot_labels_dict[class_name] = tf.slice(
one_hot, [0, start_index],
[-1, self.num_classes_dict[class_name]])
start_index += self.num_classes_dict[class_name]
return onehot_labels_dict
其中,函数 _onehot_groundtruth_dict 用于将统一分配的类标号对应的独热编码分为数字和字母这两个任务对应的两个独热编码,之后的 loss 函数就可以用来计算正常的分类交叉熵损失。为了确保全 0 的独热编码对应 0 的损失,定义了 weights 这一个变量,它的作用是:当编码为全 0 时,该样本对应的损失权重为 0,因此贡献的损失为 0,即不属于这个分类任务的样本对这个分类任务的损失贡献为 0(虽然理论上全 0 的独热编码对应的分类交叉熵为 0,但为了确保这点而不出现意外,weights 是非常必要的)。
至于,准确率函数的定义则更简单,想法如下:当一张图像经过神经网络预测后,我们得到两个分支任务的概率输出,我们不关心它来源于哪个任务,因为这不影响准确率的计算;分别对两个任务的概率输出取 tf.argmax
得到在每个任务内的预测类标号,然后对这两个预测的类标号再计算它在对应任务内的独热编码,把这两个独热编码与上面计算损失时切割得到的两个独热编码分别按对应元素求和,如果求和结果中出现 2 说明预测结果正确,否则错误;对一个批量中的所有图像累计处理之后,即可算出准确率。继续上面的例子,前面已经说过,G 的类标号 16 对应的独热编码一分为二的结果为:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
假如现在神经网络的两个分支预测的类标号分别为 1 和 6,那么它们分别对应独热编码:
0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
以上独热编码按位置对应相加,得到:
0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
前一个结果(即 0 1 0 0 0 0 0 0 0 0
)所有位置上都没有出现 2,说明预测和实际的类标号没有重合,对准确率没有产生作用;后一个结果(即 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
) 中,第 6 个索引位置出现 2 说明预测和实际的类标号是一样的,因此预测正确,预测正确数加 1。显然,每一张图像要么加 0 (两个任务都预测错误)要么加 1(其中一个任务预测正确),因此这样计算准确率是正确的(不可能加 2,因为实际的两个独热编码中,其中的一个全是 0)。详细的细节请参考如下完整代码(将其命名为 model.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from abc import ABCMeta
from abc import abstractmethod
slim = tf.contrib.slim
class BaseModel(object):
"""Abstract base class for any model."""
__metaclass__ = ABCMeta
def __init__(self, num_classes_dict):
"""Constructor.
Args:
num_classes: Number of classes.
"""
self._num_classes_dict = num_classes_dict
@property
def num_classes_dict(self):
return self._num_classes_dict
@abstractmethod
def preprocess(self, inputs):
"""Input preprocessing. To be override by implementations.
Args:
inputs: A float32 tensor with shape [batch_size, height, width,
num_channels] representing a batch of images.
Returns:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, widht, num_channels] representing a batch of images.
"""
pass
@abstractmethod
def predict(self, preprocessed_inputs):
"""Predict prediction tensors from inputs tensor.
Outputs of this function can be passed to loss or postprocess functions.
Args:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, width, num_channels] representing a batch of images.
Returns:
prediction_dict: A dictionary holding prediction tensors to be
passed to the Loss or Postprocess functions.
"""
pass
@abstractmethod
def postprocess(self, prediction_dict, **params):
"""Convert predicted output tensors to final forms.
Args:
prediction_dict: A dictionary holding prediction tensors.
**params: Additional keyword arguments for specific implementations
of specified models.
Returns:
A dictionary containing the postprocessed results.
"""
pass
@abstractmethod
def loss(self, prediction_dict, groundtruth_lists):
"""Compute scalar loss tensors with respect to provided groundtruth.
Args:
prediction_dict: A dictionary holding prediction tensors.
groundtruth_lists: A list of tensors holding groundtruth
information, with one entry for each image in the batch.
Returns:
A dictionary mapping strings (loss names) to scalar tensors
representing loss values.
"""
pass
class Model(BaseModel):
"""xxx definition."""
def __init__(self,
is_training,
num_classes_dict={'digits': 10, 'letters': 24}):
"""Constructor.
Args:
is_training: A boolean indicating whether the training version of
computation graph should be constructed.
num_classes: Number of classes.
"""
super(Model, self).__init__(num_classes_dict=num_classes_dict)
self._is_training = is_training
self._class_order = ['digits', 'letters']
def preprocess(self, inputs):
"""Predict prediction tensors from inputs tensor.
Outputs of this function can be passed to loss or postprocess functions.
Args:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, width, num_channels] representing a batch of images.
Returns:
prediction_dict: A dictionary holding prediction tensors to be
passed to the Loss or Postprocess functions.
"""
preprocessed_inputs = tf.to_float(inputs)
preprocessed_inputs = tf.subtract(preprocessed_inputs, 128.0)
preprocessed_inputs = tf.div(preprocessed_inputs, 128.0)
return preprocessed_inputs
def predict(self, preprocessed_inputs):
"""Predict prediction tensors from inputs tensor.
Outputs of this function can be passed to loss or postprocess functions.
Args:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, width, num_channels] representing a batch of images.
Returns:
prediction_dict: A dictionary holding prediction tensors to be
passed to the Loss or Postprocess functions.
"""
net = preprocessed_inputs
net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
net = slim.max_pool2d(net, [2, 2], scope='pool1')
net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [2, 2], scope='pool2')
net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
net = slim.flatten(net, scope='flatten')
net = slim.dropout(net, keep_prob=0.5,
is_training=self._is_training)
net = slim.fully_connected(net, 512, scope='fc1')
net = slim.fully_connected(net, 512, scope='fc2')
prediction_dict = {}
for class_name, num_classes in self.num_classes_dict.items():
logits = slim.fully_connected(net, num_outputs=num_classes,
activation_fn=None,
scope='Predict/' + class_name)
prediction_dict[class_name] = logits
return prediction_dict
def postprocess(self, prediction_dict):
"""Convert predicted output tensors to final forms.
Args:
prediction_dict: A dictionary holding prediction tensors.
**params: Additional keyword arguments for specific implementations
of specified models.
Returns:
A dictionary containing the postprocessed results.
"""
postprecessed_dict = {}
for class_name in self.num_classes_dict:
logits = prediction_dict[class_name]
# logits = tf.nn.softmax(logits, name=class_name)
postprecessed_dict[class_name] = logits
return postprecessed_dict
def loss(self, prediction_dict, groundtruth_lists):
"""Compute scalar loss tensors with respect to provided groundtruth.
Args:
prediction_dict: A dictionary holding prediction tensors.
groundtruth_lists: A list of tensors holding groundtruth
information, with one entry for each task.
Returns:
A dictionary mapping strings (loss names) to scalar tensors
representing loss values.
"""
onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
for class_name in self.num_classes_dict:
weights = tf.cast(tf.greater(
tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
dtype=tf.float32)
slim.losses.softmax_cross_entropy(
logits=prediction_dict[class_name],
onehot_labels=onehot_labels_dict[class_name],
weights=weights,
scope='Loss/' + class_name)
loss = slim.losses.get_total_loss()
loss_dict = {'loss': loss}
return loss_dict
def _onehot_groundtruth_dict(self, groundtruth_lists):
"""Transform groundtruth lables to one-hot formats.
Args:
groundtruth_lists: A dict of tensors holding groundtruth
information, with one entry for task.
Returns:
onehot_labels_dict: A dictionary mapping strings (class names)
to one-hot lable tensors.
"""
one_hot = tf.one_hot(
groundtruth_lists, depth=sum(self.num_classes_dict.values()))
onehot_labels_dict = {}
start_index = 0
for class_name in self._class_order:
onehot_labels_dict[class_name] = tf.slice(
one_hot, [0, start_index],
[-1, self.num_classes_dict[class_name]])
start_index += self.num_classes_dict[class_name]
return onehot_labels_dict
def accuracy(self, postprocessed_dict, groundtruth_lists):
"""Calculate accuracy.
Args:
postprocessed_dict: A dictionary containing the postprocessed
results
groundtruth_lists: A dict of tensors holding groundtruth
information, with one entry for each image in the batch.
Returns:
accuracy: The scalar accuracy.
"""
onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
num_corrections = 0.
for class_name in self.num_classes_dict:
predicted_argmax = tf.argmax(tf.nn.softmax(
postprocessed_dict[class_name]), axis=1)
onehot_predicted = tf.one_hot(
predicted_argmax, depth=self.num_classes_dict[class_name])
onehot_sum = tf.add(onehot_labels_dict[class_name],
onehot_predicted)
correct = tf.greater(onehot_sum, 1)
num = tf.reduce_sum(tf.cast(correct, tf.float32))
num_corrections += num
total_nums = tf.cast(tf.shape(groundtruth_lists)[0], dtype=tf.float32)
accuracy = num_corrections / total_nums
return accuracy
在定义 postprocess 函数时,我把语句:
logits = tf.nn.softmax(logits, name=class_name)
注释掉了(这显得这个函数没有任何用处),我的本意是为了观察 predict 函数中两个网络分支的最本原输出,主要考虑的是:当一张图片送到网络入口时,如果根本不知道它是数字还是字母,那么经过神经网络处理后,我们面临着两个任务的输出,要怎么判断它属于哪个任务中的哪个标签呢?如果我们已经知道这张图像来源于其中某一个任务,比如来源于数字任务,那么直接对数字任务分支的输出取 tf.argmax 就知道它对应的预测标签了。但现在的关键问题是,如果不知道它属于其中哪个任务,能否根据两个分支的输出直接判断出来呢?答案是可以的,尽管这是基于经验观察的。通过模型训练并导出为 .pb 文件之后,运行 evaluate.py 文件(很多次),可以观察两个分支的直接输出,你会发现两个任务中所有这些输出的最大值对应的标签就是网络的预测输出,也就是说:可以通过比较两个任务的所有输出,来预测图像来源于哪个任务(进而预测属于哪个标签)——所有输出的值中,最大值所在的任务就可以认为是图像来源的任务。
二、模型训练与保存
因为模型训练的代码和文章 TensorFlow-slim 训练 CNN 分类模型(续) 中 train.py 的是一样的,这里直接忽略(也可以访问 github:multi_task_test 获取本文所有代码)。
当你获取到代码后,首先在项目当前目录下新建文件夹 datasets/images,然后在当前目录下的终端运行
python3 generate_train_data.py
生成 100000 张训练图像。之后,继续运行
python3 generate_tfrecord.py \
--images_path ./datasets/images/ \
--output_path ./datasets/train.record
得到训练的 .record 文件。 此时,在项目目录下再新建文件夹 training,接着在终端执行如下命令
python3 train.py --record_path ./datasets/train.record --logdir ./training/
便开始了训练过程。如果你要可视化的观看损失和准确率的变化情况,在当前目录下的终端执行
tensorboard --logdir ./training/
得到本地浏览器链接,打开这个链接即可监控训练的全过程。比如,我训练 5000 多次之后,准确率和损失的图像如下:
Tensorboard 显示的准确率和损失曲线当你觉得训练的准确率已经足够高了,并且文件夹 training 中也保存好了当前训练次数的模型文件之后,使用 Ctrl + C 中断训练过程。接下来,就是将 training 中的训练模型文件 .ckpt 转化为 .pb 文件,然后测试训练效果了。有关自定义的将 .ckpt 格式转化为 .pb 格式的模型文件请访问文章 TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式。在那篇文章中,已经指出,需要针对不同的分类模型做出改变的地方主要是包含 model 参数的那些函数,尤其是由输入得到输出的函数 _add_output_tensor_nodes。比如,我们这篇文章有两个分支任务的输出,对应的函数 _add_output_tensor_nodes 修改为:
def _add_output_tensor_nodes(postprocessed_tensors,
output_collection_name='inference_op'):
"""Adds output nodes.
Adjust according to specified implementations.
Adds the following nodes for output tensors:
* classes: A float32 tensor of shape [batch_size] containing class
predictions.
Args:
postprocessed_tensors: A dictionary containing the following fields:
'classes': [batch_size].
output_collection_name: Name of collection to add output tensors to.
Returns:
A tensor dict containing the added output tensor nodes.
"""
outputs = {}
for class_name, logits in postprocessed_tensors.items():
outputs[class_name] = tf.identity(logits, name=class_name)
for output_key in outputs:
tf.add_to_collection(output_collection_name, outputs[output_key])
return outputs
其它函数不需要修改,完整文件请查看 github:multi_task_test 的 export.py 文件。然后,在项目的当前目录终端执行模型导出命令:
python3 export_inference_graph.py \
--trained_checkpoint_prefix ./training/model.ckpt-5265 \
--output_directory ./training/inference_graph_pb
你会在 training 文件夹中看到一个新的文件夹 inference_graph_pb,里面的文件 frozen_inference_graph.pb 就是我们用来做模型推断的文件。上面一条命令中的 model.ckpt-5265 请根据你自己的训练情况做修改,这里我是只训练了 5000 多次,然后使用训练了 5265 次的模型用于图像推断。
当你一切都顺利执行之后,恭喜你来到最后一步,是时候验证一下你训练的模型的效果了。写个简单的模型验证文件 evaluate.py:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 2 14:02:05 2018
@author: shirhe-lyh
"""
"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 evaluate.py \
--frozen_graph_path: Path to model frozen graph.
"""
import numpy as np
import tensorflow as tf
from captcha.image import ImageCaptcha
flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS
def generate_captcha(text='1'):
capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
image = capt.generate_image(text)
image = np.array(image, dtype=np.uint8)
return image
def main(_):
alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T',
'U', 'V', 'W', 'X', 'Y', 'Z']
model_graph = tf.Graph()
with model_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
with model_graph.as_default():
with tf.Session(graph=model_graph) as sess:
inputs = model_graph.get_tensor_by_name('image_tensor:0')
digits = model_graph.get_tensor_by_name('digits:0')
digit_classes = tf.argmax(tf.nn.softmax(digits), axis=1)
letters = model_graph.get_tensor_by_name('letters:0')
letter_classes = tf.argmax(tf.nn.softmax(letters), axis=1)
for i in range(10):
label = np.random.randint(0, 34)
image = generate_captcha(alphabets[label])
image_np = np.expand_dims(image, axis=0)
predicted_ = sess.run([digits, digit_classes,
letters, letter_classes],
feed_dict={inputs: image_np})
predicted_digits = np.round(predicted_[0], 2)
predicted_digit_classes = predicted_[1]
predicted_letters = np.round(predicted_[2], 2)
predicted_letter_classes = predicted_[3]
print(predicted_digits, '----', predicted_digit_classes)
print(predicted_letters, '----', predicted_letter_classes)
predicted_label = predicted_letter_classes[0] + 10
if label < 10:
predicted_label = predicted_digit_classes[0]
print(alphabets[predicted_label], ' vs ', alphabets[label])
if __name__ == '__main__':
tf.app.run()
在终端执行如下命令,进行模型评估:
python3 evaluate.py \
--frozen_graph_path ./training/inference_graph_pb/frozen_inference_graph.pb
你可以仔细的观察最后两个分支的直接输出,看看最大值对应的那个任务是否恰好是验证图像实际来源的任务。
预告:下一篇文章将要介绍如何用 TensorFlow 实现 生成对抗网络,敬请期待!
网友评论