diff --git a/setup.py b/setup.py index e16e1ed253b0f4..b34acd5986298e 100644 --- a/setup.py +++ b/setup.py @@ -89,9 +89,12 @@ ] extras["torch"] = ["torch>=1.0"] -extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"] if os.name == "nt": # windows + extras["retrieval"] = ["datasets"] # faiss is not supported on windows extras["flax"] = [] # jax is not supported on windows +else: + extras["retrieval"] = ["faiss-cpu", "datasets"] + extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"] extras["tokenizers"] = ["tokenizers==0.9.2"] extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]