Skip to content

Commit

Permalink
Merge pull request #5 from dagardner-nv/david-mdd_dfp-example
Browse files Browse the repository at this point in the history
Cache list of registered models, avoiding lookups for users which don't have models
  • Loading branch information
mdemoret-nv committed Aug 27, 2022
2 parents 92cf857 + 4cfaa85 commit cdb078e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/dfp_workflow/mlflow/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ FROM python:3.8-slim-buster
# Install curl for health check
RUN apt update && \
apt install -y --no-install-recommends \
curl && \
curl libyaml-cpp-dev libyaml-dev && \
apt autoremove -y && \
apt clean all && \
rm -rf /var/cache/apt/* /var/lib/apt/lists/*

# Install python packages
RUN pip install mlflow boto3 pymysql
RUN pip install mlflow boto3 pymysql pyyaml

# We run on port 5000
EXPOSE 5000
Expand Down
15 changes: 15 additions & 0 deletions examples/dfp_workflow/morpheus/dfp/stages/dfp_inference_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
import os
import threading
import time
import typing
Expand All @@ -34,6 +35,15 @@
logger = logging.getLogger("morpheus.{}".format(__name__))


def get_registered_models():
client = MlflowClient(os.environ.get('DFP_TRACKING_URI'))
models = client.list_registered_models()
return set(model.name for model in models)


REGISTERED_MODELS = get_registered_models()


class ModelCache:

def __init__(self, reg_model_name: str, model_uri: str) -> None:
Expand Down Expand Up @@ -115,6 +125,8 @@ def load_model(self, client):
# Our model does not exist, use fallback
self._child_user_model_cache = self._manager.load_user_model_cache(
self._fallback_user_ids[0], fallback_user_ids=self._fallback_user_ids[1:])
else:
return model_cache

# See if we have a child cache and use that
if (self._child_user_model_cache is not None):
Expand Down Expand Up @@ -173,6 +185,9 @@ def load_model_cache(self, client, reg_model_name: str) -> ModelCache:

# Cache miss. Try to check for a model
try:
if reg_model_name not in REGISTERED_MODELS:
raise MlflowException("")

latest_versions = client.get_latest_versions(reg_model_name)

# Default to the first returned one
Expand Down

0 comments on commit cdb078e

Please sign in to comment.