美文网首页
trainer.py

trainer.py

作者: Amdur | 来源:发表于2019-07-26 19:01 被阅读0次
    import os
    import math
    from decimal import Decimal
    
    import utility
    
    import torch
    from torch.autograd import Variable
    from tqdm import tqdm
    
    from functools import reduce
    
    class Trainer():
        def __init__(self, args, loader, my_model, my_loss, ckp):
            self.args = args
            self.scale = args.scale
    
            self.ckp = ckp
            self.loader_train = loader.loader_train
            self.loader_test = loader.loader_test
            self.model = my_model
            self.loss = my_loss
            self.optimizer = utility.make_optimizer(args, self.model)
            self.scheduler = utility.make_scheduler(args, self.optimizer)
    
            if self.args.load != '.':
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
                )
                for _ in range(len(ckp.log)): self.scheduler.step()
    
            self.error_last = 1e8
    
        def train(self):
            self.scheduler.step()
            self.loss.step()
            epoch = self.scheduler.last_epoch + 1
            lr = self.scheduler.get_lr()[0]
    
            self.ckp.write_log(
                '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
            )
            self.loss.start_log()
            self.model.train()
    
            timer_data, timer_model = utility.timer(), utility.timer()
            for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
                lr, hr = self.prepare([lr, hr])
                timer_data.hold()
                timer_model.tic()
    
                self.optimizer.zero_grad()
                sr = self.model(lr, idx_scale, True)
    
                if (self.args.branch_label.lower() == 'residual') and (self.args.enable_branches): 
                    sr_ = reduce(lambda x,y: x+y, self.model.model.branch_outputs[:-1])
    
                    if self.args.bilateral_residuals: 
                        sr_ = utility.get_bilateral(sr_, self.args.rgb_range)
                        hr = utility.get_bilateral(hr, self.args.rgb_range)
    
                        if not self.args.cpu: 
                            sr_ = sr_.cuda() 
                            hr = hr.cuda()
    
                    hr = hr - sr_ 
    
                loss = self.loss(sr, hr)
                if loss.item() < self.args.skip_threshold * self.error_last:
                    loss.backward()
                    self.optimizer.step()
                else:
                    print('Skip this batch {}! (Loss: {})'.format(
                        batch + 1, loss.item()
                    ))
    
                timer_model.hold()
    
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                        (batch + 1) * self.args.batch_size,
                        len(self.loader_train.dataset),
                        self.loss.display_loss(batch),
                        timer_model.release(),
                        timer_data.release()))
    
                timer_data.tic()
    
            self.loss.end_log(len(self.loader_train))
            self.error_last = self.loss.log[-1, -1]
    
        def test(self):
            epoch = self.scheduler.last_epoch + 1
            self.ckp.write_log('\nEvaluation:')
            self.ckp.add_log(torch.zeros(1, len(self.scale)))
            self.model.eval()
    
            timer_test = utility.timer()
            with torch.no_grad():
                for idx_scale, scale in enumerate(self.scale):
                    eval_acc = 0
                    self.loader_test.dataset.set_scale(idx_scale)
                    tqdm_test = tqdm(self.loader_test, ncols=80)
                    for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test):
                        
                        filename = filename[0]
                        no_eval = (hr.nelement() == 1)
                        if not no_eval:
                            lr, hr = self.prepare([lr, hr])
                        else:
                            lr = self.prepare([lr])[0]
    
                        sr = self.model(lr, idx_scale)
                        sr = utility.quantize(sr, self.args.rgb_range)
    
                        save_list = [sr]                    
                        if not no_eval:
                            eval_acc += utility.calc_psnr(
                                sr, hr, scale, self.args.rgb_range,
                                benchmark=self.loader_test.dataset.benchmark
                            )
                            save_list.extend([lr, hr])
    
                        if self.args.save_results:
                            self.ckp.save_results(filename, save_list, scale)
                        
                        if self.args.save_residuals: 
                            self.ckp.save_residuals(filename, save_list, scale)
    
                        if self.args.save_branches: 
                            self.ckp.save_branches(filename, self.model.model.branch_outputs
                                                ,scale)
    
                    self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
                    best = self.ckp.log.max(0)
                    self.ckp.write_log(
                        '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                            self.args.data_test,
                            scale,
                            self.ckp.log[-1, idx_scale],
                            best[0][idx_scale],
                            best[1][idx_scale] + 1
                        )
                    )
    
            self.ckp.write_log(
                'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
            )
            if not self.args.test_only:
                self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
    
        def prepare(self, l, volatile=False):
            device = torch.device('cpu' if self.args.cpu else 'cuda')
            def _prepare(tensor):
                if self.args.precision == 'half': tensor = tensor.half()
                return tensor.to(device)
               
            return [_prepare(_l) for _l in l]
    
        def terminate(self):
            if self.args.test_only:
                self.test()
                return True
            else:
                epoch = self.scheduler.last_epoch + 1
                return epoch >= self.args.epochs
    

    相关文章

      网友评论

          本文标题:trainer.py

          本文链接:https://www.haomeiwen.com/subject/ioadrctx.html