高效架构¶
让模型更快不仅仅是降低精度,还在于设计更智能的架构,使每个token的计算量更少。本文涵盖StreamingLLM、稀疏和线性注意力、多查询和分组查询注意力、推理时的混合专家、知识蒸馏、剪枝和神经架构搜索
- 量化(文件1)使每个操作更廉价。本文从源头上减少操作数量。两者互补:一个架构高效且量化的模型可以比原始模型快10-100倍。
StreamingLLM:无限长度生成¶
-
标准Transformer将所有先前的token存储在KV缓存中,KV缓存随序列长度线性增长。在某一点上,缓存超过GPU内存,生成失败。StreamingLLM(Xiao等人,2023)使用固定大小的滚动KV缓存解决了这个问题。
-
关键观察:序列中的前几个token,无论其内容如何,都获得不成比例的高注意力分数。这些被称为注意力汇聚点。如果将它们从缓存中逐出,注意力分布会崩溃,生成质量灾难性下降。
-
StreamingLLM的解决方案:在缓存中永久保留少量汇聚token(前1-4个token),加上最近\(w\)个token的滚动窗口。总缓存大小为\(\text{sink} + w\),无论生成了多少token都是固定的。
-
注意力汇聚点锚定softmax分布,滚动窗口提供最近的上下文。这实现了无限长度生成,内存恒定,代价是失去了访问序列中间上下文的能力。
-
StreamingLLM无需重新训练即可用于自然形成注意力汇聚点的模型(大多数预训练LLM都会)。对于不形成汇聚点的模型,在训练期间添加一个可学习的汇聚token即可解决。
稀疏注意力¶
- 全自注意力在序列长度\(n\)上是\(O(n^2)\),因为每个token关注所有其他token。对于\(n = 128K\),注意力矩阵有\(128K^2 = 160\)亿个条目。稀疏注意力模式通过限制哪些token关注哪些token来减少这个数量。
-
滑动窗口注意力(Mistral、Gemma):每个token只关注之前\(w\)个token(例如\(w = 4096\))。注意力是\(O(n \cdot w)\)而不是\(O(n^2)\)。信息通过多层在窗口之外传播:经过\(L\)层后,有效上下文为\(L \times w\)。
-
局部+全局注意力(Longformer、BigBird):大多数token使用滑动窗口注意力(局部),但少数指定token(例如[CLS],每512个token)关注所有token(全局)。这同时捕获了局部模式和长距离依赖。
-
膨胀注意力:关注窗口内每第\(k\)个token,创建一个覆盖更大范围但注意力分数数量相同的稀疏模式。跨层增加膨胀度创建类似于膨胀卷积的层次结构(第8章)。
-
现代LLM的实际胜者是滑动窗口+全注意力交错:某些层使用滑动窗口(廉价,处理局部上下文),某些层使用全注意力(昂贵,捕获长距离)。Mistral/Mixtral使用这种模式。
线性注意力和状态空间模型¶
-
我们能完全替换\(O(n^2)\)的注意力吗?线性注意力和状态空间模型(SSM)通过避免显式注意力矩阵,以\(O(n)\)时间处理序列。
-
线性注意力用核近似替换softmax注意力:
-
通过先关联\(K^T V\)乘积(这是\(d \times d\),与序列长度无关),计算变成\(O(n \cdot d^2)\)而不是\(O(n^2 \cdot d)\)。对于\(n \gg d\)的长序列,这是巨大的节省。
-
RWKV结合了RNN和Transformer的思想。它使用循环公式顺序处理token(像RNN),但可以在训练时并行化(像Transformer)。推理是每个token \(O(1)\)(常量内存,KV缓存不增长)。
-
Mamba(Gu & Dao,2023)是一种选择性状态空间模型。它通过学习到的状态转换处理序列:
-
其中\(\bar{A}\)和\(\bar{B}\)是依赖于输入的(选择性),允许Mamba动态关注或忽略输入的部分。与固定SSM不同,选择性使Mamba在语言任务上与Transformer具有竞争力,同时保持\(O(n)\)的扩展性。
-
权衡:线性注意力和SSM对长序列更快,但对于需要精确长距离检索的任务,通常不如全注意力。混合架构(一些Transformer层+一些Mamba层)通常能提供两全其美的效果。
多查询和分组查询注意力¶
-
标准多头注意力(MHA,第7章)为每个头使用独立的\(K\)、\(V\)投影。对于\(h\)个head,KV缓存中有\(h\)个独立的键和值张量。多查询注意力(MQA)和分组查询注意力(GQA)减少了这个数量。
-
MQA(Shazeer,2019):所有头共享单组\(K, V\)投影。每个头仍然有自己的\(Q\)投影。KV缓存缩小了\(h\)倍(例如,32个头则缩小32倍)。
-
GQA(Ainslie等人,2023):一个中间方案。头被分组,每组共享一组\(K, V\)投影。有\(h = 32\)个头和\(g = 8\)个组,每组4个头共享K/V。KV缓存缩小了\(h/g = 4\)倍。
- 大多数现代LLM使用GQA(Llama 2/3、Gemma、Mistral)。它减少了KV缓存内存和推理延迟,与MHA相比质量损失可以忽略不计。
多头潜在注意力(MLA)¶
- MLA(DeepSeek-V2,2024)通过将KV缓存压缩为低秩潜在空间,比GQA更进一步。MLA不是缓存完整的键和值向量,而是每个token缓存一个压缩后的潜在向量\(\mathbf{c}_t\),并在注意力期间动态重构K/V:
-
压缩向量\(\mathbf{c}_t\)比原始K和V的组合小得多。DeepSeek-V2实现了与MHA相比93.3%的KV缓存大小减少,甚至优于MQA,同时保持MHA级别的质量。
-
权衡:从潜在向量重构K/V在每个注意力操作中增加了少量计算成本。但由于LLM解码是内存带宽受限的(而非计算受限),这总体上是个净收益:更少的内存加载 > 每token稍多计算。
Flash Attention¶
-
Flash Attention(Dao等人,2022,第16章文件05有详细论述)不是架构变化,而是一种实现优化,在任何高效注意力讨论中都不可或缺。它计算精确的标准注意力,具有以下特点:
- O(n)内存而不是O(n²)(注意力矩阵从未在HBM中具体化)。
- 比标准注意力快2-4倍(通过分块和在线softmax将数据保留在SRAM中)。
- 无质量损失——输出在数学上与标准注意力完全相同。
-
Flash Attention现在是PyTorch(
torch.nn.functional.scaled_dot_product_attention)、JAX和所有主要推理框架中默认的注意力实现。如果你在2024+年运行注意力,你几乎肯定在使用Flash Attention。
Ring Attention¶
-
Ring Attention(Liu等人,2023)将注意力计算分布到多个设备上,用于即使使用Flash Attention也无法装入单GPU内存的长序列。
-
思路:将序列分区到\(N\)个设备上。每个设备持有\(n/N\)个token的Q、K、V。设备排列成环形。每一步:
- 每个设备计算局部注意力(其Q对其局部K/V)。
- 每个设备将其K/V块发送到环中的下一个设备。
- 每个设备从上一个设备接收K/V,并针对这些K/V计算注意力。
- 经过\(N\)步后,每个设备已经关注过每个K/V块。
-
通信与计算重叠:在当前K/V块上计算注意力的同时,下一个块正在传输中。这几乎完全隐藏了通信延迟。
-
Ring Attention通过将KV缓存分布在一圈GPU上,实现了百万token上下文窗口。每台设备的内存为O(n/N),使得任意长序列都可行(仅受设备数量限制)。
推理时的混合专家¶
-
MoE模型(第7章)每个token只激活其参数的一小部分(通常8个专家中的2个)。在推理时,独特的挑战是专家缓存:所有专家都必须在内存中(因为任何token可能路由到任何专家),但每个token只有2个活跃。
-
对于Mixtral 8x7B模型:总参数 = 47B(8 × 7B专家,但有共享组件)。每个token的活跃参数 ≈ 13B(2个专家 + 共享层)。该模型具有LLM-70B级别的质量,但推理成本为LLM-13B级别,不过需要在内存中保留47B参数。
-
专家卸载:对于GPU内存受限的部署,将非活跃专家保留在CPU或SSD上,按需加载。这之所以有效,是因为token路由足够可预测,可以预取可能的专家。
-
专家缓存:在GPU内存中维护最近使用的专家的LRU缓存。如果相同的专家被重复激活(在领域内数据中常见),缓存命中率很高。
知识蒸馏¶
- 蒸馏(第6章)训练一个小的"学生"模型来模仿一个大的"教师"。学生从教师的软预测(类上的概率分布)中学习,这比单独的硬标签包含更多信息。
-
其中\(T\)是温度(更高的\(T\)使分布变软,揭示教师的不确定性),\(\alpha\)平衡蒸馏损失与标准交叉熵损失。
-
对于LLM:蒸馏用于从大型、能力强的模型创建小型、快速的模型。GPT-4 → 一个7B学生模型,在特定任务上捕获GPT-4的大部分行为。学生模型的推理成本可以低10-100倍。
-
任务特定蒸馏:仅在与部署任务相关的数据上进行蒸馏。从70B教师模型在医疗问答上蒸馏出的7B模型,在该特定任务上可以超越70B模型(因为学生有限的容量完全集中在目标领域上)。
剪枝¶
-
剪枝移除不必要的权重(将其设为零),减少模型大小和计算量。
-
非结构化剪枝(基于幅值):移除绝对值最小的单个权重。这创建了稀疏权重矩阵。简单有效用于压缩,但当前硬件(GPU)除非稀疏性遵循特定模式,否则无法高效加速稀疏操作。
-
结构化剪枝:移除整个单元——注意力头、MLP神经元或层。这产生一个更小的稠密模型,可以在标准硬件上直接加速。权衡是粒度更粗(移除一个完整的头可能同时移除了有用和无用的权重)。
-
2:4稀疏性(NVIDIA Ampere+):一种硬件支持的稀疏模式,每4个权重中有2个为零。GPU的稀疏Tensor Core跳过零乘法,实现约2倍加速。这是目前唯一具有实际硬件加速的稀疏模式。
-
彩票假说(Frankle & Carlin,2019):在随机初始化的网络中,存在一个子网络("中奖彩票"),可以单独训练以匹配完整网络的性能。找到这些子网络(通过训练、剪枝和重置)成本高昂,但这个洞察激励了剪枝研究。
神经架构搜索(NAS)¶
-
NAS通过搜索可能的架构空间来自动化架构设计,找到在硬件约束(延迟、内存、功耗)下最大化精度的架构。
-
EfficientNet(第8章)就是通过NAS找到的:复合缩放规则(平衡深度、宽度、分辨率)是从搜索中涌现的,而非人类直觉。
-
对于推理效率,NAS可以找到针对特定硬件目标优化的架构:"找到一个在iPhone神经引擎上延迟<5ms且在ImageNet上精度>80%的模型。"搜索空间包括层类型、宽度、激活函数和注意力模式。
-
一次性网络训练一个单个过参数化网络,为不同的部署目标提取子网络。一次训练运行产生针对云GPU、移动GPU和CPU优化的模型,每个都针对其目标进行了优化。
编程任务(使用CoLab或notebook)¶
-
实现滑动窗口注意力,并与全注意力比较内存使用。
import jax import jax.numpy as jnp def full_attention(Q, K, V): """标准O(n²)注意力。""" scores = Q @ K.T / jnp.sqrt(Q.shape[-1]) weights = jax.nn.softmax(scores, axis=-1) return weights @ V def sliding_window_attention(Q, K, V, window_size=128): """滑动窗口注意力:每个token关注前window_size个token。""" n = Q.shape[0] d = Q.shape[-1] output = jnp.zeros_like(Q) for i in range(n): start = max(0, i - window_size + 1) k_window = K[start:i+1] v_window = V[start:i+1] scores = Q[i] @ k_window.T / jnp.sqrt(d) weights = jax.nn.softmax(scores) output = output.at[i].set(weights @ v_window) return output n, d = 512, 64 key = jax.random.PRNGKey(0) Q = jax.random.normal(key, (n, d)) K = jax.random.normal(jax.random.PRNGKey(1), (n, d)) V = jax.random.normal(jax.random.PRNGKey(2), (n, d)) print(f"全注意力内存: O(n²) = {n*n} 个条目") print(f"窗口 (w=128) 内存: O(n*w) = {n*128} 个条目") print(f"减少: {n*n / (n*128):.1f}x") -
比较MHA、GQA和MQA的KV缓存大小。展示为什么GQA是实际的最佳选择。
def kv_cache_size(n_heads, n_kv_heads, d_head, seq_len, bytes=2): """KV缓存大小(MB)。""" return 2 * n_kv_heads * d_head * seq_len * bytes / 1e6 n_heads = 32 d_head = 128 seq_len = 32768 mha = kv_cache_size(n_heads, n_heads, d_head, seq_len) # 32个KV头 gqa = kv_cache_size(n_heads, 8, d_head, seq_len) # 8个KV头 mqa = kv_cache_size(n_heads, 1, d_head, seq_len) # 1个KV头 print(f"MHA (32个KV头): {mha:.0f} MB 每层") print(f"GQA (8个KV头): {gqa:.0f} MB 每层 ({mha/gqa:.0f}x 更小)") print(f"MQA (1个KV头): {mqa:.0f} MB 每层 ({mha/mqa:.0f}x 更小)") -
通过从随机注意力层中移除最不重要的注意力头并测量输出变化来模拟结构化剪枝。
import jax import jax.numpy as jnp key = jax.random.PRNGKey(0) n_heads, seq_len, d_head = 8, 64, 32 # 随机多头注意力输出(每个头一个) head_outputs = jax.random.normal(key, (n_heads, seq_len, d_head)) # 完整输出:连接所有头 full_output = head_outputs.reshape(seq_len, n_heads * d_head) # 重要性:通过范数度量每个头的贡献 head_norms = jnp.linalg.norm(head_outputs, axis=(1, 2)) print("头重要性(按范数):", jnp.round(head_norms, 2)) # 剪枝最不重要的头 for n_keep in [8, 6, 4, 2]: top_heads = jnp.argsort(head_norms)[-n_keep:] pruned = head_outputs[top_heads].reshape(seq_len, n_keep * d_head) # 填充到原始大小用于比较(将剪掉的头设为零) full_pruned = jnp.zeros_like(head_outputs) full_pruned = full_pruned.at[top_heads].set(head_outputs[top_heads]) full_pruned = full_pruned.reshape(seq_len, n_heads * d_head) error = jnp.linalg.norm(full_output - full_pruned) / jnp.linalg.norm(full_output) print(f"保留 {n_keep}/{n_heads} 个头: 相对误差 = {error:.4f}, " f"内存 = {n_keep/n_heads:.0%}")