jaxlib<=0.6.1,>=0.6.1
ml_dtypes>=0.5.0
numpy>=1.25
opt_einsum
scipy>=1.11.1

[:python_version >= "3.12"]
numpy>=1.26.0

[ci]
jaxlib==0.6.0

[cpu]

[cuda]
jaxlib<=0.6.1,>=0.6.1
jax-cuda12-plugin[with-cuda]<=0.6.1,>=0.6.1

[cuda12]
jaxlib<=0.6.1,>=0.6.1
jax-cuda12-plugin[with-cuda]<=0.6.1,>=0.6.1

[cuda12-local]
jaxlib<=0.6.1,>=0.6.1
jax-cuda12-plugin<=0.6.1,>=0.6.1

[k8s]
kubernetes

[minimum-jaxlib]
jaxlib==0.6.1

[rocm]
jaxlib<=0.6.1,>=0.6.1
jax-rocm60-plugin<=0.6.1,>=0.6.1

[tpu]
jaxlib<=0.6.1,>=0.6.1
libtpu==0.0.15.*
requests
