美文网首页
pytorch模型转keras模型

pytorch模型转keras模型

作者: AI算法_图哥 | 来源:发表于2019-10-02 11:22 被阅读0次
在这里插入图片描述

1. 概述

使用pytorch建立的模型,有时想把pytorch建立好的模型装换为keras,本人使用TensorFlow作为keras的backend

2. 依赖

依赖的标准库:

  • pytorch
  • keras
  • tensorflow
  • pytorch2keras

3. 安装方式

git clone https://github.com/nerox8664/pytorch2keras.git
python setup.py install

4. 代码

import numpy as np
import torch
from torch.autograd import Variable
from pytorch2keras import converter

class Pytorch2KerasTestNet(torch.nn.Module):
    def __init__(self):
        super(Pytorch2KerasTestNet, self).__init__()
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        y = self.relu(self.in1(self.conv1(x)))
        return y


class ConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        
        print("conv2d")
        out = self.conv2d(out)
        return out

def check_error(output, k_model, input_np, epsilon=1e-5):
    pytorch_output = output.data.numpy()
    keras_output = k_model.predict(input_np)

    error = np.max(pytorch_output - keras_output)
    print('Error:', error)

    assert error < epsilon
    return error        

model   = Pytorch2KerasTestNet()
input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
input_var = Variable(torch.FloatTensor(input_np))
output = model(input_var)
k_model = converter.pytorch_to_keras(model, input_var, [(3, 224, 224,)], verbose=True)
k_model.summary()

max_error = 0
error = check_error(output, k_model, input_np)
if max_error < error:
    max_error = error
print('Max error: {0}'.format(max_error))

#保存模型
k_model.save('my_model.h5')

# 重新载入模型
from keras.models import load_model
import tensorflow as tf

model = load_model('my_model.h5',custom_objects={"tf": tf})
model.summary()

输出结果:

Layer (type)                 Output Shape              Param #   
=================================================================
input_0 (InputLayer)         (None, 3, 224, 224)       0         
_________________________________________________________________
5 (Lambda)                   (None, 3, 232, 232)       0         
_________________________________________________________________
6 (Conv2D)                   (None, 32, 224, 224)      7808      
_________________________________________________________________
7 (Lambda)                   (None, 32, 224, 224)      0         
_________________________________________________________________
output_0 (Activation)        (None, 32, 224, 224)      0         
=================================================================
Total params: 7,808
Trainable params: 7,808
Non-trainable params: 0

5. 最后

image

相关文章

  • pytorch模型转keras模型

    1. 概述 使用pytorch建立的模型,有时想把pytorch建立好的模型装换为keras,本人使用Tensor...

  • pytorch转caffe2 之 onnx转caffe2报错的解

    目标:将 pytorch模型 转为 onnx模型 再转为 caffe2模型,得到两个.pb文件 pytorch转o...

  • 自动部署深度神经网络模型TensorFlow(Keras)到生产

    目录 Keras简介 Keras模型分类 Keras模型部署准备 默认部署Keras模型 自定义部署Keras模型...

  • CV-字符识别模型

    Pytorch构建CNN模型 Pytorch中构建CNN模型只需要定义好模型的参数和正向传播就可以,Pytorch...

  • 30s上手Keras

    Keras的核心数据结构是“模型”,模型是一种组织网络层的方式。Keras中主要的模型是Sequential模型,...

  • keras 基本概念

    keras的核心数据结构是模型,模型是一种组织网络层的方式。Keras中主要的模型是sequential模型,其实...

  • pytorch finetune模型

    pytorch finetune模型 文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变...

  • 构建高级模型(05)

    函数式 API tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。使用 Keras 函...

  • PyTorch模型保存深入理解

    前面写过一篇PyTorch保存模型的文章:Pytorch模型保存与加载,并在加载的模型基础上继续训练 ,简单介绍了...

  • 可视化创建的深度学习模型

    深度学习模型创建好后,有几种方式可以可视化 keras.Model.summary()方式查看模型摘要 keras...

网友评论

      本文标题:pytorch模型转keras模型

      本文链接:https://www.haomeiwen.com/subject/kxolpctx.html