evaluate.py
from __future__import print_function
import matplotlib.pyplotas plt
import argparse
from randomimport shuffle
import random
import os
from netimport *
import pandasas pd
import numpyas np
parser = argparse.ArgumentParser(description='')
parser.add_argument("--image_size", type=int, default=64, help="load image size")# 网络输入的尺度#default=256
parser.add_argument("--snapshots", default='./log', help="Path of Snapshots")# 读取训练好的模型参数的路径
parser.add_argument("--test_data_path", default='./dataset/test/', help="path of x training datas.")# 高信噪比光谱数据的训练图片路径
parser.add_argument("--out_dir_x", default='./test_output/x/', help="Output Folder")# 保存x域的输入图片与生成的y域图片的路径
parser.add_argument("--out_dir_y", default='./test_output/y/', help="Output Folder")# 保存y域的输入图片与生成的x域图片的路径
args = parser.parse_args()
hang_num=11
#df_one=pd.read_csv(r'F:/bu/data/sdss/new_experiment/new_lowtest.csv',header=None)
df_one=pd.read_csv(r'F:/bu/data/sdss/测试下载数据/low.csv',header=None)
df_one=df_one.ix[:,:4096]
x_image_resize=[]
for iin range(hang_num):
print("...low...")
data_list=df_one.ix[i,:4096]
list_df_low = []
for jin range(4096):
#print(j)
df_low_scal = (data_list[j]-np.mean(data_list)) / (max(data_list) -min(data_list))
list_df_low.append(df_low_scal)
re_list = np.array(list_df_low).reshape(64, 64, 1)
x_image_resize.append(re_list)
print(len(x_image_resize))
df_two=pd.read_csv(r'F:/bu/data/sdss/测试下载数据/high.csv',header=None)
y_image_resize=[]
for iin range(hang_num):
print("...high...")
data_list=df_two.ix[i,:]
list_df_high = []
for jin range(4096):
df_high2 = (data_list[j] - np.mean(data_list))/ (max(data_list) -min(data_list))
list_df_high.append(df_high2)
re_list = np.array(list_df_high).reshape(64, 64, 1)
y_image_resize.append(re_list)
print(len(y_image_resize))
def make_list_number_same(x_input_images_raw, y_input_images_raw):# add_train_list函数将x域和y域的图像数量变成一致
if len(x_input_images_raw) ==len(y_input_images_raw):# 如果x域和y域图像数量本来就一致,直接返回
return x_input_images_raw, shuffle(y_input_images_raw)
elif len(x_input_images_raw) >len(y_input_images_raw):# 如果x域的训练图像数量大于y域的训练图像数量,则随机选择y域的图像补充y域
add_num =int(len(x_input_images_raw) -len(y_input_images_raw))# 计算两域图像数量相差的倍数
length =len(y_input_images_raw)
for iin range(add_num):
n = random.sample(range(length), 1)# 每训练一个epoch,就打乱一下x域图像顺序
num = n[0]
y_input_images_raw.append(y_input_images_raw[num])
return x_input_images_raw, y_input_images_raw# 返回数量一致的x域和y域图像路径名称列表
else:# 与elif中的逻辑一致,只是x与y互换,不再赘述
add_num =int(len(y_input_images_raw) -len(x_input_images_raw))# 计算两域图像数量相差的倍数
length =len(x_input_images_raw)
for iin range(add_num):
n = random.sample(range(length), 1)# 每训练一个epoch,就打乱一下x域图像顺序
num = n[0]
x_input_images_raw.append(x_input_images_raw[num])
return x_input_images_raw, y_input_images_raw# 返回数量一致的x域和y域图像路径名称列表
def cv_inv_proc(img):# cv_inv_proc函数将读取图片时归一化的图片还原成原图
img_rgb = (img +1.) *127.5
return img_rgb.astype(np.float32)# bgr
def get_write_picture(x_image, y_image, fake_y, fake_x):# get_write_picture函数得到网络测试结果
x_image = cv_inv_proc(x_image)# 还原x域的图像
y_image = cv_inv_proc(y_image)# 还原y域的图像
fake_y = cv_inv_proc(fake_y[0])# 还原生成的y域的图像
fake_x = cv_inv_proc(fake_x[0])# 还原生成的x域的图像
x_output = np.concatenate((x_image, fake_y), axis=1)# 得到x域的输入图像以及对应的生成的y域图像
y_output = np.concatenate((y_image, fake_x), axis=1)# 得到y域的输入图像以及对应的生成的x域图像
return x_output, y_output
batch_size =5
def main():
if not os.path.exists(args.out_dir_x):# 如果保存x域测试结果的文件夹不存在则创建
os.makedirs(args.out_dir_x)
if not os.path.exists(args.out_dir_y):# 如果保存y域测试结果的文件夹不存在则创建
os.makedirs(args.out_dir_y)
test_x_image = tf.placeholder(tf.float32, shape=[batch_size, 64, 64, 1], name='test_x_image')# 输入的x域图像
test_y_image = tf.placeholder(tf.float32, shape=[batch_size, 64, 64, 1], name='test_y_image')# 输入的y域图像
fake_y = generator(image=test_x_image, reuse=False, name='generator_x2y')# 生成的y域图像
fake_x_ = generator(image=fake_y, reuse=False, name='generator_y2x')# 重建的x域图像
fake_x = generator(image=test_y_image, reuse=True, name='generator_y2x')# 生成的x域图像
fake_y_ = generator(image=fake_x, reuse=True, name='generator_x2y')# 重建的y域图像
restore_var = [vfor vin tf.global_variables()if 'generator' in v.name]# 需要载入的已训练的模型参数
config = tf.ConfigProto()
config.gpu_options.allow_growth =True # 设定显存不超量使用
sess = tf.Session(config=config)# 建立会话层
saver = tf.train.Saver(var_list=restore_var, max_to_keep=1)# 导入模型参数时使用
checkpoint = tf.train.latest_checkpoint(args.snapshots)# 读取模型参数
saver.restore(sess, checkpoint)# 导入模型参数
#total_step = len(x_image_resize)
total_step =50
for stepin range(total_step):
n = random.sample(range(len(x_image_resize)), batch_size)
num = n[0]
#num=x_image_resize[step]
batch_x_image = np.array(x_image_resize)[n]
batch_y_image = np.array(y_image_resize)[n]
feed_dict = {test_x_image: batch_x_image, test_y_image: batch_y_image}# 建立feed_dict
fake_x_value, fake_y_value = sess.run([fake_x, fake_y], feed_dict=feed_dict)# 得到生成的x域图像与y域图像
###三张图画在一起 去噪l_h_l
real_x = []##画真的x
plt.yticks([0.0,0.5,1.0,1.5,2.0,2.5])
#plt.tight_layout(1)
#plt.subplots_adjust(bottom=0.1,left=0.1,right=0.15,top=0.15)
plt.figure(figsize=(6,2.8))
for iin range(64):
for jin range(64):
real_x.append(x_image_resize[num][i][j]+1.9)
x = [ifor iin range(4000, 8096, 1)]
plt.plot(x, real_x, color='seagreen')
# plt.savefig('F:/bu/data/sdss/new_experiment/ll/' + str(counter))
real_yy = []##画真的y
for iin range(64):
for jin range(64):
real_yy.append(y_image_resize[num][i][j]+1.1)
x = [ifor iin range(4000, 8096, 1)]
plt.plot(x, real_yy, color='coral')
fake_xxx = []
for iin range(64):
for jin range(64):
fake_xxx.append(fake_y_value[0][i][j]+0.5)
x = [ifor iin range(4000, 8096, 1)]
plt.plot(x, fake_xxx, color='seagreen')
#plt.savefig('F:/bu/data/sdss/new_experiment/test/l_h_l/' + str(step))
plt.savefig('F:/bu/data/sdss/测试下载数据/pic/' +str(step))
plt.close()
'''
real_yy = [] ##画真的y
for i in range(64):
for j in range(64):
real_yy.append(y_image_resize[num][i][j] + 1.3)
x = [i for i in range(4000, 8096, 1)]
plt.plot(x, real_yy, color='seagreen')
real_x = [] ##画真的x
for i in range(64):
for j in range(64):
real_x.append(x_image_resize[num][i][j] + 0.6)
x = [i for i in range(4000, 8096, 1)]
plt.plot(x, real_x, color='coral')
# plt.savefig('F:/bu/data/sdss/new_experiment/ll/' + str(counter))
fake_xxx = []
for i in range(64):
for j in range(64):
fake_xxx.append(fake_x_value[0][i][j])
x = [i for i in range(4000, 8096, 1)]
plt.plot(x, fake_xxx, color='seagreen')
plt.savefig('F:/bu/data/sdss/new_experiment/test/h_l_h/' + str(step))
plt.close()
'''
print('step {:d}'.format(step))
if __name__ =='__main__':
main()
网友评论