文章名为:Learning Continuous Image Representation with Local Implicit Image Function,简称LIIF,收录在CVPR21年,源码为:https://yinboc.github.io/liif/。由于拜读文章之后,对实现比较感兴趣,所以学习并分享源码实现部分(有所简化,主要为了将代码和原文对应上。):
核心思想:首先,需要得到离散图片的特征,然后通过一个网络,将特征和连续域的坐标映射成目标的预测值。
以EDSR数据集为例,作者构建了一个Encoder(EDSR类),来完成三通道的离散图片到特征的映射。
def conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)
class EDSR(nn.Module):
def __init__(self):
super(EDSR, self).__init__()
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.out_dim = n_feats
def forward(self, x):
x = self.head(x)
res = self.body(x)
res += x
x = res
return x
代码有所简化,主要结构是卷积(self.head)+16层ResBlock+卷积(self.body)。
然后是这个网络,代码中体现为一个MLP:
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_list):
super().__init__()
layers = []
lastv = in_dim
for hidden in hidden_list:
layers.append(nn.Linear(lastv, hidden))
layers.append(nn.ReLU())
lastv = hidden
layers.append(nn.Linear(lastv, out_dim))
self.layers = nn.Sequential(*layers)
def forward(self, x):
shape = x.shape[:-1]
x = self.layers(x.view(-1, x.shape[-1]))
return x.view(*shape, -1)
主要结构是(Linear+ReLU)*4 + Linear共9层。
最简单的情况是:实际在预测的时候,以待预测下标相对最近的离散查询点的坐标来进行的:
另外,结合了三个特性来优化模型表现:
-
Feature unfolding:个人理解:应该是多个离散的像素合成为一个block,增加信息量。
feat = self.encoder(x) # 输入x到前面的EDSR实例,得到特征$z$
feat = nn.functional.unfold(feat, 3, padding=1)
.view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) # (n, c, w, h) -> (n, c*9, block_index1, block_index2),其中,第二维是按照先块后通道来排列的。而且block_index1和block_index2的数值对应为w和h。
-
Local ensemble:个人理解:为了缓解图像边界的突变问题,把各个部分按照距离进行加权平均和,平滑过渡边界区域。
其中,需要先解决中心坐标的问题。以下是根据图像大小来生成对应的中心坐标。
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
数据集生成,主要是dataset的getitem方法:
# 用torchvision来缩放,缩放的方法是双立方插值
def resize_fn(img, size):
return transforms.ToTensor()(
transforms.Resize(size, Image.BICUBIC)(
transforms.ToPILImage()(img)))
# 把图像转成对应的坐标和数值
def to_pixel_samples(img):
""" Convert the image to coord-RGB pairs.
img: Tensor, (3, H, W)
"""
coord = make_coord(img.shape[-2:])
rgb = img.view(3, -1).permute(1, 0)
return coord, rgb
#
def __getitem__(self, idx):
img_lr, img_hr = self.dataset[idx]
p = idx / (len(self.dataset) - 1)
w_hr = round(self.size_min + (self.size_max - self.size_min) * p) # 随着idx的变大,目标图像越来越大。
img_hr = resize_fn(img_hr, w_hr) # 缩放目标图像
if self.augment: # 图像增强
if random.random() < 0.5: # 概率性反转
img_lr = img_lr.flip(-1)
img_hr = img_hr.flip(-1)
if self.gt_resize is not None: # 如果指定了目标图像大小,则缩放目标图像
img_hr = resize_fn(img_hr, self.gt_resize)
# 得到hr的坐标和rgb数值
hr_coord, hr_rgb = to_pixel_samples(img_hr)
# hr部分只随机取出一部分像素?(不放回取样)
if self.sample_q is not None:
sample_lst = np.random.choice(
len(hr_coord), self.sample_q, replace=False)
hr_coord = hr_coord[sample_lst]
hr_rgb = hr_rgb[sample_lst]
# 得到目标图像对应cell的大小。
cell = torch.ones_like(hr_coord)
cell[:, 0] *= 2 / img_hr.shape[-2]
cell[:, 1] *= 2 / img_hr.shape[-1]
# 每次是输入图像,目标图像的坐标,目标图像的cell大小,目标图像。
return {
'inp': img_lr,
'coord': hr_coord,
'cell': cell,
'gt': hr_rgb
}
然后再对应到具体的操作:整体的操作是:构造目标图像的坐标coord_,然后通过这个坐标去采样特征和坐标,包括:从当前输入低分图像的特征feat采样到q_feat;从当前输入低分图像的坐标feat_coord采样得到q_coord。然后构造出相对的坐标rel_coord,再乘上特征对应的shape大小,得到相对的特征空间偏移,作为公式(4)中的
# 先通过函数算出$z$
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
.permute(2, 0, 1) \
.unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
# field radius (global: [-1, 1])
rx = 2 / feat.shape[-2] / 2
ry = 2 / feat.shape[-1] / 2
preds = []
areas = []
for vx in vx_lst:
for vy in vy_lst:
coord_ = coord.clone()
coord_[:, :, 0] += vx * rx + eps_shift
coord_[:, :, 1] += vy * ry + eps_shift
coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
q_feat = F.grid_sample(
feat, coord_.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
q_coord = F.grid_sample(
feat_coord, coord_.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
rel_coord = coord - q_coord
rel_coord[:, :, 0] *= feat.shape[-2]
rel_coord[:, :, 1] *= feat.shape[-1]
inp = torch.cat([q_feat, rel_coord], dim=-1)
# cell decoding放在下面解说。
if self.cell_decode:
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
inp = torch.cat([inp, rel_cell], dim=-1)
bs, q = coord.shape[:2]
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
preds.append(pred)
area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
areas.append(area + 1e-9)
tot_area = torch.stack(areas).sum(dim=0)
if self.local_ensemble: # 为什么要做面积计算的结果翻转?
t = areas[0]; areas[0] = areas[3]; areas[3] = t
t = areas[1]; areas[1] = areas[2]; areas[2] = t
ret = 0
for pred, area in zip(preds, areas):
ret = ret + pred * (area / tot_area).unsqueeze(-1)
return ret
根据相对的坐标,rel_coord可以算出对应部分的面积,进而在最后以公式(4)的比例累加的方式进行计算。(体现在最后8行代码)
- Cell decoding:个人理解:把坐标以及对应的小块大小作为额外信息输入,主要是的加入,实验表明是有效果的。
先把目标图像的cell大小,乘上特征对应的shape大小,变为相对大小rel_cell。这部分就是公式(5)中的。
最终,把这些特征传入MLP(即self.imnet)进行计算,获得预测的值。
网友评论