Defeating Nondeterminism in LLM Inference
Introduction
The article examines why results from large language model (LLM) inference are hard to reproduce. It questions the common "concurrency + floating point" hypothesis for nondeterminism, offers a more complete explanation, and describes how to obtain deterministic results. It also explains why an inference stack built from individually deterministic components can still produce nondeterministic outputs from a user's perspective.
The Original Sin: Floating-Point Non-Associativity
- Floating-point numbers' non-associativity, (a + b) + c ≠ a + (b + c), is a fundamental cause of numerical differences.
- This non-associativity arises from floating-point numbers' dynamic precision, which allows representation of both very small and very large values.
- When adding floating-point numbers with different exponents, information can be lost due to rounding, leading to different results depending on the order of operations.
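A minimal sketch of the effect, using Python's built-in 64-bit floats. The variable names are illustrative only; the same behavior applies, with different magnitudes, to the lower-precision formats used on GPUs.

```python
# The same three numbers, grouped two different ways.
a, b, c = 0.1, 1e20, -1e20

left = (a + b) + c   # 0.1 is far below the precision of 1e20, so it is rounded away
right = a + (b + c)  # 1e20 and -1e20 cancel first, so 0.1 survives

print(left)           # 0.0
print(right)          # 0.1
print(left == right)  # False
```

Both groupings are mathematically equal to 0.1, yet only one of them returns it. Any kernel that changes the order of a reduction can therefore change the result.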
Why Kernels Don’t Always Add Numbers in the Same Order
- The "concurrency + floating point" hypothesis suggests that nondeterministic thread completion order in concurrent execution leads to nondeterministic accumulation.
- However, the article argues that concurrency and atomic adds are not the primary cause of LLM inference nondeterminism.
- Modern GPU kernels rarely need atomic adds due to sufficient parallelism along the "batch" dimension.
- Techniques like split reductions and semaphores are used to achieve determinism without sacrificing performance.
- The forward pass of an LLM typically does not involve operations that require atomic adds, making it "run-to-run deterministic."
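The contrast between atomic-style accumulation and a fixed-order reduction can be sketched in plain Python. This is a toy simulation, not GPU code: `atomic_style_sum`, `place_by_index`, and `tree_sum` are hypothetical helpers, and "arrival order" stands in for the order in which threads happen to finish.

```python
from itertools import permutations

values = [1e16, 1.0, -1e16, 1.0]

def atomic_style_sum(order):
    """Accumulate in whatever order the simulated threads finish."""
    total = 0.0
    for i in order:
        total += values[i]
    return total

def place_by_index(order):
    """Each simulated thread writes its value into its own slot; arrival order is irrelevant."""
    slots = [None] * len(values)
    for i in order:
        slots[i] = values[i]
    return slots

def tree_sum(xs):
    """Pairwise reduction whose order depends only on position."""
    xs = list(xs)
    while len(xs) > 1:
        paired = [xs[i] + xs[i + 1] for i in range(0, len(xs) - 1, 2)]
        if len(xs) % 2:
            paired.append(xs[-1])
        xs = paired
    return xs[0]

orders = list(permutations(range(len(values))))
atomic_results = {atomic_style_sum(p) for p in orders}
tree_results = {tree_sum(place_by_index(p)) for p in orders}

print(len(atomic_results) > 1)  # True: the result depends on arrival order
print(len(tree_results) == 1)   # True: the result is the same for every arrival order
```

Note that the fixed-order result is reproducible, not exact: it still carries rounding error, but the same error every time.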
Batch Invariance and "Determinism"
- Even if the forward pass is deterministic, the output of an LLM inference server can be nondeterministic from a user's perspective.
- This is because a request's output can depend on parallel user requests due to the forward pass lacking "batch invariance."
- Batch invariance means that the output for each element in a batch should be independent of the batch size and other elements in the batch.
- Non-batch-invariant kernels can lead to nondeterminism because the load on the server (and thus the batch size) varies nondeterministically.
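As a rough illustration, consider a toy kernel that is perfectly deterministic for a fixed batch but picks its reduction strategy from the batch size. Everything here is hypothetical (`toy_kernel`, `serve`, and the threshold of 4 are made up for the example); the point is only that the same request gets a different answer depending on how busy the simulated server is.

```python
def sum_in_order(xs):
    """Sequential left-to-right accumulation."""
    total = 0.0
    for x in xs:
        total += x
    return total

def toy_kernel(row, batch_size):
    """Hypothetical kernel: splits each row's reduction in two when the batch is small."""
    num_chunks = 2 if batch_size < 4 else 1
    size = len(row) // num_chunks
    partials = [sum_in_order(row[i * size:(i + 1) * size]) for i in range(num_chunks)]
    return sum_in_order(partials)

def serve(batch):
    """Run the kernel over every request currently in the batch."""
    return [toy_kernel(row, len(batch)) for row in batch]

my_request = [0.1] * 10
other_request = [0.5] * 10

quiet = serve([my_request])[0]                       # batch size 1
busy = serve([my_request] + [other_request] * 7)[0]  # batch size 8

print(quiet == busy)                    # False: same request, different result
print(serve([my_request])[0] == quiet)  # True: repeating the same batch reproduces the result
```

From the server's point of view nothing is random. From the user's point of view the batch size is an uncontrolled input, so the output looks nondeterministic.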
How to Make Kernels Batch-Invariant
- To achieve determinism in inference servers, all kernels must be batch-invariant.
- Pointwise operations are generally batch-invariant.
- The focus is on making RMSNorm, matrix multiplication, and attention batch-invariant.
Batch-invariant RMSNorm
- Implemented using a data-parallel strategy where each batch element is assigned to one core, ensuring that each reduction is done entirely within a single core, maintaining batch invariance.
- Challenges arise with small batch sizes, where some cores sit idle. Filling them with split reductions or atomic adds would change the reduction order, so those techniques should be avoided to maintain batch invariance.
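The data-parallel strategy can be sketched in pure Python, with a per-row function standing in for the work of a single core. This is an illustrative sketch rather than the article's kernel: `rmsnorm_row`, `rmsnorm_batch`, and the `eps` default are assumptions made for the example.

```python
import math

def rmsnorm_row(x, weight, eps=1e-6):
    """Normalize one row; the whole reduction runs in one fixed order, as on a single core."""
    sum_sq = 0.0
    for v in x:
        sum_sq += v * v
    inv_rms = 1.0 / math.sqrt(sum_sq / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def rmsnorm_batch(batch, weight):
    """Data-parallel: one row per simulated core; no row's reduction is split or shared."""
    return [rmsnorm_row(row, weight) for row in batch]

weight = [1.0, 1.0, 1.0, 1.0]
row = [0.1, 0.2, 0.3, 0.4]
others = [[1.0, 2.0, 3.0, 4.0], [5e7, -3.0, 0.25, 9.0]]

alone = rmsnorm_batch([row], weight)[0]
batched = rmsnorm_batch([row] + others, weight)[0]

print(alone == batched)  # True: the row's result does not depend on its batch
```

Because the reduction for a row never leaves its simulated core, adding or removing other rows cannot change the order in which that row's squares are summed.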
Batch-invariant matrix multiplication
- Implemented using a data-parallel strategy by splitting the output tensor into 2D tiles and assigning each tile to a different core, with each core computing the dot products for that tile.
- Split-K Matmul (splitting along the reduction dimension) breaks batch invariance.
- The easiest way to ensure batch invariance for matmuls is to compile one kernel configuration and use it for all shapes, at the cost of some performance.
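The following toy matmul shows how Split-K interacts with shape-dependent configuration. It is a sketch under stated assumptions: `heuristic_k_splits` is a made-up stand-in for an autotuner that enables Split-K only when there are few rows, and the inputs are chosen so that the rounding difference is visible.

```python
def dot_split_k(a_row, b_col, k_splits):
    """Dot product with the K dimension cut into k_splits equal chunks (K must divide evenly)."""
    chunk = len(a_row) // k_splits
    partials = []
    for s in range(k_splits):
        acc = 0.0
        for i in range(s * chunk, (s + 1) * chunk):
            acc += a_row[i] * b_col[i]
        partials.append(acc)
    total = 0.0
    for p in partials:
        total += p
    return total

def matmul(a, b, k_splits):
    """Naive (M x K) @ (K x N) product using the given Split-K setting."""
    cols = [[b_row[j] for b_row in b] for j in range(len(b[0]))]
    return [[dot_split_k(a_row, col, k_splits) for col in cols] for a_row in a]

def heuristic_k_splits(m):
    """Hypothetical autotuner: use Split-K only when there are few rows."""
    return 2 if m < 4 else 1

a_row = [1e16, 1.0, -1e16, 1.0]
b = [[1.0], [1.0], [1.0], [1.0]]  # K = 4, N = 1
filler = [0.0, 0.0, 0.0, 0.0]

small = [a_row]                 # M = 1
large = [a_row] + [filler] * 7  # M = 8

tuned_small = matmul(small, b, heuristic_k_splits(len(small)))[0][0]
tuned_large = matmul(large, b, heuristic_k_splits(len(large)))[0][0]
fixed_small = matmul(small, b, 1)[0][0]
fixed_large = matmul(large, b, 1)[0][0]

print(tuned_small == tuned_large)  # False: the configuration changed with M
print(fixed_small == fixed_large)  # True: one configuration for every shape
```

With a single fixed configuration the reduction over K always runs in the same order for a given row, whatever the value of M.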
Batch-invariant attention (FlashAttention2 Strategy)
- Parallelize along the query tensor (Q) and reduce along key/value tensors (K/V) simultaneously, enabling a data-parallel strategy.
- The reduction order for a given token should not depend on how many other tokens from its sequence are being processed simultaneously.
- To resolve issues with KV cache, update the KV cache and page table before the attention kernel.
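- For small query lengths (decoding), the kernel also has to split along the KV dimension to keep the GPU busy. To preserve batch invariance, a fixed split size is used rather than a fixed number of splits, so the reduction order does not change with the load.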
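The difference between the two splitting policies can be sketched by computing only the split boundaries. The helper names and the split size of 256 are illustrative assumptions; in a real kernel the number of splits would be chosen by a scheduler based on how much parallelism is needed.

```python
def fixed_count_splits(kv_len, num_splits):
    """Boundaries when the number of splits is fixed, so the split size varies."""
    size = -(-kv_len // num_splits)  # ceiling division
    return [(start, min(start + size, kv_len)) for start in range(0, kv_len, size)]

def fixed_size_splits(kv_len, split_size):
    """Boundaries when the split size is fixed, so the number of splits varies."""
    return [(start, min(start + split_size, kv_len)) for start in range(0, kv_len, split_size)]

# If the split count depends on load, the same KV sequence is cut differently.
print(fixed_count_splits(1000, 4))  # [(0, 250), (250, 500), (500, 750), (750, 1000)]
print(len(fixed_count_splits(1000, 8)))  # 8

# With a fixed split size the boundaries depend only on the KV length.
print(fixed_size_splits(1000, 256))  # [(0, 256), (256, 512), (512, 768), (768, 1000)]
```

For a KV length of 1000, a fixed size of 256 yields three full splits and one split of 232 elements. Since the boundaries no longer depend on how many queries are being processed, each token's reduction order stays the same across batch sizes.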
Implementation
- A demonstration of deterministic inference on top of vLLM is provided, leveraging its FlexAttention backend and torch.Library.
- The library of "batch-invariant" kernels is available at thinking-machines-lab/batch-invariant-ops.
Experiments
- Sampling 1000 completions using Qwen/Qwen3-235B-A22B-Instruct-2507 at temperature 0 with the prompt "Tell me about Richard Feynman" resulted in 80 unique completions.
- Enabling batch-invariant kernels resulted in all 1000 completions being identical.
Performance
- Experiments show that performance remains usable, even though the batch-invariant kernels have not been heavily optimized.
True On-Policy RL
- Deterministic inference enables modification of the training stack to obtain bitwise identical results between sampling and training, resulting in true on-policy RL.
- Experiments in an RLVR setup on Bigmath show that true on-policy RL allows training to proceed smoothly without off-policy correction.
Conclusion
The article traces nondeterminism in LLM inference to kernels that lack batch invariance and offers a solution based on batch-invariant kernels. The authors encourage the community to pursue a deeper understanding of their systems and to work towards deterministic outcomes.