美文网首页
调用Keras的vgg16模型进行测试

调用Keras的vgg16模型进行测试

作者: 不求上进的夏天 | 来源:发表于2020-07-26 20:40 被阅读0次
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
import os
import re

# Dataset locations. The '/content/drive/My Drive' prefix presumably points
# at a Google Drive mount inside Colab -- adjust when running elsewhere.
path = '/content/drive/My Drive/test/calibration_set'  # directory holding the .JPEG images
val_path = f'{path}/val.txt'            # image file name -> class index listing
synsets_path = f'{path}/synsets.txt'    # one synset per class index
def ReadTxtName(rootdir, encoding='utf-8'):
    """Read a text file and return its lines without their trailing newline.

    Args:
        rootdir: Path of the text file to read (despite the name, a file,
            not a directory).
        encoding: Encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale settings.

    Returns:
        A list with one string per line, in file order. Blank lines are kept
        as empty strings; other leading/trailing whitespace is preserved.
    """
    with open(rootdir, 'r', encoding=encoding) as file_to_read:
        # Text mode translates '\r\n' / '\r' to '\n', so stripping '\n'
        # alone removes every kind of line ending.
        return [line.rstrip('\n') for line in file_to_read]


val = ReadTxtName(val_path)
synset = ReadTxtName(synsets_path)

# Map each validation image's numeric id (the digits after "val_" in its file
# name) to its ground-truth synset line.
# Assumes each val.txt line looks like "<...>val_<id>.JPEG <class index>" and
# that <class index> is a 0-based row of synsets.txt -- confirm against the
# dataset files.
map_dict = {}  # image id (digit string, leading zeros kept) -> synset line
for entry in val:
    # Raw strings: '\d' / '\s' in plain literals are invalid escape sequences.
    id_match = re.search(r'val_(\d+)', entry)
    idx_match = re.search(r'JPEG\s(\d+)', entry)
    # Parse both fields from the same line so a line missing either field is
    # skipped as a whole; collecting them into two separate flat lists and
    # zipping afterwards would shift every later id onto the wrong label.
    if id_match is None or idx_match is None:
        continue
    map_dict[id_match.group(1)] = synset[int(idx_match.group(1))]


# Ground-truth synset for every .JPEG file in `path`, in os.listdir order.
label = []
for img_name in os.listdir(path):
    if not img_name.endswith('.JPEG'):
        continue
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    id_match = re.search(r'val_(\d+)', img_name)
    if id_match is None:
        raise ValueError(f'no "val_<id>" number in image file name {img_name!r}')
    # A KeyError here means the image id has no usable entry in val.txt.
    label.append(map_dict[id_match.group(1)])

label = np.array(label)  # ground-truth labels of the test set, via the mapping dict
def vgg_predict(img, top=5):
    """Return the class ids of VGG16's ``top`` highest-scoring ImageNet classes.

    Args:
        img: An image as returned by ``load_img`` (this script resizes it to
            224x224 before calling).
        top: How many top-scoring classes to return. Defaults to 5.

    Returns:
        A 1-D numpy array of ``top`` class-id strings (the first field of each
        ``decode_predictions`` entry), in the order ``decode_predictions``
        yields them (highest score first).

    Uses the module-level ``model`` defined elsewhere in this script; it must
    exist before this function is called.
    """
    image_data = img_to_array(img)
    # Add a leading batch axis: predict() works on a batch of images.
    batch = np.expand_dims(image_data, axis=0)
    batch = preprocess_input(batch)
    prediction = model.predict(batch)
    # One list of (class_id, description, score) tuples per input image;
    # keep only the class ids of the single image we passed in. Building the
    # array from ids alone avoids mixing strings and float scores, and stays
    # 1-D (iterable) even when top == 1.
    decoded = decode_predictions(prediction, top=top)[0]
    return np.array([entry[0] for entry in decoded])


# VGG16 with ImageNet-pretrained weights, including its fully connected
# classifier head so predictions are ImageNet class scores.
model = VGG16(weights='imagenet', include_top=True)

# Top-5 class ids for every .JPEG in `path`, in os.listdir order. This is
# assumed to match the order used when the ground-truth labels were built
# from the same directory listing.
jpeg_files = [name for name in os.listdir(path) if name.endswith('.JPEG')]
pre_result = np.array([
    vgg_predict(load_img(os.path.join(path, name), target_size=(224, 224)))
    for name in jpeg_files
])


# Top-5 error: the fraction of images whose ground-truth synset is NOT among
# the model's top-5 predicted class ids. The denominator is the actual number
# of evaluated images rather than a hard-coded 1000.
if len(label) != len(pre_result):
    raise ValueError(
        f'label count ({len(label)}) does not match '
        f'prediction count ({len(pre_result)})')
if len(label) == 0:
    raise ValueError('no images to evaluate')

cnt = sum(1 for y, y_hat in zip(label, pre_result) if y in y_hat)
print('top5 error:', 1 - cnt / len(label))  # author's reported result: top5 error: 0.136

相关文章

网友评论

      本文标题:调用Keras的vgg16模型进行测试

      本文链接:https://www.haomeiwen.com/subject/zbeglktx.html