美文网首页
迁移学习

迁移学习

作者: poteman | 来源:发表于2019-08-05 17:59 被阅读0次
  • 获取预训练模型的权重
import os

from tensorflow.keras import layers
from tensorflow.keras import Model
!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 \
    -O /tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
  

from tensorflow.keras.applications.inception_v3 import InceptionV3

local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'

pret_trained_model = InceptionV3(input_shape = (150, 150, 3),
                                include_top = False,
                                weights = None)

pret_trained_model.load_weights(local_weights_file)

for layer in pret_trained_model.layers:
  layer.trainable = False

# pret_trained_model.summary()

last_layer = pret_trained_model.get_layer('mixed7')
print('last layer output shape: ', last_layer.output_shape)
last_output = last_layer.output
  • 定义模型
from tensorflow.keras.optimizers import RMSprop

x = layers.Flatten()(last_output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(1, activation='sigmoid')(x)

model = Model(pret_trained_model.input, x)

model.compile(optimizer = RMSprop(lr=0.0001),
             loss = 'binary_crossentropy',
             metrics = ['acc'])
  • 构造batch数据生成器
!wget --no-check-certificate \
        https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
       -O /tmp/cats_and_dogs_filtered.zip

from tensorflow.keras.preprocessing.image import ImageDataGenerator

import os
import zipfile

local_zip = '//tmp/cats_and_dogs_filtered.zip'

zip_ref = zipfile.ZipFile(local_zip, 'r')

zip_ref.extractall('/tmp')
zip_ref.close()

# Define our example directories and files
base_dir = '/tmp/cats_and_dogs_filtered'

train_dir = os.path.join( base_dir, 'train')
validation_dir = os.path.join( base_dir, 'validation')


train_cats_dir = os.path.join(train_dir, 'cats') # Directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs') # Directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats') # Directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')# Directory with our validation dog pictures

train_cat_fnames = os.listdir(train_cats_dir)
train_dog_fnames = os.listdir(train_dogs_dir)

# Add our data-augmentation parameters to ImageDataGenerator
train_datagen = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

# Note that the validation data should not be augmented!
test_datagen = ImageDataGenerator( rescale = 1.0/255. )

# Flow training images in batches of 20 using train_datagen generator
train_generator = train_datagen.flow_from_directory(train_dir,
                                                    batch_size = 20,
                                                    class_mode = 'binary', 
                                                    target_size = (150, 150))     

# Flow validation images in batches of 20 using test_datagen generator
validation_generator =  test_datagen.flow_from_directory( validation_dir,
                                                          batch_size  = 20,
                                                          class_mode  = 'binary', 
                                                          target_size = (150, 150))
  • 训练模型
history = model.fit_generator(
            train_generator,
            validation_data = validation_generator,
            steps_per_epoch = 100,
            epochs = 20,
            validation_steps = 50,
            verbose = 2)
  • 作图显示准确率和loss随epoch的变化
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend(loc=0)
plt.figure()

plt.show()

【参考资料】
1.google colab
2.TF官网: Transfer Learning Using Pretrained ConvNets

相关文章

  • 2018-04-15 迁移学习的度量准则

    迁移学习的方法主要包括:基于样本的迁移,基于特征的迁移,基于模型的迁移和基于关系的迁移。 “迁移学习的总体思路可以...

  • 迁移学习简述

    摘要:什么是迁移学习,迁移学习的例子有哪些,在预测建模中如何使用迁移学习?本文将带你一步步深入探讨。 迁移学习是一...

  • 教育心理学第五章 第三节 学习迁移

    教育心理学第五章 第三节 学习迁移 一 学习迁移的定义 学习迁移也称迁移,指一种学习对另一种学习的影响或习得经验对...

  • 2022-03-15

    迁移是一种学习对另一种学习的影响,先前学习对后继学习的影响称为顺向迁移,后继学习对先前学习的影响称为逆向迁移,凡是...

  • 2019年上半年收集到的人工智能迁移学习干货文章

    2019年上半年收集到的人工智能迁移学习干货文章 迁移学习全面指南:概念、项目实战、优势、挑战 迁移学习:该做的和...

  • ResNet18迁移学习-动物多任务分类

    迁移学习 迁移学习的具体内容有很多大佬文章已经说得很清楚了,这里就不献丑了。 本文尝试通过迁移学习,将Pytorc...

  • 读《教育心理学》(十三)

    今日读第十二章《学习的迁移》,该章节包含四大节: 第一节 学习迁移的概述 第二节 学习迁移...

  • 我眼中的迁移学习

    01 什么是迁移学习? 在教学中,我们说的学习迁移也称为训练迁移,是指一种学习对另一种学习的影响,或者习得的经验,...

  • 迁移学习练习-Horses vs. Humans

    最近在Coursera上学习迁移学习,使用的是TensorFlow.我们将采用迁移学习来使用Inception_v...

  • 浅析学习迁移

    浅析学习迁移 ——读《人是如何学习的》有感 迁移的概念,迁移作为一个词语时,它的意思是离开原...

网友评论

      本文标题:迁移学习

      本文链接:https://www.haomeiwen.com/subject/bxomdctx.html