Scaling State-Space Models on Multiple GPUs with Tensor Parallelism
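AI Summary
The paper proposes a communication-efficient tensor parallelism (TP) method for accelerating selective state space model (SSM) inference across multiple GPUs.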
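Key Contributions
- A communication-efficient tensor-parallel design for selective SSMs
- An SSM state cache, shared across prefill and decode, that improves time to first token (TTFT)
- Quantized AllReduce that lowers TP aggregation overhead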
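The state-cache contribution can be illustrated with a toy recurrence. This is a minimal sketch, not the paper's implementation: it assumes a scalar, non-selective SSM with fixed coefficients `a`, `b`, `c`, and the `StateCache` class and function names are hypothetical. It shows only the basic idea that the state produced by prefill can be stored and reused at decode time instead of reprocessing the prompt.

```python
def ssm_step(state, x, a=0.9, b=0.5, c=2.0):
    """One update of a toy scalar SSM: h' = a*h + b*x, y = c*h'.

    Assumption: fixed coefficients; a selective SSM would derive them from x.
    """
    new_state = a * state + b * x
    return new_state, c * new_state


def prefill(prompt, state=0.0):
    """Process the whole prompt once; return the final state and all outputs."""
    outputs = []
    for x in prompt:
        state, y = ssm_step(state, x)
        outputs.append(y)
    return state, outputs


class StateCache:
    """Hypothetical per-request cache: request id -> latest SSM state."""

    def __init__(self):
        self._states = {}

    def put(self, request_id, state):
        self._states[request_id] = state

    def get(self, request_id):
        return self._states[request_id]


def decode_with_cache(cache, request_id, x):
    """Advance one token from the cached state and store the updated state."""
    state, y = ssm_step(cache.get(request_id), x)
    cache.put(request_id, state)
    return y
```

Because the recurrent state has a fixed size, the cache entry does not grow with prompt length, which is why reusing it is cheap compared with rerunning prefill.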
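Methodology
Efficient tensor-parallel SSM inference is achieved by optimizing the SSM state cache, partitioning the mixer's parameter tensors, and quantizing AllReduce.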
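The quantized AllReduce step can be sketched as a single-process simulation. The summary does not state the bit width or scaling granularity the paper uses, so symmetric per-tensor int8 quantization is an assumption here, and all function names are hypothetical. Each simulated rank quantizes its partial result, the small integer payload plus one scale stands in for what would cross the interconnect, and the receiver dequantizes before summing.

```python
def quantize_int8(values):
    """Symmetric per-tensor int8 quantization (assumed scheme)."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Map int8 codes back to floats."""
    return [qi * scale for qi in q]


def quantized_all_reduce(per_rank):
    """Sum per-rank vectors after an int8 round trip.

    per_rank: list of equal-length float lists, one per simulated rank.
    """
    payloads = [quantize_int8(t) for t in per_rank]  # what would be sent
    total = [0.0] * len(per_rank[0])
    for q, scale in payloads:
        for i, v in enumerate(dequantize(q, scale)):
            total[i] += v
    return total
```

With this scheme each rank contributes at most half a quantization step of error per element, so the reduced result differs from the exact sum by at most the sum of `scale / 2` over ranks, while the payload shrinks from 16 or 32 bits per element to 8.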
Original Abstract
Selective state space models (SSMs) have rapidly become a compelling backbone for large language models, especially for long-context workloads. Yet in deployment, their inference performance is often bounded by the memory capacity, bandwidth, and latency limits of a single GPU, making multi-GPU execution increasingly necessary. Although tensor parallelism (TP) is widely used to scale Transformer inference, applying it to selective SSM blocks is non-trivial because the SSM mixer couples large projections with a sequence-wise recurrent state update and local mixing whose efficiency depends on preserving locality and avoiding synchronization in the critical path. This paper presents a communication-efficient TP design for selective SSM inference that addresses three practical engineering challenges: enabling TTFT improvements via an SSM state cache across prefill and decode, partitioning the mixer's packed parameter tensor so that recurrent updates remain local while minimizing communication, and reducing TP aggregation overhead with quantized AllReduce. We evaluate on three representative SSM-based LLMs spanning pure-SSM and hybrid architectures - Mamba, Falcon-Mamba, and Zamba - on NVIDIA A6000 and A100 clusters. Our experiments show substantial throughput gains from tensor-parallel SSM inference, improving batch-request throughput by ~1.6-2.1x on 2 GPUs and ~2.6-4.0x on 4 GPUs for Mamba, with the largest benefits at long context lengths, and achieving a further ~10-18% throughput improvement from quantized all-reduce by lowering synchronization bandwidth overhead.
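The partitioning idea in the abstract can be sketched with a toy diagonal SSM. This is an illustration under stated assumptions, not the paper's layout: the packed vector is assumed to be three equal-length segments `[a | b | w]` (decay, input gain, output weight per channel), the channel count is assumed divisible by the number of ranks, and all names are hypothetical. The point is that a packed tensor must be sliced per segment rather than in one contiguous cut, so that each rank holds every parameter for its own channels and the recurrence needs no communication until the final sum.

```python
def shard_packed(packed, n_channels, rank, world):
    """Slice each segment of a packed [a | b | w] vector for one rank."""
    step = n_channels // world
    lo, hi = rank * step, (rank + 1) * step
    segments = [packed[k * n_channels:(k + 1) * n_channels] for k in range(3)]
    return [seg[lo:hi] for seg in segments]


def scan(a, b, xs):
    """Per-channel recurrence h = a*h + b*x over timesteps; returns final h."""
    h = [0.0] * len(a)
    for x in xs:  # each x holds one input value per channel
        h = [a[i] * h[i] + b[i] * x[i] for i in range(len(a))]
    return h


def out_proj(h, w):
    """Project channel states to a scalar: sum_i w[i] * h[i]."""
    return sum(wi * hi for wi, hi in zip(w, h))


def tp_forward(packed, n_channels, xs, world):
    """Simulate tensor-parallel execution across `world` ranks in one process."""
    step = n_channels // world
    partials = []
    for rank in range(world):
        a_r, b_r, w_r = shard_packed(packed, n_channels, rank, world)
        xs_r = [x[rank * step:(rank + 1) * step] for x in xs]
        h_r = scan(a_r, b_r, xs_r)           # local: no cross-rank traffic
        partials.append(out_proj(h_r, w_r))  # partial output per rank
    return sum(partials)                     # stands in for one all-reduce
```

A single contiguous split of `packed` would instead hand rank 0 all of `a` and none of `w`, forcing extra communication before the recurrence could run.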