diff --git a/launch/client.py b/launch/client.py index 1bae2350..adbd1401 100644 --- a/launch/client.py +++ b/launch/client.py @@ -39,7 +39,7 @@ SyncEndpoint, ) from launch.request_validation import validate_task_request -from launch.utils import trim_kwargs +from launch.utils import infer_env_params, trim_kwargs DEFAULT_NETWORK_TIMEOUT_SEC = 120 @@ -198,9 +198,10 @@ def create_model_bundle_from_dirs( model_bundle_name: str, base_paths: List[str], requirements_path: str, - env_params: Dict[str, str], load_predict_fn_module_path: str, load_model_fn_module_path: str, + env_params: Optional[Dict[str, str]], + env_selector: Optional[str], app_config: Optional[Union[Dict[str, Any], str]] = None, ) -> ModelBundle: """ @@ -275,6 +276,9 @@ def create_model_bundle_from_dirs( with open(requirements_path, "r", encoding="utf-8") as req_f: requirements = req_f.read().splitlines() + if env_params is None: + env_params = infer_env_params(env_selector) + tmpdir = tempfile.mkdtemp() try: zip_path = os.path.join(tmpdir, "bundle.zip") @@ -331,7 +335,8 @@ def create_model_bundle_from_dirs( def create_model_bundle( # pylint: disable=too-many-statements self, model_bundle_name: str, - env_params: Dict[str, str], + env_params: Optional[Dict[str, str]], + env_selector: Optional[str], *, load_predict_fn: Optional[ Callable[[LaunchModel_T], Callable[[Any], Any]] @@ -435,6 +440,9 @@ def create_model_bundle( # pylint: disable=too-many-statements ) # TODO should we try to catch when people intentionally pass both model and load_model_fn as None? + if env_params is None: + env_params = infer_env_params(env_selector) + if requirements is None: # TODO explore: does globals() actually work as expected? Should we use globals_copy instead? 
def trim_kwargs(kwargs_dict: Dict[Any, Any]):
    """
    Return a copy of ``kwargs_dict`` with all ``None``-valued entries removed.
    """
    dict_copy = {k: v for k, v in kwargs_dict.items() if v is not None}
    return dict_copy


def infer_env_params(env_selector: Optional[str]) -> Dict[str, str]:
    """
    Infer an ``env_params`` dict for a model bundle from a framework selector.

    Args:
        env_selector: Either ``"pytorch"`` or ``"tensorflow"``.

    Returns:
        A dict containing ``framework_type`` plus framework-specific version
        information (``pytorch_image_tag`` or ``tensorflow_version``).

    Raises:
        ValueError: If ``env_selector`` is unsupported, or the installed
            PyTorch build's CUDA version cannot be parsed.
    """
    if env_selector == "pytorch":
        import torch

        version_parts = torch.__version__.split("+")
        torch_version = version_parts[0]
        # Local version suffix looks like "cu113" -> CUDA "113". CPU-only
        # builds ("+cpu") or builds without a suffix fall back to 11.3.
        if len(version_parts) > 1 and version_parts[1].startswith("cu"):
            cuda_version = version_parts[1][2:]
        else:
            cuda_version = "113"
        if len(cuda_version) < 3:
            # We can only parse CUDA versions in the double digits,
            # e.g. "102" -> 10.2, "113" -> 11.3.
            raise ValueError(
                "PyTorch version parsing does not support CUDA versions below 10.0"
            )
        # cudnn.version() returns an int like 8200 (or None when cuDNN is
        # unavailable); keep only the major digit, defaulting to 8.
        cudnn_full = (
            torch.backends.cudnn.version()
            if torch.backends.cudnn.is_available()
            else None
        )
        cudnn_version = str(cudnn_full)[:1] if cudnn_full is not None else "8"
        tag = (
            f"{torch_version}-cuda{cuda_version[:2]}.{cuda_version[2:]}"
            f"-cudnn{cudnn_version}-runtime"
        )
        return {
            "framework_type": "pytorch",
            "pytorch_image_tag": tag,
        }
    if env_selector == "tensorflow":
        import tensorflow as tf

        return {
            "framework_type": "tensorflow",
            "tensorflow_version": tf.__version__,
        }
    raise ValueError(
        "Unsupported env_selector, please set to pytorch or tensorflow, or set your own env_params."
    )