Distributed inference with JAX: GPU/TPU interconnect

Hello, this is a follow-up to the earlier thread "Benchmarks for distributed HMC?".

In particular, we’re interested in acquiring a machine with several GPUs so that we can begin our own benchmarking. Part of choosing that machine is knowing whether the GPU interconnect topology can significantly reduce communication overhead in distributed HMC. For example, is the jax.lax.psum() that underpins the distributed log density built on an all-reduce collective, and can it therefore take advantage of direct GPU-to-GPU communication over, e.g., NVLink?
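For concreteness, here is a minimal sketch of the pattern we have in mind, not any library's actual HMC implementation. The toy Normal model, the function names, and the axis name "devices" are hypothetical, and the XLA_FLAGS line only fakes four CPU devices so the sketch runs without GPUs. Each device computes its own shard's log-likelihood and gradient, and a single jax.lax.psum then all-reduces them; on GPUs, XLA typically implements this collective with NCCL, which is where the interconnect would come into play.

```python
import os

# Simulate 4 devices on a CPU-only machine so the sketch runs anywhere;
# this flag only affects the CPU backend and must be set before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

AXIS = "devices"  # hypothetical name for the pmap axis


def local_log_lik(mu, x_shard):
    """Log-likelihood of this device's data shard under a toy Normal(mu, 1) model."""
    return jnp.sum(norm.logpdf(x_shard, loc=mu, scale=1.0))


def log_prior(mu):
    """Toy Normal(0, 10) prior on mu, evaluated identically on every device."""
    return norm.logpdf(mu, loc=0.0, scale=10.0)


def sharded_log_density_and_grad(mu, x_shard):
    # Each device differentiates only its own shard's contribution...
    val, grad = jax.value_and_grad(local_log_lik)(mu, x_shard)
    # ...then psum all-reduces (value, grad) across devices, so every device
    # ends up holding the full-data log-likelihood and its gradient.
    val, grad = jax.lax.psum((val, grad), axis_name=AXIS)
    # The prior is added after the reduction so it is counted once, not n_dev times.
    p_val, p_grad = jax.value_and_grad(log_prior)(mu)
    return val + p_val, grad + p_grad


n_dev = jax.local_device_count()
step = jax.pmap(sharded_log_density_and_grad, axis_name=AXIS)

# Same parameter value replicated on every device; data split into n_dev shards.
mu = jnp.full((n_dev,), 0.5)
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4) / 10.0
vals, grads = step(mu, x)
print(vals, grads)  # one identical copy per device
```

As we understand it, the interconnect-sensitive cost here is that one psum per gradient evaluation; whether it travels over NVLink or PCIe is decided by XLA and NCCL, not by the model code.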

The motivation for this question: whether we invest in something like a torus or fully connected GPU topology hinges on this and similar technical details.

Thanks,
Jeremy