# These requirements are used for the CI and CPU-only Docker images, so we install CPU-only versions of torch.
# For GPU Docker images, you should install dl-gpu-requirements.txt afterwards.

# TensorFlow stack. Every entry in this group is gated on python_version < '3.12',
# so none of these packages is installed on Python 3.12 or newer.
#
# The two TensorFlow distributions carry complementary platform markers:
# `tensorflow-macos` is selected only on darwin/arm64 and `tensorflow` everywhere
# else. Both are pinned to the same version.
tensorflow==2.15.1; python_version < '3.12' and (sys_platform != 'darwin' or platform_machine != 'arm64')
tensorflow-macos==2.15.1; python_version < '3.12' and sys_platform == 'darwin' and platform_machine == 'arm64'
tensorflow-probability==0.23.0; python_version < '3.12'
tensorflow-io-gcs-filesystem==0.31.0; python_version < '3.12'
# NOTE(review): tensorflow-datasets is the only entry in this group without a
# version pin -- confirm this is intentional.
tensorflow-datasets; python_version < '3.12'
# array-record is additionally skipped on darwin and on Windows.
array-record==0.5.1; python_version < '3.12' and sys_platform != 'darwin' and platform_system != 'Windows'
etils==1.5.2; python_version < '3.12'

# If you make changes below this line, please also make the corresponding changes to `dl-gpu-requirements.txt`
# and to `install-dependencies.sh`!

# Additional package sources for the CPU builds (see the inline note on each option).
# NOTE(review): the --find-links URL embeds "torch-2.3.0+cpu", which matches the
# torch==2.3.0 pin below -- presumably the two must be bumped together; confirm
# when changing the torch version.
--extra-index-url https://download.pytorch.org/whl/cpu  # for CPU versions of torch, torchvision
--find-links https://data.pyg.org/whl/torch-2.3.0+cpu.html  # for CPU versions of torch-scatter, torch-sparse, torch-cluster, torch-spline-conv

# Core PyTorch packages, each pinned to an exact version.
torch==2.3.0
torchmetrics==0.10.3
torchtext==0.18.0
torchvision==0.18.0

# Graph-learning packages, each pinned to an exact version. The --find-links
# note above names torch-scatter, torch-sparse, torch-cluster and
# torch-spline-conv as the packages served from the PyG wheel page.
torch-scatter==2.1.2
torch-sparse==0.6.18
torch-cluster==1.6.3
torch-spline-conv==1.2.2
torch-geometric==2.5.3

# CuPy is skipped on darwin. Unlike the exact (==) pins used for the torch and JAX
# packages in this file, it is specified with a lower bound (>=) only.
# NOTE(review): the package name suggests a CUDA 12 build, while the header of
# this file describes CPU-only images -- confirm this entry is intended here.
cupy-cuda12x>=13.4.0; sys_platform != 'darwin'

# Keep JAX version consistent with dl-gpu-requirements.txt
# jax and jaxlib are pinned to the same version and share the same marker:
# they are installed only on Python < 3.12 and are skipped on darwin.
jax==0.4.13; python_version < '3.12' and sys_platform != 'darwin'
jaxlib==0.4.13; python_version < '3.12' and sys_platform != 'darwin'
