Shiva-DiT: Residual-Based Differentiable Top-$k$ Selection for Efficient Diffusion Transformers
AI Summary
Shiva-DiT accelerates Diffusion Transformers via residual-based differentiable top-$k$ selection, learned end to end.
Key Contributions
- Proposes a residual-based differentiable top-$k$ selection method for efficient DiT token pruning
- Introduces a Context-Aware Router and an Adaptive Ratio Policy that learn the pruning schedule automatically
- Achieves a 1.54$\times$ wall-clock speedup on SD3.5 with fidelity superior to existing methods
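To make the second contribution concrete, here is a minimal sketch of what a context-aware router and a learnable keep-ratio might look like. All names, shapes, and parameterizations below are assumptions for illustration, not the paper's actual formulation: the router scores each token against a mean-pooled context vector with an assumed learned weighting, and the ratio policy maps a learnable scalar to a keep fraction via a sigmoid.

```python
import math

def context_aware_router(token_feats, ctx_weight, ctx_bias):
    """Hypothetical router: scores each token against pooled context.

    token_feats: list of token feature vectors (lists of floats).
    ctx_weight, ctx_bias: assumed learned parameters of the scorer.
    """
    d = len(token_feats[0])
    # Mean-pool a context vector over all tokens (one simple choice of "context").
    ctx = [sum(t[j] for t in token_feats) / len(token_feats) for j in range(d)]
    scores = []
    for t in token_feats:
        # Weighted token/context interaction plus a bias (assumed scoring form).
        s = sum(ti * ci * w for ti, ci, w in zip(t, ctx, ctx_weight)) + ctx_bias
        scores.append(s)
    return scores

def adaptive_keep_ratio(theta, r_min=0.25):
    """Hypothetical adaptive ratio policy: a learnable scalar theta is
    squashed into a keep ratio in (r_min, 1.0), so the pruning rate itself
    can be optimized by gradient descent."""
    return r_min + (1.0 - r_min) / (1.0 + math.exp(-theta))
```

In a full model one such ratio would typically be learned per layer (or per timestep), with the router deciding *which* tokens fill the resulting budget.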
Methodology
A residual-aware straight-through estimator enforces a static token count in the forward pass, while residual gradient estimation keeps the selection end-to-end learnable.
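The straight-through pattern described above can be sketched as follows. Since the paper's exact residual formulation is not reproduced here, this assumes the standard trick: the forward output is the hard top-$k$ mask written as `soft + stopgrad(hard - soft)`, so the backward pass differentiates through the soft scores only. Function names and the softmax scoring are illustrative assumptions.

```python
import math

def soft_scores(logits):
    # Softmax over token-importance logits (the differentiable surrogate).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def hard_topk_mask(logits, k):
    # Deterministic top-k mask: exactly k tokens kept, enabling static shapes.
    idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    mask = [0.0] * len(logits)
    for i in idx:
        mask[i] = 1.0
    return mask

def ste_forward(logits, k):
    soft = soft_scores(logits)
    hard = hard_topk_mask(logits, k)
    # The residual (hard - soft) is treated as a detached constant, so the
    # forward value equals the hard mask while gradients flow through `soft`.
    residual = [h - s for h, s in zip(hard, soft)]
    out = [s + r for s, r in zip(soft, residual)]  # numerically == hard mask
    return out, soft

def ste_backward(grad_out, soft):
    # Straight-through: d(out)/d(soft) is the identity (residual detached),
    # then chain through the softmax Jacobian to get gradients w.r.t. logits.
    dot = sum(g * s for g, s in zip(grad_out, soft))
    return [s * (g - dot) for g, s in zip(soft, grad_out)]
```

The forward output always contains exactly $k$ ones, which is what allows static compilation without ragged tensors, while the backward pass remains well-defined.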
Original Abstract
Diffusion Transformers (DiTs) incur prohibitive computational costs due to the quadratic scaling of self-attention. Existing pruning methods fail to simultaneously satisfy differentiability, efficiency, and the strict static budgets required to avoid hardware overhead. To address this, we propose Shiva-DiT, which effectively reconciles these conflicting requirements via Residual-Based Differentiable Top-$k$ Selection. By leveraging a residual-aware straight-through estimator, our method enforces deterministic token counts for static compilation while preserving end-to-end learnability through residual gradient estimation. Furthermore, we introduce a Context-Aware Router and Adaptive Ratio Policy to autonomously learn an adaptive pruning schedule. Experiments on mainstream models, including SD3.5, demonstrate that Shiva-DiT establishes a new Pareto frontier, achieving a 1.54$\times$ wall-clock speedup with superior fidelity compared to existing baselines, effectively eliminating ragged tensor overheads.