本文介绍 TensorFlow1.5 新年全新教程(系列)
TensorFlow1.5 新年全新教程(系列)
This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat:
jintianiloveu
很久没有更博客了,眨眼都已经2018年了,遥想去年跨年就好像发生在前天一样,预祝大家2019年猪年大吉。
闲话不多说。在家呆久了不学点东西感觉心虚,科技发展这么快,不脚踏实地开疆拓土怎么行呢?新年就要有新气象嘛,作为一位人工智能行业从业者,希望以一个过来的人的身份,带领更多的人在这条道路上披荆斩棘,开拓新的领域。工欲善其事必先利其器,TensorFlow1.5都已经发布了,我们还有什么理由不去学习一下最新的tf.data.Dataset API? 还有什么理由不期待一下TensorFlow Lite的终极版本以及专属于移动端的模型存储框架FlatBuf…感觉科技又前进了一个世纪,不过没有关系。凡事都得从当下做起。自从1.5版本发布 之后,tensorflow里面的很多API都将冻住了,并且会越来越规范化,为的正式迎接2018年深度学习应用落地的爆发之年。
闲话就说到这里了。我们首先从tensorflow的最新dataset API说起。
开始之前给大家安利一个工具:alfred, 专门为深度学习打造的工具,欢迎大家star, fork,enhance。我们接下来用它来随时爬几张猪啊狗啊的图片。
tf.data.Dataset
这个以前是在contrib下面的一个接口,现在放到了data下面,可以说是非常正统的tensorflow数据导入接口了。以前都是用tfrecords,现在不管是从单张图片,从文件夹路径,还是从numpy array类型的数据,都非常方便了。
假设我们有一个图片分类的简单任务。我们的目录是这样的:
-data
|-dog
|-pig
|-...
这个猪啊狗啊的图片alfred可以帮你爬取:
sudo pip3 install alfred-py
alfred scrap image -q 'dog'
alfred scrap image -q 'pig'
每个类别装了许多同一类的图片。那直接读取到python的list,然后转成tensor,通过tf.data.Dataset
就可以读入到tensorflow里面。
import tensorflow as tf
import os
NUMC_CLASSES = 2
def load_image():
train_dir = 'data'
all_classes = []
all_images = []
all_labels = []
for i in os.listdir(train_dir):
current_dir = os.path.join(train_dir, i)
if os.path.isdir(current_dir):
all_classes.append(i)
for img in os.listdir(current_dir):
if img.endswith('png') or img.endswith('jpg'):
all_images.append(os.path.join(current_dir, img))
all_labels.append(all_classes.index(i))
return all_images, all_labels, all_classes
def train():
all_images, all_labels, all_classes = load_image()
print(all_classes)
# convert all images list to tensor, using Dataset API to load
train_data = tf.data.Dataset.from_tensor_slices((tf.constant(all_images), tf.constant(all_labels)))
iterator = tf.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
next_elem = iterator.get_next()
train_init_op = iterator.make_initializer(train_data)
with tf.Session() as sess:
sess.run(train_init_op)
while True:
try:
print(sess.run(next_elem))
except tf.errors.OutOfRangeError:
print('data iterator finish.')
break
if __name__ == '__main__':
train()
我们可以看到输出结果是:
['dog', 'pig']
(b'dog_00.jpg', 0)
(b'dog_01.jpg', 0)
(b'pig_00.jpg', 1)
(b'pig_01.jpg', 1)
(b'pig_010.jpg', 1)
(b'pig_02.jpg', 1)
(b'pig_03.jpg', 1)
(b'pig_04.jpg', 1)
(b'pig_05.jpg', 1)
(b'pig_06.jpg', 1)
(b'pig_07.jpg', 1)
(b'pig_08.jpg', 1)
(b'pig_09.jpg', 1)
data iterator finish.
图片和标签都已经获得。用最新的Dataset API中的 from_tensor_slices
可以非常方便的从list中将数据导入。
很多时候我们都需要对图片进行预处理,比如我们需要做一个检测数据集,我们要读入label和bbox,这个时候label需要one-hot,我们就需要对这个东西进行预处理,这个时候map就有用了。
tf.data.Dataset.map
这还没有完,我们的目的是操作每一张图片,做一些变换。或者对label进行一些处理,比如one-hot。在最新的dataset API中也有map函数进行操作。可以在这个map方法里,指定所有应有的操作。
def input_map_fn(img_path, label):
# do some process to label
one_hot = tf.one_hot(label, NUMC_CLASSES)
img_f = tf.read_file(img_path)
img_decodes = tf.image.decode_image(img_f, channels=3)
return img_decodes, one_hot
然后将train_data加上即可。
train_data = train_data.map(input_map_fn)
最终我们可以看到熟悉的,图片值 + one_hot label的训练数据。如果是对于像多标签分类,目标检测这样的任务label,也是做同样的处理。只要能保证前期的输入能在后期的网络中拿到就行了。
好了,现在tensorflow全新的数据导入API应该已经融会贯通了。下一篇大家等待更新,博主这还得去乡下拜个年。
网友评论