# TensorFlow CPU-only version (needed for testing).
tensorflow-cpu~=2.20.0
tf2onnx

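# PyTorch with CUDA 12.6 support (wheels from the PyTorch cu126 index).
# - torch is pinned to a version compatible with torch-xla; torch-xla itself
#   is not installed on macOS (sys_platform == 'darwin').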
--extra-index-url https://download.pytorch.org/whl/cu126
torch==2.9.1+cu126
torch-xla==2.9.0;sys_platform != 'darwin'

# JAX CPU-only version (needed for testing).
jax[cpu]

-r requirements-common.txt
