Skip to content

Commit

Permalink
Configured for use with a customizable MLflow model
Browse files Browse the repository at this point in the history
  • Loading branch information
jinensetpal committed Jul 4, 2024
1 parent d182572 commit 03002a0
Showing 1 changed file with 28 additions and 13 deletions.
41 changes: 28 additions & 13 deletions label_studio_ml/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import hmac
import json
import logging
import os
import dagshub
import mlflow
import base64
import cloudpickle
from dagshub.data_engine import datasources

from flask import Flask, request, jsonify, Response

Expand All @@ -11,25 +17,34 @@
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Flask application object; the route decorators in this module register on it.
_server = Flask(__name__)
# Model class used by handlers that construct a model via MODEL_CLASS(...).
MODEL_CLASS = LabelStudioMLBase
# Basic-auth credentials as a (user, password) tuple; stays None until
# init_app receives (or finds in the environment) both values.
BASIC_AUTH = None


def init_app(model_instance, basic_auth_user=None, basic_auth_pass=None):
    """Register the model instance served by this module and return the Flask app.

    Args:
        model_instance: Model object stored in the module-level ``model``
            global; the request handlers in this module operate on it.
        basic_auth_user: Basic-auth user name. When falsy, the
            ``BASIC_AUTH_USER`` environment variable is used instead.
        basic_auth_pass: Basic-auth password. When falsy, the
            ``BASIC_AUTH_PASS`` environment variable is used instead.

    Returns:
        The module-level Flask application ``_server``.
    """
    global model
    global BASIC_AUTH

    model = model_instance

    # Explicit arguments take precedence over the environment.
    basic_auth_user = basic_auth_user or os.environ.get('BASIC_AUTH_USER')
    basic_auth_pass = basic_auth_pass or os.environ.get('BASIC_AUTH_PASS')
    # BASIC_AUTH is only set when both halves of the credential are present.
    if basic_auth_user and basic_auth_pass:
        BASIC_AUTH = (basic_auth_user, basic_auth_pass)

    return _server

@_server.post('/configure')
@exception_handler
def _configure():
    """Configure the global model from an MLflow model and a DagsHub datasource.

    The request body is JSON with the keys ``repo``, ``username``, ``model``,
    ``version``, ``pre_hook``, ``post_hook``, ``datasource_repo`` and
    ``datasource_name``. ``pre_hook`` and ``post_hook`` are base64-encoded
    cloudpickle payloads.

    Returns:
        An empty list.
    """
    payload = request.get_json()
    # The body may be a JSON-encoded *string* that itself contains JSON
    # (double-encoded), or a regular JSON object. get_json() has already
    # decoded one layer, so only decode again when a string came back.
    args = json.loads(payload) if isinstance(payload, str) else payload

    dagshub.init(args['repo'], args['username'])  # user-level privileged auth token
    ls_model = mlflow.pyfunc.load_model(f'models:/{args["model"]}/{args["version"]}')

    # SECURITY: unpickling executes arbitrary code. The hooks come straight
    # from the request body, so this endpoint must only be reachable by
    # trusted, authenticated clients.
    hooks = [
        cloudpickle.loads(base64.b64decode(args[hook]))
        for hook in ('pre_hook', 'post_hook')
    ]
    model.configure(ls_model, *hooks)
    # model.api = dagshub.common.api.repo.RepoAPI(f'https://dagshub.com/{args["username"]}/{args["repo"]}', host=args['host'])

    model.ds = datasources.get_datasource(args['datasource_repo'], args['datasource_name'])
    # Keep only the two columns of the datasource dataframe stored as dp_map.
    model.dp_map = model.ds.all().dataframe[['path', 'datapoint_id']]
    return []

@_server.route('/predict', methods=['POST'])
@exception_handler
Expand Down Expand Up @@ -61,8 +76,8 @@ def _predict():
params = data.get('params', {})
context = params.pop('context', {})

model = MODEL_CLASS(project_id=project_id,
label_config=label_config)
model.project_id = project_id
model.use_label_config(label_config)

# model.use_label_config(label_config)

Expand Down Expand Up @@ -96,8 +111,8 @@ def _setup():
project_id = data.get('project').split('.', 1)[0]
label_config = data.get('schema')
extra_params = data.get('extra_params')
model = MODEL_CLASS(project_id=project_id,
label_config=label_config)
model.project_id = project_id
model.use_label_config(label_config)

if extra_params:
model.set_extra_params(extra_params)
Expand All @@ -122,7 +137,8 @@ def webhook():
return jsonify({'status': 'Unknown event'}), 200
project_id = str(data['project']['id'])
label_config = data['project']['label_config']
model = MODEL_CLASS(project_id, label_config=label_config)
model.project_id = project_id
model.use_label_config(label_config)
model.fit(event, data)
return jsonify({}), 201

Expand All @@ -133,7 +149,6 @@ def webhook():
def health():
return jsonify({
'status': 'UP',
'model_class': MODEL_CLASS.__name__
})


Expand Down

0 comments on commit 03002a0

Please sign in to comment.