LLM 模型显存占用计算公式

发布时间: 更新时间: 总字数:2667 阅读时间:6m 作者: IP上海 分享 网址

理解 LLM 的资源占用主要分为两个方面:1. 显存占用 (Memory Footprint):模型在运行时(推理或训练)需要占用多少内存(通常是 GPU 的 VRAM),这是决定需要多大显存的 GPU 的关键。2. 计算量 (Computational Cost):模型进行一次完整的计算需要多少次浮点运算(FLOPs),这决定模型的运行速度。

在线工具

LLM 运行内存计算器

计算结果

请在左侧输入参数并点击计算。

注意: 这是一个估算值。实际显存占用可能因模型具体架构、CUDA版本和底层库的差异而有少量浮动。模型的层数和隐藏维度是基于参数量估算的。

显存占用计算 (Memory/VRAM Usage)主要分为推理(Inference)训练(Training)两种情况,训练所需的显存远大于推理。

1. 推理(Inference)时的显存占用

推理时,显存主要由三部分组成:模型参数KV Cache一些临时缓存区

总显存占用 ≈ 模型参数占用 + KV Cache 占用 + 临时缓存

1.1 模型参数占用 (Model Weights)

这是最主要且最固定的部分。它等于模型的参数量乘以每个参数占用的字节数。

$$\text{显存}_{参数} = \text{模型参数量} \times \text{每个参数的字节数}$$

每个参数的字节数取决于其数据类型(精度):

  • FP32 (单精度): 4 字节
  • FP16 (半精度) / BF16 (bfloat16): 2 字节
  • INT8 (8 位整型): 1 字节
  • INT4 (4 位整型): 0.5 字节

示例: 对于一个 70 亿(7B)参数的 Llama 2 模型:

  • FP16 精度下占用:$7 \times 10^9 \times 2 \text{ bytes} = 14 \times 10^9 \text{ bytes} = \textbf{14 GB}$
  • INT4 量化后占用:$7 \times 10^9 \times 0.5 \text{ bytes} = 3.5 \times 10^9 \text{ bytes} = \textbf{3.5 GB}$

1.2 KV Cache 占用

在自回归生成文本时,为了避免重复计算,模型会缓存每个 token 在每一层中的 Key (K) 和 Value (V) 向量。这部分的显存是动态变化的,随着生成序列的增长而增长。

$$\text{显存}_{\text{KV Cache}} = \text{批处理大小} \times \text{序列长度} \times \text{层数} \times \text{隐藏层维度} \times 2 \times \text{每个参数的字节数}$$

  • 批处理大小 (Batch Size): 同时处理的序列数量。
  • 序列长度 (Sequence Length): 当前已经处理的 token 数量。
  • 层数 (Number of Layers): 模型的总层数。
  • 隐藏层维度 (Hidden Size): 模型内部向量的维度。
  • x 2: 因为需要同时存储 Key 和 Value 两个向量。

示例: 对于 Llama 3 8B 模型(32 层,隐藏层维度 4096),在 FP16 精度下,处理一个批次大小为 1,序列长度为 2048 的请求:

  • KV Cache 占用 = $1 \times 2048 \times 32 \times 4096 \times 2 \times 2 \text{ bytes} \approx \textbf{1.07 GB}$

1.3 临时缓存 (Workspace)

这部分用于存储中间计算结果,如注意力得分(Attention Scores)等。它的大小相对较小,且难以精确计算,通常预留 1-2 GB 的空间就足够了。

推理总结: 一个 7B 模型,在 FP16 精度下,即使只处理一个请求,也需要大约 14 GB (参数) + small KV Cache + 1 GB (缓存) ≈ 15-16 GB 的显存。如果序列很长或批处理很大,KV Cache 会显著增加显存需求。

2. 训练(Training)时的显存占用

训练时的显存需求要大得多,因为它除了包含推理时的部分,还需要存储梯度(Gradients)优化器状态(Optimizer States)

总显存占用 ≈ 模型参数 + 梯度 + 优化器状态 + 激活值 + 临时缓存

2.1 模型参数 (FP16/BF16)

与推理相同,通常使用半精度以节省显存。 $$\text{显存}_{参数} = P \times 2 \text{ bytes}$$ ($P$ 为参数量)

2.2 梯度 (Gradients, FP16/BF16)

反向传播计算出的梯度,其数量与参数量完全相同。 $$\text{显存}_{梯度} = P \times 2 \text{ bytes}$$

2.3 优化器状态 (Optimizer States, 通常为 FP32)

为了训练的稳定性,优化器的状态(如动量和方差)通常以 FP32 存储,即使模型本身用 FP16 训练。

  • Adam / AdamW 优化器: 需要存储一阶动量(m)和二阶动量(v),每个都是 FP32。 $$\text{显存}_{优化器} = P \times 4 \text{ bytes (m)} + P \times 4 \text{ bytes (v)} = P \times 8 \text{ bytes}$$
  • SGD with Momentum: 只需要存储动量。 $$\text{显存}_{优化器} = P \times 4 \text{ bytes}$$

一个重要的经验法则: 对于使用 AdamW 优化器的混合精度训练,仅参数、梯度和优化器状态三项加起来就需要: $P \times (2 + 2 + 8) = \mathbf{P \times 12 \text{ bytes}}$ 对于一个 70 亿参数的模型,这部分就是 $7 \times 12 \approx \textbf{84 GB}$!

2.4 激活值 (Activations)

这是训练时最耗费显存的动态部分。在前向传播过程中,每一层的输出(激活值)都必须被保存下来,以便在反向传播时计算梯度。其大小与批处理大小序列长度成正比。

一个粗略的估算公式是: $$\text{显存}_{激活} \approx \text{批处理大小} \times \text{序列长度} \times \text{层数} \times \text{隐藏层维度} \times \text{系数}$$ 这个系数会根据模型结构(如注意力机制和前馈网络的具体实现)而变化,通常在 10 到 20 之间。这部分显存非常巨大,也是为什么训练时需要使用梯度检查点(Gradient Checkpointing)等技术来用计算换显存。

训练总结: 全量微调一个 7B 模型,即使使用混合精度,也至少需要 84 GB (静态部分) + 巨大的激活值显存。这就是为什么全量微调大模型通常需要多张 A100/H100 (80GB) 这样的高端 GPU。像 LoRA 这样的参数高效微调(PEFT)技术,通过只训练一小部分参数,极大地减少了梯度和优化器状态的显存占用,从而可以在消费级 GPU 上进行微调。

3. 计算量计算 (FLOPs)

计算量衡量模型处理速度,单位是 FLOPs (Floating Point Operations per Second)。我们通常关心的是模型处理一个 token 的计算成本。

一个被广泛接受的、非常简洁的估算公式(来自 OpenAI 的论文《Scaling Laws for Neural Language Models》):

3.1 前向传播 (Forward Pass)

模型处理一个 token 所需的计算量大约是参数量的两倍。 $$\text{FLOPs}_{\text{前向}} \approx 2 \times \text{模型参数量}$$

处理一个包含 S 个 token 的序列: $$\text{FLOPs}_{\text{前向}} \approx 2 \times \text{模型参数量} \times \text{序列长度 (S)}$$

3.2 反向传播 (Backward Pass)

反向传播的计算量大约是前向传播的两倍:

$$\text{FLOPs}_{\text{反向}} \approx 2 \times \text{FLOPs}_{\text{前向}} \approx 4 \times \text{模型参数量}$$

3.3 训练一步 (Forward + Backward)

$$\text{FLOPs}_{\text{训练}} = \text{FLOPs}_{\text{前向}} + \text{FLOPs}_{\text{反向}} \approx 6 \times \text{模型参数量}$$

示例: 对于一个 70 亿(7B)参数的模型,训练一步处理一个长度为 2048 的序列(假设批处理大小为 1):

  • 总处理 token 数 = $1 \times 2048 = 2048$
  • 总计算量 ≈ $6 \times (7 \times 10^9) \times 2048 \approx 8.6 \times 10^{16} \text{ FLOPs} = \textbf{86 PetaFLOPs}$

这个公式为估算模型训练和推理所需的时间提供了理论基础。例如,如果知道你的 GPU 每秒可以提供多少 FLOPs(TFLOPS),就可以估算出处理一个序列或完成一次训练所需的时间。

本文总阅读量 次 本站总访问量 次 本站总访客数
Home Archives Categories Tags Statistics