论文笔记 - FCN 用于图像分割的全卷积网络

用全卷积网络做图像分割的鼻祖

FCN

参考资料

Fully Convolutional Networks for Semantic Segmentation
知乎/图像语义分割入门+FCN/U-Net网络解析
知乎/FCN学习:Semantic Segmentation
GitHub/pytorch FCN-32s 模型代码

这是一个端到端的网络,可以输入任意大小的图片。

语义分割面临的问题是semantics和location的统一,全局信息告诉我们semantics,而局部信息告诉我们location。

这篇论文16年的版本增加了更多的tuning、分析和实验结果。

老外写的论文是真滴难度啊。。。

早期的研究都存在以下问题:

  • 小模型限制了性能和感受野;
  • 都是patchwise traning;
  • 都要通过超像素投影、随机条件场、滤波或局部分类的后处理;
  • “interlacing” to obtain dense output(所以老外写的是什么鬼啦!);
  • 多尺度金字塔处理;
  • 非线性函数tanh的饱和问题;
  • ensembles(what???);

Fully Convolutional Networks

再读下去我可能会疯掉,什么鬼!直接上博客了……

嗯……看完博客后再接着读吧。

这一部分主要讲了几个点:

1. 将图像分类网络用到像素级别的分割上

FCN网络转移

思路就是将全连接层用1x1的conv替代嘛,这样最后输出的就不是一个标量了,而是一个feature map,此时对feature map中的每个点做softmax,就可以得到上图的heat map,也就是对feature map中的每个点进行分类。至于如何得到原图大小的heat map,就得拿出上采样了。

(但是输入图像太小的话,会导致到1x1卷积的时候,feature map就缩没了,作者在第一层卷积加了Padding=100(太暴力了)。

2. Upsampling

上采样是为了将heat map恢复成原图大小,从而对每个像素进行预测。

论文中作者用到的上采样方法有shift-and-stich和反卷积。

吐槽一下论文标题:Shift-and-stich is filter dilation、Upsampling is (fractionally strided) convolution、Patchwise training is loss sampling,我寻思你写尼玛呢,正常人看得懂吗?

哎总的来说,作者最后用了反卷积。

然而,直接把heat map整回原图大小就会丢失非常多的信息(毕竟一路conv下来的,信息该丢的全丢了),这个时候就有了论文中的FCN-32s、FCN-16s、FCN-8s,其中8s的性能最好,这就是多层特征融合了。

3. 多层特征融合

特征融合

实属睿智!特征融合直接用element-wise相加,不过可能是因为我活在9102年吧,知道这个情况应该用concat(实际上这也是U-Net的做法)。

参考一下参考资料中对多层特征融合的解释:

  • 对于FCN-32s,直接对pool5 feature进行32倍上采样获得32x upsampled feature,再对32x upsampled feature每个点做softmax prediction获得32x upsampled feature prediction(即分割图)。
  • 对于FCN-16s,首先对pool5 feature进行2倍上采样获得2x upsampled feature,再把pool4 feature和2x upsampled feature逐点相加,然后对相加的feature进行16倍上采样,并softmax prediction,获得16x upsampled feature prediction。
  • 对于FCN-8s,首先进行pool4+2x upsampled feature逐点相加,然后又进行pool3+2x upsampled逐点相加,即进行更多次特征融合。具体过程与16s类似,不再赘述。

可以看看他们的效果图:

特征融合实验

4. 模型

直接上代码,简单明了根本不需要讲的啦!(这里用的backbone是VGG16),最后输出为(H, W, n_class)

几个比较困惑的地方:

  • 第一层卷积padding=100;
  • 第一个全连接层用的kernel_size=7;
  • 最后的反卷积kernel_size=64,stride=32。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class FCN32s(nn.Module):

def __init__(self, n_class=21):
super(FCN32s, self).__init__()
# conv1
self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
self.relu1_1 = nn.ReLU(inplace=True)
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
self.relu1_2 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2

# conv2
self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
self.relu2_1 = nn.ReLU(inplace=True)
self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
self.relu2_2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4

# conv3
self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
self.relu3_1 = nn.ReLU(inplace=True)
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
self.relu3_2 = nn.ReLU(inplace=True)
self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
self.relu3_3 = nn.ReLU(inplace=True)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8

# conv4
self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
self.relu4_1 = nn.ReLU(inplace=True)
self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
self.relu4_2 = nn.ReLU(inplace=True)
self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
self.relu4_3 = nn.ReLU(inplace=True)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16

# conv5
self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_1 = nn.ReLU(inplace=True)
self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_2 = nn.ReLU(inplace=True)
self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_3 = nn.ReLU(inplace=True)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32

# fc6
self.fc6 = nn.Conv2d(512, 4096, 7)
self.relu6 = nn.ReLU(inplace=True)
self.drop6 = nn.Dropout2d()

# fc7
self.fc7 = nn.Conv2d(4096, 4096, 1)
self.relu7 = nn.ReLU(inplace=True)
self.drop7 = nn.Dropout2d()

self.score_fr = nn.Conv2d(4096, n_class, 1)
self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32,
bias=False)

5. 损失函数

论文中居然没有损失函数的公式!!!这也太反人类了,看了下代码,原来是对每个像素点做cross_entropy……

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def cross_entropy2d(input, target, weight=None, size_average=True):
# input: (n, c, h, w), target: (n, h, w)
n, c, h, w = input.size()
# log_p: (n, c, h, w)
if LooseVersion(torch.__version__) < LooseVersion('0.3'):
# ==0.2.X
log_p = F.log_softmax(input)
else:
# >=0.3
log_p = F.log_softmax(input, dim=1)
# log_p: (n*h*w, c)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
log_p = log_p.view(-1, c)
# target: (n*h*w,)
mask = target >= 0
target = target[mask]
loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
if size_average:
loss /= mask.data.sum()
return loss