diff --git a/setup.py b/setup.py index 6da562fc0c..cec5f41084 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,10 @@ def parse_requirements(): else: # detect the version of torch already installed # and set it so dependencies don't clobber the torch version - torch_version = version("torch") + try: + torch_version = version("torch") + except PackageNotFoundError: + torch_version = "2.5.1" _install_requires.append(f"torch=={torch_version}") version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) @@ -54,6 +57,10 @@ def parse_requirements(): if (major, minor) >= (2, 5): _install_requires.pop(_install_requires.index(xformers_version)) + if patch == 0: + _install_requires.append("xformers==0.0.28.post2") + else: + _install_requires.append("xformers==0.0.28.post3") _install_requires.pop(_install_requires.index(autoawq_version)) elif (major, minor) >= (2, 4): if patch == 0: