diff --git a/label_studio_ml/api.py b/label_studio_ml/api.py index 767e54c..cf8065d 100644 --- a/label_studio_ml/api.py +++ b/label_studio_ml/api.py @@ -1,6 +1,12 @@ import hmac +import json import logging import os +import dagshub +import mlflow +import base64 +import cloudpickle +from dagshub.data_engine import datasources from flask import Flask, request, jsonify, Response @@ -11,18 +17,14 @@ logger = logging.getLogger(__name__) _server = Flask(__name__) -MODEL_CLASS = LabelStudioMLBase BASIC_AUTH = None -def init_app(model_class, basic_auth_user=None, basic_auth_pass=None): - global MODEL_CLASS +def init_app(model_instance, basic_auth_user=None, basic_auth_pass=None): + global model global BASIC_AUTH - if not issubclass(model_class, LabelStudioMLBase): - raise ValueError('Inference class should be the subclass of ' + LabelStudioMLBase.__class__.__name__) - - MODEL_CLASS = model_class + model = model_instance basic_auth_user = basic_auth_user or os.environ.get('BASIC_AUTH_USER') basic_auth_pass = basic_auth_pass or os.environ.get('BASIC_AUTH_PASS') if basic_auth_user and basic_auth_pass: @@ -30,6 +32,19 @@ def init_app(model_class, basic_auth_user=None, basic_auth_pass=None): return _server +@_server.post('/configure') +@exception_handler +def _configure(): + args = json.loads(request.get_json()) + dagshub.init(args['repo'], args['username']) # user-level privileged auth token + ls_model = mlflow.pyfunc.load_model(f'models:/{args["model"]}/{args["version"]}') + + model.configure(ls_model, *[cloudpickle.loads(base64.b64decode(args[hook])) for hook in ['pre_hook', 'post_hook']]) + # model.api = dagshub.common.api.repo.RepoAPI(f'https://dagshub.com/{args["username"]}/{args["repo"]}', host=args['host']) + + model.ds = datasources.get_datasource(args['datasource_repo'], args['datasource_name']) + model.dp_map = model.ds.all().dataframe[['path', 'datapoint_id']] + return [] @_server.route('/predict', methods=['POST']) @exception_handler @@ -61,8 +76,8 @@ def _predict(): params = data.get('params', {}) context = params.pop('context', {}) - model = MODEL_CLASS(project_id=project_id, - label_config=label_config) + model.project_id = project_id + model.use_label_config(label_config) # model.use_label_config(label_config) @@ -96,8 +111,8 @@ def _setup(): project_id = data.get('project').split('.', 1)[0] label_config = data.get('schema') extra_params = data.get('extra_params') - model = MODEL_CLASS(project_id=project_id, - label_config=label_config) + model.project_id = project_id + model.use_label_config(label_config) if extra_params: model.set_extra_params(extra_params) @@ -122,7 +137,8 @@ def webhook(): return jsonify({'status': 'Unknown event'}), 200 project_id = str(data['project']['id']) label_config = data['project']['label_config'] - model = MODEL_CLASS(project_id, label_config=label_config) + model.project_id = project_id + model.use_label_config(label_config) model.fit(event, data) return jsonify({}), 201 @@ -133,7 +149,6 @@ def webhook(): def health(): return jsonify({ 'status': 'UP', - 'model_class': MODEL_CLASS.__name__ })