SFTGAN test

作者: 狼无雨雪 | 来源:发表于2019-07-04 18:57 被阅读0次

    SFTGAN的github地址:SFTGAN

    程序的组织架构如下:


    SFTGAN.png

    首先说明执行命令


    运行 test_subdir.py(可直接运行的脚本,代码见下文),会把 images 下的所有图片放大到 4096 并保存到 images_4096 下:大于 1024 且小于 2048 的图片先 resize 到 1024,再通过 SFTGAN 做超分辨率(super resolution);大于 2048 的图片则直接 resize 到 4096。命令行参数依次为:输入目录、输出目录、通道标志、迭代次数,以及 imgSize(2048)、finalSize(4096)、minImgSize(1024)。其中通道标志 1 表示先把 images 中的图片拷贝并转换为三通道,再放到 images_4096 下;迭代次数 5 表示最多处理五轮——SFTGAN 每次只能放大四倍,256 的图片放大到 4096 需要两次,更小的图片需要更多次,保险起见设置为 5。

    python test_subdir.py images/ images_4096/ 1 5 2048 4096 1024
    

    以下代码放置于pytorch_test文件夹下,用于将目录下的文件放大到指定大小

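    '''
    Batch super-resolution with SFTGAN: optionally copy every image under an
    input directory to an output directory as 3-channel RGB, then repeatedly
    upscale the images in the output directory with SFTGAN (conditioned on
    segmentation probability maps) or resize them to the final size.
    '''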
    
    import os
    import glob
    import numpy as np
    import cv2
    import sys
    import torch
    import torchvision.utils
    import time
    import architectures as arch
    import util
    from PIL import Image
    # Channel conversion
    def change_image_channels(input_image_path, output_image_path):
        """Save a 3-channel RGB copy of an image and return it.

        The image at ``input_image_path`` is opened with PIL. An RGBA image
        has its alpha band dropped, any other non-RGB mode goes through
        ``Image.convert("RGB")``, and an RGB image is kept as it is. The
        result is saved to ``output_image_path``; a file already at that
        path is removed first.

        Args:
            input_image_path: Path of the image to read.
            output_image_path: Path the RGB image is written to.

        Returns:
            The RGB PIL image that was saved.
        """
        image = Image.open(input_image_path)
        if image.mode == 'RGBA':
            # Keep the three colour bands untouched and discard alpha.
            red, green, blue, _alpha = image.split()
            image = Image.merge("RGB", (red, green, blue))
        elif image.mode != 'RGB':
            image = image.convert("RGB")

        # Best-effort removal of a previous output file. Only OS-level
        # failures (typically "file not found") are ignored, so Ctrl-C and
        # programming errors are no longer silently swallowed.
        try:
            os.remove(output_image_path)
        except OSError:
            pass
        image.save(output_image_path)
        return image
    
    
    # options
    # Make only GPU 0 visible to CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Defaults for the optional trailing command-line arguments.
    channel_mark = 1   # 1 means: first copy all input images as 3-channel RGB
    times = 3          # number of passes over the output directory
    imgSize = 4096     # images with a side below this go through SFTGAN
    finalSize = 4096   # side length the remaining images are resized to
    minImgSize = 1024  # larger inputs are shrunk to this before SFTGAN

    if len(sys.argv) < 3:
        sys.exit('usage: python test_subdir.py <imagespath> <outputdir> '
                 '[channel_mark] [times] [imgSize] [finalSize] [minImgSize]')

    # os.path.join(p, "") appends a trailing separator only when it is
    # missing, so callers no longer have to end both paths with "/".
    imagespath = os.path.join(sys.argv[1], "")
    outputdir = os.path.join(sys.argv[2], "")

    # Positional overrides, in order: channel_mark, times, imgSize, finalSize,
    # minImgSize. Arguments that are not given keep the defaults above.
    _overrides = [int(arg) for arg in sys.argv[3:8]]
    _defaults = [channel_mark, times, imgSize, finalSize, minImgSize]
    channel_mark, times, imgSize, finalSize, minImgSize = (
        _overrides + _defaults[len(_overrides):]
    )

    os.makedirs(outputdir, exist_ok=True)
    device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> 'cpu'
    # device = torch.device('cpu')
    
    
    
    # Load the SFTGAN network used for super-resolution.
    model_path = '/home/t-huch/SFTGAN/pretrained_models/SFTGAN_torch.pth'  # torch version

    # Pick the architecture from the checkpoint's file name only, so that a
    # 'torch' substring in a parent directory cannot influence the choice.
    if 'torch' in os.path.basename(model_path):  # torch version
        model = arch.SFT_Net_torch()
    else:  # pytorch version
        model = arch.SFT_Net()
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU also loads when `device` is switched to CPU.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    model = model.to(device)


    # Load the segmentation network whose output is passed to SFTGAN.
    seg_model = arch.OutdoorSceneSeg()
    model_path = '/home/t-huch/SFTGAN/pretrained_models/segmentation_OST_bic.pth'
    seg_model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    seg_model.eval()
    seg_model = seg_model.to(device)
    
    print('Testing SFTGAN ...')

    print(channel_mark)
    if channel_mark == 1:
        # Mirror the input tree into outputdir, saving every file as RGB.
        for root, dirs, files in os.walk(imagespath):
            if not files:
                # Directories without files are not created in the output.
                continue
            # relpath strips only the leading imagespath. str.replace would
            # remove every occurrence, and without a trailing "/" it leaves
            # an absolute remainder that makes os.path.join drop outputdir.
            subDir = os.path.normpath(
                os.path.join(outputdir, os.path.relpath(root, imagespath))
            )
            os.makedirs(subDir, exist_ok=True)
            for file in files:
                path = os.path.join(root, file)
                print(path)
                change_image_channels(path, os.path.join(subDir, file))
    
    
    # Make up to `times` passes over outputdir. In each pass an image with a
    # side below imgSize is run through SFTGAN, an image that is already
    # finalSize x finalSize is left alone, and anything else is resized to
    # finalSize x finalSize. Results overwrite the file that was read.
    while times > 0:
        times -= 1
        changed = False
        for root, dirs, files in os.walk(outputdir):
            for file in files:
                start_time = time.time()
                path = os.path.join(root, file)
                # read image
                img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
                if img is None:
                    # cv2.imread returns None (it does not raise) when the
                    # file cannot be decoded; skip it instead of crashing.
                    print('skipping unreadable file: {}'.format(path))
                    continue
                print(img.shape, path)
                if img.shape[0] < imgSize or img.shape[1] < imgSize:
                    if img.shape[0] > minImgSize or img.shape[1] > minImgSize:
                        # Shrink to a minImgSize square before SFTGAN.
                        img = cv2.resize(img, (minImgSize, minImgSize), interpolation=cv2.INTER_CUBIC)
                    # presumably crops both sides to a multiple of 8 -- confirm in util.modcrop
                    img = util.modcrop(img, 8)
                    test_img = img  # kept for the SFTGAN input built below

                    # --- segmentation input: CHW float tensor in OpenCV channel order ---
                    if img.ndim == 2:
                        img = np.expand_dims(img, axis=2)
                    img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()

                    img_LR = util.imresize(img / 255, 1, antialiasing=True)
                    img = util.imresize(img_LR, 4, antialiasing=True) * 255

                    # Per-channel offsets subtracted before segmentation
                    # (presumably the channel means the network was trained with).
                    img[0] -= 103.939
                    img[1] -= 116.779
                    img[2] -= 123.68
                    img = img.unsqueeze(0)
                    img = img.to(device)

                    with torch.no_grad():
                        seg = seg_model(img).detach().float().cpu().squeeze()

                    # --- SFTGAN input: [0, 1] floats, channel order reversed (BGR -> RGB) ---
                    test_img = test_img * 1.0 / 255
                    if test_img.ndim == 2:
                        test_img = np.expand_dims(test_img, axis=2)
                    test_img = torch.from_numpy(np.transpose(test_img[:, :, [2, 1, 0]], (2, 0, 1))).float()
                    img_LR = util.imresize(test_img, 1, antialiasing=True)
                    img_LR = img_LR.unsqueeze(0)
                    img_LR = img_LR.to(device)

                    seg = seg.unsqueeze(0)
                    seg = seg.to(device)
                    with torch.no_grad():
                        output = model((img_LR, seg)).detach().float().cpu().squeeze()
                    output = util.tensor2img(output)
                    # Overwrite the file in place. `path` is already inside
                    # outputdir, so no directory needs to be derived or created.
                    util.save_img(output, path)
                    changed = True

                    print("time consumption : {}".format(time.time() - start_time))
                elif img.shape[0] == finalSize and img.shape[1] == finalSize:
                    # Already at the target size: nothing to do.
                    pass
                else:
                    img = cv2.resize(img, (finalSize, finalSize), interpolation=cv2.INTER_CUBIC)
                    cv2.imwrite(path, img)
                    changed = True
        if not changed:
            # A pass that modified nothing leaves the directory unchanged, so
            # every remaining pass would only re-read the same files.
            break
    
    

    相关文章

      网友评论

        本文标题:SFTGAN test

        本文链接:https://www.haomeiwen.com/subject/lnxdhctx.html