美文网首页
TensorFlow(一)模型参数名称和数值读取

TensorFlow(一)模型参数名称和数值读取

作者: 续袁 | 来源:发表于2019-05-13 14:11 被阅读0次

1. 模型参数名称读取(新式4个模型文件)

image.png
import os
import re
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Directory containing the checkpoint files to inspect.
# Alternatives are kept commented out so they can be switched in quickly;
# only one assignment is live, so there is no dead store.
# model_exp = "./log/vgg16/fine_tune"
# model_exp = "model-lenet"
model_exp = "model_vgg16"

def get_model_filenames(model_dir):
    """Locate the metagraph file and the checkpoint file inside ``model_dir``.

    The checkpoint name is taken from the ``checkpoint`` state file when
    ``tf.train.get_checkpoint_state`` reports one. Otherwise the directory
    listing is scanned for names of the form ``model-<name>.ckpt-<step>`` and
    the entry with the highest step number is used.

    Args:
        model_dir: Directory holding the ``.meta`` file and checkpoint files.

    Returns:
        A ``(meta_file, ckpt_file)`` tuple of file names without the
        directory part.

    Raises:
        ValueError: If the directory has no ``.meta`` file, has more than one
            ``.meta`` file, or no checkpoint can be identified.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        # Fail with a clear message instead of an IndexError on meta_files[0].
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    if len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    # Use the "checkpoint" state file to find the checkpoint name when possible.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # model_checkpoint_path is a path; only the file name is returned.
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # Fallback: pick the checkpoint prefix with the largest step number.
    ckpt_file = None
    max_step = -1
    for f in files:
        # The dot before "ckpt" is escaped so it only matches a literal ".".
        step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if step_str is None:
            continue
        step = int(step_str.group(2))
        if step > max_step:
            max_step = step
            ckpt_file = step_str.group(1)
    if ckpt_file is None:
        # Previously this path hit an UnboundLocalError at the return statement.
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file


# Resolve the metagraph and checkpoint file names inside model_exp.
meta_file, ckpt_file = get_model_filenames(model_exp)

# The context manager closes data.txt even if reading the checkpoint fails.
with open("data.txt", 'w+') as data:
    # Write the file names both to the text file and to stdout.
    print('Metagraph file: %s' % meta_file, file=data)
    print('Checkpoint file: %s' % ckpt_file, file=data)
    print('Metagraph file: %s' % meta_file)
    print('Checkpoint file: %s' % ckpt_file)

    reader = pywrap_tensorflow.NewCheckpointReader(os.path.join(model_exp, ckpt_file))
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        # Read each tensor once and reuse it for both outputs.
        value = reader.get_tensor(key)
        print("tensor_name: ", key)
        print(value)   # prints the values of the model parameter
        print('tensor_name: %s'% key, file=data)
        print(value, file=data)

# Rebuild the graph from the .meta file and restore the checkpoint into it.
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file))
    saver.restore(sess, os.path.join(model_exp, ckpt_file))
    #print(tf.get_default_graph().get_tensor_by_name("Logits/weights:0"))

2. 模型参数名称读取(旧式1个模型文件)

image.png
import os
import re
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Directory containing the old-style single checkpoint file to inspect.
# Alternatives are kept commented out so they can be switched in quickly;
# only one assignment is live, so there is no dead store.
# model_exp = "./log/vgg16/fine_tune"
# model_exp = "model_vgg16"
model_exp = "model_vgg16_0"

def get_model_filenames(model_dir):
    """Locate the metagraph file and the checkpoint file inside ``model_dir``.

    The checkpoint name is taken from the ``checkpoint`` state file when
    ``tf.train.get_checkpoint_state`` reports one. Otherwise the directory
    listing is scanned for names of the form ``model-<name>.ckpt-<step>`` and
    the entry with the highest step number is used.

    Args:
        model_dir: Directory holding the ``.meta`` file and checkpoint files.

    Returns:
        A ``(meta_file, ckpt_file)`` tuple of file names without the
        directory part.

    Raises:
        ValueError: If the directory has no ``.meta`` file, has more than one
            ``.meta`` file, or no checkpoint can be identified.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        # Fail with a clear message instead of an IndexError on meta_files[0].
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    if len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    # Use the "checkpoint" state file to find the checkpoint name when possible.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # model_checkpoint_path is a path; only the file name is returned.
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # Fallback: pick the checkpoint prefix with the largest step number.
    ckpt_file = None
    max_step = -1
    for f in files:
        # The dot before "ckpt" is escaped so it only matches a literal ".".
        step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if step_str is None:
            continue
        step = int(step_str.group(2))
        if step > max_step:
            max_step = step
            ckpt_file = step_str.group(1)
    if ckpt_file is None:
        # Previously this path hit an UnboundLocalError at the return statement.
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file


# An old-style checkpoint is a single file with no .meta companion, so the
# file name is given directly instead of being looked up.
#meta_file, ckpt_file = get_model_filenames(model_exp)
ckpt_file ="vgg_16.ckpt"

# The context manager closes the output file even if reading the checkpoint fails.
with open("data_name00.txt",'w+') as data:
    # Write the checkpoint file name both to the text file and to stdout.
    print('Checkpoint file: %s' % ckpt_file, file=data)
    print('Checkpoint file: %s' % ckpt_file)

    reader = pywrap_tensorflow.NewCheckpointReader(os.path.join(model_exp, ckpt_file))
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        #print(reader.get_tensor(key))   # uncomment to print the parameter values
        print('tensor_name: %s'% key, file=data)
        #print(reader.get_tensor(key), file=data)

3. 打印结果

3.1 输出结果到TXT中

# Writing results to a text file: print(..., file=data) redirects the output,
# and the context manager closes the file even if one of the prints fails.
with open("data.txt",'w+') as data:
    print('Metagraph file: %s' % meta_file, file=data)  # print into the TXT file
    print('Checkpoint file: %s' % ckpt_file, file=data)
    print('tensor_name: %s'% key, file=data)
    print(reader.get_tensor(key), file=data)

TensorFlow(一)InceptionResnetV1参数结构
TensorFlow(一)模型InceptionResnetV1参数结构和参数值展示

参考资料

[1] 【python】读取和输出到txt
[2] 如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图
[3] Tensorflow 模型文件结构、模型中Tensor查看
[4] 查看TensorFlow checkpoint文件中的变量名和对应值
[5] 输出TensorFlow中checkpoint内变量的几种方法

相关文章

网友评论

      本文标题:TensorFlow(一)模型参数名称和数值读取

      本文链接:https://www.haomeiwen.com/subject/osylaqtx.html