Swin-Transformer中的相对位置编码
1 |
|
什么是位置编码
首先在transformer中,patch之间是没有位置关系的,因此,不管猫在狗左边还是在狗右边,对于没有添加位置编码的transformer来说,都是一样的,在某些情况下是不应的,因此引入了位置编码。
绝对位置编码,其实就是人手工设计的一种位置编码, 它不随着网络学习,而是希望网络能够学会这种位置编码。
相对位置编码,直白的理解就是,只要一个元素在我左边m个,下边n个,只要是这种相对位置,那么贡献应该是一样的。
在Swin-Transformer中,使用的window窗口大小为 $ 7\times7$,那么在一行方向上有[-6, 6]共13种相对位置关系,在一列上也是13种相对位置关系,那么总共有$13 \times 13$种相对位置关系。
1 |
|
这里便是为这169种位置关系生成一个table,只要具有相同的位置关系,都使用同样的可学习的位置编码。
下面便是生成这样的索引:
1 |
|
生成网格,分别表示行和列,比如coords[0,0,:] = [0, 0, 0, 0,… ,0]也就是说第一行的行号都为0,coords[1,0,:] = [0, 1, 2, 3,… ,6],也就是说第一行的列号分别是0, 1, … ,6
1 |
|
然后因为一个patch会和所有位置做自注意力, 那么就是任一个位置都会和所有位置产生一个相对位置所以relative_coords的形状为[2, 49, 49]
然后其实这样的位置关系已经做好了,但是我们要通过查表来获得位置编码的参数,而relative_position_bias_table索引是[0, 168],但是我们现在的到索引都是[-6, 6]为了能够查表,需要将其做一个转化
1 |
|
首先把范围转换到[0, 12],然后呢,当出现[2, 0] 和 [0,2]时如何区分呢,其实就是不把两个位置单独相加,而是做一个乘法再相加,也就是代码的第三行,这样做后$[2, 0]\rightarrow2 * 13 + 0=26, [0, 2]\rightarrow 0 * 13 + 2= 2$,这样就可以分别查到不同的位置编码参数。nice