美文网首页
utility.py

utility.py

作者: Amdur | 来源:发表于2019-07-26 19:02 被阅读0次
    import os
    import math
    import time
    import datetime
    from functools import reduce
    
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    
    import numpy as np
    import scipy.misc as misc
    from skimage.restoration import denoise_bilateral
    
    import torch
    import torch.optim as optim
    import torch.optim.lr_scheduler as lrs
    
    class timer():
        """Wall-clock stopwatch with an accumulator.

        ``tic`` sets the reference time, ``toc`` reports seconds elapsed since it,
        ``hold`` adds the current elapsed time to the accumulator, and
        ``release`` hands back the accumulated total while zeroing it.
        """

        def __init__(self):
            self.acc = 0
            self.tic()

        def tic(self):
            """Record the current time as the reference point."""
            self.t0 = time.time()

        def toc(self):
            """Return the seconds elapsed since the last ``tic``."""
            elapsed = time.time() - self.t0
            return elapsed

        def hold(self):
            """Add the time elapsed since the last ``tic`` to the accumulator."""
            self.acc = self.acc + self.toc()

        def release(self):
            """Return the accumulated time and reset the accumulator to zero."""
            total, self.acc = self.acc, 0
            return total

        def reset(self):
            """Zero the accumulator without touching the reference time."""
            self.acc = 0
    
    class checkpoint():
        """Manage an experiment directory: text log, config dump, PSNR history and images.

        The directory is ``../experiment/<args.save>`` for a fresh run (``args.save``
        is set to the current timestamp when it is ``'.'``), or
        ``../experiment/<args.load>`` when resuming a previous run.
        """

        def __init__(self, args):
            """Create (or reopen) the experiment directory described by ``args``.

            Reads ``args.load``, ``args.save`` and ``args.reset``, and may rewrite
            ``args.save`` / ``args.load`` in place. Opens ``log.txt`` and keeps it
            open on ``self.log_file``; call :meth:`done` to close it.
            """
            import shutil  # local import: only needed for the reset path

            self.args = args
            self.ok = True
            self.log = torch.Tensor()
            now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

            if args.load == '.':
                if args.save == '.':
                    args.save = now
                self.dir = '../experiment/' + args.save
            else:
                self.dir = '../experiment/' + args.load
                if not os.path.exists(self.dir):
                    # Nothing to resume from: treat it as a fresh run in that dir.
                    args.load = '.'
                else:
                    self.log = torch.load(self.dir + '/psnr_log.pt')
                    print('Continue from epoch {}...'.format(len(self.log)))

            if args.reset:
                # shutil.rmtree instead of os.system('rm -rf ' + dir): the path is
                # built from command-line arguments, so routing it through a shell
                # allowed command injection and broke on paths containing spaces.
                shutil.rmtree(self.dir, ignore_errors=True)
                args.load = '.'

            for sub in ('', '/model', '/results', '/residuals', '/branches'):
                os.makedirs(self.dir + sub, exist_ok=True)

            open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
            self.log_file = open(self.dir + '/log.txt', open_type)
            with open(self.dir + '/config.txt', open_type) as f:
                f.write(now + '\n\n')
                for arg in vars(args):
                    f.write('{}: {}\n'.format(arg, getattr(args, arg)))
                f.write('\n')

        def save(self, trainer, epoch, is_best=False):
            """Persist the model, loss state/plot, PSNR plot/log and optimizer state."""
            trainer.model.save(self.dir, epoch, is_best=is_best)
            trainer.loss.save(self.dir)
            trainer.loss.plot_loss(self.dir, epoch)

            self.plot_psnr(epoch)
            torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
            torch.save(
                trainer.optimizer.state_dict(),
                os.path.join(self.dir, 'optimizer.pt')
            )

        def add_log(self, log):
            """Append ``log`` to the PSNR history along dim 0."""
            self.log = torch.cat([self.log, log])

        def write_log(self, log, refresh=False):
            """Append one line to ``log.txt``; with ``refresh`` flush it by reopening."""
            self.log_file.write(log + '\n')
            if refresh:
                self.log_file.close()
                self.log_file = open(self.dir + '/log.txt', 'a')

        def done(self):
            """Close the text log file."""
            self.log_file.close()

        def plot_psnr(self, epoch):
            """Plot the PSNR history, one curve per scale, to ``test_<data_test>.pdf``.

            Column ``idx_scale`` of ``self.log`` is drawn for each entry of
            ``args.scale``. The x-axis is 1..epoch, so this assumes ``self.log``
            has exactly ``epoch`` rows -- TODO confirm at call sites.
            """
            axis = np.linspace(1, epoch, epoch)
            label = 'SR on {}'.format(self.args.data_test)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
            plt.close(fig)

        def _to_numpy_image(self, x, absolute=False):
            """Return the first image of batch ``x``, divided by rgb_range, as an HWC array.

            A single-channel result is squeezed to HW. With ``absolute`` the
            element-wise absolute value is taken.
            """
            out = x[0].data.mul(1. / self.args.rgb_range).permute(1, 2, 0).cpu().numpy()
            if absolute:
                out = np.abs(out)
            if out.shape[-1] == 1:
                out = out[:, :, 0]
            return out

        def save_results(self, filename, save_list, scale):
            """Save the SR / LR / HR images (in that order of ``save_list``) as 8-bit PNGs."""
            filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
            postfix = ('SR', 'LR', 'HR')
            for v, p in zip(save_list, postfix):
                normalized = v[0].data.mul(255 / self.args.rgb_range)
                # Round and clamp before the uint8 cast: .byte() truncates, so a
                # value such as 254.9999 (float error after quantize) would become
                # 254, and values outside [0, 255] would wrap around.
                ndarr = (
                    normalized.round().clamp(0, 255)
                    .byte().permute(1, 2, 0).cpu().numpy()
                )
                if ndarr.shape[-1] == 1:
                    ndarr = ndarr[:, :, 0]
                misc.imsave('{}{}.png'.format(filename, p), ndarr)

        def save_residuals(self, filename, save_list, scale):
            """Save |HR - SR| (first and last entries of ``save_list``) as a PNG."""
            filename = '{}/residuals/{}_x{}'.format(self.dir, filename, scale)
            sr, hr = save_list[0], save_list[-1]
            out = np.abs(self._to_numpy_image(hr) - self._to_numpy_image(sr))
            misc.imsave('{}.png'.format(filename), out)

        def save_branches(self, filename, save_list, scale):
            """Save each branch output in ``save_list`` as ``<name>_branch<i>.png``."""
            filename = '{}/branches/{}_x{}'.format(self.dir, filename, scale)
            for i, branch_output in enumerate(save_list):
                # Branch 0 is saved as-is; later (residual) branches are saved as
                # their absolute value.
                ndarr = self._to_numpy_image(branch_output, absolute=(i != 0))
                misc.imsave('{}_branch{}.png'.format(filename, i), ndarr)
    
    def get_bilateral(tensor, rgb_range):
        """Apply skimage's ``denoise_bilateral`` to every image in a batch.

        Args:
            tensor: batch of images, indexed as (N, C, H, W) by the transposes
                below. It may live on any device or require grad.
            rgb_range: value range of the input. Images are divided by it before
                filtering and multiplied back afterwards.

        Returns:
            A ``torch.Tensor`` (default float dtype) in the same NCHW layout, on
            the same device as ``tensor``.
        """
        device = tensor.device
        # .detach().cpu(): .numpy() raises for CUDA tensors and for tensors that
        # require grad.
        batch = tensor.detach().cpu().numpy().transpose(0, 2, 3, 1) / rgb_range
        out = np.zeros_like(batch)

        for i, img in enumerate(batch):
            out[i] = denoise_bilateral(img)

        return torch.Tensor(out.transpose(0, 3, 1, 2)).to(device) * rgb_range
    
    def quantize(img, rgb_range):
        """Snap ``img`` onto the 256-level grid of ``[0, rgb_range]``.

        Values are mapped to [0, 255], clamped, rounded to integers and mapped
        back to the ``rgb_range`` scale.
        """
        factor = 255 / rgb_range
        levels = torch.round(torch.clamp(img * factor, 0, 255))
        return levels / factor
    
    def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
        """Compute the PSNR (dB) between a super-resolved and a ground-truth batch.

        Args:
            sr, hr: image tensors indexed as (N, C, H, W); the channel dim is dim 1.
            scale: upscaling factor, which sets how many border pixels are shaved.
            rgb_range: maximum pixel value, used to normalize the difference.
            benchmark: if True, shave ``scale`` pixels and, for multi-channel
                input, evaluate on a weighted luma channel. Otherwise shave
                ``scale + 6`` pixels and use all channels.

        Returns:
            The PSNR as a Python float, or ``float('inf')`` when the compared
            regions are identical.
        """
        diff = (sr - hr).data.div(rgb_range)
        if benchmark:
            shave = scale
            if diff.size(1) > 1:
                # RGB -> Y weights (presumably BT.601 / MATLAB rgb2ycbcr, scaled by 256)
                convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
                diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
        else:
            shave = scale + 6

        # A zero shave would make `shave:-shave` an empty slice (NaN mean).
        valid = diff[:, :, shave:-shave, shave:-shave] if shave > 0 else diff
        mse = valid.pow(2).mean().item()
        if mse == 0:
            # Identical images: log10(0) would raise a math domain error.
            return float('inf')
        return -10 * math.log10(mse)
    
    def make_optimizer(args, my_model):
        """Build a torch optimizer over the parameters of ``my_model`` that require grad.

        Reads ``args.optimizer`` ('SGD', 'ADAM' or 'RMSprop'), ``args.lr`` and
        ``args.weight_decay``, plus the optimizer-specific fields:
        ``args.momentum`` (SGD), ``args.beta1``/``args.beta2``/``args.epsilon``
        (ADAM) and ``args.epsilon`` (RMSprop).

        Raises:
            ValueError: if ``args.optimizer`` is not one of the supported names.
        """
        trainable = filter(lambda x: x.requires_grad, my_model.parameters())

        if args.optimizer == 'SGD':
            optimizer_function = optim.SGD
            kwargs = {'momentum': args.momentum}
        elif args.optimizer == 'ADAM':
            optimizer_function = optim.Adam
            kwargs = {
                'betas': (args.beta1, args.beta2),
                'eps': args.epsilon
            }
        elif args.optimizer == 'RMSprop':
            optimizer_function = optim.RMSprop
            kwargs = {'eps': args.epsilon}
        else:
            # Previously fell through to an UnboundLocalError below.
            raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

        kwargs['lr'] = args.lr
        kwargs['weight_decay'] = args.weight_decay
        return optimizer_function(trainable, **kwargs)
    
    def make_scheduler(args, my_optimizer):
        """Build a learning-rate scheduler for ``my_optimizer`` from ``args.decay_type``.

        ``'step'`` gives a ``StepLR`` with ``step_size=args.lr_decay``. Any other
        string containing ``'step'`` (e.g. ``'step_200_400'``) gives a
        ``MultiStepLR`` whose milestones are the integers after the first ``_``-
        separated token. Both use ``gamma=args.gamma``.

        Raises:
            ValueError: if ``args.decay_type`` matches neither form.
        """
        if args.decay_type == 'step':
            return lrs.StepLR(
                my_optimizer,
                step_size=args.lr_decay,
                gamma=args.gamma
            )
        if 'step' in args.decay_type:
            # 'step_200_400' -> [200, 400]; the leading token is dropped.
            milestones = [int(m) for m in args.decay_type.split('_')[1:]]
            return lrs.MultiStepLR(
                my_optimizer,
                milestones=milestones,
                gamma=args.gamma
            )
        # Previously fell through to an UnboundLocalError on `scheduler`.
        raise ValueError('Unsupported decay_type: {}'.format(args.decay_type))
    

    相关文章

      网友评论

          本文标题:utility.py

          本文链接:https://www.haomeiwen.com/subject/faadrctx.html