# Pip requirements: JAX with CUDA support, plus CPU-only TensorFlow and
# PyTorch builds.

# TensorFlow CPU-only version (needed for testing).
tensorflow-cpu~=2.20.0
tf2onnx

# PyTorch CPU-only version (needed for testing).
# The extra index below is PyTorch's CPU wheel index, which is where the
# `+cpu` local-version build pinned here is looked up.
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.9.1+cpu

# JAX with CUDA 12 support.
# NOTE(review): recent JAX releases appear to publish their CUDA wheels on
# PyPI, so this --find-links page may no longer be needed -- confirm before
# removing it.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12]==0.6.2
flax

# Include the requirements listed in requirements-common.txt.
-r requirements-common.txt
