Feature/SK-866 | Inference TaskType workflow #614

Merged 5 commits on Jun 11, 2024
3 changes: 2 additions & 1 deletion fedn/network/api/v1/__init__.py
@@ -1,10 +1,11 @@
from fedn.network.api.v1.client_routes import bp as client_bp
from fedn.network.api.v1.combiner_routes import bp as combiner_bp
from fedn.network.api.v1.inference_routes import bp as inference_bp
from fedn.network.api.v1.model_routes import bp as model_bp
from fedn.network.api.v1.package_routes import bp as package_bp
from fedn.network.api.v1.round_routes import bp as round_bp
from fedn.network.api.v1.session_routes import bp as session_bp
from fedn.network.api.v1.status_routes import bp as status_bp
from fedn.network.api.v1.validation_routes import bp as validation_bp

_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp]
_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp]
34 changes: 34 additions & 0 deletions fedn/network/api/v1/inference_routes.py
@@ -0,0 +1,34 @@
import threading

from flask import Blueprint, jsonify, request

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import api_version

bp = Blueprint("inference", __name__, url_prefix=f"/api/{api_version}/infer")


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
"""Start a new inference session.
param: session_id: The session id to start.
type: session_id: str
param: rounds: The number of rounds to run.
type: rounds: int
"""
try:
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

session_config = {"session_id": session_id}

threading.Thread(target=control.inference_session, kwargs={"config": session_config}).start()

return jsonify({"message": "Inference session started"}), 200
except Exception:
return jsonify({"message": "Failed to start inference session"}), 500
33 changes: 19 additions & 14 deletions fedn/network/combiner/aggregators/aggregatorbase.py
@@ -1,6 +1,7 @@
import importlib
import json
import queue
import traceback
from abc import ABC, abstractmethod

from fedn.common.log_config import logger
@@ -9,7 +10,7 @@


class AggregatorBase(ABC):
""" Abstract class defining an aggregator.
"""Abstract class defining an aggregator.

:param id: A reference to id of :class: `fedn.network.combiner.Combiner`
:type id: str
@@ -25,7 +26,7 @@ class AggregatorBase(ABC):

@abstractmethod
def __init__(self, storage, server, modelservice, round_handler):
""" Initialize the aggregator."""
"""Initialize the aggregator."""
self.name = self.__class__.__name__
self.storage = storage
self.server = server
@@ -75,25 +76,31 @@ def on_model_update(self, model_update):
else:
logger.warning("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
except Exception as e:
logger.error("AGGREGATOR({}): failed to receive model update! {}".format(self.name, e))
tb = traceback.format_exc()
logger.error("AGGREGATOR({}): failed to receive model update: {}".format(self.name, e))
logger.error(tb)
pass

def _validate_model_update(self, model_update):
""" Validate the model update.
"""Validate the model update.

:param model_update: A ModelUpdate message.
:type model_update: object
:return: True if the model update is valid, False otherwise.
:rtype: bool
"""
data = json.loads(model_update.meta)["training_metadata"]
if "num_examples" not in data.keys():
logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
try:
data = json.loads(model_update.meta)["training_metadata"]
_ = data["num_examples"]
except KeyError:
tb = traceback.format_exc()
logger.error("AGGREGATOR({}): Invalid model update, missing metadata.".format(self.name))
logger.error(tb)
return False
return True

def next_model_update(self):
""" Get the next model update from the queue.
"""Get the next model update from the queue.

:param helper: A helper object.
:type helper: object
@@ -104,7 +111,7 @@ def next_model_update(self):
return model_update

def load_model_update(self, model_update, helper):
""" Load the memory representation of the model update.
"""Load the memory representation of the model update.

Load the model update parameters and the
associated metadata into memory.
@@ -132,15 +139,13 @@ def load_model_update(self, model_update, helper):
return model, training_metadata

def get_state(self):
""" Get the state of the aggregator's queue, including the number of model updates."""
state = {
"queue_len": self.model_updates.qsize()
}
"""Get the state of the aggregator's queue, including the number of model updates."""
state = {"queue_len": self.model_updates.qsize()}
return state


def get_aggregator(aggregator_module_name, storage, server, modelservice, control):
""" Return an instance of the helper class.
"""Return an instance of the helper class.

:param helper_module_name: The name of the helper plugin module.
:type helper_module_name: str
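The reworked _validate_model_update above reads num_examples from the training_metadata block of model_update.meta. A minimal sketch of metadata that passes and metadata that fails the check (values are illustrative):

import json

# Passes validation: training_metadata carries num_examples.
good_meta = json.dumps({"training_metadata": {"num_examples": 128}})

# Fails validation: num_examples is missing, so the KeyError is caught,
# the traceback is logged, and the method returns False.
bad_meta = json.dumps({"training_metadata": {}})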
95 changes: 61 additions & 34 deletions fedn/network/combiner/combiner.py
@@ -15,7 +15,7 @@
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.network.combiner.connect import ConnectorCombiner, Status
from fedn.network.combiner.modelservice import ModelService
from fedn.network.combiner.roundhandler import RoundHandler
from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler
from fedn.network.grpc.server import Server
from fedn.network.storage.s3.repository import Repository
from fedn.network.storage.statestore.mongostatestore import MongoStateStore
@@ -65,7 +65,6 @@ def __init__(self, config):
# Client queues
self.clients = {}


# Validate combiner name
match = re.search(VALID_NAME_REGEX, config["name"])
if not match:
@@ -161,7 +160,7 @@ def __whoami(self, client, instance):
client.role = role_to_proto_role(instance.role)
return client

def request_model_update(self, config, clients=[]):
def request_model_update(self, session_id, model_id, config, clients=[]):
"""Ask clients to update the current global model.

:param config: the model configuration to send to clients
@@ -170,32 +169,14 @@ def request_model_update(self, config, clients=[]):
:type clients: list

"""
# The request to be added to the client queue
request = fedn.TaskRequest()
request.model_id = config["model_id"]
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.data = json.dumps(config)
request.type = fedn.StatusType.MODEL_UPDATE
request.session_id = config["session_id"]

request.sender.name = self.id
request.sender.role = fedn.COMBINER

if len(clients) == 0:
clients = self.get_active_trainers()

for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)
request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)

if len(clients) < 20:
logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients)))

def request_model_validation(self, model_id, config, clients=[]):
def request_model_validation(self, session_id, model_id, clients=[]):
"""Ask clients to validate the current global model.

:param model_id: the model id to validate
@@ -206,30 +187,76 @@ def request_model_validation(self, model_id, config, clients=[]):
:type clients: list

"""
# The request to be added to the client queue
request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))

def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None:
"""Ask clients to perform inference on the model.

:param model_id: the model id to perform inference on
:type model_id: str
:param config: the model configuration to send to clients
:type config: dict
:param clients: the clients to send the request to
:type clients: list

"""
request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients)

if len(clients) < 20:
logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients)))

def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]):
"""Send a request of a specific type to clients.

:param request_type: the type of request
:type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType`
:param model_id: the model id to send in the request
:type model_id: str
:param config: the model configuration to send to clients
:type config: dict
:param clients: the clients to send the request to
:type clients: list
:return: the request and the clients
:rtype: tuple
"""
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
# request.is_inference = (config['task'] == 'inference')
request.type = fedn.StatusType.MODEL_VALIDATION
request.type = request_type
request.session_id = session_id

request.sender.name = self.id
request.sender.role = fedn.COMBINER
request.session_id = config["session_id"]

if len(clients) == 0:
clients = self.get_active_validators()
if request_type == fedn.StatusType.MODEL_UPDATE:
request.data = json.dumps(config)
if len(clients) == 0:
clients = self.get_active_trainers()
elif request_type == fedn.StatusType.MODEL_VALIDATION:
if len(clients) == 0:
clients = self.get_active_validators()
elif request_type == fedn.StatusType.INFERENCE:
request.data = json.dumps(config)
if len(clients) == 0:
# TODO: add inference clients type
clients = self.get_active_validators()

# TODO: if inference, request.data should be user-defined data/parameters

for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))
return request, clients

def get_active_trainers(self):
"""Get a list of active trainers.
@@ -410,7 +437,7 @@ def Start(self, control: fedn.ControlRequest, context):
"""
logger.info("grpc.Combiner.Start: Starting round")

config = {}
config = RoundConfig()
for parameter in control.parameter:
config.update({parameter.key: parameter.value})

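With the shared _send_request_type helper in place, the three public request methods differ mainly in the StatusType they pass and in which set of active clients they fall back to. A hedged usage sketch; the combiner instance and the session, model, and config values are placeholders, not code from this diff:

# Placeholder ids; in FEDn these come from the session and model stores.
combiner.request_model_update(session_id="sid-1", model_id="mid-1", config={"rounds": "1"})
combiner.request_model_validation(session_id="sid-1", model_id="mid-1")
combiner.request_model_inference(session_id="sid-1", model_id="mid-1")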
3 changes: 2 additions & 1 deletion fedn/network/combiner/interfaces.py
@@ -8,6 +8,7 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.network.combiner.roundhandler import RoundConfig


class CombinerUnavailableError(Exception):
@@ -202,7 +203,7 @@ def set_aggregator(self, aggregator):
else:
raise

def submit(self, config):
def submit(self, config: RoundConfig):
"""Submit a compute plan to the combiner.

:param config: The job configuration.
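RoundConfig itself is defined in fedn.network.combiner.roundhandler and its fields are not shown in this diff; judging from how Combiner.Start populates it, it is used like a dict-style mapping, so a typed submit call might look roughly as follows (the keys and the interface instance are illustrative assumptions):

from fedn.network.combiner.roundhandler import RoundConfig

config = RoundConfig()
config.update({"session_id": "sid-1", "rounds": "1"})  # keys mirror the gRPC control parameters
combiner_interface.submit(config)  # combiner_interface: an instance of the combiner interface class, assumed available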