diff --git a/automl/tables/automl_tables_dataset.py b/automl/tables/automl_tables_dataset.py new file mode 100644 index 000000000000..144f2ee6b65e --- /dev/null +++ b/automl/tables/automl_tables_dataset.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python + +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This application demonstrates how to perform basic operations on dataset +with the Google AutoML Tables API. + +For more information, the documentation at +https://cloud.google.com/automl-tables/docs. +""" + +import argparse +import os + + +def create_dataset(project_id, compute_region, dataset_display_name): + """Create a dataset.""" + # [START automl_tables_create_dataset] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # dataset_display_name = 'DATASET_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Create a dataset with the given display name + dataset = client.create_dataset(dataset_display_name) + + # Display the dataset information. + print("Dataset name: {}".format(dataset.name)) + print("Dataset id: {}".format(dataset.name.split("/")[-1])) + print("Dataset display name: {}".format(dataset.display_name)) + print("Dataset metadata:") + print("\t{}".format(dataset.tables_dataset_metadata)) + print("Dataset example count: {}".format(dataset.example_count)) + print("Dataset create time:") + print("\tseconds: {}".format(dataset.create_time.seconds)) + print("\tnanos: {}".format(dataset.create_time.nanos)) + + # [END automl_tables_create_dataset] + + return dataset + + +def list_datasets(project_id, compute_region, filter_=None): + """List all datasets.""" + result = [] + # [START automl_tables_list_datasets] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # filter_ = 'filter expression here' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # List all the datasets available in the region by applying filter. + response = client.list_datasets(filter_=filter_) + + print("List of datasets:") + for dataset in response: + # Display the dataset information. + print("Dataset name: {}".format(dataset.name)) + print("Dataset id: {}".format(dataset.name.split("/")[-1])) + print("Dataset display name: {}".format(dataset.display_name)) + metadata = dataset.tables_dataset_metadata + print( + "Dataset primary table spec id: {}".format( + metadata.primary_table_spec_id + ) + ) + print( + "Dataset target column spec id: {}".format( + metadata.target_column_spec_id + ) + ) + print( + "Dataset target column spec id: {}".format( + metadata.target_column_spec_id + ) + ) + print( + "Dataset weight column spec id: {}".format( + metadata.weight_column_spec_id + ) + ) + print( + "Dataset ml use column spec id: {}".format( + metadata.ml_use_column_spec_id + ) + ) + print("Dataset example count: {}".format(dataset.example_count)) + print("Dataset create time:") + print("\tseconds: {}".format(dataset.create_time.seconds)) + print("\tnanos: {}".format(dataset.create_time.nanos)) + print("\n") + + # [END automl_tables_list_datasets] + result.append(dataset) + + return result + + +def get_dataset(project_id, compute_region, dataset_display_name): + """Get the dataset.""" + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # dataset_display_name = 'DATASET_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Get complete detail of the dataset. + dataset = client.get_dataset(dataset_display_name=dataset_display_name) + + # Display the dataset information. + print("Dataset name: {}".format(dataset.name)) + print("Dataset id: {}".format(dataset.name.split("/")[-1])) + print("Dataset display name: {}".format(dataset.display_name)) + print("Dataset metadata:") + print("\t{}".format(dataset.tables_dataset_metadata)) + print("Dataset example count: {}".format(dataset.example_count)) + print("Dataset create time:") + print("\tseconds: {}".format(dataset.create_time.seconds)) + print("\tnanos: {}".format(dataset.create_time.nanos)) + + return dataset + + +def import_data(project_id, compute_region, dataset_display_name, path): + """Import structured data.""" + # [START automl_tables_import_data] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # dataset_display_name = 'DATASET_DISPLAY_NAME' + # path = 'gs://path/to/file.csv' or 'bq://project_id.dataset.table_id' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + response = None + if path.startswith("bq"): + response = client.import_data( + dataset_display_name=dataset_display_name, bigquery_input_uri=path + ) + else: + # Get the multiple Google Cloud Storage URIs. + input_uris = path.split(",") + response = client.import_data( + dataset_display_name=dataset_display_name, + gcs_input_uris=input_uris, + ) + + print("Processing import...") + # synchronous check of operation status. + print("Data imported. {}".format(response.result())) + + # [END automl_tables_import_data] + + +def update_dataset( + project_id, + compute_region, + dataset_display_name, + target_column_spec_name=None, + weight_column_spec_name=None, + test_train_column_spec_name=None, +): + """Update dataset.""" + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # dataset_display_name = 'DATASET_DISPLAY_NAME_HERE' + # target_column_spec_name = 'TARGET_COLUMN_SPEC_NAME_HERE' or None + # weight_column_spec_name = 'WEIGHT_COLUMN_SPEC_NAME_HERE' or None + # test_train_column_spec_name = 'TEST_TRAIN_COLUMN_SPEC_NAME_HERE' or None + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + if target_column_spec_name is not None: + response = client.set_target_column( + dataset_display_name=dataset_display_name, + column_spec_display_name=target_column_spec_name, + ) + print("Target column updated. {}".format(response)) + if weight_column_spec_name is not None: + response = client.set_weight_column( + dataset_display_name=dataset_display_name, + column_spec_display_name=weight_column_spec_name, + ) + print("Weight column updated. {}".format(response)) + if test_train_column_spec_name is not None: + response = client.set_test_train_column( + dataset_display_name=dataset_display_name, + column_spec_display_name=test_train_column_spec_name, + ) + print("Test/train column updated. {}".format(response)) + + +def delete_dataset(project_id, compute_region, dataset_display_name): + """Delete a dataset""" + # [START automl_tables_delete_dataset] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # dataset_display_name = 'DATASET_DISPLAY_NAME_HERE + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Delete a dataset. + response = client.delete_dataset(dataset_display_name=dataset_display_name) + + # synchronous check of operation status. + print("Dataset deleted. {}".format(response.result())) + # [END automl_tables_delete_dataset] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + subparsers = parser.add_subparsers(dest="command") + + create_dataset_parser = subparsers.add_parser( + "create_dataset", help=create_dataset.__doc__ + ) + create_dataset_parser.add_argument("--dataset_name") + + list_datasets_parser = subparsers.add_parser( + "list_datasets", help=list_datasets.__doc__ + ) + list_datasets_parser.add_argument("--filter_") + + get_dataset_parser = subparsers.add_parser( + "get_dataset", help=get_dataset.__doc__ + ) + get_dataset_parser.add_argument("--dataset_display_name") + + import_data_parser = subparsers.add_parser( + "import_data", help=import_data.__doc__ + ) + import_data_parser.add_argument("--dataset_display_name") + import_data_parser.add_argument("--path") + + update_dataset_parser = subparsers.add_parser( + "update_dataset", help=update_dataset.__doc__ + ) + update_dataset_parser.add_argument("--dataset_display_name") + update_dataset_parser.add_argument("--target_column_spec_name") + update_dataset_parser.add_argument("--weight_column_spec_name") + update_dataset_parser.add_argument("--ml_use_column_spec_name") + + delete_dataset_parser = subparsers.add_parser( + "delete_dataset", help=delete_dataset.__doc__ + ) + delete_dataset_parser.add_argument("--dataset_display_name") + + project_id = os.environ["PROJECT_ID"] + compute_region = os.environ["REGION_NAME"] + + args = parser.parse_args() + if args.command == "create_dataset": + create_dataset(project_id, compute_region, args.dataset_name) + if args.command == "list_datasets": + list_datasets(project_id, compute_region, args.filter_) + if args.command == "get_dataset": + get_dataset(project_id, compute_region, args.dataset_display_name) + if args.command == "import_data": + import_data( + project_id, compute_region, args.dataset_display_name, args.path + ) + if args.command == "update_dataset": + update_dataset( + project_id, + compute_region, + args.dataset_display_name, + args.target_column_spec_name, + args.weight_column_spec_name, + args.ml_use_column_spec_name, + ) + if args.command == "delete_dataset": + delete_dataset(project_id, compute_region, args.dataset_display_name) diff --git a/automl/tables/automl_tables_model.py b/automl/tables/automl_tables_model.py new file mode 100644 index 000000000000..a77dfe62d7a1 --- /dev/null +++ b/automl/tables/automl_tables_model.py @@ -0,0 +1,514 @@ +#!/usr/bin/env python + +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This application demonstrates how to perform basic operations on model +with the Google AutoML Tables API. + +For more information, the documentation at +https://cloud.google.com/automl-tables/docs. +""" + +import argparse +import os + + +def create_model( + project_id, + compute_region, + dataset_display_name, + model_display_name, + train_budget_milli_node_hours, + include_column_spec_names=None, + exclude_column_spec_names=None, +): + """Create a model.""" + # [START automl_tables_create_model] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # dataset_display_name = 'DATASET_DISPLAY_NAME_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + # train_budget_milli_node_hours = 'TRAIN_BUDGET_MILLI_NODE_HOURS_HERE' + # include_column_spec_names = 'INCLUDE_COLUMN_SPEC_NAMES_HERE' + # or None if unspecified + # exclude_column_spec_names = 'EXCLUDE_COLUMN_SPEC_NAMES_HERE' + # or None if unspecified + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Create a model with the model metadata in the region. + response = client.create_model( + model_display_name, + train_budget_milli_node_hours=train_budget_milli_node_hours, + dataset_display_name=dataset_display_name, + include_column_spec_names=include_column_spec_names, + exclude_column_spec_names=exclude_column_spec_names, + ) + + print("Training model...") + print("Training operation name: {}".format(response.operation.name)) + print("Training completed: {}".format(response.result())) + + # [END automl_tables_create_model] + + +def get_operation_status(operation_full_id): + """Get operation status.""" + # [START automl_tables_get_operation_status] + # TODO(developer): Uncomment and set the following variables + # operation_full_id = + # 'projects//locations//operations/' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient() + + # Get the latest state of a long-running operation. + op = client.auto_ml_client.transport._operations_client.get_operation( + operation_full_id + ) + + print("Operation status: {}".format(op)) + + # [END automl_tables_get_operation_status] + + +def list_models(project_id, compute_region, filter_=None): + """List all models.""" + result = [] + # [START automl_tables_list_models] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # filter_ = 'DATASET_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + from google.cloud.automl_v1beta1 import enums + + client = automl.TablesClient(project=project_id, region=compute_region) + + # List all the models available in the region by applying filter. + response = client.list_models(filter_=filter_) + + print("List of models:") + for model in response: + # Retrieve deployment state. + if model.deployment_state == enums.Model.DeploymentState.DEPLOYED: + deployment_state = "deployed" + else: + deployment_state = "undeployed" + + # Display the model information. + print("Model name: {}".format(model.name)) + print("Model id: {}".format(model.name.split("/")[-1])) + print("Model display name: {}".format(model.display_name)) + metadata = model.tables_model_metadata + print( + "Target column display name: {}".format( + metadata.target_column_spec.display_name + ) + ) + print( + "Training budget in node milli hours: {}".format( + metadata.train_budget_milli_node_hours + ) + ) + print( + "Training cost in node milli hours: {}".format( + metadata.train_cost_milli_node_hours + ) + ) + print("Model create time:") + print("\tseconds: {}".format(model.create_time.seconds)) + print("\tnanos: {}".format(model.create_time.nanos)) + print("Model deployment state: {}".format(deployment_state)) + print("\n") + + # [END automl_tables_list_models] + result.append(model) + + return result + + +def get_model(project_id, compute_region, model_display_name): + """Get model details.""" + # [START automl_tables_get_model] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + from google.cloud.automl_v1beta1 import enums + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Get complete detail of the model. + model = client.get_model(model_display_name=model_display_name) + + # Retrieve deployment state. + if model.deployment_state == enums.Model.DeploymentState.DEPLOYED: + deployment_state = "deployed" + else: + deployment_state = "undeployed" + + # get features of top importance + feat_list = [ + (column.feature_importance, column.column_display_name) + for column in model.tables_model_metadata.tables_model_column_info + ] + feat_list.sort(reverse=True) + if len(feat_list) < 10: + feat_to_show = len(feat_list) + else: + feat_to_show = 10 + + # Display the model information. + print("Model name: {}".format(model.name)) + print("Model id: {}".format(model.name.split("/")[-1])) + print("Model display name: {}".format(model.display_name)) + print("Features of top importance:") + for feat in feat_list[:feat_to_show]: + print(feat) + print("Model create time:") + print("\tseconds: {}".format(model.create_time.seconds)) + print("\tnanos: {}".format(model.create_time.nanos)) + print("Model deployment state: {}".format(deployment_state)) + + # [END automl_tables_get_model] + + return model + + +def list_model_evaluations( + project_id, compute_region, model_display_name, filter_=None +): + + """List model evaluations.""" + result = [] + # [START automl_tables_list_model_evaluations] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + # filter_ = 'filter expression here' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # List all the model evaluations in the model by applying filter. + response = client.list_model_evaluations( + model_display_name=model_display_name, filter_=filter_ + ) + + print("List of model evaluations:") + for evaluation in response: + print("Model evaluation name: {}".format(evaluation.name)) + print("Model evaluation id: {}".format(evaluation.name.split("/")[-1])) + print( + "Model evaluation example count: {}".format( + evaluation.evaluated_example_count + ) + ) + print("Model evaluation time:") + print("\tseconds: {}".format(evaluation.create_time.seconds)) + print("\tnanos: {}".format(evaluation.create_time.nanos)) + print("\n") + # [END automl_tables_list_model_evaluations] + result.append(evaluation) + + return result + + +def get_model_evaluation( + project_id, compute_region, model_id, model_evaluation_id +): + """Get model evaluation.""" + # [START automl_tables_get_model_evaluation] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_id = 'MODEL_ID_HERE' + # model_evaluation_id = 'MODEL_EVALUATION_ID_HERE' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient() + + # Get the full path of the model evaluation. + model_evaluation_full_id = client.auto_ml_client.model_evaluation_path( + project_id, compute_region, model_id, model_evaluation_id + ) + + # Get complete detail of the model evaluation. + response = client.get_model_evaluation( + model_evaluation_name=model_evaluation_full_id + ) + + print(response) + # [END automl_tables_get_model_evaluation] + return response + + +def display_evaluation( + project_id, compute_region, model_display_name, filter_=None +): + """Display evaluation.""" + # [START automl_tables_display_evaluation] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + # filter_ = 'filter expression here' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # List all the model evaluations in the model by applying filter. + response = client.list_model_evaluations( + model_display_name=model_display_name, filter_=filter_ + ) + + # Iterate through the results. + for evaluation in response: + # There is evaluation for each class in a model and for overall model. + # Get only the evaluation of overall model. + if not evaluation.annotation_spec_id: + model_evaluation_name = evaluation.name + break + + # Get a model evaluation. + model_evaluation = client.get_model_evaluation( + model_evaluation_name=model_evaluation_name + ) + + classification_metrics = model_evaluation.classification_evaluation_metrics + if str(classification_metrics): + confidence_metrics = classification_metrics.confidence_metrics_entry + + # Showing model score based on threshold of 0.5 + print("Model classification metrics (threshold at 0.5):") + for confidence_metrics_entry in confidence_metrics: + if confidence_metrics_entry.confidence_threshold == 0.5: + print( + "Model Precision: {}%".format( + round(confidence_metrics_entry.precision * 100, 2) + ) + ) + print( + "Model Recall: {}%".format( + round(confidence_metrics_entry.recall * 100, 2) + ) + ) + print( + "Model F1 score: {}%".format( + round(confidence_metrics_entry.f1_score * 100, 2) + ) + ) + print("Model AUPRC: {}".format(classification_metrics.au_prc)) + print("Model AUROC: {}".format(classification_metrics.au_roc)) + print("Model log loss: {}".format(classification_metrics.log_loss)) + + regression_metrics = model_evaluation.regression_evaluation_metrics + if str(regression_metrics): + print("Model regression metrics:") + print( + "Model RMSE: {}".format(regression_metrics.root_mean_squared_error) + ) + print("Model MAE: {}".format(regression_metrics.mean_absolute_error)) + print( + "Model MAPE: {}".format( + regression_metrics.mean_absolute_percentage_error + ) + ) + print("Model R^2: {}".format(regression_metrics.r_squared)) + + # [END automl_tables_display_evaluation] + + +def deploy_model(project_id, compute_region, model_display_name): + """Deploy model.""" + # [START automl_tables_deploy_model] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Deploy model + response = client.deploy_model(model_display_name=model_display_name) + + # synchronous check of operation status. + print("Model deployed. {}".format(response.result())) + + # [END automl_tables_deploy_model] + + +def undeploy_model(project_id, compute_region, model_display_name): + """Undeploy model.""" + # [START automl_tables_undeploy_model] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Undeploy model + response = client.undeploy_model(model_display_name=model_display_name) + + # synchronous check of operation status. + print("Model undeployed. {}".format(response.result())) + + # [END automl_tables_undeploy_model] + + +def delete_model(project_id, compute_region, model_display_name): + """Delete a model.""" + # [START automl_tables_delete_model] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Undeploy model + response = client.delete_model(model_display_name=model_display_name) + + # synchronous check of operation status. + print("Model deleted. {}".format(response.result())) + + # [END automl_tables_delete_model] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + subparsers = parser.add_subparsers(dest="command") + + create_model_parser = subparsers.add_parser( + "create_model", help=create_model.__doc__ + ) + create_model_parser.add_argument("--dataset_display_name") + create_model_parser.add_argument("--model_display_name") + create_model_parser.add_argument( + "--train_budget_milli_node_hours", type=int + ) + + get_operation_status_parser = subparsers.add_parser( + "get_operation_status", help=get_operation_status.__doc__ + ) + get_operation_status_parser.add_argument("--operation_full_id") + + list_models_parser = subparsers.add_parser( + "list_models", help=list_models.__doc__ + ) + list_models_parser.add_argument("--filter_") + + get_model_parser = subparsers.add_parser( + "get_model", help=get_model.__doc__ + ) + get_model_parser.add_argument("--model_display_name") + + list_model_evaluations_parser = subparsers.add_parser( + "list_model_evaluations", help=list_model_evaluations.__doc__ + ) + list_model_evaluations_parser.add_argument("--model_display_name") + list_model_evaluations_parser.add_argument("--filter_") + + get_model_evaluation_parser = subparsers.add_parser( + "get_model_evaluation", help=get_model_evaluation.__doc__ + ) + get_model_evaluation_parser.add_argument("--model_id") + get_model_evaluation_parser.add_argument("--model_evaluation_id") + + display_evaluation_parser = subparsers.add_parser( + "display_evaluation", help=display_evaluation.__doc__ + ) + display_evaluation_parser.add_argument("--model_display_name") + display_evaluation_parser.add_argument("--filter_") + + deploy_model_parser = subparsers.add_parser( + "deploy_model", help=deploy_model.__doc__ + ) + deploy_model_parser.add_argument("--model_display_name") + + undeploy_model_parser = subparsers.add_parser( + "undeploy_model", help=undeploy_model.__doc__ + ) + undeploy_model_parser.add_argument("--model_display_name") + + delete_model_parser = subparsers.add_parser( + "delete_model", help=delete_model.__doc__ + ) + delete_model_parser.add_argument("--model_display_name") + + project_id = os.environ["PROJECT_ID"] + compute_region = os.environ["REGION_NAME"] + + args = parser.parse_args() + + if args.command == "create_model": + create_model( + project_id, + compute_region, + args.dataset_display_name, + args.model_display_name, + args.train_budget_milli_node_hours, + # Input columns are omitted here as argparse does not support + # column spec objects, but it is still included in function def. + ) + if args.command == "get_operation_status": + get_operation_status(args.operation_full_id) + if args.command == "list_models": + list_models(project_id, compute_region, args.filter_) + if args.command == "get_model": + get_model(project_id, compute_region, args.model_display_name) + if args.command == "list_model_evaluations": + list_model_evaluations( + project_id, compute_region, args.model_display_name, args.filter_ + ) + if args.command == "get_model_evaluation": + get_model_evaluation( + project_id, + compute_region, + args.model_display_name, + args.model_evaluation_id, + ) + if args.command == "display_evaluation": + display_evaluation( + project_id, compute_region, args.model_display_name, args.filter_ + ) + if args.command == "deploy_model": + deploy_model(project_id, compute_region, args.model_display_name) + if args.command == "undeploy_model": + undeploy_model(project_id, compute_region, args.model_display_name) + if args.command == "delete_model": + delete_model(project_id, compute_region, args.model_display_name) diff --git a/automl/tables/automl_tables_predict.py b/automl/tables/automl_tables_predict.py new file mode 100644 index 000000000000..9787e1b9b4a6 --- /dev/null +++ b/automl/tables/automl_tables_predict.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python + +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This application demonstrates how to perform basic operations on prediction +with the Google AutoML Tables API. + +For more information, the documentation at +https://cloud.google.com/automl-tables/docs. +""" + +import argparse +import os + + +def predict( + project_id, + compute_region, + model_display_name, + inputs, + feature_importance=None, +): + """Make a prediction.""" + # [START automl_tables_predict] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + # inputs = {'value': 3, ...} + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + if feature_importance: + response = client.predict( + model_display_name=model_display_name, + inputs=inputs, + feature_importance=True, + ) + else: + response = client.predict( + model_display_name=model_display_name, inputs=inputs + ) + + print("Prediction results:") + for result in response.payload: + print( + "Predicted class name: {}".format(result.tables.value.string_value) + ) + print("Predicted class score: {}".format(result.tables.score)) + + if feature_importance: + # get features of top importance + feat_list = [ + (column.feature_importance, column.column_display_name) + for column in result.tables.tables_model_column_info + ] + feat_list.sort(reverse=True) + if len(feat_list) < 10: + feat_to_show = len(feat_list) + else: + feat_to_show = 10 + + print("Features of top importance:") + for feat in feat_list[:feat_to_show]: + print(feat) + + # [END automl_tables_predict] + + +def batch_predict_bq( + project_id, + compute_region, + model_display_name, + bq_input_uri, + bq_output_uri, + params +): + """Make a batch of predictions.""" + # [START automl_tables_batch_predict_bq] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + # bq_input_uri = 'bq://my-project.my-dataset.my-table' + # bq_output_uri = 'bq://my-project' + # params = {} + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Query model + response = client.batch_predict(bigquery_input_uri=bq_input_uri, + bigquery_output_uri=bq_output_uri, + model_display_name=model_display_name, + params=params) + print("Making batch prediction... ") + # `response` is a async operation descriptor, + # you can register a callback for the operation to complete via `add_done_callback`: + # def callback(operation_future): + # result = operation_future.result() + # response.add_done_callback(callback) + # + # or block the thread polling for the operation's results: + response.result() + # AutoML puts predictions in a newly generated dataset with a name by a mask "prediction_" + model_id + "_" + timestamp + # here's how to get the dataset name: + dataset_name = response.metadata.batch_predict_details.output_info.bigquery_output_dataset + + print("Batch prediction complete.\nResults are in '{}' dataset.\n{}".format( + dataset_name, response.metadata)) + + # [END automl_tables_batch_predict_bq] + + +def batch_predict( + project_id, + compute_region, + model_display_name, + gcs_input_uri, + gcs_output_uri, + params, +): + """Make a batch of predictions.""" + # [START automl_tables_batch_predict] + # TODO(developer): Uncomment and set the following variables + # project_id = 'PROJECT_ID_HERE' + # compute_region = 'COMPUTE_REGION_HERE' + # model_display_name = 'MODEL_DISPLAY_NAME_HERE' + # gcs_input_uri = 'gs://YOUR_BUCKET_ID/path_to_your_input_csv' + # gcs_output_uri = 'gs://YOUR_BUCKET_ID/path_to_save_results/' + # params = {} + + from google.cloud import automl_v1beta1 as automl + + client = automl.TablesClient(project=project_id, region=compute_region) + + # Query model + response = client.batch_predict( + gcs_input_uris=gcs_input_uri, + gcs_output_uri_prefix=gcs_output_uri, + model_display_name=model_display_name, + params=params + ) + print("Making batch prediction... ") + # `response` is a async operation descriptor, + # you can register a callback for the operation to complete via `add_done_callback`: + # def callback(operation_future): + # result = operation_future.result() + # response.add_done_callback(callback) + # + # or block the thread polling for the operation's results: + response.result() + + print("Batch prediction complete.\n{}".format(response.metadata)) + + # [END automl_tables_batch_predict] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + subparsers = parser.add_subparsers(dest="command") + + predict_parser = subparsers.add_parser("predict", help=predict.__doc__) + predict_parser.add_argument("--model_display_name") + predict_parser.add_argument("--file_path") + + batch_predict_parser = subparsers.add_parser( + "batch_predict", help=predict.__doc__ + ) + batch_predict_parser.add_argument("--model_display_name") + batch_predict_parser.add_argument("--input_path") + batch_predict_parser.add_argument("--output_path") + + project_id = os.environ["PROJECT_ID"] + compute_region = os.environ["REGION_NAME"] + + args = parser.parse_args() + + if args.command == "predict": + predict( + project_id, compute_region, args.model_display_name, args.file_path + ) + + if args.command == "batch_predict": + batch_predict( + project_id, + compute_region, + args.model_display_name, + args.input_path, + args.output_path, + ) diff --git a/automl/tables/automl_tables_set_endpoint.py b/automl/tables/automl_tables_set_endpoint.py new file mode 100644 index 000000000000..d6ab898b4f5d --- /dev/null +++ b/automl/tables/automl_tables_set_endpoint.py @@ -0,0 +1,33 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def create_client_with_endpoint(gcp_project_id): + """Create a Tables client with a non-default endpoint.""" + # [START automl_set_endpoint] + from google.cloud import automl_v1beta1 as automl + from google.api_core.client_options import ClientOptions + + # Set the endpoint you want to use via the ClientOptions. + # gcp_project_id = 'YOUR_PROJECT_ID' + client_options = ClientOptions(api_endpoint="eu-automl.googleapis.com:443") + client = automl.TablesClient( + project=gcp_project_id, region="eu", client_options=client_options + ) + # [END automl_set_endpoint] + + # do simple test to check client connectivity + print(client.list_datasets()) + + return client diff --git a/automl/tables/batch_predict_test.py b/automl/tables/batch_predict_test.py new file mode 100644 index 000000000000..203f4c8d55a2 --- /dev/null +++ b/automl/tables/batch_predict_test.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.cloud.automl_v1beta1.gapic import enums + +import pytest + +import automl_tables_model +import automl_tables_predict +import model_test + + +PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] +REGION = "us-central1" +STATIC_MODEL = model_test.STATIC_MODEL +GCS_INPUT = "gs://{}-automl-tables-test/bank-marketing.csv".format(PROJECT) +GCS_OUTPUT = "gs://{}-automl-tables-test/TABLE_TEST_OUTPUT/".format(PROJECT) +BQ_INPUT = "bq://{}.automl_test.bank_marketing".format(PROJECT) +BQ_OUTPUT = "bq://{}".format(PROJECT) +PARAMS = {} + + +@pytest.mark.slow +def test_batch_predict(capsys): + ensure_model_online() + + automl_tables_predict.batch_predict( + PROJECT, REGION, STATIC_MODEL, GCS_INPUT, GCS_OUTPUT, PARAMS + ) + out, _ = capsys.readouterr() + assert "Batch prediction complete" in out + + +@pytest.mark.slow +def test_batch_predict_bq(capsys): + ensure_model_online() + automl_tables_predict.batch_predict_bq( + PROJECT, REGION, STATIC_MODEL, BQ_INPUT, BQ_OUTPUT, PARAMS + ) + out, _ = capsys.readouterr() + assert "Batch prediction complete" in out + + +def ensure_model_online(): + model = model_test.ensure_model_ready() + if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: + automl_tables_model.deploy_model(PROJECT, REGION, model.display_name) + + return automl_tables_model.get_model(PROJECT, REGION, model.display_name) diff --git a/automl/tables/dataset_test.py b/automl/tables/dataset_test.py new file mode 100644 index 000000000000..27570f0bee98 --- /dev/null +++ b/automl/tables/dataset_test.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import string +import time + +from google.api_core import exceptions +import pytest + +import automl_tables_dataset + + +PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] +REGION = "us-central1" +STATIC_DATASET = "do_not_delete_this_table_python" +GCS_DATASET = ("gs://python-docs-samples-tests-automl-tables-test" + "/bank-marketing.csv") + +ID = "{rand}_{time}".format( + rand="".join( + [random.choice(string.ascii_letters + string.digits) for n in range(4)] + ), + time=int(time.time()), +) + + +def _id(name): + return "{}_{}".format(name, ID) + + +def ensure_dataset_ready(): + dataset = None + name = STATIC_DATASET + try: + dataset = automl_tables_dataset.get_dataset(PROJECT, REGION, name) + except exceptions.NotFound: + dataset = automl_tables_dataset.create_dataset(PROJECT, REGION, name) + + if dataset.example_count is None or dataset.example_count == 0: + automl_tables_dataset.import_data(PROJECT, REGION, name, GCS_DATASET) + dataset = automl_tables_dataset.get_dataset(PROJECT, REGION, name) + + automl_tables_dataset.update_dataset( + PROJECT, + REGION, + dataset.display_name, + target_column_spec_name="Deposit", + ) + + return dataset + + +@pytest.mark.slow +def test_dataset_create_import_delete(capsys): + name = _id("d_cr_dl") + dataset = automl_tables_dataset.create_dataset(PROJECT, REGION, name) + assert dataset is not None + assert dataset.display_name == name + + automl_tables_dataset.import_data(PROJECT, REGION, name, GCS_DATASET) + + out, _ = capsys.readouterr() + assert "Data imported." in out + + automl_tables_dataset.delete_dataset(PROJECT, REGION, name) + + with pytest.raises(exceptions.NotFound): + automl_tables_dataset.get_dataset(PROJECT, REGION, name) + + +def test_dataset_update(capsys): + dataset = ensure_dataset_ready() + automl_tables_dataset.update_dataset( + PROJECT, + REGION, + dataset.display_name, + target_column_spec_name="Deposit", + weight_column_spec_name="Balance", + ) + + out, _ = capsys.readouterr() + assert "Target column updated." in out + assert "Weight column updated." in out + + +def test_list_datasets(): + ensure_dataset_ready() + assert ( + next( + ( + d + for d in automl_tables_dataset.list_datasets(PROJECT, REGION) + if d.display_name == STATIC_DATASET + ), + None, + ) + is not None + ) diff --git a/automl/tables/endpoint_test.py b/automl/tables/endpoint_test.py new file mode 100644 index 000000000000..5a20aba5c488 --- /dev/null +++ b/automl/tables/endpoint_test.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import automl_tables_set_endpoint + +PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] + + +def test_client_creation(capsys): + automl_tables_set_endpoint.create_client_with_endpoint(PROJECT) + out, _ = capsys.readouterr() + assert "GRPCIterator" in out diff --git a/automl/tables/model_test.py b/automl/tables/model_test.py new file mode 100644 index 000000000000..484eaf824878 --- /dev/null +++ b/automl/tables/model_test.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python + +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import string +import time + +from google.api_core import exceptions + +import automl_tables_model +import dataset_test + + +PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] +REGION = "us-central1" +STATIC_MODEL = "do_not_delete_this_model_0" +GCS_DATASET = "gs://cloud-ml-tables-data/bank-marketing.csv" + +ID = "{rand}_{time}".format( + rand="".join( + [random.choice(string.ascii_letters + string.digits) for n in range(4)] + ), + time=int(time.time()), +) + + +def _id(name): + return "{}_{}".format(name, ID) + + +def test_list_models(): + ensure_model_ready() + assert ( + next( + ( + m + for m in automl_tables_model.list_models(PROJECT, REGION) + if m.display_name == STATIC_MODEL + ), + None, + ) + is not None + ) + + +def test_list_model_evaluations(): + model = ensure_model_ready() + mes = automl_tables_model.list_model_evaluations( + PROJECT, REGION, model.display_name + ) + assert len(mes) > 0 + for me in mes: + assert me.name.startswith(model.name) + + +def test_get_model_evaluations(): + model = ensure_model_ready() + me = automl_tables_model.list_model_evaluations( + PROJECT, REGION, model.display_name + )[0] + mep = automl_tables_model.get_model_evaluation( + PROJECT, + REGION, + model.name.rpartition("/")[2], + me.name.rpartition("/")[2], + ) + + assert mep.name == me.name + + +def ensure_model_ready(): + name = STATIC_MODEL + try: + return automl_tables_model.get_model(PROJECT, REGION, name) + except exceptions.NotFound: + pass + + dataset = dataset_test.ensure_dataset_ready() + return automl_tables_model.create_model( + PROJECT, REGION, dataset.display_name, name, 1000 + ) diff --git a/automl/tables/noxfile.py b/automl/tables/noxfile.py new file mode 100644 index 000000000000..ba55d7ce53ca --- /dev/null +++ b/automl/tables/noxfile.py @@ -0,0 +1,224 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +from pathlib import Path +import sys + +import nox + + +# WARNING - WARNING - WARNING - WARNING - WARNING +# WARNING - WARNING - WARNING - WARNING - WARNING +# DO NOT EDIT THIS FILE EVER! +# WARNING - WARNING - WARNING - WARNING - WARNING +# WARNING - WARNING - WARNING - WARNING - WARNING + +# Copy `noxfile_config.py` to your directory and modify it instead. + + +# `TEST_CONFIG` dict is a configuration hook that allows users to +# modify the test configurations. The values here should be in sync +# with `noxfile_config.py`. Users will copy `noxfile_config.py` into +# their directory and modify it. + +TEST_CONFIG = { + # You can opt out from the test for specific Python versions. + 'ignored_versions': ["2.7"], + + # An envvar key for determining the project id to use. Change it + # to 'BUILD_SPECIFIC_GCLOUD_PROJECT' if you want to opt in using a + # build specific Cloud project. You can also use your own string + # to use your own Cloud project. + 'gcloud_project_env': 'GOOGLE_CLOUD_PROJECT', + # 'gcloud_project_env': 'BUILD_SPECIFIC_GCLOUD_PROJECT', + + # A dictionary you want to inject into your test. Don't put any + # secrets here. These values will override predefined values. + 'envs': {}, +} + + +try: + # Ensure we can import noxfile_config in the project's directory. + sys.path.append('.') + from noxfile_config import TEST_CONFIG_OVERRIDE +except ImportError as e: + print("No user noxfile_config found: detail: {}".format(e)) + TEST_CONFIG_OVERRIDE = {} + +# Update the TEST_CONFIG with the user supplied values. +TEST_CONFIG.update(TEST_CONFIG_OVERRIDE) + + +def get_pytest_env_vars(): + """Returns a dict for pytest invocation.""" + ret = {} + + # Override the GCLOUD_PROJECT and the alias. + env_key = TEST_CONFIG['gcloud_project_env'] + # This should error out if not set. + ret['GOOGLE_CLOUD_PROJECT'] = os.environ[env_key] + + # Apply user supplied envs. + ret.update(TEST_CONFIG['envs']) + return ret + + +# DO NOT EDIT - automatically generated. +# All versions used to tested samples. +ALL_VERSIONS = ["2.7", "3.6", "3.7", "3.8"] + +# Any default versions that should be ignored. +IGNORED_VERSIONS = TEST_CONFIG['ignored_versions'] + +TESTED_VERSIONS = sorted([v for v in ALL_VERSIONS if v not in IGNORED_VERSIONS]) + +INSTALL_LIBRARY_FROM_SOURCE = bool(os.environ.get("INSTALL_LIBRARY_FROM_SOURCE", False)) +# +# Style Checks +# + + +def _determine_local_import_names(start_dir): + """Determines all import names that should be considered "local". + + This is used when running the linter to insure that import order is + properly checked. + """ + file_ext_pairs = [os.path.splitext(path) for path in os.listdir(start_dir)] + return [ + basename + for basename, extension in file_ext_pairs + if extension == ".py" + or os.path.isdir(os.path.join(start_dir, basename)) + and basename not in ("__pycache__") + ] + + +# Linting with flake8. +# +# We ignore the following rules: +# E203: whitespace before ‘:’ +# E266: too many leading ‘#’ for block comment +# E501: line too long +# I202: Additional newline in a section of imports +# +# We also need to specify the rules which are ignored by default: +# ['E226', 'W504', 'E126', 'E123', 'W503', 'E24', 'E704', 'E121'] +FLAKE8_COMMON_ARGS = [ + "--show-source", + "--builtin=gettext", + "--max-complexity=20", + "--import-order-style=google", + "--exclude=.nox,.cache,env,lib,generated_pb2,*_pb2.py,*_pb2_grpc.py", + "--ignore=E121,E123,E126,E203,E226,E24,E266,E501,E704,W503,W504,I202", + "--max-line-length=88", +] + + +@nox.session +def lint(session): + session.install("flake8", "flake8-import-order") + + local_names = _determine_local_import_names(".") + args = FLAKE8_COMMON_ARGS + [ + "--application-import-names", + ",".join(local_names), + "." + ] + session.run("flake8", *args) + + +# +# Sample Tests +# + + +PYTEST_COMMON_ARGS = ["--junitxml=sponge_log.xml"] + + +def _session_tests(session, post_install=None): + """Runs py.test for a particular project.""" + if os.path.exists("requirements.txt"): + session.install("-r", "requirements.txt") + + if os.path.exists("requirements-test.txt"): + session.install("-r", "requirements-test.txt") + + if INSTALL_LIBRARY_FROM_SOURCE: + session.install("-e", _get_repo_root()) + + if post_install: + post_install(session) + + session.run( + "pytest", + *(PYTEST_COMMON_ARGS + session.posargs), + # Pytest will return 5 when no tests are collected. This can happen + # on travis where slow and flaky tests are excluded. + # See http://doc.pytest.org/en/latest/_modules/_pytest/main.html + success_codes=[0, 5], + env=get_pytest_env_vars() + ) + + +@nox.session(python=ALL_VERSIONS) +def py(session): + """Runs py.test for a sample using the specified version of Python.""" + if session.python in TESTED_VERSIONS: + _session_tests(session) + else: + session.skip("SKIPPED: {} tests are disabled for this sample.".format( + session.python + )) + + +# +# Readmegen +# + + +def _get_repo_root(): + """ Returns the root folder of the project. """ + # Get root of this repository. Assume we don't have directories nested deeper than 10 items. + p = Path(os.getcwd()) + for i in range(10): + if p is None: + break + if Path(p / ".git").exists(): + return str(p) + p = p.parent + raise Exception("Unable to detect repository root.") + + +GENERATED_READMES = sorted([x for x in Path(".").rglob("*.rst.in")]) + + +@nox.session +@nox.parametrize("path", GENERATED_READMES) +def readmegen(session, path): + """(Re-)generates the readme for a sample.""" + session.install("jinja2", "pyyaml") + dir_ = os.path.dirname(path) + + if os.path.exists(os.path.join(dir_, "requirements.txt")): + session.install("-r", os.path.join(dir_, "requirements.txt")) + + in_file = os.path.join(dir_, "README.rst.in") + session.run( + "python", _get_repo_root() + "/scripts/readme-gen/readme_gen.py", in_file + ) diff --git a/automl/tables/predict_test.py b/automl/tables/predict_test.py new file mode 100644 index 000000000000..d608e182f1f0 --- /dev/null +++ b/automl/tables/predict_test.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python + +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.cloud.automl_v1beta1.gapic import enums + +import automl_tables_model +import automl_tables_predict +import model_test + + +PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] +REGION = "us-central1" +STATIC_MODEL = model_test.STATIC_MODEL + + +def test_predict(capsys): + inputs = { + "Age": 31, + "Balance": 200, + "Campaign": 2, + "Contact": "cellular", + "Day": "4", + "Default": "no", + "Duration": 12, + "Education": "primary", + "Housing": "yes", + "Job": "blue-collar", + "Loan": "no", + "MaritalStatus": "divorced", + "Month": "jul", + "PDays": 4, + "POutcome": "0", + "Previous": 12, + } + + ensure_model_online() + automl_tables_predict.predict(PROJECT, REGION, STATIC_MODEL, inputs, True) + out, _ = capsys.readouterr() + assert "Predicted class name:" in out + assert "Predicted class score:" in out + assert "Features of top importance:" in out + + +def ensure_model_online(): + model = model_test.ensure_model_ready() + if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: + automl_tables_model.deploy_model(PROJECT, REGION, model.display_name) + + return automl_tables_model.get_model(PROJECT, REGION, model.display_name) diff --git a/automl/tables/requirements-test.txt b/automl/tables/requirements-test.txt new file mode 100644 index 000000000000..7e460c8c866e --- /dev/null +++ b/automl/tables/requirements-test.txt @@ -0,0 +1 @@ +pytest==6.0.1 diff --git a/automl/tables/requirements.txt b/automl/tables/requirements.txt new file mode 100644 index 000000000000..867dfc61e77d --- /dev/null +++ b/automl/tables/requirements.txt @@ -0,0 +1 @@ +google-cloud-automl==1.0.1