相对位置编码

Swin-Transformer中的相对位置编码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
self.relative_position_bias_table = nn.Parameter(
  torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

什么是位置编码

首先在transformer中,patch之间是没有位置关系的,因此,不管猫在狗左边还是在狗右边,对于没有添加位置编码的transformer来说,都是一样的,在某些情况下是不应的,因此引入了位置编码。

绝对位置编码,其实就是人手工设计的一种位置编码, 它不随着网络学习,而是希望网络能够学会这种位置编码。

相对位置编码,直白的理解就是,只要一个元素在我左边m个,下边n个,只要是这种相对位置,那么贡献应该是一样的。

在Swin-Transformer中,使用的window窗口大小为 $ 7\times7$,那么在一行方向上有[-6, 6]共13种相对位置关系,在一列上也是13种相对位置关系,那么总共有$13 \times 13$种相对位置关系。

1
2
self.relative_position_bias_table = nn.Parameter(
  torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

这里便是为这169种位置关系生成一个table,只要具有相同的位置关系,都使用同样的可学习的位置编码。

下面便是生成这样的索引:

1
2
3
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww

生成网格,分别表示行和列,比如coords[0,0,:] = [0, 0, 0, 0,… ,0]也就是说第一行的行号都为0,coords[1,0,:] = [0, 1, 2, 3,… ,6],也就是说第一行的列号分别是0, 1, … ,6

1
2
3
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

然后因为一个patch会和所有位置做自注意力, 那么就是任一个位置都会和所有位置产生一个相对位置所以relative_coords的形状为[2, 49, 49]

然后其实这样的位置关系已经做好了,但是我们要通过查表来获得位置编码的参数,而relative_position_bias_table索引是[0, 168],但是我们现在的到索引都是[-6, 6]为了能够查表,需要将其做一个转化

1
2
3
4
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

首先把范围转换到[0, 12],然后呢,当出现[2, 0] 和 [0,2]时如何区分呢,其实就是不把两个位置单独相加,而是做一个乘法再相加,也就是代码的第三行,这样做后$[2, 0]\rightarrow2 * 13 + 0=26, [0, 2]\rightarrow 0 * 13 + 2= 2$,这样就可以分别查到不同的位置编码参数。nice