diff --git a/setup.py b/setup.py index 7eb94528af..001877e23c 100644 --- a/setup.py +++ b/setup.py @@ -39,9 +39,12 @@ def _make_version_file(version, sha): def _get_pytorch_version(): - if "PYTORCH_VERSION" in os.environ: - return f"torch=={os.environ['PYTORCH_VERSION']}" - return "torch" + pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch") + if version_pin := os.getenv("PYTORCH_VERSION"): + pytorch_dep += "==" + version_pin + elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")): + pytorch_dep += f">={version_pin_ge},<{version_pin_lt}" + return pytorch_dep class clean(distutils.command.clean.clean):