本文是用于医学图像配准的空间变化网络(Spatial Transformer Networks, STN)的论文笔记。
STN 可以插入到已有的卷积神经网络结构中,让 CNN 具有空间变换的能力,不仅可以让网络能够提取出一张图片中所关心的区域,而且还可以把图片转换为规范的形式,以更方便下层网络进行处理。对于多通道的输入来说,产生的变形将会作用于每一个通道。
另一种表述就是 STN 能够根据分类或者其它任务自适应地将数据进行空间变换和对齐。比如 CNN 在分类时,通常需要考虑输入样本的局部性、平移不变性、缩小不变性和旋转不变性等,以提高分类的准确度,这些不变性就对应着图像的裁剪、平移、缩放和旋转等,而 STN 可以实现这些。
上图的 (a) 列是经过扭曲变形的 MNIST 手写数字图像,(b) 列图像中的框是 STN 中的定位网络预测出来的变换(可以理解为感兴趣的区域),(c) 列图像是 STN 的输出结果,即经过规范化和对齐后的手写数字图像,(d) 列是分类预测结果。
空间变换器(spatial transformer)可以分为三个部分:第一部分是定位网络(localisation network),它可以根据输入的图像/特征图得到一组空间变换的参数;第二部分是网格生成器(grid generator),它可以根据第一部分预测出来的空间变换参数,生成一个采样网格,即输出图像/特征图的每个点是从输入图像/特征图中哪些点采样而来的;第三部分是采样器(sampler),它可以将采样网格作用在输入特征图上,并产生相应的输出特征图。
上图是空间变换器的示意图,U 是输入特征图,V 是输出特征图。
定位网络
可以根据输入特征图 U 来生成空间变换 $T_\theta$ 的参数 $\theta$,当变换的类型不同时,参数 $\theta$ 的大小也是不同的,比如在二维仿射变换中 $\theta$ 是 6 维的。
网格生成器
假设输出特征图对应的网格是一个规则网格,即网格的每个坐标值都是整数,且相邻坐标之间的间隔一致。通过定位网络,我们已经得到了输出特征图中的每个点对应输入特征图中的哪些点。这样我们就可以利用输出特征图的网格和空间变换参数计算得到所对应的输入特征图对应的采样网格(即由目标坐标得到源坐标),如下公式所示:
其中 s 是 source 的缩写,表示源图像中的坐标;t 是 target 的缩写,表示目标图像中的坐标。$A_\theta$ 是二维仿射变换矩阵。
采样器
直接简单地从源像素数组中复制像素值是不可行的,因为仿射变换后的坐标 $(x^s,y^s)$ 可能为实数,但是像素位置坐标必须是整数。为了解决像素值缺失问题,必须进行插值。插值核函数很多,源码中选择了论文中提供的第二种插值方式——双线性插值。双线性插值的示意图和计算公式如下图所示:
但是上图中的计算公式非常不优雅,DeepMind 在论文利用 max 与 abs 函数,改写成一个简洁、优雅的插值等式:
$$
V_i^c=\sum_n^H\sum_m^W U_{nm}^c\max(0,1-|x_i^s-m|)\max(0,1-|y_i^s-n|)
$$
其中 $V_i^c$ 表示输出特征图中第 $c$ 个通道、第 $i$ 个像素(即坐标 $(x_i^t,y_i^t$)处的像素值;$U_{nm}^c$ 表示输入特征图中第 $c$ 个通道的坐标 $(n,m)$ 处的像素值;H 和 W 分别为输入特征图的高和宽。两个 $\sum$ 实际上只筛出了四个邻近插值点,虽然写法简洁,但白循环很多,所以源码中选择了直接算4个点,而不是用循环筛。为了让上述公式可微,在论文中还定义了该公式的导数,这里就不详细介绍了。
以上部分是从他人的博客中看到的,对于那个优雅的插值公式不是很懂。
上图中输出特征图 V 对应的网格是一个规范网格,输入特征图 U 对应的网格是规范网格经过仿射变化后的一个扭曲变形的网格。定位网络的目的是得到规范网格和扭曲网格之间的对应关系,网格生成器的目的就是根据规整网格得到扭曲网格,采样器的目的是根据以上信息由输入特征图 U 得到输出特征图 V。
- 本文作者: 俎志昂
- 本文链接: zuzhiang.cn/2020/02/28/STN/
- 版权声明: 本博客所有文章除特别声明外,均采用 Apache License 2.0 许可协议。转载请注明出处!