Gcp ai hyperparameter tuning (#17790)
* Add hyperparameters to MLEngineStartTrainingJobOperator

* Fixed pre-commit errors

* Added passing hyperparameters to MLEngineStartTrainingJobOperator in example_mlengine.py

* Added passing hyperparameters to Google ML CreateTrainingJob operator
subkanthi authored Aug 27, 2021
1 parent 87769db commit aa5952e
Showing 3 changed files with 87 additions and 0 deletions.
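At a glance, the change adds a single optional keyword argument to the operator. A minimal sketch of the new usage (the import path and argument names are taken from the diff below; the project, bucket, and module values are placeholders, and the full DAG example is in example_mlengine.py):

    from airflow.providers.google.cloud.operators.mlengine import MLEngineStartTrainingJobOperator

    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id="my-gcp-project",                          # placeholder
        region="us-central1",                                 # placeholder
        job_id="example_training_job",                        # placeholder
        package_uris=["gs://my-bucket/trainer-0.1.tar.gz"],   # placeholder
        training_python_module="trainer.task",                # placeholder
        training_args=[],
        labels={"job_type": "training"},
        # New in this commit: forwarded to the training job as a HyperparameterSpec
        hyperparameters={
            'goal': 'MAXIMIZE',
            'hyperparameterMetricTag': 'metric1',
            'maxTrials': 30,
            'maxParallelTrials': 1,
            'params': [
                {
                    'parameterName': 'hidden1',
                    'type': 'INTEGER',
                    'minValue': 40,
                    'maxValue': 400,
                    'scaleType': 'UNIT_LINEAR_SCALE',
                },
            ],
        },
    )
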
37 changes: 37 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_mlengine.py
@@ -64,6 +64,42 @@
    tags=['example'],
    params={"model_name": MODEL_NAME},
) as dag:
    hyperparams = {
        'goal': 'MAXIMIZE',
        'hyperparameterMetricTag': 'metric1',
        'maxTrials': 30,
        'maxParallelTrials': 1,
        'enableTrialEarlyStopping': True,
        'params': [],
    }

    hyperparams['params'].append(
        {
            'parameterName': 'hidden1',
            'type': 'INTEGER',
            'minValue': 40,
            'maxValue': 400,
            'scaleType': 'UNIT_LINEAR_SCALE',
        }
    )

    hyperparams['params'].append(
        {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]}
    )

    hyperparams['params'].append(
        {
            'parameterName': 'rnnCellType',
            'type': 'CATEGORICAL',
            'categoricalValues': [
                'BasicLSTMCell',
                'BasicRNNCell',
                'GRUCell',
                'LSTMCell',
                'LayerNormBasicLSTMCell',
            ],
        }
    )
    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
@@ -77,6 +113,7 @@
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
        hyperparameters=hyperparams,
    )
    # [END howto_operator_gcp_mlengine_training]

10 changes: 10 additions & 0 deletions airflow/providers/google/cloud/operators/mlengine.py
@@ -1124,6 +1124,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
    :type mode: str
    :param labels: a dictionary containing labels for the job; passed to BigQuery
    :type labels: Dict[str, str]
    :param hyperparameters: Optional HyperparameterSpec dictionary for hyperparameter tuning.
        For further reference, check:
        https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#HyperparameterSpec
    :type hyperparameters: Dict
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
@@ -1149,6 +1153,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
        '_python_version',
        '_job_dir',
        '_service_account',
        '_hyperparameters',
        '_impersonation_chain',
    ]

@@ -1175,6 +1180,7 @@ def __init__(
        mode: str = 'PRODUCTION',
        labels: Optional[Dict[str, str]] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        hyperparameters: Optional[Dict] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
@@ -1195,6 +1201,7 @@ def __init__(
        self._delegate_to = delegate_to
        self._mode = mode
        self._labels = labels
        self._hyperparameters = hyperparameters
        self._impersonation_chain = impersonation_chain

        custom = self._scale_tier is not None and self._scale_tier.upper() == 'CUSTOM'
@@ -1260,6 +1267,9 @@ def execute(self, context):
        if self._service_account:
            training_request['trainingInput']['serviceAccount'] = self._service_account

        if self._hyperparameters:
            training_request['trainingInput']['hyperparameters'] = self._hyperparameters

        if self._labels:
            training_request['labels'] = self._labels
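
For orientation, self._hyperparameters is forwarded verbatim, so the body submitted to the AI Platform Training API ends up shaped roughly as below. This is an illustrative sketch, not output from this commit: the surrounding keys follow the projects.jobs REST resource, and the values echo the example DAG above.

    training_request = {
        'jobId': 'example_training_job',           # illustrative
        'trainingInput': {
            'scaleTier': 'STANDARD_1',             # illustrative
            'packageUris': ['gs://my-bucket/trainer-0.1.tar.gz'],   # illustrative
            'pythonModule': 'trainer.task',        # illustrative
            'region': 'us-central1',               # illustrative
            'args': [],
            'hyperparameters': {                   # injected by the new branch above
                'goal': 'MAXIMIZE',
                'hyperparameterMetricTag': 'metric1',
                'maxTrials': 30,
                'maxParallelTrials': 1,
                'enableTrialEarlyStopping': True,
                'params': [...],                   # the three parameter specs from the example DAG
            },
        },
        'labels': {'job_type': 'training'},
    }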

40 changes: 40 additions & 0 deletions tests/providers/google/cloud/operators/test_mlengine.py
@@ -457,6 +457,45 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
        training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
        training_input['trainingInput']['serviceAccount'] = 'test@serviceaccount.com'

        hyperparams = {
            'goal': 'MAXIMIZE',
            'hyperparameterMetricTag': 'metric1',
            'maxTrials': 30,
            'maxParallelTrials': 1,
            'enableTrialEarlyStopping': True,
            'params': [],
        }

        hyperparams['params'].append(
            {
                'parameterName': 'hidden1',
                'type': 'INTEGER',
                'minValue': 40,
                'maxValue': 400,
                'scaleType': 'UNIT_LINEAR_SCALE',
            }
        )

        hyperparams['params'].append(
            {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]}
        )

        hyperparams['params'].append(
            {
                'parameterName': 'rnnCellType',
                'type': 'CATEGORICAL',
                'categoricalValues': [
                    'BasicLSTMCell',
                    'BasicRNNCell',
                    'GRUCell',
                    'LSTMCell',
                    'LayerNormBasicLSTMCell',
                ],
            }
        )

        training_input['trainingInput']['hyperparameters'] = hyperparams

        success_response = self.TRAINING_INPUT.copy()
        success_response['state'] = 'SUCCEEDED'
        hook_instance = mock_hook.return_value
@@ -468,6 +507,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
            job_dir='gs://some-bucket/jobs/test_training',
            service_account='test@serviceaccount.com',
            **self.TRAINING_DEFAULT_ARGS,
            hyperparameters=hyperparams,
        )
        training_op.execute(MagicMock())

