diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index fc3a5b2ddae..c52b55ab482 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -140,6 +140,7 @@ def tune( parameters: Dict[str, Any], base_image: str = constants.BASE_IMAGE_TENSORFLOW, namespace: Optional[str] = None, + env_per_trial: Optional[Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]] = None, algorithm_name: str = "random", algorithm_settings: Union[dict, List[models.V1beta1AlgorithmSetting], None] = None, objective_metric_name: str = None, @@ -172,6 +173,12 @@ def tune( objective function. base_image: Image to use when executing the objective function. namespace: Namespace for the Experiment. + env_per_trial: Environment variable(s) to be attached to each trial container. + You can specify a dictionary as a mapping object representing the environment variables. + Otherwise, you can specify a list, in which the element can either be a kubernetes.client.models.V1EnvVar (documented here: + https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) + or a kubernetes.client.models.V1EnvFromSource (documented here: + https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) algorithm_name: Search algorithm for the HyperParameter tuning. algorithm_settings: Settings for the search algorithm given. For available fields, check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#search-algorithms-in-detail. @@ -318,6 +325,15 @@ def tune( requests=resources_per_trial, limits=resources_per_trial, ) + + if isinstance(env_per_trial, dict): + env, env_from = [client.V1EnvVar(name=str(k), value=str(v)) for k, v in env_per_trial.items()] or None, None + + if env_per_trial: + env = [x for x in env_per_trial if isinstance(x, client.V1EnvVar)] or None + env_from = [x for x in env_per_trial if isinstance(x, client.V1EnvFromSource)] or None + else: + env, env_from = None, None # Create Trial specification. trial_spec = client.V1Job( @@ -336,6 +352,8 @@ def tune( image=base_image, command=["bash", "-c"], args=[exec_script], + env=env, + env_from=env_from, resources=resources_per_trial, ) ],