
Add a guide for implementing server-side batching using the Python Predictor #1470

Closed
RobertLucian opened this issue Oct 21, 2020 · 0 comments · Fixed by #1653
Labels
docs Improvements or additions to documentation
Milestone
v0.25

Comments

RobertLucian (Member) commented Oct 21, 2020

Description

The guide can draw on the existing server-side batching documentation: https://docs.cortex.dev/deployments/realtime-api/parallelism#server-side-batching.

Template implementation:

import threading as td
import time

class PythonPredictor:
    def __init__(self, config):
        self.model = None  # initialize the model here

        self.waiter = td.Event()
        self.waiter.set()

        self.batch_max_size = config["batch_max_size"]
        self.batch_interval = config["batch_interval"]  # measured in seconds
        # one extra barrier party for the batch engine thread itself
        self.barrier = td.Barrier(self.batch_max_size + 1)

        self.samples = {}
        self.predictions = {}
        # daemon thread so the batch engine doesn't block interpreter shutdown
        td.Thread(target=self._batch_engine, daemon=True).start()

    def _batch_engine(self):
        while True:
            # wait until the predictions from the previous batch have all been consumed
            if len(self.predictions) > 0:
                time.sleep(0.001)
                continue

            # block until the batch fills up or the batch interval expires
            try:
                self.barrier.wait(self.batch_interval)
            except td.BrokenBarrierError:
                pass

            # stop accepting new samples while this batch is processed
            self.waiter.clear()
            self.predictions = {}

            self.batch_inference()

            # discard the processed samples and start accepting new ones
            self.samples = {}
            self.barrier.reset()
            self.waiter.set()

    def batch_inference(self):
        """
        Run the batch inference here.
        """
        # batch process self.samples
        # store results in self.predictions
        # make sure to write the results to self.predictions using the keys from self.samples

    def enqueue_sample(self, sample):
        """
        Enqueue a sample for batch inference. This is a blocking method.
        """
        thread_id = td.get_ident()

        # wait until the batch engine is accepting samples, then register this one
        self.waiter.wait()
        self.samples[thread_id] = sample
        try:
            # released when the batch fills up or when the batch engine's wait times out
            self.barrier.wait()
        except td.BrokenBarrierError:
            pass

    def get_prediction(self):
        """
        Return the prediction. This is a blocking method.
        """
        thread_id = td.get_ident()
        # poll until the batch engine has stored a prediction for this request's thread
        while thread_id not in self.predictions:
            time.sleep(0.001)
        prediction = self.predictions[thread_id]
        del self.predictions[thread_id]

        return prediction

    def predict(self, payload):
        self.enqueue_sample(payload)
        prediction = self.get_prediction()

        return prediction
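
For reference (not part of the original template), here is a minimal sketch of how the class above could be exercised locally, outside of Cortex. It assumes PythonPredictor is defined in the same module and fills in batch_inference with a hypothetical echo "model" so the flow can run end to end; EchoBatchPredictor, the config values, and the simulated requests are illustrative choices, not part of Cortex's API.

import threading as td

class EchoBatchPredictor(PythonPredictor):
    def batch_inference(self):
        # hypothetical "model": echo each sample back, keyed by the thread id
        # under which it was enqueued (the contract described in the template)
        self.predictions = {
            thread_id: {"echo": sample} for thread_id, sample in self.samples.items()
        }

predictor = EchoBatchPredictor({"batch_max_size": 4, "batch_interval": 0.5})
results = {}

def simulate_request(i):
    # each simulated client blocks in predict() until its batch has been run
    results[i] = predictor.predict({"sample_id": i})

threads = [td.Thread(target=simulate_request, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # one echoed prediction per simulated request

With batch_max_size set to 4, the four concurrent calls fill the barrier and get processed as a single batch; with fewer concurrent calls, the batch_interval timeout would release a partial batch instead.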

Motivation

Useful for users who need server-side batching with the Python Predictor.
This has been requested by @manneshiva.

RobertLucian added the docs label Oct 21, 2020
deliahu added this to the v0.25 milestone Dec 23, 2020