LLaMA-3 8B 现代 LLM 架构、计算规模与能耗

参考 Llm arch 1080P.html 架构图,以及文末基于官方代码整理的在线架构示意资料,围绕现代 decoder-only LLM 的核心设计展开:GQA + RoPE + SwiGLU + KV Cache,并说明 FlashAttention 这类实现优化的作用。本文以 LLaMA-3 8B 为真实参考对象,用具体矩阵形状和数值例子演示一次前向传播、推理缓存、FLOPs 与训练能耗。


一、LLaMA-3 8B 基本参数

参数 数值
模型名称 LLaMA-3 8B
发布机构 Meta
发布日期 2024-04-18
架构类型 Decoder-only Transformer
参数量 约 8.03B
训练 token 数 15T+
知识截止 2023-03
上下文长度 8,192 tokens
词表大小 $\lvert V\rvert$ 128,256
层数 $L$ 32
模型维度 $d_{model}$ 4,096
Query heads 32
KV heads 8
每头维度 $d_{head}$ 128
GQA 分组比例 每 4 个 Query head 共享 1 组 K/V
SwiGLU 中间维度 $d_{ff}$ 14,336
位置编码 RoPE,$\theta=500000$
Norm Pre-RMSNorm
权重精度 BF16
Embedding / LM Head 不共享权重

关键直觉:

GPT-3 的每层注意力通常可以近似看成 $Q,K,V,O$ 都是 $[d,d]$ 方阵;LLaMA-3 8B 因为用了 GQA,$K,V$ 不再投影到 4096 维,而是只投影到:

$$ n_{kv} \times d_{head} = 8 \times 128 = 1024 $$

所以:

$$ W_Q:[4096,4096],\quad W_K:[4096,1024],\quad W_V:[4096,1024],\quad W_O:[4096,4096] $$

这就是图中 “GQA 共享” 和 “KV Cache 在推理时复用历史 K/V” 的核心。


二、从架构图看一次完整数据流

图中的主路径可以读成:

文本输入
  ↓ Tokenizer
token ids
  ↓ Embeddings
X: [batch, seq_len, 4096]
  ↓ 重复 32 层 Transformer Block
  ├─ Pre-RMSNorm
  ├─ GQA / FlashAttention / RoPE / KV Cache
  ├─ Add 残差
  ├─ Pre-RMSNorm
  ├─ SwiGLU MLP
  └─ Add 残差
  ↓ Final RMSNorm
  ↓ Linear Head
logits: [batch, seq_len, 128256]
  ↓ Softmax / top-p / top-k
新 token 输出
  ↓ 自回归循环

为了把形状看清楚,下面先用一个小例子:

输入句子:机器学习很有趣
假设分词后:["机器", "学习", "很", "有趣"]
seq_len = 4
batch = 1

真实模型里 token id 不是中文词本身,而是 tokenizer 产生的整数。这里用中文词只是为了便于理解。


三、Embedding:从 token id 到 4096 维向量

词表矩阵:

$$ E:[128256,\ 4096] $$

参数量:

$$ 128256 \times 4096 = 525,336,576 \approx \textbf{5.25 亿} $$

假设 “学习” 的 token id 是 12345,查表就是取出第 12345 行:

$$ x_{\text{学习}}:[1,\ 4096] $$

对于 4 个 token:

$$ X:[1,\ 4,\ 4096] $$

注意:LLaMA 系列没有传统的绝对位置 embedding 表。位置信息主要通过 RoPE 注入到 $Q$ 和 $K$ 中,而不是加到 embedding 上。


四、一层 Transformer Block 的完整计算

下面追踪一层。输入为:

$$ X:[1,\ 4,\ 4096] $$

1. Pre-RMSNorm

RMSNorm 形状不变:

$$ [1,\ 4,\ 4096] \rightarrow [1,\ 4,\ 4096] $$

RMSNorm 与 LayerNorm 类似,都是让数值更稳定;区别是 RMSNorm 不减均值,只按均方根缩放:

$$ \text{RMS}(x)=\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2+\epsilon} $$
$$ y_i = \frac{x_i}{\text{RMS}(x)} \cdot w_i $$

其中 $d=4096$,$w$ 是可训练缩放参数,形状为 $[4096]$。


2. 生成 Q、K、V:GQA 的形状差异

Q 投影

$$ Q = XW_Q $$
$$ [1,\ 4,\ 4096] \times [4096,\ 4096] \rightarrow [1,\ 4,\ 4096] $$

拆成 32 个 query head:

$$ [1,\ 4,\ 4096] \rightarrow [1,\ 4,\ 32,\ 128] $$

K 投影

$$ K = XW_K $$
$$ [1,\ 4,\ 4096] \times [4096,\ 1024] \rightarrow [1,\ 4,\ 1024] $$

拆成 8 个 KV head:

$$ [1,\ 4,\ 1024] \rightarrow [1,\ 4,\ 8,\ 128] $$

V 投影

$$ V = XW_V $$
$$ [1,\ 4,\ 4096] \times [4096,\ 1024] \rightarrow [1,\ 4,\ 1024] $$

拆成:

$$ [1,\ 4,\ 8,\ 128] $$

这就是 GQA:

类型 head 数 总维度
Query 32 4096
Key 8 1024
Value 8 1024

每 4 个 Query head 共享 1 个 Key head 和 1 个 Value head:

Q heads 0,1,2,3     → 使用 KV head 0
Q heads 4,5,6,7     → 使用 KV head 1
...
Q heads 28,29,30,31 → 使用 KV head 7

不要把“4 个 token”和“4 个 Query head”混为一谈

前面例子中的“机器、学习、很、有趣”是 4 个 token 位置,属于序列维度 seq_len=4。而“每 4 个 Query head 共享 1 个 KV head”属于注意力头维度。两者彼此独立:

4 个 token       → 一个 Query head 要查看的 4 个序列位置
4 个 Query heads → 同一 GQA 分组中共享一个 KV head 的 4 个特征通道

因此,只展示一个 Query head 如何关注 4 个 token,不能证明 4Q→1KV。下面会用 8 Query heads / 2 KV heads 的完整玩具模型,显式展示两个 GQA 分组。

如果是传统 MHA,K/V 也会有 32 个 head,KV Cache 会是 GQA 的 4 倍。


3. RoPE:把位置旋转进 Q 和 K

RoPE 不改变张量形状:

$$ Q:[1,\ 4,\ 32,\ 128] \rightarrow Q_{\text{rope}}:[1,\ 4,\ 32,\ 128] $$
$$ K:[1,\ 4,\ 8,\ 128] \rightarrow K_{\text{rope}}:[1,\ 4,\ 8,\ 128] $$

它的做法是把每个 head 的 128 维向量按两两一组看成二维平面:

$$ (x_0,x_1),\ (x_2,x_3),\ ...,\ (x_{126},x_{127}) $$

在位置 $p$ 上,每一组会旋转一个角度:

$$ \begin{bmatrix} x'_{2i}\\ x'_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos \theta_{p,i} & -\sin \theta_{p,i}\\ \sin \theta_{p,i} & \cos \theta_{p,i} \end{bmatrix} \begin{bmatrix} x_{2i}\\ x_{2i+1} \end{bmatrix} $$

直观理解:RoPE 让 “第 4 个 token 的 query” 和 “第 2 个 token 的 key” 的点积天然包含相对距离信息。


4. GQA 注意力计算与 FlashAttention 优化

下面把真实 LLaMA-3 8B 缩小为一个仍然属于 GQA、但可以手算的模型。数值是人为设计的,只用于展示映射和计算。

4.1 玩具模型参数

输入:["机器", "学习", "很", "有趣"]

batch       = 1
seq_len     = 4
Query heads = 8
KV heads    = 2
head_dim    = 2
d_model     = 16

模型维度和共享比例:

$$ d_{model}=n_qd_{head}=8\times2=16 $$
$$ n_{rep}=\frac{n_q}{n_{kv}}=\frac{8}{2}=4 $$

所以它有两个真正的 GQA 分组:

Q heads 0~3 → KV head 0
Q heads 4~7 → KV head 1

如果只有 4 Query heads / 1 KV head,严格分类更接近 MQA;这里采用 8Q/2KV,避免把 MQA 特例误称为完整 GQA。

4.2 投影与张量形状

输入:

$$ X:[1,\ 4,\ 16] $$

投影矩阵:

$$ W_Q:[16,\ 16],\quad W_K:[16,\ 4],\quad W_V:[16,\ 4] $$

拆分 head 后:

$$ Q:[1,\ 4,\ 8,\ 2] $$
$$ K,V:[1,\ 4,\ 2,\ 2] $$

维度顺序是:

Q:[batch, token位置, Query head, head_dim]
K:[batch, token位置, KV head, head_dim]
V:[batch, token位置, KV head, head_dim]

Q 和 K 随后应用 RoPE,V 不应用 RoPE。下面给出的 Q/K 数值视为完成投影和 RoPE 后的结果;V 是完成投影后的结果。

4.3 两个 KV head 在四个 token 位置的数据

KV head 0:

位置 $j$ token $k_{j,0}$ $v_{j,0}$
1 机器 $[1,0]$ $[1,0]$
2 学习 $[0,1]$ $[0,1]$
3 $[-1,0]$ $[-1,0]$
4 有趣 $[0,-1]$ $[0,-1]$

KV head 1:

位置 $j$ token $k_{j,1}$ $v_{j,1}$
1 机器 $[1,1]$ $[2,0]$
2 学习 $[-1,1]$ $[0,2]$
3 $[-1,-1]$ $[-2,0]$
4 有趣 $[1,-1]$ $[0,-2]$

“共享一个 KV head”不是所有 token 共用一个 K/V 向量。每个 KV head 在每个 token 位置都有不同的 K/V;同组 Query head 共享的是该 KV head 的整组投影结果,以及推理时对应的 KV Cache。

4.4 第 4 个 token 的 8 个 Query head

第 4 个 token“有趣”可以看到位置 1~4。保留 batch 和 token 维度时:

$$ Q_{t=4}:[1,\ 1,\ 8,\ 2] $$

假设它产生:

Query head $h$ $q_{4,h}$ 使用的 KV head
0 $[2,0]$ 0
1 $[0,2]$ 0
2 $[-2,0]$ 0
3 $[0,-2]$ 0
4 $[1,1]$ 1
5 $[-1,1]$ 1
6 $[-1,-1]$ 1
7 $[1,-1]$ 1

映射函数:

$$ g(h)=\left\lfloor\frac{h}{4}\right\rfloor $$

因此:

Q0 ─┐
Q1 ─┤
Q2 ─┼── KV head 0
Q3 ─┘

Q4 ─┐
Q5 ─┤
Q6 ─┼── KV head 1
Q7 ─┘

Q0~Q3 不读取 KV head 1,Q4~Q7 也不读取 KV head 0。各 Query head 之间没有顺序依赖,通常由批量张量运算同时处理,而不是轮流占用 KV head。

实现时可以在逻辑上把每个 KV head 提供给组内 4 个 Query head,但不一定需要在显存中真正复制四份;GQA kernel 可以直接读取共享数据。

4.5 第一组:Q0~Q3 读取 KV head 0

计算公式:

$$ s_{h,4,j} = \frac{q_{4,h}\cdot k_{j,0}}{\sqrt{2}}, \qquad h\in\{0,1,2,3\} $$

分数:

Query 机器 学习 有趣
$Q_0$ 1.414 0 -1.414 0
$Q_1$ 0 1.414 0 -1.414
$Q_2$ -1.414 0 1.414 0
$Q_3$ 0 -1.414 0 1.414

沿 token 位置做 Softmax:

Query 机器 学习 有趣
$Q_0$ 0.647 0.157 0.038 0.157
$Q_1$ 0.157 0.647 0.157 0.038
$Q_2$ 0.038 0.157 0.647 0.157
$Q_3$ 0.157 0.038 0.157 0.647

每个 Query head 使用自己的权重,对 KV head 0 的 Value 加权:

$$ o_{4,h}=\sum_{j=1}^{4}\alpha_{h,4,j}v_{j,0} $$
Query 输出
$Q_0$ $[0.609,0]$
$Q_1$ $[0,0.609]$
$Q_2$ $[-0.609,0]$
$Q_3$ $[0,-0.609]$

4.6 第二组:Q4~Q7 读取 KV head 1

$$ s_{h,4,j} = \frac{q_{4,h}\cdot k_{j,1}}{\sqrt{2}}, \qquad h\in\{4,5,6,7\} $$

分数:

Query 机器 学习 有趣
$Q_4$ 1.414 0 -1.414 0
$Q_5$ 0 1.414 0 -1.414
$Q_6$ -1.414 0 1.414 0
$Q_7$ 0 -1.414 0 1.414

Softmax:

Query 机器 学习 有趣
$Q_4$ 0.647 0.157 0.038 0.157
$Q_5$ 0.157 0.647 0.157 0.038
$Q_6$ 0.038 0.157 0.647 0.157
$Q_7$ 0.157 0.038 0.157 0.647

对 KV head 1 的 Value 加权:

Query 输出
$Q_4$ $[1.218,0]$
$Q_5$ $[0,1.218]$
$Q_6$ $[-1.218,0]$
$Q_7$ $[0,-1.218]$

第二组的权重碰巧与第一组相同,是因为这里选用了对称向量;两组实际读取的是不同 K/V,输出也不同。真实模型中,两组的注意力权重通常不会相同。

4.7 拼接与输出投影

两个分组一共得到 8 个 head 输出:

$$ O_{t=4}:[1,\ 1,\ 8,\ 2] $$

拼接后:

$$ [1,\ 1,\ 8,\ 2]\rightarrow[1,\ 1,\ 16] $$

输出投影:

$$ [1,\ 1,\ 16]W_O[16,\ 16]\rightarrow[1,\ 1,\ 16] $$

这个玩具例子完整展示了:

  1. 一个 KV head 被组内 4 个 Query head 共享;
  2. 不同分组使用不同 KV head;
  3. 同组 Query 共享 K/V,但不共享 Query、注意力权重或输出;
  4. Query 不与所有 KV head 做笛卡尔积;
  5. 4 个 token 位置和 4 个 Query head 是两个独立维度。

真实 LLaMA-3 8B 只是把两个分组扩大为八个分组,每组仍然是 4Q→1KV。32 个 Query head 计算完后拼接回 4096 维:

$$ 32\times128=4096 $$
$$ \text{attn\_concat}:[1,\ 4,\ 4096] $$

FlashAttention 的作用:

数学结果仍然等价于:

$$ \text{softmax}\left(\frac{QK^T}{\sqrt{d_{head}}}+\text{mask}\right)V $$

但实现上可以不把完整 $[seq,seq]$ 注意力矩阵全部写进显存,而是分块在 SRAM 中完成 softmax 和加权求和,从而减少显存读写。它主要优化的是内存访问和显存占用,不改变上面的矩阵形状,也不改变模型权重结构。


5. O 输出投影

拼接后的 attention 输出:

$$ [1,\ 4,\ 4096] $$

乘以:

$$ W_O:[4096,\ 4096] $$

得到:

$$ [1,\ 4,\ 4096] $$

$W_O$ 的作用是重新混合 32 个 head 的信息。


6. Add 残差连接

$$ H = X + \text{Attention}(X) $$

形状:

$$ [1,\ 4,\ 4096] $$

五、SwiGLU 前馈网络:三次矩阵乘法

图中右侧绿色区域是 SwiGLU,它和 GPT-3 常见的 GELU FFN 不一样。

输入先做 Pre-RMSNorm:

$$ H:[1,\ 4,\ 4096] \rightarrow \hat{H}:[1,\ 4,\ 4096] $$

然后走两条上投影路径:

1. Gate 投影

$$ G = \hat{H}W_{gate} $$
$$ [1,\ 4,\ 4096] \times [4096,\ 14336] \rightarrow [1,\ 4,\ 14336] $$

再过 SiLU:

$$ \text{SiLU}(G)=G\cdot\sigma(G) $$

2. Up 投影

$$ U = \hat{H}W_{up} $$
$$ [1,\ 4,\ 4096] \times [4096,\ 14336] \rightarrow [1,\ 4,\ 14336] $$

3. 门控相乘

$$ M = \text{SiLU}(G) \odot U $$

形状不变:

$$ [1,\ 4,\ 14336] $$

4. Down 降维

$$ Y = MW_{down} $$
$$ [1,\ 4,\ 14336] \times [14336,\ 4096] \rightarrow [1,\ 4,\ 4096] $$

5. 第二次残差

$$ X_{\text{next}} = H + Y $$

形状:

$$ [1,\ 4,\ 4096] $$

SwiGLU 的直觉:

GELU FFN 像是 “升维 → 激活 → 降维”。SwiGLU 多了一条门控分支:

$$ \text{SwiGLU}(x)=W_{down}\left(\text{SiLU}(xW_{gate})\odot xW_{up}\right) $$

它让模型可以动态选择哪些中间特征通过,哪些被压低。


六、一层参数量逐项计算

Attention 参数

矩阵 形状 参数量
$W_Q$ $[4096,4096]$ 16,777,216
$W_K$ $[4096,1024]$ 4,194,304
$W_V$ $[4096,1024]$ 4,194,304
$W_O$ $[4096,4096]$ 16,777,216
Attention 合计 41,943,040

如果不用 GQA,而是传统 MHA:

$$ 4 \times 4096^2 = 67,108,864 $$

GQA 让一层 attention 参数从 6711 万降到 4194 万,约减少 37.5%。

SwiGLU MLP 参数

矩阵 形状 参数量
$W_{gate}$ $[4096,14336]$ 58,720,256
$W_{up}$ $[4096,14336]$ 58,720,256
$W_{down}$ $[14336,4096]$ 58,720,256
MLP 合计 176,160,768

一层 Block 合计

组件 参数量
Attention 41,943,040
SwiGLU MLP 176,160,768
两个 RMSNorm 8,192
一层合计 218,112,000

32 层:

$$ 32 \times 218,112,000 = 6,979,584,000 $$

Embedding:

$$ 128256 \times 4096 = 525,336,576 $$

LM Head 不共享权重,所以再来一个:

$$ 4096 \times 128256 = 525,336,576 $$

最终 RMSNorm:

$$ 4096 $$

总参数量:

$$ 6,979,584,000 + 525,336,576 + 525,336,576 + 4,096 = \textbf{8,030,261,248} $$

也就是约 8.03B


七、KV Cache:为什么 GQA 对推理很重要?

自回归推理时,每生成一个新 token,旧 token 的 $K,V$ 不需要重复计算,可以缓存下来。

对于 LLaMA-3 8B,一层、一个 token 的 KV Cache 元素数:

$$ 2 \times n_{kv} \times d_{head} =2 \times 8 \times 128 =2048 $$

BF16 每个数 2 字节:

$$ 2048 \times 2 = 4096 \text{ bytes} = 4 \text{ KB} $$

32 层:

$$ 32 \times 4\text{ KB} = 128\text{ KB/token} $$

这里的 KB/MB/GB 为了讲解方便,按计算机里常用的二进制口径换算:$1\text{ KB}=1024\text{ bytes}$,所以严格写法也可以叫 KiB/MiB/GiB。

如果上下文长度为 8192:

$$ 8192 \times 128\text{ KB} = 1,073,741,824\text{ bytes} = \textbf{1 GB} $$
上下文长度 BF16 KV Cache,batch=1
1,024 128 MB
8,192 1 GB
32,768 4 GB
128,000 15.6 GB

如果不用 GQA,而是 32 个 KV head:

$$ 2 \times 32 \times 128 \times 2 \times 32 = 524,288 \text{ bytes/token} = 512 \text{ KB/token} $$

8K 上下文就会变成:

$$ 8192 \times 512\text{ KB} = \textbf{4 GB} $$

所以 GQA 对长上下文推理的意义非常直接:KV Cache 显存约降为 1/4


八、FLOPs 计算:Prefill 与 Decode 分开看

推理分两个阶段:

阶段 含义 计算特点
Prefill 一次性读入 prompt 可并行,attention 有 $S^2$ 项
Decode 每次生成 1 个新 token token-by-token,主要受 KV Cache 和显存带宽影响

下面使用 $2mnp$ 估算矩阵乘法 FLOPs。

1. 一层线性层 FLOPs,seq_len = 8192

一层 attention + MLP 参数矩阵总量:

$$ 41,943,040 + 176,160,768 = 218,103,808 $$

对 8192 个 token 做一次前向线性投影:

$$ 2 \times 8192 \times 218,103,808 \approx 3.57 \times 10^{12} $$

也就是一层约 3.57 TFLOPs 的线性计算。

2. 一层 attention 二次项 FLOPs

对完整 8192 prompt,注意力分数可以用 dense/full attention 口径做上界估算:

$$ QK^T:\quad 2 \times n_{heads} \times S^2 \times d_{head} $$
$$ 2 \times 32 \times 8192^2 \times 128 \approx 5.50 \times 10^{11} $$

再乘以 $V$ 也差不多一次:

$$ \text{Attention 二次项合计} \approx 1.10 \times 10^{12} $$

3. 一层总前向 FLOPs

$$ 3.57 \times 10^{12} + 1.10 \times 10^{12} = 4.67 \times 10^{12} $$

32 层:

$$ 32 \times 4.67 \times 10^{12} = \textbf{1.50} \times 10^{14} $$

也就是 8K prompt 的一次 prefill,Transformer blocks 部分约 150 TFLOPs

4. 不同 prompt 长度的前向计算量

seq_len 一层线性 FLOPs 一层 attention 二次项 32 层总 FLOPs
128 $5.58\times10^{10}$ $2.68\times10^8$ $1.80\times10^{12}$
1,024 $4.47\times10^{11}$ $1.72\times10^{10}$ $1.48\times10^{13}$
8,192 $3.57\times10^{12}$ $1.10\times10^{12}$ $1.50\times10^{14}$

短 prompt 时,线性层占主导;长 prompt 时,attention 的 $S^2$ 项开始明显抬头。


九、Decode:生成一个新 token 的计算

假设已经有 8192 个历史 token 的 KV Cache,现在要生成第 8193 个 token。

一层的线性层 FLOPs:

$$ 2 \times 1 \times 218,103,808 = 4.36 \times 10^8 $$

这个新 token 的 query 需要和历史 8192 个 token 加上当前 token 自己的 key 做注意力,所以 attention 长度是 8193:

$$ QK^T:\quad 2 \times 32 \times 1 \times 8193 \times 128 \approx 6.71 \times 10^7 $$

再乘以 $V$:

$$ \approx 6.71 \times 10^7 $$

一层合计:

$$ 4.36 \times 10^8 + 1.34 \times 10^8 = 5.70 \times 10^8 $$

32 层:

$$ 32 \times 5.70 \times 10^8 = \textbf{1.83} \times 10^{10} $$

也就是在 8K cache 后继续生成 1 个 token,Transformer blocks 部分约 18.3 GFLOPs

再加上 LM Head:

$$ 2 \times 4096 \times 128256 \approx 1.05 \times 10^9 $$

所以单 token decode 约 19.3 GFLOPs,但真实速度往往不是只由 FLOPs 决定,还会被 KV Cache 读写、batch、kernel 调度、采样逻辑限制。


十、输出层与采样

32 层结束后:

$$ H_{\text{final}}:[1,\ 4,\ 4096] $$

Final RMSNorm:

$$ [1,\ 4,\ 4096] \rightarrow [1,\ 4,\ 4096] $$

Linear Head:

$$ [1,\ 4,\ 4096] \times [4096,\ 128256] \rightarrow [1,\ 4,\ 128256] $$

对最后一个位置取 logits:

$$ [1,\ 128256] $$

Softmax 得到整个词表的概率分布。然后可以用:

方法 含义
argmax 永远选概率最高的 token
temperature 调整分布尖锐程度
top-k 只在概率最高的 k 个 token 中采样
top-p 只在累计概率达到 p 的候选集中采样

采样出新 token 后,把它送回 tokenizer/embedding 路径,继续下一轮自回归生成。


十一、训练 FLOPs 与能耗换算

1. 用经验公式估算训练 FLOPs

大语言模型训练常用粗略估算:

$$ \text{训练 FLOPs} \approx 6 \times N \times T $$

其中:

LLaMA-3 8B:

$$ N \approx 8.03 \times 10^9 $$
$$ T \approx 15 \times 10^{12} $$

所以:

$$ 6 \times 8.03\times10^9 \times 15\times10^{12} \approx \textbf{7.23}\times10^{23}\text{ FLOPs} $$

这是一个理论级的训练计算量估算,不等于电表读数。

2. Meta 披露的训练 GPU 小时

Meta model card 给出的 LLaMA-3 8B 预训练计算为:

$$ 1.3\text{M GPU hours} $$

硬件为 H100-80GB,表中功耗口径为:

$$ 700W $$

电能:

$$ 1.3\times10^6\text{ GPU小时} \times 0.7\text{ kW} = 910,000\text{ kWh} $$
$$ = \textbf{910 MWh} = \textbf{0.91 GWh} $$

Meta 披露的碳排放估计为:

$$ \textbf{390 tCO}_2\text{eq} $$

3. FLOPs 与 GPU 小时的连接

如果把理论训练 FLOPs 除以 1.3M GPU hours:

$$ \frac{7.23\times10^{23}}{1.3\times10^6 \times 3600} \approx 1.54\times10^{14}\text{ FLOP/s/GPU} $$

即平均每张 GPU 的端到端有效吞吐约:

$$ \textbf{154 TFLOP/s} $$

这不是 H100 的峰值算力,而是把数据加载、通信、并行切分、重算、优化器、空泡和系统开销都揉进去后的训练有效值。


十二、现代设计为什么会出现在这张图里?

设计 解决的问题 数值体现
GQA 降低 KV Cache 显存和 K/V 投影成本 8K KV Cache 从约 4GB 降到约 1GB
RoPE 注入相对位置信息,适合长上下文扩展 Q/K 形状不变,但点积包含位置信息
FlashAttention 优化 attention 的显存读写 避免显式保存完整 $S\times S$ 注意力矩阵,不改变权重结构
SwiGLU 提升 FFN 表达力 每层 MLP 约 1.76 亿参数,是主要参数来源
KV Cache 推理时复用历史 K/V decode 不必重复计算历史 token 的 K/V
Pre-RMSNorm 稳定深层训练 每层 attention 与 MLP 前都先归一化

十三、完整形状速查表

以 batch=1、seq_len=4 为例:

步骤 张量形状
token ids $[1,4]$
Embedding $[1,4,4096]$
RMSNorm $[1,4,4096]$
Q projection $[1,4,4096]$
Q reshape $[1,4,32,128]$
K projection $[1,4,1024]$
K reshape $[1,4,8,128]$
V projection $[1,4,1024]$
V reshape $[1,4,8,128]$
RoPE(Q,K) Q/K 形状不变
repeat KV for GQA $[1,4,32,128]$
attention output all heads $[1,4,32,128]$
concat heads $[1,4,4096]$
O projection $[1,4,4096]$
Add residual 1 $[1,4,4096]$
Gate projection $[1,4,14336]$
Up projection $[1,4,14336]$
SwiGLU multiply $[1,4,14336]$
Down projection $[1,4,4096]$
Add residual 2 $[1,4,4096]$
32 层后 Final RMSNorm $[1,4,4096]$
LM Head $[1,4,128256]$
取最后位置 logits $[1,128256]$

十四、一句话总结

LLaMA-3 8B 的现代 LLM 架构可以看成:用 RoPE 把位置信息旋进 Q/K,用 GQA 把 KV Cache 压到传统 MHA 的 1/4,用 SwiGLU 扩大并门控 FFN 表达力,再用 KV Cache 把自回归推理从“反复重算全文”变成“只算新 token + 读取历史 K/V”;实现上还可以用 FlashAttention 降低长序列 attention 的显存读写。


参考来源