论文:Recurrent Feature Reasoning for Image Inpainting

paper:Recurrent Feature Reasoning for Image Inpainting

code:OpenSourceCode:PyTorch

解析参考文章

Recurrent Feature Reasoning for Image Inpainting 作者:nachifur

图像修复网络-RFRNet网络结构简介 作者:softmat

一、为什么要做这个研究?

目前的图像补全可以处理常规或小的缺陷,对于补全较大的连续孔洞仍然困难。

二、本文解决的什么问题

采用Recurrent的方式,逐步的补全较大的连续孔洞。

三、怎么做的?

本文对inpainting的整个过程做了充分的分解。读的过程中感觉论文很复杂,但实质的点没那么多。Method部分有5个模块,让人感觉工作量很大。可以学习写作技巧。

大致的思想:编码,迭代多次:[识别空洞->特征推理(填补空洞)],融合多次推理的特征,解码。

算法伪代码

四、Partial Convolution Layer

4.1 图解

PartialConv最开始也是做修补的,但是没有在原文找到便于理解的图。本文提供的这个图很有水平,直接不用看论文,就能明白文章做了什么。先下采样(encode),然后循环多次填充空洞,然后上采样(decode)。

preview

​ 逐步的填充空洞

背景:我们人类想让inpainting逐步的填充空洞。但是传统的inpainting,mask仅仅作为网络开始的输入,逐步的卷积是否会将特征空洞变小?这个其实是黑箱,我们也不知道网络是不是这么干的。

人类的知识:杠精可能说:目前深度学习已经很牛了,大家现在都是在水文章。那么目前人类到底是否存在智慧?人类的智慧是否已经不足以胜任网络设计?

深度学习目前还是工具,人类设计仍在起作用:PartialConv其实就是在做这件事:逐步的卷积将特征空洞变小。这其实就是inpainting的指导思想。之前做inpainting的人们都知道这个思想,但是对于大家仅仅也只是指导,而PartialConv巧妙的直接将这个指导思想用于inpainting。PartialConv的成功说明了:人类设计仍在起作用

4.2 公式

具有美感的公式,就是如此简洁(赞扬PartialConv)。

partial convolution operation:(加大边缘的权重)

preview

mask update function:(通过卷积,逐步的扩大边缘)

img

4.3 code实现

直接上伪代码,方便理解,不可直接使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 伪代码没有bias。可用的代码查阅:https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
class PartialConv2d(nn.Conv2d):
def __init__(self, *args, **kwargs):
super(PartialConv2d, self).__init__(*args, **kwargs)
self.w = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
self.slide_winsize = self.w.shape[2] * self.w.shape[3]

def forward(self, input, mask=None):
# w 是不更新的,权都是1。
self.update_mask = F.conv2d(mask, self.w, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

# 如果mask的非空洞为1,空洞为0。那么这里的update_mask经卷积之后是:0,1,2...9。
# 空洞边缘:mask_ratio>1;无孔洞:mask_ratio=1;孔洞:mask_ratio=1e8。
self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)

# 这里的clamp直接将0,1,2...9截断为0,1。其实就是将mask的边缘变为1,缩小空洞。
# 作者使用卷积来更新mask。整数权重conv+clamp的配合,秒不可言!其实就是图像的形态学膨胀,但是作者的实现真妙!
self.update_mask = torch.clamp(self.update_mask, 0, 1)

# 将空洞部分的值,变为0。
# 空洞边缘:mask_ratio>1;无孔洞:mask_ratio=1;孔洞:mask_ratio=0。
self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

# 这里调用的父类,nn.Conv2d(传统的卷积),对特征torch.mul(input, mask)进行处理。
# 如果仅仅有一层,这与传统inpainting是一样的,但是对于PartialConv2d,下一次的mask是变化的。
raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask))

# 对边缘进行加权,让网络关注空洞边缘的修补
output = torch.mul(raw_out, self.mask_ratio)
return output, self.update_mask

源码

官方的代码只给模型结构,但是不给inpainting的全部训练代码,nvidia的商业机密?

NVIDIA论文1

NVIDIA论文2

NVIDIA开源Mask掩码数据集

五、 Area Identification和Feature Reasoning

Area Identification:两层PartialConv2d

Feature Reasoning:RFR-module,其实就是普通的UNet加了一个mask注意力。而在Image Inpainting for Irregular Holes Using Partial Convolutions 中,结构是Unet,然后conv->PartialConv2d。

为啥不用Unet+PartialConv2d?反而拆成Area Identification+Feature Reasoning+Recurrent?全用别人的不好吧,这么搞便于讲故事?如果熟悉inpainting的,可能感觉文中的3.1.1 Area Identification和3.1.2 Feature Reasoning,没啥贡献?读完文中的3.1.2节,感觉这段一笔带过就行。但这也是写作的功底,没啥咱也要写出花。会讲故事才是硬道理。

六、Feature Merging

preview

普通的concat和sum都存在一些问题:

However, using convolutional operations to do so limits the number of recurrences, because the number of channels in the concatenation is fixed. Directly summing all feature maps removes image details, because the hole regions in different feature maps are inconsistent and prominent signals are smoothed.

怎么解决上述问题?也是sum,但是仅仅加有效的pixel。正常的sum,分母为N。很巧妙。

img

七、Knowledge Consistent Attention

这里图与代码不对应,应该把i+1的图去掉。

img

一个图中的不同位置求相似度。(x,y)位置的所有channel是一个向量。

img

然后进行平滑处理,代码使用的是avg_pool2d。

img

为了关联上一个迭代,和本次迭代,采用 [公式] 平衡。

img

然后,使用注意力score重建特征。

img

最后,将重建特征和输入特征concat输出。

img

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class KnowledgeConsistentAttention(nn.Module):
def __init__(self, patch_size = 3, propagate_size = 3, stride = 1):
super(KnowledgeConsistentAttention, self).__init__()
self.patch_size = patch_size
self.propagate_size = propagate_size
self.stride = stride
self.prop_kernels = None
self.att_scores_prev = None
self.masks_prev = None
self.ratio = nn.Parameter(torch.ones(1))
def forward(self, foreground, masks):
bz, nc, h, w = foreground.size()
if masks.size(3) != foreground.size(3):
masks = F.interpolate(masks, foreground.size()[2:])
background = foreground.clone()
background = background
conv_kernels_all = background.view(bz, nc, w * h, 1, 1)
conv_kernels_all = conv_kernels_all.permute(0, 2, 1, 3, 4)
output_tensor = []
att_score = []
for i in range(bz):
# 一张图,不同位置做attention
feature_map = foreground[i:i+1]
conv_kernels = conv_kernels_all[i] + 0.0000001
norm_factor = torch.sum(conv_kernels**2, [1, 2, 3], keepdim = True)**0.5
conv_kernels = conv_kernels/norm_factor
# 得到相似度,也可以理解为编码,对应(4)
conv_result = F.conv2d(feature_map, conv_kernels, padding = self.patch_size//2)
# 平滑相似度,对应(5)
conv_result = F.avg_pool2d(conv_result, 3, 1, padding = 1)*9
# 沿通道进行归一化处理,score
attention_scores = F.softmax(conv_result, dim = 1)
# 对应(6),这里分母可以看成归一化,因为这里与paper中不一样,self.ratio初始为1。
# mask的作用是,限定像素是有效的。paper中:Formally, if the pixel at location (x, y) is a valid pixel in the last recurrence, ...
if self.att_scores_prev is not None:
attention_scores = (self.att_scores_prev[i:i+1]*self.masks_prev[i:i+1] + attention_scores * (torch.abs(self.ratio)+1e-7))/(self.masks_prev[i:i+1]+(torch.abs(self.ratio)+1e-7))
att_score.append(attention_scores)
# 对应(8),可以看成解码
feature_map = F.conv_transpose2d(attention_scores, conv_kernels, stride = 1, padding = self.patch_size//2)
final_output = feature_map
output_tensor.append(final_output)
# 上一次迭代的score和mask
self.att_scores_prev = torch.cat(att_score, dim = 0).view(bz, h*w, h, w)
self.masks_prev = masks.view(bz, 1, h, w)
return torch.cat(output_tensor, dim = 0)

class AttentionModule(nn.Module):
def __init__(self, inchannel, patch_size_list = [1], propagate_size_list = [3], stride_list = [1]):
assert isinstance(patch_size_list, list), "patch_size should be a list containing scales, or you should use Contextual Attention to initialize your module"
assert len(patch_size_list) == len(propagate_size_list) and len(propagate_size_list) == len(stride_list), "the input_lists should have same lengths"
super(AttentionModule, self).__init__()

self.att = KnowledgeConsistentAttention(patch_size_list[0], propagate_size_list[0], stride_list[0])
self.num_of_modules = len(patch_size_list)
self.combiner = nn.Conv2d(inchannel * 2, inchannel, kernel_size = 1)
def forward(self, foreground, mask):
outputs = self.att(foreground, mask)
outputs = torch.cat([outputs, foreground],dim = 1)
# 对应(9)式
outputs = self.combiner(outputs)
return outputs

八、Loss Functions

代码:loss_G = ( tv_loss 0.1+ style_loss 120+ preceptual_loss 0.05+ valid_loss 1+ hole_loss 6)。代码的loss和系数与Image Inpainting for Irregular Holes Using Partial Convolutions 中的一样。*paper:没有tvloss,系数与代码的也不一样。

preview

为什么要分hole和valid?这两个loss在PartialConv中使用,但是没有找到原因。本文中有这么一句:

This kind of loss function combination also enables efficient training due to the smaller number of parameters to update.

九、对比实验

KCA与其他的注意力的对比:

preview

更多的IterNums,没有提升性能。

preview

将RFR模块更深地移动,计算成本显着降低,而性能仅受影响:

img

不同特征融合:

preview

用RFR模块替换edge-connect中的8个残余块:(b,e是edge-connect的输出)

img

9.1 实验结果

在place2,celeba,Paris StreetView上的对比:

preview

img

作者自己跑的别人的实验(在相同配置下,paper中说全部模型达到收敛,这个挺难判断的)。

img

不加KCA注意力,与其他模型的参数量对比。(你的是迭代的,这里为什么不和PConv比时间?)

img

时间的对比:(与一些Coarse-To-Fine的多阶段模型对比)

The inference time of our model for each image is usually between 85 and 95 ms, which is also faster than several state-of-the-art methods (e.g. [30, 14, 11]).

拿自己长处和别人的短处比?

9.2 存在的问题:

没有提供place2的预训练模型。

place2多次迭代出现伪影:

https://github.com/jingyuanli001/RFR-Inpainting/issues/39

十、TvLoss

【CNN基础】常见的loss函数及其实现(一)——TV Loss