美文网首页
[CVPR'21] LIIF文章及源码理解

[CVPR'21] LIIF文章及源码理解

作者: KyoDante | 来源:发表于2022-04-22 11:19 被阅读0次

文章名为:Learning Continuous Image Representation with Local Implicit Image Function,简称LIIF,收录在CVPR21年,源码为:https://yinboc.github.io/liif/。由于拜读文章之后,对实现比较感兴趣,所以学习并分享源码实现部分(有所简化,主要为了将代码和原文对应上。):


核心思想:首先,需要得到离散图片的特征z,然后通过一个网络f_{\theta},将特征z和连续域的坐标x映射成目标的预测值。

以EDSR数据集为例,作者构建了一个Encoder(EDSR类),来完成三通道的离散图片到特征z的映射。

def conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class EDSR(nn.Module):
    def __init__(self):
        super(EDSR, self).__init__()
        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]
        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.out_dim = n_feats
        
    def forward(self, x):
        x = self.head(x)
        res = self.body(x)
        res += x
        x = res
        return x

代码有所简化,主要结构是卷积(self.head)+16层ResBlock+卷积(self.body)。

然后是这个网络f_{\theta},代码中体现为一个MLP:

class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)

主要结构是(Linear+ReLU)*4 + Linear共9层。

最简单的情况是:实际在预测的时候,以待预测下标x_{q}相对最近的离散查询点的坐标来进行的:

另外,结合了三个特性来优化模型表现:

  • Feature unfolding:个人理解:应该是多个离散的像素合成为一个block,增加信息量。


feat = self.encoder(x) # 输入x到前面的EDSR实例,得到特征$z$
feat = nn.functional.unfold(feat, 3, padding=1)
                    .view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) # (n, c, w, h) -> (n, c*9, block_index1, block_index2),其中,第二维是按照先块后通道来排列的。而且block_index1和block_index2的数值对应为w和h。
  • Local ensemble:个人理解:为了缓解图像边界的突变问题,把各个部分按照距离进行加权平均和,平滑过渡边界区域。


其中,需要先解决中心坐标的问题。以下是根据图像大小来生成对应的中心坐标。

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret

数据集生成,主要是dataset的getitem方法:

# 用torchvision来缩放,缩放的方法是双立方插值
def resize_fn(img, size):
    return transforms.ToTensor()(
        transforms.Resize(size, Image.BICUBIC)(
            transforms.ToPILImage()(img)))
# 把图像转成对应的坐标和数值
def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])
    rgb = img.view(3, -1).permute(1, 0)
    return coord, rgb
# 
def __getitem__(self, idx):
    img_lr, img_hr = self.dataset[idx]
    p = idx / (len(self.dataset) - 1)
    w_hr = round(self.size_min + (self.size_max - self.size_min) * p) # 随着idx的变大,目标图像越来越大。
    img_hr = resize_fn(img_hr, w_hr) # 缩放目标图像

    if self.augment: # 图像增强
        if random.random() < 0.5: # 概率性反转
            img_lr = img_lr.flip(-1)
            img_hr = img_hr.flip(-1)

    if self.gt_resize is not None: # 如果指定了目标图像大小,则缩放目标图像
        img_hr = resize_fn(img_hr, self.gt_resize)

    # 得到hr的坐标和rgb数值
    hr_coord, hr_rgb = to_pixel_samples(img_hr)

    # hr部分只随机取出一部分像素?(不放回取样)
    if self.sample_q is not None:
        sample_lst = np.random.choice(
            len(hr_coord), self.sample_q, replace=False)
        hr_coord = hr_coord[sample_lst]
        hr_rgb = hr_rgb[sample_lst]
    # 得到目标图像对应cell的大小。
    cell = torch.ones_like(hr_coord)
    cell[:, 0] *= 2 / img_hr.shape[-2]
    cell[:, 1] *= 2 / img_hr.shape[-1]
    # 每次是输入图像,目标图像的坐标,目标图像的cell大小,目标图像。
    return {
        'inp': img_lr,
        'coord': hr_coord,
        'cell': cell,
        'gt': hr_rgb
    }

然后再对应到具体的操作:整体的操作是:构造目标图像的坐标coord_,然后通过这个坐标去采样特征和坐标,包括:从当前输入低分图像的特征feat采样到q_feat;从当前输入低分图像的坐标feat_coord采样得到q_coord。然后构造出相对的坐标rel_coord,再乘上特征对应的shape大小,得到相对的特征空间偏移,作为公式(4)中的x_{q}-v_{t}^{*}

# 先通过函数算出$z$
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
    .permute(2, 0, 1) \
    .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

# field radius (global: [-1, 1])
rx = 2 / feat.shape[-2] / 2
ry = 2 / feat.shape[-1] / 2
preds = []
areas = []
for vx in vx_lst:
    for vy in vy_lst:
        coord_ = coord.clone()
        coord_[:, :, 0] += vx * rx + eps_shift
        coord_[:, :, 1] += vy * ry + eps_shift
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
        q_feat = F.grid_sample(
            feat, coord_.flip(-1).unsqueeze(1),
            mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        q_coord = F.grid_sample(
            feat_coord, coord_.flip(-1).unsqueeze(1),
            mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        rel_coord = coord - q_coord
        rel_coord[:, :, 0] *= feat.shape[-2]
        rel_coord[:, :, 1] *= feat.shape[-1]
        inp = torch.cat([q_feat, rel_coord], dim=-1)

        # cell decoding放在下面解说。
        if self.cell_decode:
            rel_cell = cell.clone()
            rel_cell[:, :, 0] *= feat.shape[-2]
            rel_cell[:, :, 1] *= feat.shape[-1]
            inp = torch.cat([inp, rel_cell], dim=-1)

        bs, q = coord.shape[:2]
        pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
        preds.append(pred)

        area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
        areas.append(area + 1e-9)

tot_area = torch.stack(areas).sum(dim=0)
if self.local_ensemble: # 为什么要做面积计算的结果翻转?
    t = areas[0]; areas[0] = areas[3]; areas[3] = t
    t = areas[1]; areas[1] = areas[2]; areas[2] = t
ret = 0
for pred, area in zip(preds, areas):
    ret = ret + pred * (area / tot_area).unsqueeze(-1)
return ret

根据相对的坐标,rel_coord可以算出对应部分的面积,进而在最后以公式(4)的比例累加的方式进行计算。(体现在最后8行代码)

  • Cell decoding:个人理解:把坐标以及对应的小块大小作为额外信息输入,主要是c的加入,实验表明是有效果的。

先把目标图像的cell大小,乘上特征对应的shape大小,变为相对大小rel_cell。这部分就是公式(5)中的c

最终,把这些特征传入MLP(即self.imnet)进行计算,获得预测的值。


相关文章

网友评论

      本文标题:[CVPR'21] LIIF文章及源码理解

      本文链接:https://www.haomeiwen.com/subject/wpeqertx.html