jaxlib<=0.5.3,>=0.5.3
ml_dtypes>=0.4.0
numpy>=1.25
opt_einsum
scipy>=1.11.1

[:python_version >= "3.12"]
numpy>=1.26.0

[ci]
jaxlib==0.5.1

[cpu]

[cuda]
jaxlib==0.5.3
jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3

[cuda12]
jaxlib==0.5.3
jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3

[cuda12_local]
jaxlib==0.5.3
jax-cuda12-plugin==0.5.3

[cuda12_pip]
jaxlib==0.5.3
jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3

[k8s]
kubernetes

[minimum-jaxlib]
jaxlib==0.5.3

[rocm]
jaxlib==0.5.3
jax-rocm60-plugin<=0.5.3,>=0.5.3

[tpu]
jaxlib<=0.5.3,>=0.5.3
libtpu==0.0.11.*
requests
