Dear all,
I am currently learning Bayesian analysis and using tensorflow_probability.substrates.jax, but I've run into some issues. When running NUTS alone under jax with jit, performance is quite fast. However, when NUTS is combined with TransformedTransitionKernel, the speed drops drastically. Here's a summary of the times (a rough sketch of the kernel composition I'm timing follows the list):
- TFP GPU: NUTS alone took 118.2952 seconds
- TFP GPU: NUTS + Bijector took 1986.8306 seconds
- TFP GPU: NUTS + DualAveragingStepSizeAdaptation took 141.0955 seconds
- TFP GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 2397.5875 seconds
- NumPyro GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 180 seconds
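Roughly, the kernel composition I'm timing looks like the following minimal sketch (the LogNormal target, Softplus bijector, and step counts are stand-ins, not my actual model, which is in the code linked below):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Stand-in target: a single positive scalar parameter.
def target_log_prob_fn(x):
    return tfd.LogNormal(0., 1.).log_prob(x)

nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.1)

# Constrain the chain to the positive reals via a bijector.
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts,
    bijector=tfb.Softplus())

# Adapt the step size during burn-in.
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=800)

@jax.jit
def run_chain(seed):
    return tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=1000,
        current_state=jnp.ones([]),
        kernel=kernel,
        trace_fn=None,
        seed=seed)

samples = run_chain(jax.random.PRNGKey(0))
```

The timings above correspond to dropping the TransformedTransitionKernel and/or DualAveragingStepSizeAdaptation wrappers from this composition.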
I've run speed comparisons against NumPyro, and essentially NumPyro with dual-averaging step-size adaptation and parameter constraints is about as fast as tensorflow_probability NUTS alone.
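For reference, the NumPyro run I'm comparing against is set up roughly like this (again a minimal sketch with a stand-in model; NumPyro applies the constraining transform and dual-averaging step-size adaptation by default):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Stand-in model: sigma is constrained to be positive; NumPyro handles the
# unconstraining transform and step-size adaptation automatically.
def model(y):
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Normal(0.0, sigma), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), y=jnp.zeros(100))
mcmc.print_summary()
```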
Could there be something I’ve missed? Is there room for optimization in this process?
Please find the data and code for reproducibility below:
Please note that I’m only using the first 100 lines of the data.
Additionally, and possibly related, I observed a similar slowdown when using the LKJ distribution in other models. (I can post one of them if needed.)
Thank you in advance for your assistance.
Sebastian