Enable users to pass in Pandas Dataframe when calling import_data() and batch_predict() from AutoML Tables client (googleapis#9116)

TrucHLe authored and emar-kar committed Sep 18, 2019
1 parent eb92473 commit b9c0892
Showing 9 changed files with 382 additions and 23 deletions.
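
In a nutshell, the new surface looks like the sketch below. This is a minimal usage sketch, not part of the diff: it assumes an existing AutoML Tables dataset named `my_dataset`, application-default credentials, and illustrative project/region values.

import pandas as pd
from google.cloud import automl_v1beta1

client = automl_v1beta1.TablesClient(project="my-project", region="us-central1")

df = pd.DataFrame({"sepal_length": [5.1, 4.9], "species": ["setosa", "setosa"]})

# The DataFrame is serialized to CSV, staged in the
# gs://{project}-automl-tables-staging bucket, then imported from there.
op = client.import_data(dataset_display_name="my_dataset", pandas_dataframe=df)
op.result()  # block until the import operation completes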
14 changes: 13 additions & 1 deletion automl/google/cloud/automl_v1beta1/__init__.py
@@ -21,6 +21,7 @@
 from google.cloud.automl_v1beta1.gapic import auto_ml_client
 from google.cloud.automl_v1beta1.gapic import enums
 from google.cloud.automl_v1beta1.gapic import prediction_service_client
+from google.cloud.automl_v1beta1.tables import gcs_client
 from google.cloud.automl_v1beta1.tables import tables_client
 
 
@@ -38,4 +39,15 @@ class PredictionServiceClient(prediction_service_client.PredictionServiceClient)
     enums = enums
 
 
-__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", "TablesClient")
+class GcsClient(gcs_client.GcsClient):
+    __doc__ = gcs_client.GcsClient.__doc__
+
+
+__all__ = (
+    "enums",
+    "types",
+    "AutoMlClient",
+    "PredictionServiceClient",
+    "TablesClient",
+    "GcsClient",
+)
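
With the re-export above, the storage helper is reachable from the package root alongside the existing clients; a one-line sketch:

from google.cloud import automl_v1beta1

gcs = automl_v1beta1.GcsClient()  # picks up credentials from the environment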
112 changes: 112 additions & 0 deletions automl/google/cloud/automl_v1beta1/tables/gcs_client.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wraps the Google Cloud Storage client library for use in tables helper."""
+
+import time
+
+from google.api_core import exceptions
+
+try:
+    import pandas
+except ImportError:  # pragma: NO COVER
+    pandas = None
+
+try:
+    from google.cloud import storage
+except ImportError:  # pragma: NO COVER
+    storage = None
+
+_PANDAS_REQUIRED = "pandas is required to verify type DataFrame."
+_STORAGE_REQUIRED = (
+    "google-cloud-storage is required to create a Google Cloud Storage client."
+)
+
+
+class GcsClient(object):
+    """Uploads Pandas DataFrame to a bucket in Google Cloud Storage."""
+
+    def __init__(self, client=None, credentials=None):
+        """Constructor.
+
+        Args:
+            client (Optional[storage.Client]): A Google Cloud Storage Client
+                instance.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+        """
+        if storage is None:
+            raise ImportError(_STORAGE_REQUIRED)
+
+        if client is not None:
+            self.client = client
+        elif credentials is not None:
+            self.client = storage.Client(credentials=credentials)
+        else:
+            self.client = storage.Client()
+
+    def ensure_bucket_exists(self, project, region):
+        """Checks if a bucket named '{project}-automl-tables-staging' exists.
+
+        Creates this bucket if it doesn't exist.
+
+        Args:
+            project (str): The project that stores the bucket.
+            region (str): The region of the bucket.
+
+        Returns:
+            A string representing the created bucket name.
+        """
+        bucket_name = "{}-automl-tables-staging".format(project)
+
+        try:
+            self.client.get_bucket(bucket_name)
+        except exceptions.NotFound:
+            bucket = self.client.bucket(bucket_name)
+            bucket.create(project=project, location=region)
+        return bucket_name
+
+    def upload_pandas_dataframe(self, bucket_name, dataframe, uploaded_csv_name=None):
+        """Uploads a Pandas DataFrame as CSV to the bucket.
+
+        Args:
+            bucket_name (str): The bucket name to upload the CSV to.
+            dataframe (pandas.DataFrame): The Pandas DataFrame to be uploaded.
+            uploaded_csv_name (Optional[str]): The name for the uploaded CSV.
+
+        Returns:
+            A string representing the GCS URI of the uploaded CSV.
+        """
+        if pandas is None:
+            raise ImportError(_PANDAS_REQUIRED)
+
+        if not isinstance(dataframe, pandas.DataFrame):
+            raise ValueError("'dataframe' must be a pandas.DataFrame instance.")
+
+        if uploaded_csv_name is None:
+            uploaded_csv_name = "automl-tables-dataframe-{}.csv".format(
+                int(time.time())
+            )
+        csv_string = dataframe.to_csv()
+
+        bucket = self.client.get_bucket(bucket_name)
+        blob = bucket.blob(uploaded_csv_name)
+        blob.upload_from_string(csv_string)
+
+        return "gs://{}/{}".format(bucket_name, uploaded_csv_name)
85 changes: 76 additions & 9 deletions automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -18,11 +18,13 @@
 
 import pkg_resources
 import logging
+import google.auth
 
 from google.api_core.gapic_v1 import client_info
 from google.api_core import exceptions
 from google.cloud.automl_v1beta1 import gapic
 from google.cloud.automl_v1beta1.proto import data_types_pb2
+from google.cloud.automl_v1beta1.tables import gcs_client
 
 _GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version
 _LOGGER = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ def __init__(
         region="us-central1",
         client=None,
         prediction_client=None,
+        gcs_client=None,
         **kwargs
     ):
         """Constructor.
@@ -118,6 +121,7 @@ def __init__(
 
         self.project = project
         self.region = region
+        self.gcs_client = gcs_client
 
     def __lookup_by_display_name(self, object_type, items, display_name):
         relevant_items = [i for i in items if i.display_name == display_name]
@@ -403,6 +407,21 @@ def __type_code_to_value_type(self, type_code, value):
         else:
             raise ValueError("Unknown type_code: {}".format(type_code))
 
+    def __ensure_gcs_client_is_initialized(self, credentials=None):
+        """Checks if GCS client is initialized. Initializes it if not.
+
+        Args:
+            credentials (google.auth.credentials.Credentials): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+        """
+        if self.gcs_client is None:
+            if credentials is None:
+                credentials, _ = google.auth.default()
+            self.gcs_client = gcs_client.GcsClient(credentials=credentials)
+
     def list_datasets(self, project=None, region=None, **kwargs):
         """List all datasets in a particular project and region.
 
@@ -642,10 +661,12 @@ def import_data(
         dataset=None,
         dataset_display_name=None,
         dataset_name=None,
+        pandas_dataframe=None,
         gcs_input_uris=None,
         bigquery_input_uri=None,
         project=None,
         region=None,
+        credentials=None,
         **kwargs
     ):
         """Imports data into a dataset.
@@ -679,6 +700,11 @@
             region (Optional[string]):
                 If you have initialized the client with a value for `region` it
                 will be used if this parameter is not supplied.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
             dataset_display_name (Optional[string]):
                 The human-readable name given to the dataset you want to import
                 data into. This must be supplied if `dataset` or `dataset_name`
@@ -691,13 +717,21 @@
                 The `Dataset` instance you want to import data into. This must
                 be supplied if `dataset_display_name` or `dataset_name` are not
                 supplied.
+            pandas_dataframe (Optional[pandas.DataFrame]):
+                A Pandas DataFrame object containing the data to import. The
+                data will be converted to CSV, and this CSV will be staged to
+                GCS in `gs://{project}-automl-tables-staging/{uploaded_csv_name}`.
+                This parameter must be supplied if neither `gcs_input_uris`
+                nor `bigquery_input_uri` is supplied.
             gcs_input_uris (Optional[Union[string, Sequence[string]]]):
                 Either a single `gs://..` prefixed URI, or a list of URIs
                 referring to GCS-hosted CSV files containing the data to
-                import. This must be supplied if `bigquery_input_uri` is not.
+                import. This must be supplied if neither `bigquery_input_uri`
+                nor `pandas_dataframe` is supplied.
             bigquery_input_uri (Optional[string]):
                 A URI pointing to the BigQuery table containing the data to
-                import. This must be supplied if `gcs_input_uris` is not.
+                import. This must be supplied if neither `gcs_input_uris` nor
+                `pandas_dataframe` is supplied.
 
         Returns:
             A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
@@ -720,15 +754,23 @@
             )
 
         request = {}
-        if gcs_input_uris is not None:
+
+        if pandas_dataframe is not None:
+            self.__ensure_gcs_client_is_initialized(credentials)
+            bucket_name = self.gcs_client.ensure_bucket_exists(project, region)
+            gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
+                bucket_name, pandas_dataframe
+            )
+            request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
+        elif gcs_input_uris is not None:
             if type(gcs_input_uris) != list:
                 gcs_input_uris = [gcs_input_uris]
             request = {"gcs_source": {"input_uris": gcs_input_uris}}
         elif bigquery_input_uri is not None:
             request = {"bigquery_source": {"input_uri": bigquery_input_uri}}
         else:
             raise ValueError(
-                "One of 'gcs_input_uris', or " "'bigquery_input_uri' must be set."
+                "One of 'gcs_input_uris', or 'bigquery_input_uri', or 'pandas_dataframe' must be set."
             )
 
         op = self.auto_ml_client.import_data(dataset_name, request, **kwargs)
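
The three branches above are mutually exclusive and checked in order, so `pandas_dataframe` takes precedence. Continuing the `client` and `df` from the first sketch, the equivalent calls, one per input source (dataset and URI values are placeholders):

client.import_data(dataset_display_name="my_dataset", pandas_dataframe=df)
client.import_data(
    dataset_display_name="my_dataset",
    gcs_input_uris="gs://my-bucket/data.csv",  # a single URI or a list of URIs
)
client.import_data(
    dataset_display_name="my_dataset",
    bigquery_input_uri="bq://my-project.my_dataset_id.my_table",
)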
@@ -2605,6 +2647,7 @@ def predict(
 
     def batch_predict(
         self,
+        pandas_dataframe=None,
         bigquery_input_uri=None,
         bigquery_output_uri=None,
         gcs_input_uris=None,
@@ -2614,6 +2657,7 @@
         model_display_name=None,
         project=None,
         region=None,
+        credentials=None,
         inputs=None,
         **kwargs
     ):
@@ -2645,15 +2689,30 @@
             region (Optional[string]):
                 If you have initialized the client with a value for `region` it
                 will be used if this parameter is not supplied.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            pandas_dataframe (Optional[pandas.DataFrame]):
+                A Pandas DataFrame object containing the data you want to
+                predict off of. The data will be converted to CSV, and staged
+                to GCS in `gs://{project}-automl-tables-staging/{uploaded_csv_name}`.
+                This must be supplied if neither `gcs_input_uris` nor
+                `bigquery_input_uri` is supplied.
             gcs_input_uris (Optional[Union[List[string], string]]):
                 Either a list of or a single GCS URI containing the data you
-                want to predict off of.
+                want to predict off of. This must be supplied if neither
+                `pandas_dataframe` nor `bigquery_input_uri` is supplied.
             gcs_output_uri_prefix (Optional[string]):
-                The folder in GCS you want to write output to.
+                The folder in GCS you want to write output to. This must be
+                supplied if `bigquery_output_uri` is not.
             bigquery_input_uri (Optional[string]):
-                The BigQuery table to input data from.
+                The BigQuery table to input data from. This must be supplied
+                if neither `pandas_dataframe` nor `gcs_input_uris` is supplied.
             bigquery_output_uri (Optional[string]):
-                The BigQuery table to output data to.
+                The BigQuery table to output data to. This must be supplied
+                if `gcs_output_uri_prefix` is not.
             model_display_name (Optional[string]):
                 The human-readable name given to the model you want to predict
                 with. This must be supplied if `model` or `model_name` are not
@@ -2688,7 +2747,15 @@
             )
 
         input_request = None
-        if gcs_input_uris is not None:
+
+        if pandas_dataframe is not None:
+            self.__ensure_gcs_client_is_initialized(credentials)
+            bucket_name = self.gcs_client.ensure_bucket_exists(project, region)
+            gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
+                bucket_name, pandas_dataframe
+            )
+            input_request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
+        elif gcs_input_uris is not None:
             if type(gcs_input_uris) != list:
                 gcs_input_uris = [gcs_input_uris]
             input_request = {"gcs_source": {"input_uris": gcs_input_uris}}
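Mirroring import_data(), batch_predict() now accepts a DataFrame as its input source; a minimal sketch (project, model, and bucket names are illustrative):

import pandas as pd
from google.cloud import automl_v1beta1

client = automl_v1beta1.TablesClient(project="my-project", region="us-central1")
df = pd.DataFrame({"sepal_length": [5.0], "sepal_width": [3.3]})

# Input rows are staged to GCS as CSV; results are written under the prefix.
op = client.batch_predict(
    model_display_name="my_model",
    pandas_dataframe=df,
    gcs_output_uri_prefix="gs://my-bucket/predictions",
)
op.result()  # wait for the batch prediction job to finish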
2 changes: 2 additions & 0 deletions automl/noxfile.py
@@ -70,6 +70,7 @@ def default(session):
     for local_dep in LOCAL_DEPS:
         session.install("-e", local_dep)
     session.install("-e", ".")
+    session.install("-e", ".[pandas,storage]")
 
     # Run py.test against the unit tests.
     session.run(
@@ -117,6 +118,7 @@ def system(session):
         session.install("-e", local_dep)
     session.install("-e", "../test_utils/")
     session.install("-e", ".")
+    session.install("-e", ".[pandas,storage]")
 
     # Run py.test against the system tests.
     if system_test_exists:
5 changes: 5 additions & 0 deletions automl/setup.py
@@ -25,6 +25,10 @@
     "google-api-core[grpc] >= 1.14.0, < 2.0.0dev",
     'enum34; python_version < "3.4"',
 ]
+extras = {
+    "pandas": ["pandas>=0.24.0"],
+    "storage": ["google-cloud-storage >= 1.18.0, < 2.0.0dev"],
+}
 
 package_root = os.path.abspath(os.path.dirname(__file__))
 
@@ -67,6 +71,7 @@
     packages=packages,
     namespace_packages=namespaces,
     install_requires=dependencies,
+    extras_require=extras,
     python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*",
     include_package_data=True,
     zip_safe=False,
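With these extras declared, the optional dependencies install via, e.g., `pip install google-cloud-automl[pandas,storage]`; without them, the new code paths raise the ImportError messages defined in gcs_client.py.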