AFT(Attention-free Transformer)

Junity 发布于 11 天前 291 次阅读 最后更新于 10 天前 695 字 预计阅读时间: 3 分钟


AI 摘要

本研究提出了AFT(Attention-free Transformer),通过简化注意力机制的设计,有效降低了空间复杂度至O(Td),同时保持模型性能。分析表明AFT可视为逐通道多头注意力的变体,并结合局部注意力机制衍生出计算高效的AFT-local变体。其中AFT-free在s=1时实现线性时间复杂度,为图像处理任务提供了高效解决方案。

传统的注意力机制需要对 QK 中的每对元素进行计算以得到注意力评分,在计算时需要得到一个注意力矩阵,这导致其空间开销的增长速度是 O(T2) 的。AFT通过,将空间复杂度降低到了 O(Td) ,其中 d 为特征向量长度。并且通过忽略长距离向量间的注意力,AFT还可以将时间复杂度降低到 O(Tsd) ,其中 s 为注意力区间的长度,在图像处理中很有用。
aft.png

AFT

AFT的过程如下:对于输入向量 XRTd ,首先和传统自注意力机制一样,通过三个线性变换得到 Q , K , V 三个量:

Q=XWQK=XWKV=XWV

然后,通过下面的公式来生成结果:

Y=f(X);Yt=σq(Qt)t=1Texp(Kt+wt,t)Vtt=1Texp(Kt+wt,t)

其中, wt,t 是一个可以学习的参数,作用和位置编码类似。
初看这个公式可能很难理解为什么AFT要这样做,我们将其按特征向量的维度展开,使用上标 i 来表示特征向量的第 i 维,那么公式可以写成下面的形式:

Yti=ati,Viati=σq(Qti)exp(Kti+wt,t)t=1Texp(Kti+wt,t)

然后 Yt=Concat(Yt1,Yt2,...,Ytd) 。可以发现上面的式子还可以进一步写成下面的形式:

Yti=g(Ki,Qti)VYt=Concat(Yt1,Yt2,...,Ytd)

下面是MHA的公式:

fi(X)=Score(Q,K)Vf(X)=Concat(f1(X),f2(X),...,fn(X))

可以看出AFT实际上类似一个逐特征通道进行的MHA。

AFT的计算复杂度

在公式 Yti=g(Ki,Qti)V 中,计算一个 Yti 的时间复杂度为 O(T2) ,这是因为计算 ati 的复杂度是 O(T) ,与 Vi 相乘后就是 O(T2) 。因此总的时间复杂度就是 O(T2d) 。所以上面的原始AFT相当于Transformer在时间复杂度上没有改进。

但在空间复杂度上,Transformer需要计算一个 T2 大小的注意力矩阵,而在AFT中没有了这个操作,而是使用了类似 Softmax

AFT 变体

AFT-Local

论文作者通过实验,发现Transformer应用在图片上时更加注重局部性而忽略长程的注意力,因此作者提出了 AFT-local ,通过仅保留局部的注意力来降低计算的复杂度。

AFT-local 的原理是改造了 wt,t:

wt,t={wt,t|tt|<s0otherwise

其中,s 为注意力窗口的长度。在这种情况下,AFT的时间复杂度降为了 O(Tsd)

AFT-Free

当上面的 s=1 时,就得到了AFT-Free,这是一个时间复杂度为线性的算法。

此作者没有提供个人介绍。
最后更新于 2025-04-16