SFTGAN's GitHub repository: SFTGAN
The project is organized as follows:
![](https://img.haomeiwen.com/i15647055/834a8aa1adfe8716.png)
First, the command used to run it.
Running test_subdir.py (the runnable script; its code is given below) upscales every image under images/ to 4096 and writes the results to images_4096/. Images larger than 1024 but smaller than 2048 are first resized to 1024 and then super-resolved with SFTGAN, while images larger than 2048 are resized directly to 4096. The argument 1 means the images in images/ are first copied and converted to three channels before being placed in images_4096/, and 5 is the number of iterations: SFTGAN only upscales by a factor of 4 per pass, so going from 256 to 4096 takes two passes, smaller images take more, and 5 is chosen as a safe upper bound.
python test_subdir.py images/ images_4096/ 1 5 2048 4096 1024
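As a rough check on why 5 iterations is enough, here is a minimal sketch (not part of the repository; the helper name `passes_needed` is made up for illustration) that computes how many 4x passes a given input size needs to reach 4096:

```python
import math

def passes_needed(input_size, target_size=4096, scale=4):
    """Number of 4x super-resolution passes to go from input_size to target_size."""
    if input_size >= target_size:
        return 0
    return math.ceil(math.log(target_size / input_size, scale))

print(passes_needed(256))   # 2 passes: 256 -> 1024 -> 4096
print(passes_needed(1024))  # 1 pass:  1024 -> 4096
print(passes_needed(64))    # 3 passes: 64 -> 256 -> 1024 -> 4096
```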
The following code is placed in the pytorch_test folder and upscales all files in a directory to the specified size.
```python
'''
Batch upscaling with SFTGAN: segmentation probability maps are generated with
the OutdoorSceneSeg model, then every image in a directory is super-resolved.
'''
import os
import glob
import numpy as np
import cv2
import sys
import torch
import torchvision.utils
import time
import architectures as arch
import util
from PIL import Image


# Channel conversion: force every image to 3-channel RGB
def change_image_channels(input_image_path, output_image_path):
    image = Image.open(input_image_path)
    if image.mode == 'RGBA':
        r, g, b, a = image.split()
        image = Image.merge("RGB", (r, g, b))
    elif image.mode != 'RGB':
        image = image.convert("RGB")
    try:
        os.remove(output_image_path)
    except OSError:
        pass
    image.save(output_image_path)
    return image


# options (defaults, overridden by the command-line arguments below)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
times = 3
channel_mark = 1
imgSize = 4096
finalSize = 4096
minImgSize = 1024
imagespath = sys.argv[1]         # must end with "/"
outputdir = sys.argv[2]          # must end with "/"
channel_mark = int(sys.argv[3])  # default 1: convert all images to 3 channels first
times = int(sys.argv[4])         # default 3: number of upscaling passes
imgSize = int(sys.argv[5])       # images smaller than this go through SFTGAN
finalSize = int(sys.argv[6])     # target size
minImgSize = int(sys.argv[7])    # size to downscale to before SFTGAN
if not os.path.exists(outputdir):
    os.makedirs(outputdir)

device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> 'cpu'
# device = torch.device('cpu')

# load the SFTGAN model
model_path = '/home/t-huch/SFTGAN/pretrained_models/SFTGAN_torch.pth'  # torch version
if 'torch' in model_path:  # torch version
    model = arch.SFT_Net_torch()
else:  # pytorch version
    model = arch.SFT_Net()
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

# load the segmentation model
seg_model = arch.OutdoorSceneSeg()
model_path = '/home/t-huch/SFTGAN/pretrained_models/segmentation_OST_bic.pth'
seg_model.load_state_dict(torch.load(model_path), strict=True)
seg_model.eval()
seg_model = seg_model.to(device)

print('Testing SFTGAN ...')
print(channel_mark)

# step 1: copy all input images to the output directory as 3-channel RGB
if channel_mark == 1:
    for root, dirs, files in os.walk(imagespath):
        for file in files:
            start_time = time.time()
            path = os.path.join(root, file)
            imgname = os.path.basename(path)
            subDir = os.path.join(outputdir, root.replace(imagespath, ""))
            if not os.path.exists(subDir):
                os.makedirs(subDir)
            print(path)
            change_image_channels(path, os.path.join(subDir, imgname))

# step 2: repeatedly upscale the copies in place until they reach finalSize
while times > 0:
    times -= 1
    for root, dirs, files in os.walk(outputdir):
        for file in files:
            start_time = time.time()
            path = os.path.join(root, file)
            imgname = os.path.basename(path)
            # read image
            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            if img is None:  # skip files OpenCV cannot decode
                continue
            print(img.shape, path)
            if img.shape[0] < imgSize or img.shape[1] < imgSize:
                # small image: downscale to minImgSize if needed, then SFTGAN 4x
                if img.shape[0] > minImgSize or img.shape[1] > minImgSize:
                    img = cv2.resize(img, (minImgSize, minImgSize), interpolation=cv2.INTER_CUBIC)
                test_img = util.modcrop(img, 8)
                img = util.modcrop(img, 8)
                if img.ndim == 2:
                    img = np.expand_dims(img, axis=2)
                # segmentation input: BGR order, per-channel means subtracted
                img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
                img_LR = util.imresize(img / 255, 1, antialiasing=True)
                img = util.imresize(img_LR, 4, antialiasing=True) * 255
                img[0] -= 103.939
                img[1] -= 116.779
                img[2] -= 123.68
                img = img.unsqueeze(0)
                img = img.to(device)
                with torch.no_grad():
                    output = seg_model(img).detach().float().cpu().squeeze()
                # SFTGAN input: RGB order, scaled to [0, 1]
                test_img = test_img * 1.0 / 255
                if test_img.ndim == 2:
                    test_img = np.expand_dims(test_img, axis=2)
                test_img = torch.from_numpy(np.transpose(test_img[:, :, [2, 1, 0]], (2, 0, 1))).float()
                img_LR = util.imresize(test_img, 1, antialiasing=True)
                img_LR = img_LR.unsqueeze(0)
                img_LR = img_LR.to(device)
                seg = output
                seg = seg.unsqueeze(0)
                seg = seg.to(device)
                with torch.no_grad():
                    output = model((img_LR, seg)).data.float().cpu().squeeze()
                output = util.tensor2img(output)
                subDir = os.path.join(outputdir, root.replace(outputdir, ""))
                if not os.path.exists(subDir):
                    os.makedirs(subDir)
                util.save_img(output, os.path.join(subDir, imgname))
                print("time consumption : {}".format(time.time() - start_time))
            elif img.shape[0] == finalSize and img.shape[1] == finalSize:
                # already at the target size, nothing to do
                pass
                # subDir = os.path.join(outputdir, root.replace(outputdir, ""))
                # if not os.path.exists(subDir):
                #     os.makedirs(subDir)
                # cv2.imwrite(os.path.join(subDir, imgname), img)
            else:
                # large image (>= imgSize but not yet finalSize): bicubic resize directly
                img = cv2.resize(img, (finalSize, finalSize), interpolation=cv2.INTER_CUBIC)
                subDir = os.path.join(outputdir, root.replace(outputdir, ""))
                if not os.path.exists(subDir):
                    os.makedirs(subDir)
                cv2.imwrite(os.path.join(subDir, imgname), img)
```
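After the script finishes, a quick sanity check is to confirm that every output image reached the target resolution. The snippet below is an optional sketch (not part of the original script; the directory name matches the second command-line argument from the example above) that walks the output directory with Pillow and reports any image that is not 4096x4096:

```python
import os
from PIL import Image

output_dir = 'images_4096/'  # same directory passed as the second argument
target = 4096

for root, dirs, files in os.walk(output_dir):
    for name in files:
        path = os.path.join(root, name)
        with Image.open(path) as im:
            if im.size != (target, target):
                print('not at target size:', path, im.size)
```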