From 36ffc778acbcb6da8985ba6f68723fe95eb92681 Mon Sep 17 00:00:00 2001 From: subkanthi Date: Mon, 23 Aug 2021 13:35:01 -0400 Subject: [PATCH 1/4] Add hyperparameters to MLEngineStartTrainingJobOperator --- .../google/cloud/operators/mlengine.py | 10 ++++++ .../google/cloud/operators/test_mlengine.py | 35 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index 047a39e770ec6..c227f8f785392 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -1124,6 +1124,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator): :type mode: str :param labels: a dictionary containing labels for the job; passed to BigQuery :type labels: Dict[str, str] + :param hyperparameters: Optional HyperparameterSpec dictionary for hyperparameter tuning. + For further reference, check: + https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#HyperparameterSpec + :type hyperparameters: Dict :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -1149,6 +1153,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator): '_python_version', '_job_dir', '_service_account', + '_hyperparameters', '_impersonation_chain', ] @@ -1175,6 +1180,7 @@ def __init__( mode: str = 'PRODUCTION', labels: Optional[Dict[str, str]] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + hyperparameters: Optional[Dict] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1195,6 +1201,7 @@ def __init__( self._delegate_to = delegate_to self._mode = mode self._labels = labels + self._hyperparameters = hyperparameters self._impersonation_chain = impersonation_chain custom = self._scale_tier is not None and self._scale_tier.upper() == 'CUSTOM' @@ -1260,6 +1267,9 @@ def execute(self, context): if self._service_account: training_request['trainingInput']['serviceAccount'] = self._service_account + if self._hyperparameters: + training_request['trainingInput']['hyperparameters'] = self._hyperparameters + if self._labels: training_request['labels'] = self._labels diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py index d67c8a2b9986e..5a812a35c38a2 100644 --- a/tests/providers/google/cloud/operators/test_mlengine.py +++ b/tests/providers/google/cloud/operators/test_mlengine.py @@ -457,6 +457,40 @@ def test_success_create_training_job_with_optional_args(self, mock_hook): training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training' training_input['trainingInput']['serviceAccount'] = 'test@serviceaccount.com' + hyperparams = { + 'goal': 'MAXIMIZE', + 'hyperparameterMetricTag': 'metric1', + 'maxTrials': 30, + 'maxParallelTrials': 1, + 'enableTrialEarlyStopping': True, + 'params': []} + + hyperparams['params'].append({ + 'parameterName': 'hidden1', + 'type': 'INTEGER', + 'minValue': 40, + 'maxValue': 400, + 'scaleType': 'UNIT_LINEAR_SCALE'}) + + hyperparams['params'].append({ + 'parameterName': 'numRnnCells', + 'type': 'DISCRETE', + 'discreteValues': [1, 2, 3, 4]}) + + hyperparams['params'].append({ + 'parameterName': 'rnnCellType', + 'type': 'CATEGORICAL', + 'categoricalValues': [ + 'BasicLSTMCell', + 'BasicRNNCell', + 'GRUCell', + 'LSTMCell', + 'LayerNormBasicLSTMCell' + ] + }) + + training_input['trainingInput']['hyperparameters'] = hyperparams + success_response = self.TRAINING_INPUT.copy() success_response['state'] = 'SUCCEEDED' hook_instance = mock_hook.return_value @@ -468,6 +502,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook): job_dir='gs://some-bucket/jobs/test_training', service_account='test@serviceaccount.com', **self.TRAINING_DEFAULT_ARGS, + hyperparameters=hyperparams ) training_op.execute(MagicMock()) From 8c3d83bc4e61b75daec81f1de84b8e9445309ec0 Mon Sep 17 00:00:00 2001 From: subkanthi Date: Mon, 23 Aug 2021 13:37:28 -0400 Subject: [PATCH 2/4] Fixed pre-commit errors --- .../google/cloud/operators/test_mlengine.py | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py index 5a812a35c38a2..7a226e325ea05 100644 --- a/tests/providers/google/cloud/operators/test_mlengine.py +++ b/tests/providers/google/cloud/operators/test_mlengine.py @@ -463,31 +463,36 @@ def test_success_create_training_job_with_optional_args(self, mock_hook): 'maxTrials': 30, 'maxParallelTrials': 1, 'enableTrialEarlyStopping': True, - 'params': []} - - hyperparams['params'].append({ - 'parameterName': 'hidden1', - 'type': 'INTEGER', - 'minValue': 40, - 'maxValue': 400, - 'scaleType': 'UNIT_LINEAR_SCALE'}) - - hyperparams['params'].append({ - 'parameterName': 'numRnnCells', - 'type': 'DISCRETE', - 'discreteValues': [1, 2, 3, 4]}) - - hyperparams['params'].append({ - 'parameterName': 'rnnCellType', - 'type': 'CATEGORICAL', - 'categoricalValues': [ - 'BasicLSTMCell', - 'BasicRNNCell', - 'GRUCell', - 'LSTMCell', - 'LayerNormBasicLSTMCell' - ] - }) + 'params': [], + } + + hyperparams['params'].append( + { + 'parameterName': 'hidden1', + 'type': 'INTEGER', + 'minValue': 40, + 'maxValue': 400, + 'scaleType': 'UNIT_LINEAR_SCALE', + } + ) + + hyperparams['params'].append( + {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]} + ) + + hyperparams['params'].append( + { + 'parameterName': 'rnnCellType', + 'type': 'CATEGORICAL', + 'categoricalValues': [ + 'BasicLSTMCell', + 'BasicRNNCell', + 'GRUCell', + 'LSTMCell', + 'LayerNormBasicLSTMCell', + ], + } + ) training_input['trainingInput']['hyperparameters'] = hyperparams @@ -502,7 +507,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook): job_dir='gs://some-bucket/jobs/test_training', service_account='test@serviceaccount.com', **self.TRAINING_DEFAULT_ARGS, - hyperparameters=hyperparams + hyperparameters=hyperparams, ) training_op.execute(MagicMock()) From 61e1c4c5bb734affb3640db6f97412270a248691 Mon Sep 17 00:00:00 2001 From: subkanthi Date: Mon, 23 Aug 2021 13:43:41 -0400 Subject: [PATCH 3/4] Added passing hyperparameters to MLEngineStartTrainingJobOperator in example_mlengine.py --- .../cloud/example_dags/example_mlengine.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/airflow/providers/google/cloud/example_dags/example_mlengine.py b/airflow/providers/google/cloud/example_dags/example_mlengine.py index 56214e924ec72..d368271cb2b0f 100644 --- a/airflow/providers/google/cloud/example_dags/example_mlengine.py +++ b/airflow/providers/google/cloud/example_dags/example_mlengine.py @@ -57,6 +57,7 @@ SUMMARY_STAGING = os.environ.get("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/") + with models.DAG( "example_gcp_mlengine", schedule_interval=None, # Override to match your needs @@ -64,6 +65,42 @@ tags=['example'], params={"model_name": MODEL_NAME}, ) as dag: + hyperparams = { + 'goal': 'MAXIMIZE', + 'hyperparameterMetricTag': 'metric1', + 'maxTrials': 30, + 'maxParallelTrials': 1, + 'enableTrialEarlyStopping': True, + 'params': [], + } + + hyperparams['params'].append( + { + 'parameterName': 'hidden1', + 'type': 'INTEGER', + 'minValue': 40, + 'maxValue': 400, + 'scaleType': 'UNIT_LINEAR_SCALE', + } + ) + + hyperparams['params'].append( + {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]} + ) + + hyperparams['params'].append( + { + 'parameterName': 'rnnCellType', + 'type': 'CATEGORICAL', + 'categoricalValues': [ + 'BasicLSTMCell', + 'BasicRNNCell', + 'GRUCell', + 'LSTMCell', + 'LayerNormBasicLSTMCell', + ], + } + ) # [START howto_operator_gcp_mlengine_training] training = MLEngineStartTrainingJobOperator( task_id="training", @@ -77,6 +114,7 @@ training_python_module=TRAINER_PY_MODULE, training_args=[], labels={"job_type": "training"}, + hyperparameters=hyperparams ) # [END howto_operator_gcp_mlengine_training] From ceca3b48d1a1bd72f9fa4049dc65f72748853d06 Mon Sep 17 00:00:00 2001 From: subkanthi Date: Mon, 23 Aug 2021 14:12:16 -0400 Subject: [PATCH 4/4] Added passing hyperparameters to Google ML CreateTrainingJob operator --- .../providers/google/cloud/example_dags/example_mlengine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/example_dags/example_mlengine.py b/airflow/providers/google/cloud/example_dags/example_mlengine.py index d368271cb2b0f..082392c3e3ffc 100644 --- a/airflow/providers/google/cloud/example_dags/example_mlengine.py +++ b/airflow/providers/google/cloud/example_dags/example_mlengine.py @@ -57,7 +57,6 @@ SUMMARY_STAGING = os.environ.get("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/") - with models.DAG( "example_gcp_mlengine", schedule_interval=None, # Override to match your needs @@ -114,7 +113,7 @@ training_python_module=TRAINER_PY_MODULE, training_args=[], labels={"job_type": "training"}, - hyperparameters=hyperparams + hyperparameters=hyperparams, ) # [END howto_operator_gcp_mlengine_training]