# If you make changes below this line, please also make the corresponding changes to `dl-cpu-requirements.txt`!

# TensorFlow stack. Every entry is limited to Python < 3.12 by its marker.
# Apple-silicon macOS (sys_platform == 'darwin' and platform_machine == 'arm64')
# installs `tensorflow-macos`; every other platform installs `tensorflow`.
# Both are pinned to the same version. `tensorflow-datasets` is left unpinned.
tensorflow==2.15.1; python_version < '3.12' and (sys_platform != 'darwin' or platform_machine != 'arm64')
tensorflow-macos==2.15.1; python_version < '3.12' and sys_platform == 'darwin' and platform_machine == 'arm64'
tensorflow-probability==0.23.0; python_version < '3.12'
tensorflow-datasets; python_version < '3.12'

# PyTorch stack: cu121 builds of torch/torchvision plus the PyG extension wheels
# published for torch 2.3.0+cu121. The two options below tell pip where to find them.
--extra-index-url https://download.pytorch.org/whl/cu121  # for GPU versions of torch, torchvision
--find-links https://data.pyg.org/whl/torch-2.3.0+cu121.html  # for GPU versions of torch-scatter, torch-sparse, torch-cluster, torch-spline-conv
# Specifying explicit plus-notation (e.g. `+cu121`) below so pip overwrites the existing CPU versions.
# Keep the `+cu121` / `+pt23cu121` suffixes in sync with the URLs above when bumping versions.
torch==2.3.0+cu121
torchvision==0.18.0+cu121
torch-scatter==2.1.2+pt23cu121
torch-sparse==0.6.18+pt23cu121
torch-cluster==1.6.3+pt23cu121
torch-spline-conv==1.2.2+pt23cu121

# Not installed on macOS (markers require sys_platform != 'darwin').
# cupy-cuda12x only has a lower bound; nixl is pinned to an exact version.
cupy-cuda12x>=13.4.0; sys_platform != 'darwin'
nixl==0.4.0; sys_platform != 'darwin'

# Extra wheel page pip searches for the `+cuda12.cudnn89` jaxlib build pinned below.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Downgrading to JAX 0.4.13 (pinned here via jaxlib) to be compatible with CUDA 12.1.
# Only installed for Python < 3.12 on non-macOS platforms.
jaxlib==0.4.13+cuda12.cudnn89; python_version < '3.12' and sys_platform != 'darwin'
