Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stateful inference #2513

Merged
merged 59 commits into from
Nov 8, 2023
Merged

stateful inference #2513

merged 59 commits into from
Nov 8, 2023

Conversation

lxning
Copy link
Collaborator

@lxning lxning commented Aug 1, 2023

Description

Please read our CONTRIBUTING.md prior to creating your first pull request.

Please include a summary of the feature or issue being fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes #(issue)

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Feature/Issue validation/testing

Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • Regression test
    reg.txt

  • Normal Sequential Inference

# model_store/stateful/model-config.yaml
minWorkers: 2
maxWorkers: 2
batchSize: 4
maxBatchDelay: 100
sequenceMaxIdleMSec: 600000
maxNumSequence: 4
maxSequenceJobQueueSize: 10

handler:
  cache:
    capacity: 4

# Start model server and load  example model stateful.mar which responses the accumulated value from the sequential input
torchserve --ncs --start --model-store model_store --models stateful.mar --ts-config benchmarks/config.properties

# Run sequential inference
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
InferStream2 started
prediction: "1"

prediction: "3"

prediction: "6"

Sequence completed!
InferStream2 closed
  • Expired or Streaming Closed Sequential Inference: the second sequence inference call gets error.
# model_store/stateful/model-config.yaml
minWorkers: 2
maxWorkers: 2
batchSize: 4
maxBatchDelay: 5000
sequenceMaxIdleMSec: 600000
maxNumSequence: 4
maxSequenceJobQueueSize: 10

handler:
  cache:
    capacity: 4

# Start model server and load  example model stateful.mar which responses the accumulated value from the sequential input
torchserve --ncs --start --model-store model_store --models stateful.mar --ts-config benchmarks/config.properties

# Run the first sequential inference
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
InferStream2 started
prediction: "1"

prediction: "3"

prediction: "6"

Sequence completed!
InferStream2 closed

# Run the 2nd sequential inference
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
InferStream2 started
status {
  code: 13
  message: "Model \"stateful\" please check if the sequence is closed, or expired; or exceeds maxSequenceJobQueueSize in log"
  details {
    type_url: "type.googleapis.com/google.rpc.ErrorInfo"
    value: "\n\032InternalServerException.()"
  }
}

status {
  code: 13
  message: "Model \"stateful\" please check if the sequence is closed, or expired; or exceeds maxSequenceJobQueueSize in log"
  details {
    type_url: "type.googleapis.com/google.rpc.ErrorInfo"
    value: "\n\032InternalServerException.()"
  }
}

status {
  code: 13
  message: "Model \"stateful\" please check if the sequence is closed, or expired; or exceeds maxSequenceJobQueueSize in log"
  details {
    type_url: "type.googleapis.com/google.rpc.ErrorInfo"
    value: "\n\032InternalServerException.()"
  }
}

Sequence completed!
InferStream2 closed
  • Concurrently Run 2 Sequential Inferences on the same worker
# model_store/stateful/model-config.yaml
minWorkers: 2
maxWorkers: 2
batchSize: 4
maxBatchDelay: 5000
sequenceMaxIdleMSec: 600000
maxNumSequence: 4
maxSequenceJobQueueSize: 10

handler:
  cache:
    capacity: 4

# Start model server and load  example model stateful.mar which responses the accumulated value from the sequential input
torchserve --ncs --start --model-store model_store --models stateful.mar --ts-config benchmarks/config.properties

# The first sequential inference
 python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt,examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt,examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt,examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
InferStream2 started
prediction: "1"

prediction: "3"

prediction: "6"

prediction: "7"

status {
  code: 13
  message: "Model \"stateful\" please check if the sequence is closed, or expired; or exceeds maxSequenceJobQueueSize in log"
  details {
    type_url: "type.googleapis.com/google.rpc.ErrorInfo"
    value: "\n\032InternalServerException.()"
  }
}

prediction: "9"

prediction: "12"

prediction: "13"

prediction: "15"

prediction: "18"

prediction: "19"

prediction: "21"

Sequence completed!
InferStream2 closed

# The 2nd sequential inference
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_1 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt,examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
InferStream2 started
prediction: "1"

prediction: "3"

prediction: "6"

prediction: "7"

prediction: "9"

prediction: "12"

Sequence completed!
InferStream2 closed

Checklist:

  • Did you have fun?
  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

@lxning lxning self-assigned this Aug 1, 2023
@lxning lxning changed the title [WIP] stateful inference stateful inference Aug 15, 2023
Copy link

@calebho calebho left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not familiar with the implementation details so I can only comment on the API


self.sequence_ids = {}
results = []
for idx, row in enumerate(data):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, is it the case that batchSize is the least upper bound of len(data), i.e.len(data) <= batchSize and for all l such that len(data) <= l, batchSize <= l?

Is it possible for two separate requests to get batched to this worker? If so, suppose there are two separate streaming requests that are batched to this worker. What happens if one client is much much faster than the other? Do we throttle the faster client to match the speed of the slower one by buffering the faster client's messages?

Copy link
Collaborator Author

@lxning lxning Aug 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Q1: yes, len(data) <= batchSize. data is a batch of requests received at realtime.

  • Q2: Yes, a batch of requests comes from different sequences. eg. len(data) = 4, it means there are 4 sequences. Each sequence has its own dedicated jobQ. Only the parameter "maxBatchDelay" decides the msec of batching a group of requests from different sequences. In other words, the different traffic volume of different sequences has no impact on batching latency.

Copy link

@calebho calebho Aug 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok but if two streams produce data at drastically different rates, how do you keep the batch index coherent? For instance, fix a stateful worker. At time t_0, the worker receives data d_0_0 and d_1_0 from two streams. So then len(data) == 2 and data[0] is the payload for stream 0 and data[1] is the payload for stream 1.

At t_1, stream 0 does not produce any data because it took longer than maxBatchDelay, but stream 1 produces data d_1_1. So then len(data) == 1 and data[0] is the payload for stream 1. In the line below, idx == 0, so then you fetch the sequence ID for index 0. It seems like this would fetch the sequence ID for stream 0,

sequence_id = self.context.get_sequence_id(idx)

but you actually want the sequence ID for stream 1. Am I understanding the API semantics correctly? Perhaps I am misunderstanding how context.get_sequence_id works. Does it keep track of which stream corresponds to the elements of the data list passed to the handler?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

each request's sequence id is added into its header with key = "ts_request_sequence_id". Backend can get a request's sequence id via its header. This can guarantee we can always get the sequence id regardless the real batch size is changed or the request of a sequence enters into a different batch slot.

@msaroufim msaroufim added the enhancement New feature or request label Aug 25, 2023
@codecov
Copy link

codecov bot commented Sep 29, 2023

Codecov Report

Merging #2513 (40991b3) into master (7f4419f) will decrease coverage by 0.02%.
Report is 2 commits behind head on master.
The diff coverage is 50.00%.

❗ Current head 40991b3 differs from pull request most recent head 0a90a87. Consider uploading reports for the commit 0a90a87 to get more accurate results

@@            Coverage Diff             @@
##           master    #2513      +/-   ##
==========================================
- Coverage   72.44%   72.43%   -0.02%     
==========================================
  Files          85       85              
  Lines        3963     3965       +2     
  Branches       58       58              
==========================================
+ Hits         2871     2872       +1     
- Misses       1088     1089       +1     
  Partials        4        4              
Files Coverage Δ
ts/context.py 77.21% <50.00%> (-0.71%) ⬇️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

&& model.getParallelLevel() > 1
&& model.getParallelType()
!= ModelConfig.ParallelType.PP)
? model.getParallelLevel()
: 1;
List<CompletableFuture<Void>> futureRequests = new ArrayList<>(repeats);
for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, in that case we should move the check out of the loop condition and start from the beginning. Otherwise we're getting an undefined delay before we retry sending the job through the check for results (that cannot be there as we never sent the request).

CompletableFuture.runAsync(
() -> {
Job job = jobGroup.pollJob((long) model.getMaxBatchDelay());
if (job != null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change this part into pushing the jobs instead of polling?

Copy link
Collaborator

@mreso mreso left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're already in a good shape, left some comments.

break;
}

if (cmd == WorkerCommands.STREAMPREDICT2) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate still persists

examples/stateful/Readme.md Outdated Show resolved Hide resolved
examples/stateful/Readme.md Show resolved Hide resolved
test/postman/inference_stream2_data.json Outdated Show resolved Hide resolved
test/pytest/test_parallelism.py Outdated Show resolved Hide resolved
@lxning lxning enabled auto-merge November 3, 2023 19:22
Copy link
Collaborator

@mreso mreso left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be deleted.

@lxning lxning added this pull request to the merge queue Nov 8, 2023
Merged via the queue into master with commit e1c31e1 Nov 8, 2023
12 checks passed
@lxning lxning added this to the v0.10.0 milestone Mar 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

4 participants