参考
Llm arch 1080P.html架构图,以及文末基于官方代码整理的在线架构示意资料,围绕现代 decoder-only LLM 的核心设计展开:GQA + RoPE + SwiGLU + KV Cache,并说明 FlashAttention 这类实现优化的作用。本文以 LLaMA-3 8B 为真实参考对象,用具体矩阵形状和数值例子演示一次前向传播、推理缓存、FLOPs 与训练能耗。
| 参数 | 数值 |
|---|---|
| 模型名称 | LLaMA-3 8B |
| 发布机构 | Meta |
| 发布日期 | 2024-04-18 |
| 架构类型 | Decoder-only Transformer |
| 参数量 | 约 8.03B |
| 训练 token 数 | 15T+ |
| 知识截止 | 2023-03 |
| 上下文长度 | 8,192 tokens |
| 词表大小 $\lvert V\rvert$ | 128,256 |
| 层数 $L$ | 32 |
| 模型维度 $d_{model}$ | 4,096 |
| Query heads | 32 |
| KV heads | 8 |
| 每头维度 $d_{head}$ | 128 |
| GQA 分组比例 | 每 4 个 Query head 共享 1 组 K/V |
| SwiGLU 中间维度 $d_{ff}$ | 14,336 |
| 位置编码 | RoPE,$\theta=500000$ |
| Norm | Pre-RMSNorm |
| 权重精度 | BF16 |
| Embedding / LM Head | 不共享权重 |
关键直觉:
GPT-3 的每层注意力通常可以近似看成 $Q,K,V,O$ 都是 $[d,d]$ 方阵;LLaMA-3 8B 因为用了 GQA,$K,V$ 不再投影到 4096 维,而是只投影到:
所以:
这就是图中 “GQA 共享” 和 “KV Cache 在推理时复用历史 K/V” 的核心。
图中的主路径可以读成:
文本输入
↓ Tokenizer
token ids
↓ Embeddings
X: [batch, seq_len, 4096]
↓ 重复 32 层 Transformer Block
├─ Pre-RMSNorm
├─ GQA / FlashAttention / RoPE / KV Cache
├─ Add 残差
├─ Pre-RMSNorm
├─ SwiGLU MLP
└─ Add 残差
↓ Final RMSNorm
↓ Linear Head
logits: [batch, seq_len, 128256]
↓ Softmax / top-p / top-k
新 token 输出
↓ 自回归循环
为了把形状看清楚,下面先用一个小例子:
输入句子:机器学习很有趣
假设分词后:["机器", "学习", "很", "有趣"]
seq_len = 4
batch = 1
真实模型里 token id 不是中文词本身,而是 tokenizer 产生的整数。这里用中文词只是为了便于理解。
词表矩阵:
参数量:
假设 “学习” 的 token id 是 12345,查表就是取出第 12345 行:
对于 4 个 token:
注意:LLaMA 系列没有传统的绝对位置 embedding 表。位置信息主要通过 RoPE 注入到 $Q$ 和 $K$ 中,而不是加到 embedding 上。
下面追踪一层。输入为:
RMSNorm 形状不变:
RMSNorm 与 LayerNorm 类似,都是让数值更稳定;区别是 RMSNorm 不减均值,只按均方根缩放:
其中 $d=4096$,$w$ 是可训练缩放参数,形状为 $[4096]$。
拆成 32 个 query head:
拆成 8 个 KV head:
拆成:
这就是 GQA:
| 类型 | head 数 | 总维度 |
|---|---|---|
| Query | 32 | 4096 |
| Key | 8 | 1024 |
| Value | 8 | 1024 |
每 4 个 Query head 共享 1 个 Key head 和 1 个 Value head:
Q heads 0,1,2,3 → 使用 KV head 0
Q heads 4,5,6,7 → 使用 KV head 1
...
Q heads 28,29,30,31 → 使用 KV head 7
前面例子中的“机器、学习、很、有趣”是 4 个 token 位置,属于序列维度 seq_len=4。而“每 4 个 Query head 共享 1 个 KV head”属于注意力头维度。两者彼此独立:
4 个 token → 一个 Query head 要查看的 4 个序列位置
4 个 Query heads → 同一 GQA 分组中共享一个 KV head 的 4 个特征通道
因此,只展示一个 Query head 如何关注 4 个 token,不能证明 4Q→1KV。下面会用 8 Query heads / 2 KV heads 的完整玩具模型,显式展示两个 GQA 分组。
如果是传统 MHA,K/V 也会有 32 个 head,KV Cache 会是 GQA 的 4 倍。
RoPE 不改变张量形状:
它的做法是把每个 head 的 128 维向量按两两一组看成二维平面:
在位置 $p$ 上,每一组会旋转一个角度:
直观理解:RoPE 让 “第 4 个 token 的 query” 和 “第 2 个 token 的 key” 的点积天然包含相对距离信息。
下面把真实 LLaMA-3 8B 缩小为一个仍然属于 GQA、但可以手算的模型。数值是人为设计的,只用于展示映射和计算。
输入:["机器", "学习", "很", "有趣"]
batch = 1
seq_len = 4
Query heads = 8
KV heads = 2
head_dim = 2
d_model = 16
模型维度和共享比例:
所以它有两个真正的 GQA 分组:
Q heads 0~3 → KV head 0
Q heads 4~7 → KV head 1
如果只有 4 Query heads / 1 KV head,严格分类更接近 MQA;这里采用 8Q/2KV,避免把 MQA 特例误称为完整 GQA。
输入:
投影矩阵:
拆分 head 后:
维度顺序是:
Q:[batch, token位置, Query head, head_dim]
K:[batch, token位置, KV head, head_dim]
V:[batch, token位置, KV head, head_dim]
Q 和 K 随后应用 RoPE,V 不应用 RoPE。下面给出的 Q/K 数值视为完成投影和 RoPE 后的结果;V 是完成投影后的结果。
KV head 0:
| 位置 $j$ | token | $k_{j,0}$ | $v_{j,0}$ |
|---|---|---|---|
| 1 | 机器 | $[1,0]$ | $[1,0]$ |
| 2 | 学习 | $[0,1]$ | $[0,1]$ |
| 3 | 很 | $[-1,0]$ | $[-1,0]$ |
| 4 | 有趣 | $[0,-1]$ | $[0,-1]$ |
KV head 1:
| 位置 $j$ | token | $k_{j,1}$ | $v_{j,1}$ |
|---|---|---|---|
| 1 | 机器 | $[1,1]$ | $[2,0]$ |
| 2 | 学习 | $[-1,1]$ | $[0,2]$ |
| 3 | 很 | $[-1,-1]$ | $[-2,0]$ |
| 4 | 有趣 | $[1,-1]$ | $[0,-2]$ |
“共享一个 KV head”不是所有 token 共用一个 K/V 向量。每个 KV head 在每个 token 位置都有不同的 K/V;同组 Query head 共享的是该 KV head 的整组投影结果,以及推理时对应的 KV Cache。
第 4 个 token“有趣”可以看到位置 1~4。保留 batch 和 token 维度时:
假设它产生:
| Query head $h$ | $q_{4,h}$ | 使用的 KV head |
|---|---|---|
| 0 | $[2,0]$ | 0 |
| 1 | $[0,2]$ | 0 |
| 2 | $[-2,0]$ | 0 |
| 3 | $[0,-2]$ | 0 |
| 4 | $[1,1]$ | 1 |
| 5 | $[-1,1]$ | 1 |
| 6 | $[-1,-1]$ | 1 |
| 7 | $[1,-1]$ | 1 |
映射函数:
因此:
Q0 ─┐
Q1 ─┤
Q2 ─┼── KV head 0
Q3 ─┘
Q4 ─┐
Q5 ─┤
Q6 ─┼── KV head 1
Q7 ─┘
Q0~Q3 不读取 KV head 1,Q4~Q7 也不读取 KV head 0。各 Query head 之间没有顺序依赖,通常由批量张量运算同时处理,而不是轮流占用 KV head。
实现时可以在逻辑上把每个 KV head 提供给组内 4 个 Query head,但不一定需要在显存中真正复制四份;GQA kernel 可以直接读取共享数据。
计算公式:
分数:
| Query | 机器 | 学习 | 很 | 有趣 |
|---|---|---|---|---|
| $Q_0$ | 1.414 | 0 | -1.414 | 0 |
| $Q_1$ | 0 | 1.414 | 0 | -1.414 |
| $Q_2$ | -1.414 | 0 | 1.414 | 0 |
| $Q_3$ | 0 | -1.414 | 0 | 1.414 |
沿 token 位置做 Softmax:
| Query | 机器 | 学习 | 很 | 有趣 |
|---|---|---|---|---|
| $Q_0$ | 0.647 | 0.157 | 0.038 | 0.157 |
| $Q_1$ | 0.157 | 0.647 | 0.157 | 0.038 |
| $Q_2$ | 0.038 | 0.157 | 0.647 | 0.157 |
| $Q_3$ | 0.157 | 0.038 | 0.157 | 0.647 |
每个 Query head 使用自己的权重,对 KV head 0 的 Value 加权:
| Query | 输出 |
|---|---|
| $Q_0$ | $[0.609,0]$ |
| $Q_1$ | $[0,0.609]$ |
| $Q_2$ | $[-0.609,0]$ |
| $Q_3$ | $[0,-0.609]$ |
分数:
| Query | 机器 | 学习 | 很 | 有趣 |
|---|---|---|---|---|
| $Q_4$ | 1.414 | 0 | -1.414 | 0 |
| $Q_5$ | 0 | 1.414 | 0 | -1.414 |
| $Q_6$ | -1.414 | 0 | 1.414 | 0 |
| $Q_7$ | 0 | -1.414 | 0 | 1.414 |
Softmax:
| Query | 机器 | 学习 | 很 | 有趣 |
|---|---|---|---|---|
| $Q_4$ | 0.647 | 0.157 | 0.038 | 0.157 |
| $Q_5$ | 0.157 | 0.647 | 0.157 | 0.038 |
| $Q_6$ | 0.038 | 0.157 | 0.647 | 0.157 |
| $Q_7$ | 0.157 | 0.038 | 0.157 | 0.647 |
对 KV head 1 的 Value 加权:
| Query | 输出 |
|---|---|
| $Q_4$ | $[1.218,0]$ |
| $Q_5$ | $[0,1.218]$ |
| $Q_6$ | $[-1.218,0]$ |
| $Q_7$ | $[0,-1.218]$ |
第二组的权重碰巧与第一组相同,是因为这里选用了对称向量;两组实际读取的是不同 K/V,输出也不同。真实模型中,两组的注意力权重通常不会相同。
两个分组一共得到 8 个 head 输出:
拼接后:
输出投影:
这个玩具例子完整展示了:
真实 LLaMA-3 8B 只是把两个分组扩大为八个分组,每组仍然是 4Q→1KV。32 个 Query head 计算完后拼接回 4096 维:
FlashAttention 的作用:
数学结果仍然等价于:
但实现上可以不把完整 $[seq,seq]$ 注意力矩阵全部写进显存,而是分块在 SRAM 中完成 softmax 和加权求和,从而减少显存读写。它主要优化的是内存访问和显存占用,不改变上面的矩阵形状,也不改变模型权重结构。
拼接后的 attention 输出:
乘以:
得到:
$W_O$ 的作用是重新混合 32 个 head 的信息。
形状:
图中右侧绿色区域是 SwiGLU,它和 GPT-3 常见的 GELU FFN 不一样。
输入先做 Pre-RMSNorm:
然后走两条上投影路径:
再过 SiLU:
形状不变:
形状:
SwiGLU 的直觉:
GELU FFN 像是 “升维 → 激活 → 降维”。SwiGLU 多了一条门控分支:
它让模型可以动态选择哪些中间特征通过,哪些被压低。
| 矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_Q$ | $[4096,4096]$ | 16,777,216 |
| $W_K$ | $[4096,1024]$ | 4,194,304 |
| $W_V$ | $[4096,1024]$ | 4,194,304 |
| $W_O$ | $[4096,4096]$ | 16,777,216 |
| Attention 合计 | 41,943,040 |
如果不用 GQA,而是传统 MHA:
GQA 让一层 attention 参数从 6711 万降到 4194 万,约减少 37.5%。
| 矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_{gate}$ | $[4096,14336]$ | 58,720,256 |
| $W_{up}$ | $[4096,14336]$ | 58,720,256 |
| $W_{down}$ | $[14336,4096]$ | 58,720,256 |
| MLP 合计 | 176,160,768 |
| 组件 | 参数量 |
|---|---|
| Attention | 41,943,040 |
| SwiGLU MLP | 176,160,768 |
| 两个 RMSNorm | 8,192 |
| 一层合计 | 218,112,000 |
32 层:
Embedding:
LM Head 不共享权重,所以再来一个:
最终 RMSNorm:
总参数量:
也就是约 8.03B。
自回归推理时,每生成一个新 token,旧 token 的 $K,V$ 不需要重复计算,可以缓存下来。
对于 LLaMA-3 8B,一层、一个 token 的 KV Cache 元素数:
BF16 每个数 2 字节:
32 层:
这里的 KB/MB/GB 为了讲解方便,按计算机里常用的二进制口径换算:$1\text{ KB}=1024\text{ bytes}$,所以严格写法也可以叫 KiB/MiB/GiB。
如果上下文长度为 8192:
| 上下文长度 | BF16 KV Cache,batch=1 |
|---|---|
| 1,024 | 128 MB |
| 8,192 | 1 GB |
| 32,768 | 4 GB |
| 128,000 | 15.6 GB |
如果不用 GQA,而是 32 个 KV head:
8K 上下文就会变成:
所以 GQA 对长上下文推理的意义非常直接:KV Cache 显存约降为 1/4。
推理分两个阶段:
| 阶段 | 含义 | 计算特点 |
|---|---|---|
| Prefill | 一次性读入 prompt | 可并行,attention 有 $S^2$ 项 |
| Decode | 每次生成 1 个新 token | token-by-token,主要受 KV Cache 和显存带宽影响 |
下面使用 $2mnp$ 估算矩阵乘法 FLOPs。
一层 attention + MLP 参数矩阵总量:
对 8192 个 token 做一次前向线性投影:
也就是一层约 3.57 TFLOPs 的线性计算。
对完整 8192 prompt,注意力分数可以用 dense/full attention 口径做上界估算:
再乘以 $V$ 也差不多一次:
32 层:
也就是 8K prompt 的一次 prefill,Transformer blocks 部分约 150 TFLOPs。
| seq_len | 一层线性 FLOPs | 一层 attention 二次项 | 32 层总 FLOPs |
|---|---|---|---|
| 128 | $5.58\times10^{10}$ | $2.68\times10^8$ | $1.80\times10^{12}$ |
| 1,024 | $4.47\times10^{11}$ | $1.72\times10^{10}$ | $1.48\times10^{13}$ |
| 8,192 | $3.57\times10^{12}$ | $1.10\times10^{12}$ | $1.50\times10^{14}$ |
短 prompt 时,线性层占主导;长 prompt 时,attention 的 $S^2$ 项开始明显抬头。
假设已经有 8192 个历史 token 的 KV Cache,现在要生成第 8193 个 token。
一层的线性层 FLOPs:
这个新 token 的 query 需要和历史 8192 个 token 加上当前 token 自己的 key 做注意力,所以 attention 长度是 8193:
再乘以 $V$:
一层合计:
32 层:
也就是在 8K cache 后继续生成 1 个 token,Transformer blocks 部分约 18.3 GFLOPs。
再加上 LM Head:
所以单 token decode 约 19.3 GFLOPs,但真实速度往往不是只由 FLOPs 决定,还会被 KV Cache 读写、batch、kernel 调度、采样逻辑限制。
32 层结束后:
Final RMSNorm:
Linear Head:
对最后一个位置取 logits:
Softmax 得到整个词表的概率分布。然后可以用:
| 方法 | 含义 |
|---|---|
| argmax | 永远选概率最高的 token |
| temperature | 调整分布尖锐程度 |
| top-k | 只在概率最高的 k 个 token 中采样 |
| top-p | 只在累计概率达到 p 的候选集中采样 |
采样出新 token 后,把它送回 tokenizer/embedding 路径,继续下一轮自回归生成。
大语言模型训练常用粗略估算:
其中:
LLaMA-3 8B:
所以:
这是一个理论级的训练计算量估算,不等于电表读数。
Meta model card 给出的 LLaMA-3 8B 预训练计算为:
硬件为 H100-80GB,表中功耗口径为:
电能:
Meta 披露的碳排放估计为:
如果把理论训练 FLOPs 除以 1.3M GPU hours:
即平均每张 GPU 的端到端有效吞吐约:
这不是 H100 的峰值算力,而是把数据加载、通信、并行切分、重算、优化器、空泡和系统开销都揉进去后的训练有效值。
| 设计 | 解决的问题 | 数值体现 |
|---|---|---|
| GQA | 降低 KV Cache 显存和 K/V 投影成本 | 8K KV Cache 从约 4GB 降到约 1GB |
| RoPE | 注入相对位置信息,适合长上下文扩展 | Q/K 形状不变,但点积包含位置信息 |
| FlashAttention | 优化 attention 的显存读写 | 避免显式保存完整 $S\times S$ 注意力矩阵,不改变权重结构 |
| SwiGLU | 提升 FFN 表达力 | 每层 MLP 约 1.76 亿参数,是主要参数来源 |
| KV Cache | 推理时复用历史 K/V | decode 不必重复计算历史 token 的 K/V |
| Pre-RMSNorm | 稳定深层训练 | 每层 attention 与 MLP 前都先归一化 |
以 batch=1、seq_len=4 为例:
| 步骤 | 张量形状 |
|---|---|
| token ids | $[1,4]$ |
| Embedding | $[1,4,4096]$ |
| RMSNorm | $[1,4,4096]$ |
| Q projection | $[1,4,4096]$ |
| Q reshape | $[1,4,32,128]$ |
| K projection | $[1,4,1024]$ |
| K reshape | $[1,4,8,128]$ |
| V projection | $[1,4,1024]$ |
| V reshape | $[1,4,8,128]$ |
| RoPE(Q,K) | Q/K 形状不变 |
| repeat KV for GQA | $[1,4,32,128]$ |
| attention output all heads | $[1,4,32,128]$ |
| concat heads | $[1,4,4096]$ |
| O projection | $[1,4,4096]$ |
| Add residual 1 | $[1,4,4096]$ |
| Gate projection | $[1,4,14336]$ |
| Up projection | $[1,4,14336]$ |
| SwiGLU multiply | $[1,4,14336]$ |
| Down projection | $[1,4,4096]$ |
| Add residual 2 | $[1,4,4096]$ |
| 32 层后 Final RMSNorm | $[1,4,4096]$ |
| LM Head | $[1,4,128256]$ |
| 取最后位置 logits | $[1,128256]$ |
LLaMA-3 8B 的现代 LLM 架构可以看成:用 RoPE 把位置信息旋进 Q/K,用 GQA 把 KV Cache 压到传统 MHA 的 1/4,用 SwiGLU 扩大并门控 FFN 表达力,再用 KV Cache 把自回归推理从“反复重算全文”变成“只算新 token + 读取历史 K/V”;实现上还可以用 FlashAttention 降低长序列 attention 的显存读写。
repeat_kv、KV Cache、SwiGLU 实现。hidden_size=4096、intermediate_size=14336、num_hidden_layers=32、num_attention_heads=32、num_key_value_heads=8、vocab_size=128256。llama3 代码整理的 Transformer 架构说明与示意。