Some colleagues and I are interested in applying sharded Hamiltonian Monte Carlo using the JAX pmap framework as described here. Has anyone done any benchmarking on how the TFP implementation scales across multiple accelerators, or compared sharding regimes (i.e., sharding over observations and/or parameters)? We're particularly interested in cases where the data is too big to fit on one GPU/TPU (i.e., the non-embarrassingly-parallel version). Thanks!
I'm not aware of any specific benchmarking of sharded Hamiltonian Monte Carlo using the JAX pmap framework. It might be worth reaching out to the JAX community, or checking places like GitHub Discussions or Stack Overflow, to see whether others have shared their experiences or insights on this topic.
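If it helps frame a benchmark of your own, here is a minimal sketch of the "shard over observations" regime in plain JAX. This is not the TFP implementation, just an illustration under some assumptions: a toy Gaussian linear-regression likelihood, parameters replicated on every device, and hypothetical names (`local_log_lik`, `sharded_value_and_grad`). Each device computes the log-likelihood and gradient on its own shard, and `jax.lax.psum` sums them across devices, which is the cross-device communication an HMC leapfrog step would need.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 1 on a plain CPU install


def local_log_lik(theta, x, y):
    # Toy model (an assumption): Gaussian linear regression with unit noise,
    # log-likelihood up to an additive constant, on one shard of the data.
    resid = y - x @ theta
    return -0.5 * jnp.sum(resid ** 2)


@partial(jax.pmap, axis_name="data", in_axes=(None, 0, 0))
def sharded_value_and_grad(theta, x_shard, y_shard):
    # theta is replicated; x_shard and y_shard are split along the device axis.
    val, grad = jax.value_and_grad(local_log_lik)(theta, x_shard, y_shard)
    # Sum the per-shard contributions so every device holds the full result.
    return jax.lax.psum(val, "data"), jax.lax.psum(grad, "data")


# Synthetic data: 8 observations per device, 3 parameters.
key_x, key_theta = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n_dev * 8, 3))
theta_true = jnp.array([1.0, -2.0, 0.5])
y = x @ theta_true
theta = jax.random.normal(key_theta, (3,))

# Reshape so the leading axis indexes devices.
x_sharded = x.reshape(n_dev, 8, 3)
y_sharded = y.reshape(n_dev, 8)

vals, grads = sharded_value_and_grad(theta, x_sharded, y_sharded)
total_val, total_grad = vals[0], grads[0]  # identical on every device

# Unsharded reference for comparison.
ref_val, ref_grad = jax.value_and_grad(local_log_lik)(theta, x, y)
```

Timing `sharded_value_and_grad` against the unsharded reference as the number of devices and the shard size grow would give a rough picture of how the per-step communication cost scales. Sharding over parameters would need a different layout and is not covered by this sketch.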