Scalable Simulation-Based Model Inference with Test-Time Complexity Control
AI Summary
PRISM is a scalable simulation-based model inference method that lets users control model complexity at test time.
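Main Contributions
- Introduces PRISM, a model that jointly infers discrete model structures and their continuous parameters.
- Makes model complexity controllable at test time.
- Validates PRISM on synthetic symbolic regression and on diffusion MRI data.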
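Methodology
PRISM uses a simulation-based encoder-decoder architecture that is conditioned on a tunable model prior, so model complexity can be controlled at test time when performing model selection.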
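To make the role of a tunable model prior concrete, the following minimal sketch shows how changing a prior strength shifts the posterior over candidate models. It is an illustration only, not PRISM's implementation: the exponential complexity prior p(M) ∝ exp(-lam * complexity), the function name `model_posterior`, and all numbers are assumptions chosen for the example. It also uses explicit per-model log-evidence values, whereas an amortized network conditioned on the prior would produce the posterior directly.

```python
import math


def model_posterior(log_evidence, complexity, lam):
    """Posterior over models under an assumed prior p(M) ∝ exp(-lam * complexity)."""
    # Unnormalized log posterior: log p(x | M) + log p(M; lam)
    scores = [le - lam * c for le, c in zip(log_evidence, complexity)]
    # Normalize with log-sum-exp for numerical stability
    top = max(scores)
    log_z = top + math.log(sum(math.exp(s - top) for s in scores))
    return [math.exp(s - log_z) for s in scores]


# Hypothetical values: three candidate models of increasing complexity
log_evidence = [-12.0, -10.5, -10.0]
complexity = [1, 2, 4]

# Sweep the prior strength at "test time" without changing the data
for lam in (0.0, 0.5, 2.0):
    post = model_posterior(log_evidence, complexity, lam)
    best = max(range(len(post)), key=post.__getitem__)
    print(f"lam={lam}: preferred model index={best}, "
          f"posterior={[round(p, 3) for p in post]}")
```

With these made-up numbers, a flat prior (lam = 0) favors the most complex model, while larger values of lam move the preference toward simpler ones. This is the kind of parsimony trade-off that the summary describes as being adjustable after training.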
Original Abstract
Simulation plays a central role in scientific discovery. In many applications, the bottleneck is no longer running a simulator; it is choosing among large families of plausible simulators, each corresponding to different forward models/hypotheses consistent with observations. Over large model families, classical Bayesian workflows for model selection are impractical. Furthermore, amortized model selection methods typically hard-code a fixed model prior or complexity penalty at training time, requiring users to commit to a particular parsimony assumption before seeing the data. We introduce PRISM, a simulation-based encoder-decoder that infers a joint posterior over both discrete model structures and associated continuous parameters, while enabling test-time control of model complexity via a tunable model prior that the network is conditioned on. We show that PRISM scales to families with combinatorially many (up to billions) of model instantiations on a synthetic symbolic regression task. As a scientific application, we evaluate PRISM on biophysical modeling for diffusion MRI data, showing the ability to perform model selection across several multi-compartment models, on both synthetic and in vivo neuroimaging data.