Fast Transformer Decoding: One Write-Head is All You Need 论文阅读

Fast Transformer Decoding: One Write-Head is All You Need 论文阅读

Posted by nothin on November 12, 2025

Fast Transformer Decoding: One Write-Head is All You Need 论文阅读

本篇文章在MHA的基础上做了修改,不使用多头注意力机制,而是使用多query机制,所有的query共享一个k,v头。所以这带来性能提升时必然的,因为减少了计算。所以在不降低原有的准确率似乎是重点。

刚接触这个领域,如有错误,希望能得到您的指正。

introduction

首先作者提出了Transformer 依靠注意力层在序列之间和序列之间传递信息。 Transformer 的一大挑战是增量推理(incremental inference)的速度。现代计算硬件上增量 Transformer 推理的速度受到重新加载attention layers状态的大“key”和“value”张量所需的内存带宽的限制。作者提出一种架构变体(多查询注意力),该变体极大地提高了推理速度,而质量仅略有下降。

background

这里的实例中省略了除以维度$\sqrt{d_k}$ ,因为这是一个标量计算对整个计算过程的时间复杂度不会产生影响。

点乘注意力机制

1
2
3
4
5
6
7
8
9
10
11
12
def DotProductAttention (q , K, V) :
	'''Dot−Product Attention on one query.
	Args :
		q : a vector with shape [k]
		K: a matrix with shape [m, k]
		V: a matrix with shape [m, v]
	Returns :
	    y : a vector with shape [v]
	'''
    logits = tf.einsum("k , mk−>m" , q , K)
    weights = tf.softmax(logits)
    return tf.einsum("m, mv−>v", weights , V)

batched多头注意力

transformer模型使用h个不同的注意力头,他们在计算时互相独立,只有在最终合并输出时才会产生关联。同时,将多个query合并在一起使用批处理的方式会高效得多。这里使用mask来防止后向信息流。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def MultiheadAttention(
	X, M, mask, P_q, P_k, P_v, P_o):
    '''Multi-head Attention on one query
    Args:
    	X : a vector with shape [b, n, d]
    	M : a matrix with shape [b, m, d]
    	mask: a tensor with shape [b, h, n, m]
    	P_q: a tensor with shape [h, d, k]
    	P_k: a tensot with shape [h, d, k]
    	P_v: a tensor with shape [h, d, v]
    	P_o: a tensor with shape [h, d, v]
    Returns:
    	y: a vector with shape [b, n, d]
    '''
    Q = tf.einsum("bnd,hdk -> bhnk", X, P_q)
    K = tf.einsum("bmd,hdk -> bhmk", M, P_k)
    V = tf.einsum("bmd,hdv -> bhmv", M, P_v)
    logits = tf.einsum("bhnk,bhmk -> bhnm", Q, K)
    weights = tf.softmax(logits + mask)
    o = tf.einsum("bhnm,bhmv -> bhnv", weight, V)
    y = tf.einsum("bhnv,hdv -> bnd", o, P_o)
    return y

维度符号说明

符号 含义 说明
b batch size 批次大小,一次处理多少个独立序列
n query序列长度 当前序列中有多少个位置需要计算attention
m key/value序列长度 用于计算attention的上下文长度
d 模型维度 输入/输出的embedding维度
h 注意力头数 多少个并行的attention头
k key/query维度 每个头的key和query维度(通常 k = d/h)
v value维度 每个头的value维度(通常 v = d/h)

这里介绍了典型值的说明,同时做出了简化

  • $m=n$
  • $k=v=d/h$
  • $n <= d$

运算的总的复杂度是$O(bnd^2)$ ,内存访问的复杂度是$O(bnd + bhn^2 + d^2)$,第一项由 X、M、Q、K、V、O 和 Y 决定,第二项由 Logits 和weight决定,第三项由投影张量 P_q、P_k、P_v 和 P_o 决定。

将二者做除法,内存访问与算术运算的比率为$O(1/k + 1/(bn))$ ,这和现代gpu/tpu想匹配,因为计算能力可能比内存访问速度高两个数量级。

增量多头注意力机制

某些情况下,由于数据依赖性使得无法并行处理来自多个位置的query。在训练过程中,由于所有位置的token已知,可以使用上面的并行方式进行处理,但是在训练好模型进行推理之后,一个位置的输出会影响后续位置的输出,这导致无法并行计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def MultiheadSelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """
    Multi-head Self-Attention
    
    Args:
        x: a tensor with shape [b, d]          
        prev_K: tensor with shape [b, h, m, k] 
        prev_V: tensor with shape [b, h, m, v] 
        P_q: a tensor with shape [h, d, k]     
        P_k: a tensor with shape [h, d, k]     
        P_v: a tensor with shape [h, d, v]     
        P_o: a tensor with shape [h, d, v]    
    
    Returns:
        y: a tensor with shape [b, d]        
        new_K: tensor with shape [b, h, m+1, k]
        new_V: tensor with shape [b, h, m+1, v] 
    """
   
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    new_K = tf.concat([
        prev_K, 
        tf.expand_dims(
            tf.einsum("bd,hdk->bhk", x, P_k),  
            axis=2  
        )
    ], axis=2)  
    
    new_V = tf.concat([
        prev_V,  
        tf.expand_dims(
            tf.einsum("bd,hdv->bhv", x, P_v),  
            axis=2  
        )
    ], axis=2) 
    
    logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
    weights = tf.softmax(logits)
    o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", o, P_o)
    
    return y, new_K, new_V

维度解释

变量 形状 含义
x [b, d] 当前步的输入(batch中每个序列的当前token)
prev_K [b, h, m, k] 历史缓存:已生成m个token的Key
prev_V [b, h, m, v] 历史缓存:已生成m个token的Value
q [b, h, k] 当前token的Query(h个头)
new_K [b, h, m+1, k] 更新后的Key缓存(增加了当前token)
new_V [b, h, m+1, v] 更新后的Value缓存(增加了当前token)
logits [b, h, m+1] 当前Query对所有(m+1)个Key的分数
weights [b, h, m+1] Attention权重分布
o [b, h, v] h个头的输出
y [b, d] 当前位置的最终输出

性能分析,使用相同的上述假设,算数运算的复杂度仍然是$O(bnd^2)$,内存访问总量是$O(bn^2d + nd^2)$,第一项由 K 和 V 决定,第二项由 P_q、P_k、P_v 和 P_o 决定。此时比率为$O(n/d + 1/b)$,当$n\approx d$ 或者$b\approx 1$ 时,该比率接近为1,此时内存带宽成为现代硬件的瓶颈。为了是这一项尽可能减小,可以增加批量大小b,这是很显然的,$n/d$这项如何减小是比较困难的。这一项与重复加载k和v张量有关。一种方式通过限制序列长度n,另一种方式是通过减少token关注的邻居位置,或者使用其他方式压缩位置信息来减少被关注的位置数量(就是不计算某token与之前的全部token的attention分数)。而在本文中,是使用一种正交的方式,即删除 heads 维度,但是保存query的 heads 维度。

Multi-Query Attention

多头注意力由多个注意力层(头)组成,并与查询、键、值和输出上的不同线性变换并行。多查询注意力是相同的,只是不同的头共享一组键和值。 (增量)多查询(自)注意力的代码与上面列出的多头注意力的代码相同,只是我们从 tf.einsum 方程中删除了字母“h”,其中它代表 K、V、Pk 或 Pv 的“头”维度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """
    Multi-Query Attention
    
    Args:
        X: [b, n, d] - batch中n个query位置的输入
        M: [b, m, d] - batch中m个key/value位置的输入
        mask: [b, h, n, m] - 注意力掩码
        P_q: [h, d, k] - h个query投影矩阵
        P_k: [d, k] - 单个key投影矩阵(所有头共享)
        P_v: [d, v] - 单个value投影矩阵(所有头共享)
        P_o: [h, d, v] - h个output投影矩阵
    
    Returns:
        Y: [b, n, d] - 输出
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,dk->bmk", M, P_k)
    V = tf.einsum("bmd,dv->bmv", M, P_v)
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
    weights = tf.softmax(logits + mask)
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y


def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """
    Multi-Query Self-Attention (增量版本)
    
    Args:
        x: [b, d] - 当前token的输入
        prev_K: [b, m, k] - 之前缓存的Keys(所有头共享)
        prev_V: [b, m, v] - 之前缓存的Values(所有头共享)
        P_q: [h, d, k] - h个query投影矩阵
        P_k: [d, k] - 单个key投影矩阵(所有头共享)
        P_v: [d, v] - 单个value投影矩阵(所有头共享)
        P_o: [h, d, v] - h个output投影矩阵
    
    Returns:
        y: [b, d] - 当前位置的输出
        new_K: [b, m+1, k] - 更新后的Key缓存
        new_V: [b, m+1, v] - 更新后的Value缓存
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    new_K = tf.concat([
        prev_K,
        tf.expand_dims(tf.einsum("bd,dk->bk", x, P_k), axis=1)
    ], axis=1)
    new_V = tf.concat([
        prev_V,
        tf.expand_dims(tf.einsum("bd,dv->bv", x, P_v), axis=1)
    ], axis=1)
    logits = tf.einsum("bhk,bmk->bhm", q, new_K)
    weights = tf.softmax(logits)
    o = tf.einsum("bhm,bmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", o, P_o)
    return y, new_K, new_V

符号说明:

  • b: batch size (批次大小)
  • n: query序列长度
  • m: key/value序列长度
  • d: 模型维度
  • h: attention头数
  • k: key/query维度
  • v: value维度

Multi-Query Attention的关键区别:

  • P_k和P_v没有h维度 → 所有头共享同一个Key/Value投影
  • prev_K和prev_V也没有h维度 → KV缓存大小减少h倍

使用前述假设计算复杂度,算术运算的复杂度仍然是$O(bnd^2)$。内存访问的复杂度是$O(bnd + bn^2d +nd^2)$,第一项由 x、q、o 和 y 产生,第二项由 K 和 V 产生,第三项由 P_q、P_k、P_v、P_o 产生。将内存除以计算,发现内存访问与算术运算的比率为$O(1/d + n/(dh) + 1/b)$,将计算量 n d 减少了 h 倍。理论上,给定大批量大小 b,这应该会显着提高增量生成的性能。

实验和结果

使用attention is all you need的相同任务,并将其作为baseline进行对比。

实验设置项
数据集 WMT 2014 English-German
模型架构 Encoder-Decoder Transformer
层数 6 layers
模型维度 (d_model) 1024
前馈网络维度 (d_ff) 4096
注意力头数 (h) 8
Key维度 (d_k) 128
Value维度 (d_v) 128
位置编码 Learned positional embeddings
权重共享 Token embedding 和 output layer 共享
总参数量 211 million
训练步数 100,000 steps (~20 epochs)
Batch size 128 examples
序列长度 256 tokens (input) + 256 tokens (target)
序列构造方式 多个句子拼接至256 tokens
训练硬件 32-core TPUv3 cluster
训练时长 ~2 hours per model

在“多query”模型中,将模型中的所有注意力层替换为多查询注意力,这包括编码器自注意力层、解码器自注意力层和编码器-解码器注意力层。将前馈隐藏层从 4096 扩大到 5440,以使总参数计数等于baseline。

为了证明局部注意力和多查询注意力是正交的,还训练了baseline和多查询模型的“局部”版本,其中解码器自注意力层(而不是其他注意力层)将注意力限制在当前位置和之前的 31 个位置。

减小 K 和 V 大小的一种更简单的替代方法是减少头 h 的数量和/或减小键和值的维度 k 和 v。该论文训练了几个这样的模型进行比较,同时再次扩大前馈隐藏层以使总参数计数等于基线。对于baseline,使用 6 层模型,$d_{model} = 1024, d_{ff} = 8192,h = 8,d_k = d_v = 128$。baseline和所有变体的总参数计数为 1.92 亿。

模型质量

下表显示了机器翻译实验的结果。使用贪婪最大似然解码对开发集进行解码,并使用 sacrebleu sacrebleu -t wmt13 -l en-de -tok intl计算 BLEU 分数。还列出了开发集上每个sub token的困惑度。根据这两个指标,多查询注意力模型似乎比baseline稍差,但更接近减少 h、dk 和 dv 的任何替代方案。

Attention Type h d_k, d_v d_ff ln(PPL) (dev) BLEU (dev) BLEU (test) beam 1 / 4
multi-head 8 128 4096 1.424 26.7 27.7 / 28.4
multi-query 8 128 5440 1.439 26.5 27.5 / 28.5
multi-head local 8 128 4096 1.427 26.6 27.5 / 28.3
multi-query local 8 128 5440 1.437 26.5 27.6 / 28.2
multi-head 1 128 6784 1.518 25.8 -
multi-head 2 64 6784 1.480 26.2 26.8 / 27.9
multi-head 4 32 6784 1.488 26.1 -
multi-head 8 16 6784 1.513 25.8 -

通过使用贪婪解码和beam search(beam 4,α = 0.6)对测试集进行解码来验证结果,并使用 sacrebleu sacrebleu -t wmt14 -l en-de -tok intl进行评估。同样,多查询模型的表现与baseline类似,并且实际上在使用 Beam-4 解码时具有最高的 BLEU 分数 (28.5)。

下表显示了十亿字语言建模基准的结果。模型是通过开发集上的每个字(而不是每个sub token)的困惑度来评估的。结果与翻译结果相似。多查询注意力模型比基线稍差,但明显优于涉及减少 h、dk 和 dv 的任何替代方案

Attention Type h d_k, d_v d_ff dev-PPL
multi-head 8 128 8192 29.9
multi-query 8 128 9088 30.2
multi-head 1 128 9984 31.2
multi-head 2 64 9984 31.1
multi-head 4 32 9984 31.0
multi-head 8 16 9984 30.9

速度

下表显示了各种模型的训练和推理时间。训练和推理速度均在一个 TPUv2(8 核)上进行评估。基础模型的训练步骤(由 32,768 个输入token和 32,768 个目标token组成)花费了 433 毫秒,多查询模型花费了 425 毫秒。除以 32,768,发现每个(输入token + 目标token)的训练时间为 13.2μs。

使用 128 个token的源序列长度和 128 个目标序列长度对一批 1024 个序列(每个核心 128 个)运行增量贪婪推理。对于baseline模型,模型的编码器部分花费了 222 毫秒,解码器的每个增量步骤花费了 47 毫秒。除以相应的token数量,发现编码器的平均推理时间为每个token 1.7μs,解码器的平均推理时间为每个token 46μs。对于多查询模型,编码器每步花费 195ms,解码器每步花费 3.9ms,摊销的每个token成本分别为 1.5μs 和 3.8μs。

Attention Type Training Inference enc. + dec. Beam-4 Search enc. + dec.
multi-head 13.2 1.7 + 46 2.0 + 203
multi-query 13.0 1.5 + 3.8 1.6 + 32
multi-head local 13.2 1.7 + 23 1.9 + 47
multi-query local 13.0 1.5 + 3.3 1.6 + 16

列出的值以每个输出token的 TPUv2 微秒为单位。

本文提出了训练可以并行计算,并且算术强度比内存访问多的多,和gpu/tpu的特性相符合,但是推理时无法进行并行计算,因此内存会成为推理时的瓶颈。



评论