RE2 这个名称来源于该网络三个重要部分的合体:Residual vectors;Embedding vectors;Encoded vectors;
掌握这个论文,最重要的一个细节点就是了解如何将增强残差连接融入到模型之中。
先来看架构图,如下:
这个架构图很精简,所以不太容易理解。
大体上区分可以分为三层。第一层就是输入层,第二个就是中间处理层,第三个就是输出层。
中间处理层我们可以称之为block,就是画虚线的部分,可以被循环为n次,但是需要注意的是每个block不是共享的,参数是不同的,是独立的,这点需要注意。
其实这个论文比较有意思的点就是增强残差连接这里。架构图在这里其实很精简,容易看糊涂,要理解还是要看代码和公式。
首先假设我们的句子长度为$l$,然后对于第n个block(就是第n个虚线框的部分)。
它的输入和输出分别是:$x^{(n)}=(x_{1}^{(n)},x_{2}^{(n)},...,x_{l}^{(n)})$ 和$o^{(n)}=(o_{1}^{(n)},o_{2}^{(n)},...,o_{l}^{(n)})$;
首先对一第一个block,也就是$x^{(1)}$,它的输入是embedding层,注意这里仅仅是embedding层;
对于第二个block,也就是$x^{(2)}$,它的输入是embedding层(就是初始的embedding层)和第一个block的输出$o^{(1)}$拼接在一起;
紧接着对于n大于2的情况下,也就是对于第三个,第四个等等的block,它的输入形式是这样的;
理解的重点在这里:在每个block的输入,大体可以分为两个部分,第一部分就是初始的embedding层,这个永远不变,第二个部分就是此时block之前的两层的blocks的输出和;这两个部分进行拼接。
这是第一个体现残差的部分。
第二个残差的部分在block内部:
alignment层之前的输入就有三个部分:第一部分就是embedding,第二部分就是前两层的输出,第三部分就是encoder的输出。
这点结合着图就很好理解了。
attention这里其实操作比较常规,和ESIM很类似,大家可以去看之前这个文章。
公式大概如下:
这里有一个细节点需要注意,在源码中计算softmax之前,也是做了类似TRM中的缩放,也就是参数,放个代码:
#核心代码
def __init__(self, args, __):
super().__init__()
self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(args.hidden_size)))
def _attention(self, a, b):
return torch.matmul(a, b.transpose(1, 2)) * self.temperature
融合层,就是对attentino之前和之后的特征进行一个融合,具体如下:
三种融合方式分别是直接拼接,算了对位减法然后拼接,算了对位乘法然后拼接。最后是对三个融合结果进行拼接。
有一个很有意思的点,作者说到减法强调了两句话的不同,而乘法强调了两句话相同的地方。
Pooling层之后两个句子分别得到向量表达:$v_{1}$和$v_{2}$
三个表达方式,各取所需就可以:
简单总结一下,这个论文最主要就是掌握残差连接。
残差体现在模型两个地方,一个是block外,一个是block内;
对于block,需要了解的是,每一个block的输入是有两部分拼接而成,一个是最初始的embeddding,一个是之前两层的输出和。
对于block内,需要注意的是Alignment之前,有三个部分的输入一个是最初始的embeddding,一个是之前两层的输出和,还有一个是encoder的输出。