Skip to content
Merged
21 changes: 21 additions & 0 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
return exc if exc else _handle_func_requests(error, message, error_dict)


def handle_operation_error(error):
"""Constructs a ``FirebaseError`` from the given operation error.

Args:
error: An error returned by a long running operation.

Returns:
FirebaseError: A ``FirebaseError`` that can be raised to the user code.
"""
if not isinstance(error, dict):
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)

status_code = error.get('code')
message = error.get('message')
error_code = _http_status_to_error_code(status_code)
err_type = _error_code_to_exception_type(error_code)
return err_type(message=message)


def _handle_func_requests(error, message, error_dict):
"""Constructs a ``FirebaseError`` from the given GCP error.

Expand Down
182 changes: 176 additions & 6 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import datetime
import numbers
import re
import time
import requests
import six


from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions


_MLKIT_ATTRIBUTE = '_mlkit'
Expand All @@ -36,6 +39,9 @@
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
r'^operations/project/(?P<project_id>[^/]+)/model/(?P<model_id>[A-Za-z0-9_-]{1,60})' +
r'/operation/[^/]+$')


def _get_mlkit_service(app):
Expand All @@ -53,18 +59,60 @@ def _get_mlkit_service(app):
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)


def create_model(model, app=None):
"""Creates a model in Firebase ML Kit.

Args:
model: An mlkit.Model to create.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The model that was created in Firebase ML Kit.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.create_model(model), app=app)


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Args:
model_id: The id of the model to get.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The requested model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.get_model(model_id))
return Model.from_dict(mlkit_service.get_model(model_id), app=app)


def list_models(list_filter=None, page_size=None, page_token=None, app=None):
"""Lists models from Firebase ML Kit.

Args:
list_filter: a list filter string such as "tags:'tag_1'". None will return all models.
page_size: A number between 1 and 100 inclusive that specifies the maximum
number of models to return per page. None for default.
page_token: A next page token returned from a previous page of results. None
for first page of results.
app: A Firebase app instance (or None to use the default app).

Returns:
ListModelsPage: A (filtered) list of models.
"""
mlkit_service = _get_mlkit_service(app)
return ListModelsPage(
mlkit_service.list_models, list_filter, page_size, page_token)
mlkit_service.list_models, list_filter, page_size, page_token, app=app)


def delete_model(model_id, app=None):
"""Deletes a model from Firebase ML Kit.

Args:
model_id: The id of the model you wish to delete.
app: A Firebase app instance (or None to use the default app).
"""
mlkit_service = _get_mlkit_service(app)
mlkit_service.delete_model(model_id)

Expand All @@ -78,6 +126,7 @@ class Model(object):
model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
"""
def __init__(self, display_name=None, tags=None, model_format=None):
self._app = None # Only needed for wait_for_unlo
self._data = {}
self._model_format = None

Expand All @@ -89,16 +138,26 @@ def __init__(self, display_name=None, tags=None, model_format=None):
self.model_format = model_format

@classmethod
def from_dict(cls, data):
def from_dict(cls, data, app=None):
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
if tflite_format_data:
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
return model

def _update_from_dict(self, data):
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
if tflite_format_data:
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
self.model_format = tflite_format
self._data = data_copy

def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
Expand Down Expand Up @@ -173,6 +232,15 @@ def locked(self):
return bool(self._data.get('activeOperations') and
len(self._data.get('activeOperations')) > 0)

def wait_for_unlocked(self, max_time_seconds=None):
if self.locked:
mlkit_service = _get_mlkit_service(self._app)
op_name = self._data.get('activeOperations')[0].get('name')
model_dict = mlkit_service.handle_operation(
mlkit_service.get_operation(op_name),
max_time_seconds=max_time_seconds)
self._update_from_dict(model_dict)

@property
def model_format(self):
return self._model_format
Expand Down Expand Up @@ -296,17 +364,20 @@ class ListModelsPage(object):
``iterate_all()`` can be used to iterate through all the models in the
Firebase project starting from this page.
"""
def __init__(self, list_models_func, list_filter, page_size, page_token):
def __init__(self, list_models_func, list_filter, page_size, page_token, app):
self._list_models_func = list_models_func
self._list_filter = list_filter
self._page_size = page_size
self._page_token = page_token
self._app = app
self._list_response = list_models_func(list_filter, page_size, page_token)

@property
def models(self):
"""A list of Models from this page."""
return [Model.from_dict(model) for model in self._list_response.get('models', [])]
return [
Model.from_dict(model, app=self._app) for model in self._list_response.get('models', [])
]

@property
def list_filter(self):
Expand All @@ -333,7 +404,8 @@ def get_next_page(self):
self._list_models_func,
self._list_filter,
self._page_size,
self.next_page_token)
self.next_page_token,
self._app)
return None

def iterate_all(self):
Expand Down Expand Up @@ -390,11 +462,25 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
raise ValueError('Model must have a display name.')


def _validate_model_id(model_id):
if not _MODEL_ID_PATTERN.match(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_and_parse_operation_name(op_name):
matcher = _OPERATION_NAME_PATTERN.match(op_name)
if not matcher:
raise ValueError('Operation name format is invalid.')
return matcher.group('project_id'), matcher.group('model_id')


def _validate_display_name(display_name):
if not _DISPLAY_NAME_PATTERN.match(display_name):
raise ValueError('Display name format is invalid.')
Expand All @@ -417,11 +503,13 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri


def _validate_model_format(model_format):
if not isinstance(model_format, ModelFormat):
raise TypeError('Model format must be a ModelFormat object.')
return model_format


def _validate_list_filter(list_filter):
if list_filter is not None:
if not isinstance(list_filter, six.string_types):
Expand All @@ -448,6 +536,9 @@ class _MLKitService(object):
"""Firebase MLKit service."""

PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5
POLL_BASE_WAIT_TIME_SECONDS = 3

def __init__(self, app):
project_id = app.project_id
Expand All @@ -459,6 +550,85 @@ def __init__(self, app):
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=_MLKitService.OPERATION_URL)

def get_operation(self, op_name):
_validate_and_parse_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def handle_operation(self, operation, max_polling_attempts=None, max_time_seconds=None,
always_return_model=False):
"""Handles long running operations.

Args:
operation: The operation to handle.
max_polling_attempts: The maximum number of polling requests to make.
(None for no limit)
max_time_seconds: The maximum seconds to try polling for operation complete.
(None for no limit)
always_return_model: If true, returns a locked Model instead of raising deadline
exceeded exceptions.

Returns:
dict: A dictionary of the returned model properties.

Raises:
TypeError: if the operation is not a dictionary.
ValueError: If the operation is malformed.
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)

current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while True:
if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
else:
# A 'done' operation must have either a response or an error.
raise ValueError('Operation is malformed.')
else:
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
if max_polling_attempts is not None and current_attempt >= max_polling_attempts:
if always_return_model:
return get_model(model_id).as_dict()
raise exceptions.DeadlineExceededError('Polling max attempts exceeded.')
delay_factor = pow(
_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt)
wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS
after_sleep_time = (datetime.datetime.now() +
datetime.timedelta(seconds=wait_time_seconds))
if stop_time is not None and after_sleep_time > stop_time:
if always_return_model:
return get_model(model_id).as_dict()
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
time.sleep(wait_time_seconds)
operation = self.get_operation(op_name)
current_attempt += 1


def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()),
max_polling_attempts=1,
always_return_model=True)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
Expand Down
Loading