餐饮企业网站建设方案书,建设工程施工合同范本哪个网站,微信公众号设计网站,wordpress主题百度网盘在学习huggingFace的Transformer库时#xff0c;我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数#xff0c;它被用来加速大模型的Attention计算#xff0c;本文就详细介绍一下它的使用方法#xff0c;核心内容主要参考了torch.nn.functional中该函数的注释…在学习huggingFace的Transformer库时我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数它被用来加速大模型的Attention计算本文就详细介绍一下它的使用方法核心内容主要参考了torch.nn.functional中该函数的注释。
1. Attention计算公式
Attention的计算主要涉及三个矩阵Q、K、V。我们先不考虑multi-head attention只考虑one head的self attention。在大模型的prefill阶段这三个矩阵的维度均为N x dN即为上下文的长度在decode阶段Q的维度为1 x d KV还是N x d。然后通过下面的公式计算attention矩阵 O A t t e n t i o n ( Q , K , V ) s o f t m a x ( Q K T d ) V OAttention(Q,K,V)softmax(\frac{QK^T}{\sqrt d})V OAttention(Q,K,V)softmax(d QKT)V 在真正使用attention的时候我们往往采用multi-head attention(MHA)。MHA的计算公式和one head attention基本一致它改变了Q、K、V每一行的定义将维度d的向量分成h组变成一个h x dk的矩阵Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k N∗h∗dk的三维矩阵不考虑batch维。分别将Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k h∗N∗dk的三维矩阵。此时的三个矩阵就是具有h个头的Q、K、V我们就可以按照self attention的定义计算h个头的attention值。
不过在真正进行大模型推理的时候就会发现KV Cache是非常占显存的所以大家尝试各种手段压缩KV Cache具体可以参考《大模型推理–KV Cache压缩》。一种手段就是将MHA替换成group query attention(GQA)这块在torch2.5以上的SDPA中也已经得到了支持。
2. SDPA伪代码
在SDPA的注释中给出了伪代码
def scaled_dot_product_attention(query, key, value, attn_maskNone, dropout_p0.0,is_causalFalse, scaleNone, enable_gqaFalse) - torch.Tensor:L, S query.size(-2), key.size(-2)scale_factor 1 / math.sqrt(query.size(-1)) if scale is None else scaleattn_bias torch.zeros(L, S, dtypequery.dtype)if is_causal:assert attn_mask is Nonetemp_mask torch.ones(L, S, dtypetorch.bool).tril(diagonal0)attn_bias.masked_fill_(temp_mask.logical_not(), float(-inf))attn_bias.to(query.dtype)if attn_mask is not None:if attn_mask.dtype torch.bool:attn_bias.masked_fill_(attn_mask.logical_not(), float(-inf))else:attn_bias attn_maskif enable_gqa:key key.repeat_interleave(query.size(-3)//key.size(-3), -3)value value.repeat_interleave(query.size(-3)//value.size(-3), -3)attn_weight query key.transpose(-2, -1) * scale_factorattn_weight attn_biasattn_weight torch.softmax(attn_weight, dim-1)attn_weight torch.dropout(attn_weight, dropout_p, trainTrue)return attn_weight value可以看出我们实际在使用SDPA时除了query、key和value之外还有另外几个参数attn_mask、dropout_p、is_causal、scale和enable_gqa。scale就是计算Attention时的缩放因子一般无需传递。dropout_p表示Dropout概率在推理阶段也不需要传递不过官方建议如下输入dropout_p(self.p if self.training else 0.0)。我们着重看一下另外三个参数在使用时该如何设置。
先看enable_gqa。前面提到GQA是一种KV Cache压缩方法MHA的KV和Q一样也会有h个头GQA则将KV的h个头进行压缩来减小KV Cache的大小。比如Qwen2-7B-Instruct这个模型Q的h等于28KV的h等于4相当于把KV Cache压缩到之前的七分之一。GQA虽然压缩了KV Cache但是真正要计算Attention的时候还是需要对齐KV与Q的head数所以我们可以看到HF Transformer库中的qwen2.py在Attention计算时会有一个repeat_kv的操作目的就是将QKV的head数统一。在torch2.5以后的版本中我们无需再手动去执行repeat_kv直接将SDPA的enable_gqa设置为True即可自动完成repeat_kv而且速度比自己去做repaet_kv还要更快。
attn_mask和is_causal两个参数的作用相同目的都是要给softmax之前的QKT矩阵添加mask。只不过attn_mask是自己在外面构造mask矩阵is_causal则是根据大模型推理的阶段属于prefill还是decode来进行设置。通过看伪代码可以看出SDPA会首先构造一个L x S的零矩阵attn_biasL表示Q的上下文长度S表示KV Cache的长度。在prefill阶段L和S相等在decode阶段L为1S还是N。所以在prefill阶段attn_bias就是一个N x N的矩阵将is_causal设置为True时就会构造一个下三角为0上三角为负无穷的矩阵作为attn_bias然后将其加到QKT矩阵上这样就实现了因果关系的Attention计算。在decode阶段attn_bias就是一个1 x N的向量此时可以将is_causal设置为Falseattn_bias始终为0就不会对 Q K T QK^T QKT行向量产生影响表示KV Cache所有的行都参与计算因果关系保持正确。
attn_mask作用和is_causal一样但是需要我们自行构造如果你对如何构造不了解建议就使用is_causal选项prefill阶段设置为Truedecode阶段设置为Falseattn_mask设置为None。不过如果prefill按照chunk来执行也即chunk_prefill阶段我们会发现is_causal设置为True时的attn_bias设置的不正确我们不是从左上角开始构造下三角矩阵而是要从右下角开始构造下三角矩阵这种情况下我们可以从外面自行构造attn_mask矩阵代替SDPA的构造。attn_mask有两种构造方式一种是bool类型True的位置会保持不变False的位置会置为负无穷一种是float类型会直接将attn_mask加到SDPA内部的attn_bias上和bool类型一样我们一般是构造一个下三角为0上三角为负无穷的矩阵。总结来说绝大多数情况下我们只需要设置is_causal选项prefill阶段设置为Truedecode阶段设置为Falseattn_mask设置为None即可。如果推理阶段引入了chunk_prefill则我们需要自行构造attn_mask但是要注意构造的attn_mask矩阵是从右下角开始的下三角矩阵。
3. SDPA实现(翻译自SDPA注释)
目前SDPA有三种实现
基于FlashAttention-2的实现Memory-Efficient Attention(facebook xformers)Pytorch版本对上述伪代码的c实现(对应MATH后端)。
针对CUDA后端SDPA可能会调用经过优化的内核以提高性能。对于所有其他后端将使用PyTorch实现。所有实现方式默认都是启用的SDPA会尝试根据输入自动选择最优的实现方式。为了对使用哪种实现方式提供更细粒度的控制torch提供了以下函数来启用和禁用各种实现方式
torch.nn.attention.sdpa_kernel一个上下文管理器用于启用或禁用任何一种实现方式torch.backends.cuda.enable_flash_sdp全局启用或禁用FlashAttentiontorch.backends.cuda.enable_mem_efficient_sdp全局启用或禁用memory efficient attentiontorch.backends.cuda.enable_math_sdp全局启用或禁用PyTorch的C实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现方式请使用torch.nn.attention.sdpa_kernel 禁用PyTorch 的C实现。如果某个融合实现方式不可用将会发出警告说明该融合实现方式无法运行的原因。由于融合浮点运算的特性此函数的输出可能会因所选择的后端内核而异。C 实现支持torch.float64当需要更高精度时可以使用。对于math后端如果输入是torch.half或torch.bfloat16类型那么所有中间计算结果都会保持为torch.float类型。
4. SDPA使用示例
首先强调一点灌入SDPA的QKV都是做过转置的也即维度为batch x head x N x d在老版本的torch中还需要QKV都是contiguous的新版本下无此要求。SDPA注释中还给了两个示例我们在此也给出
# Optionally use the context manager to ensure one of the fused kernels is runquery torch.rand(32, 8, 128, 64, dtypetorch.float16, devicecuda)key torch.rand(32, 8, 128, 64, dtypetorch.float16, devicecuda)value torch.rand(32, 8, 128, 64, dtypetorch.float16, devicecuda)with sdpa_kernel(backends[SDPBackend.FLASH_ATTENTION]):F.scaled_dot_product_attention(query,key,value)上述示例中给定的输入为batch等于32head等于8上下文长度128embedding维度64然后通过sdpa_kernel选择使用FlashAttention。 示例二
# Sample for GQA for llama3
query torch.rand(32, 32, 128, 64, dtypetorch.float16, devicecuda)
key torch.rand(32, 8, 128, 64, dtypetorch.float16, devicecuda)
value torch.rand(32, 8, 128, 64, dtypetorch.float16, devicecuda)
with sdpa_kernel(backends[SDPBackend.MATH]):F.scaled_dot_product_attention(query,key,value,enable_gqaTrue)示例二演示了GQA的用法给定的query head数为32key和value均为8此时我们可以通过enable_gqa选项来实现对GQA的支持此外代码还通过sdpa_kernel选项使用了MATH后端。
5. 参考
FlashAttention-2: Faster Attention with Better Parallelism and Work PartitioningMemory-Efficient AttentionGrouped-Query AttentionAttention Is All You Need