Stochastic Video Generation

Author: 杨逸凡 | Published 2018-03-10 18:40
    parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on')
    parser.add_argument('--n_future', type=int, default=10, help='number of frames to predict during training')
    parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict during eval')
    

    The model conditions on n_past frames and predicts the following n_future frames; with the defaults above, each training sequence covers 5 conditioning frames plus 10 predicted frames, as in the sketch below.
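
    A minimal sketch of the split (index names are illustrative, not from the repo):

    n_past, n_future = 5, 10
    seq_len = n_past + n_future            # 15 frames per training sequence
    context_idx = range(0, n_past)         # frames 0..4: conditioning
    target_idx = range(n_past, seq_len)    # frames 5..14: to predict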

    LSTM Model

    import models.lstm as lstm_models
    if opt.model_dir != '':
        frame_predictor = saved_model['frame_predictor']
        posterior = saved_model['posterior']
        prior = saved_model['prior']
    else:
        frame_predictor = lstm_models.lstm(opt.g_dim+opt.z_dim, opt.g_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size)
        posterior = lstm_models.gaussian_lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size)
        prior = lstm_models.gaussian_lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size)
        frame_predictor.apply(utils.init_weights)
        posterior.apply(utils.init_weights)
        prior.apply(utils.init_weights)
    

    lstm.py is as follows:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    class lstm(nn.Module):
        def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
            super(lstm, self).__init__()
            self.input_size = input_size
            self.output_size = output_size
            self.hidden_size = hidden_size
            self.batch_size = batch_size
            self.n_layers = n_layers
            self.embed = nn.Linear(input_size, hidden_size)
            self.lstm = nn.ModuleList([nn.LSTMCell(hidden_size, hidden_size) for i in range(self.n_layers)])
            self.output = nn.Sequential(
                    nn.Linear(hidden_size, output_size),
                    #nn.BatchNorm1d(output_size),
                    nn.Tanh())
            self.hidden = self.init_hidden()
    
        def init_hidden(self):
            hidden = []
            for i in range(self.n_layers):
                hidden.append((Variable(torch.zeros(self.batch_size, self.hidden_size).cuda()),
                               Variable(torch.zeros(self.batch_size, self.hidden_size).cuda())))
            return hidden
    
        def forward(self, input):
            # One call processes a single timestep; the hidden state persists
            # on self.hidden between calls, so the caller steps through time.
            embedded = self.embed(input.view(-1, self.input_size))
            h_in = embedded
            for i in range(self.n_layers):
                self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
                h_in = self.hidden[i][0]

            return self.output(h_in)
    

    A class that inherits from nn.Module receives its size parameters at construction (initialization) and its input tensors when called (the forward computation). The built-in modules used here, such as nn.Linear(in_features, out_features) and nn.LSTMCell(input_size, hidden_size), work the same way.
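
    A minimal usage sketch of this pattern (sizes are illustrative; a CUDA device is assumed, since init_hidden calls .cuda()):

    net = lstm(input_size=138, output_size=128, hidden_size=256,
               n_layers=2, batch_size=4).cuda()
    x = Variable(torch.zeros(4, 138).cuda())  # one timestep of input
    y = net(x)                                # runs forward(); y has shape 4 x 128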

    gaussian_lstm differs from lstm in that it predicts the parameters of a Gaussian and samples from it:

    class gaussian_lstm(nn.Module):
        def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
            ...
            self.mu_net = nn.Linear(hidden_size, output_size)
            self.logvar_net = nn.Linear(hidden_size, output_size)
            ...
    
        def reparameterize(self, mu, logvar):
            # note: after this line, `logvar` actually holds sigma = exp(0.5 * logvar)
            logvar = logvar.mul(0.5).exp_()
            eps = Variable(logvar.data.new(logvar.size()).normal_())
            return eps.mul(logvar).add_(mu)
    
        def forward(self, input):
            embedded = self.embed(input.view(-1, self.input_size))
            h_in = embedded
            for i in range(self.n_layers):
                self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
                h_in = self.hidden[i][0]
            mu = self.mu_net(h_in)
            logvar = self.logvar_net(h_in)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar
    

    The outputs are z, sampled from the predicted Gaussian, together with its mean and log-variance (the code's logvar is the log of the variance, not of the standard deviation).
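
    reparameterize implements the standard reparameterization trick, z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I), which keeps z differentiable with respect to mu and logvar. A standalone sketch of the same computation (not repo code):

    mu = torch.zeros(4, 10)
    logvar = torch.zeros(4, 10)        # logvar = 0  =>  sigma = 1
    eps = torch.randn(4, 10)           # eps ~ N(0, I)
    z = mu + (0.5 * logvar).exp() * eps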

    Encoder / Decoder

    Take dcgan_64 as an example:

    if opt.model == 'dcgan':
        if opt.image_width == 64:
            import models.dcgan_64 as model
    

    dcgan_64.py is as follows:

    import torch
    import torch.nn as nn
    
    class dcgan_conv(nn.Module):
        def __init__(self, nin, nout):
            super(dcgan_conv, self).__init__()
            self.main = nn.Sequential(
                    nn.Conv2d(nin, nout, 4, 2, 1),
                    nn.BatchNorm2d(nout),
                    nn.LeakyReLU(0.2, inplace=True),
                    )
    
        def forward(self, input):
            return self.main(input)
    
    class dcgan_upconv(nn.Module):
        def __init__(self, nin, nout):
            super(dcgan_upconv, self).__init__()
            self.main = nn.Sequential(
                    nn.ConvTranspose2d(nin, nout, 4, 2, 1),
                    nn.BatchNorm2d(nout),
                    nn.LeakyReLU(0.2, inplace=True),
                    )
    
        def forward(self, input):
            return self.main(input)
    
    class encoder(nn.Module):
        def __init__(self, dim, nc=1):
            super(encoder, self).__init__()
            self.dim = dim
            nf = 64
            # input is (nc) x 64 x 64
            self.c1 = dcgan_conv(nc, nf)
            # state size. (nf) x 32 x 32
            self.c2 = dcgan_conv(nf, nf * 2)
            # state size. (nf*2) x 16 x 16
            self.c3 = dcgan_conv(nf * 2, nf * 4)
            # state size. (nf*4) x 8 x 8
            self.c4 = dcgan_conv(nf * 4, nf * 8)
            # state size. (nf*8) x 4 x 4
            self.c5 = nn.Sequential(
                    nn.Conv2d(nf * 8, dim, 4, 1, 0),
                    nn.BatchNorm2d(dim),
                    nn.Tanh()
                    )
    
        def forward(self, input):
            h1 = self.c1(input)
            h2 = self.c2(h1)
            h3 = self.c3(h2)
            h4 = self.c4(h3)
            h5 = self.c5(h4)
            return h5.view(-1, self.dim), [h1, h2, h3, h4]
    

    The encoder returns not only the final feature vector but also the intermediate activations [h1, h2, h3, h4], which the decoder later uses as skip connections.
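
    A quick shape check (illustrative sizes, not repo code): each dcgan_conv halves the spatial resolution, and c5 collapses the remaining 4 x 4 map into a vector.

    enc = encoder(dim=128, nc=1)
    h, skips = enc(torch.zeros(4, 1, 64, 64))
    # h:        4 x 128
    # skips[0]: 4 x 64  x 32 x 32
    # skips[1]: 4 x 128 x 16 x 16
    # skips[2]: 4 x 256 x 8  x 8
    # skips[3]: 4 x 512 x 4  x 4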

    class decoder(nn.Module):
        def __init__(self, dim, nc=1):
            super(decoder, self).__init__()
            self.dim = dim
            nf = 64
            self.upc1 = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
                    nn.BatchNorm2d(nf * 8),
                    nn.LeakyReLU(0.2, inplace=True)
                    )
            # state size. (nf*8) x 4 x 4
            self.upc2 = dcgan_upconv(nf * 8 * 2, nf * 4)
            # state size. (nf*4) x 8 x 8
            self.upc3 = dcgan_upconv(nf * 4 * 2, nf * 2)
            # state size. (nf*2) x 16 x 16
            self.upc4 = dcgan_upconv(nf * 2 * 2, nf)
            # state size. (nf) x 32 x 32
            self.upc5 = nn.Sequential(
                    nn.ConvTranspose2d(nf * 2, nc, 4, 2, 1),
                    nn.Sigmoid()
                    # state size. (nc) x 64 x 64
                    )
    
        def forward(self, input):
            vec, skip = input 
            d1 = self.upc1(vec.view(-1, self.dim, 1, 1))
            d2 = self.upc2(torch.cat([d1, skip[3]], 1))
            d3 = self.upc3(torch.cat([d2, skip[2]], 1))
            d4 = self.upc4(torch.cat([d3, skip[1]], 1))
            output = self.upc5(torch.cat([d4, skip[0]], 1))
            return output
    

    Each decoder layer is concatenated with the matching encoder activation, a U-Net-style skip connection; this is why upc2 through upc5 take twice the channels (e.g. torch.cat([d1, skip[3]], 1) doubles nf*8 to nf*8*2).
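
    Continuing the encoder sketch above, the round trip back to image space (illustrative, not repo code):

    dec = decoder(dim=128, nc=1)
    x_hat = dec([h, skips])    # h, skips from the encoder sketch
    # x_hat: 4 x 1 x 64 x 64, in [0, 1] because of the final Sigmoid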

    Training Functions

    def train(x):
        ...
        mse = 0
        kld = 0
        for i in range(1, opt.n_past+opt.n_future):
            h = encoder(x[i-1])
            h_target = encoder(x[i])[0]
            if opt.last_frame_skip or i < opt.n_past:   
                h, skip = h
            else:
                h = h[0]
            z_t, mu, logvar = posterior(h_target)
            _, mu_p, logvar_p = prior(h)
            h_pred = frame_predictor(torch.cat([h, z_t], 1))
            x_pred = decoder([h_pred, skip])
            mse += mse_criterion(x_pred, x[i])
            kld += kl_criterion(mu, logvar, mu_p, logvar_p)
    
        loss = mse + kld*opt.beta
        loss.backward()
    
        frame_predictor_optimizer.step()
        posterior_optimizer.step()
        prior_optimizer.step()
        encoder_optimizer.step()
        decoder_optimizer.step()
    
    
        return mse.data.cpu().numpy()/(opt.n_past+opt.n_future), kld.data.cpu().numpy()/(opt.n_future+opt.n_past)
    
    

    skip is refreshed while i < n_past (or at every step if last_frame_skip is set), so the decoder's skip connections come from the most recent conditioning frame.
    Training is teacher-forced: the input at every step is the ground-truth frame x[i-1] (the model's own output is only fed back at evaluation time), and at every step the target frame x[i] is encoded and the posterior is sampled from it.
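
    The loss balances reconstruction against the KL divergence between posterior and prior, weighted by opt.beta. A sketch of kl_criterion using the closed form for the KL between two diagonal Gaussians (the batch_size argument is added here for self-containment; the repo's exact normalization may differ):

    def kl_criterion(mu1, logvar1, mu2, logvar2, batch_size):
        # KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ), elementwise:
        # log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = (torch.log(sigma2 / sigma1)
               + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (2 * torch.exp(logvar2))
               - 0.5)
        return kld.sum() / batch_size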

    for epoch in range(opt.niter):
        frame_predictor.train()
        posterior.train()
        prior.train()
        encoder.train()
        decoder.train()
        epoch_mse = 0
        epoch_kld = 0
        progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
        for i in range(opt.epoch_size):
            progress.update(i+1)
            x = next(training_batch_generator)
    
            # train frame_predictor 
            mse, kld = train(x)
            epoch_mse += mse
            epoch_kld += kld
    
    
        progress.finish()
    

    .train(): Sets the module in training mode. This matters here because the encoder and decoder contain BatchNorm layers, which behave differently in training and evaluation mode.

    Predict

    def make_gifs(x, idx):
        for s in range(nsample):
            progress.update(s+1)
            gen_seq = []
            gt_seq = []
            frame_predictor.hidden = frame_predictor.init_hidden()
            posterior.hidden = posterior.init_hidden()
            prior.hidden = prior.init_hidden()
            x_in = x[0]
            all_gen.append([])
            all_gen[s].append(x_in)
            for i in range(1, opt.n_eval):
                h = encoder(x_in)
                if opt.last_frame_skip or i < opt.n_past:   
                    h, skip = h
                else:
                    h, _ = h
                h = h.detach()
                if i + 1 < opt.n_past:
                    h_target = encoder(x[i])[0].detach()
                    z_t, _, _ = posterior(h_target)
                else:
                    z_t, _, _ = prior(h)
                if i < opt.n_past:
                    frame_predictor(torch.cat([h, z_t], 1))
                    x_in = x[i]
                    all_gen[s].append(x_in)
                else:
                    h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                    x_in = decoder([h, skip]).detach()
                    gen_seq.append(x_in.data.cpu().numpy())
                    gt_seq.append(x[i].data.cpu().numpy())
                    all_gen[s].append(x_in)
            _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    
        progress.finish()
    

    Frames from index n_past onwards are unknown at prediction time, so from step i = n_past the generated frame is fed back as the input to step i+1; and from step i = n_past - 1 onwards, z is sampled from the prior (conditioned on the encoding of the current input) instead of from the posterior.
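
    The resulting schedule, as a small sketch with the defaults (indices match the loop above):

    n_past, n_eval = 5, 30
    for i in range(1, n_eval):
        z_src = 'posterior' if i + 1 < n_past else 'prior'     # prior from i = 4
        x_src = 'ground truth' if i < n_past else 'generated'  # feedback from i = 5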
