Fix log_parsing example pipeline null output issue #2024
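The diff below shows the root cause: `_convert_one_response` always wrote each inference batch into the output tensors starting at row 0, so later batches overwrote the first one and the remaining rows kept their initial values, consistent with the null output the title refers to. The fix threads a running `batch_offset` through the call. A minimal sketch of the before/after write pattern, using plain NumPy arrays and hypothetical names rather than the Morpheus API:

```python
import numpy as np

# Pre-allocated output for 4 rows, to be filled by two batches of 2 rows each.
batches = [np.ones((2, 3)), 2 * np.ones((2, 3))]

# Old pattern: every batch is written at row 0, so the second batch
# overwrites the first and rows 2-3 are never filled.
output = np.zeros((4, 3))
for batch in batches:
    output[0:batch.shape[0], :] = batch

# New pattern: carry a running batch_offset so each batch lands in its
# own slice of the shared output.
output = np.zeros((4, 3))
batch_offset = 0
for batch in batches:
    output[batch_offset:batch.shape[0] + batch_offset, :] = batch
    batch_offset += batch.shape[0]
```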

11 changes: 6 additions & 5 deletions examples/log_parsing/inference.py
@@ -140,7 +140,8 @@ def compute_schema(self, schema: StageSchema):
         schema.output_schema.set_type(ControlMessage)

     @staticmethod
-    def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: TensorMemory) -> ControlMessage:
+    def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: TensorMemory,
+                              batch_offset: int) -> ControlMessage:
         memory = output.tensors()

         out_seq_ids = memory.get_tensor('seq_ids')
@@ -153,17 +154,17 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: TensorMemory) -> ControlMessage:
         seq_offset = seq_ids[0, 0].item()
         seq_count = seq_ids[-1, 0].item() + 1 - seq_offset

-        input_ids[0:inf.tensors().count, :] = inf.tensors().get_tensor('input_ids')
-        out_seq_ids[0:inf.tensors().count, :] = seq_ids
+        input_ids[batch_offset:inf.tensors().count + batch_offset, :] = inf.tensors().get_tensor('input_ids')
+        out_seq_ids[batch_offset:inf.tensors().count + batch_offset, :] = seq_ids

         resp_confidences = res.get_tensor('confidences')
         resp_labels = res.get_tensor('labels')

         # Two scenarios:
         if (inf.payload().count == inf.tensors().count):
             assert seq_count == res.count
-            confidences[0:inf.tensors().count, :] = resp_confidences
-            labels[0:inf.tensors().count, :] = resp_labels
+            confidences[batch_offset:inf.tensors().count + batch_offset, :] = resp_confidences
+            labels[batch_offset:inf.tensors().count + batch_offset, :] = resp_labels
         else:
             assert inf.tensors().count == res.count

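Each of the rewritten assignments above follows the same rule: copy `inf.tensors().count` rows into the shared output starting at `batch_offset` instead of at 0. A small illustrative helper (not part of Morpheus) that applies that rule to several named tensors at once, the way this hunk does for `input_ids`, `seq_ids`, `confidences`, and `labels`:

```python
import numpy as np

def write_batch_at_offset(outputs: dict, batch: dict, batch_offset: int) -> None:
    """Copy each tensor in `batch` into its pre-allocated output tensor,
    starting at row `batch_offset` (mirrors the updated slicing above)."""
    for name, values in batch.items():
        count = values.shape[0]
        outputs[name][batch_offset:count + batch_offset, :] = values

outputs = {"confidences": np.zeros((4, 2)), "labels": np.zeros((4, 2))}
write_batch_at_offset(outputs, {"confidences": np.full((2, 2), 0.9), "labels": np.ones((2, 2))}, 0)
write_batch_at_offset(outputs, {"confidences": np.full((2, 2), 0.8), "labels": np.ones((2, 2))}, 2)
```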
10 changes: 7 additions & 3 deletions python/morpheus/morpheus/stages/inference/inference_stage.py
@@ -233,15 +233,18 @@ def on_next(message: ControlMessage):

                 fut_list = []

+                batch_offset = 0
+
                 for batch in batches:
                     outstanding_requests += 1

                     completion_future = mrc.Future()

                     def set_output_fut(resp: TensorMemory, inner_batch, batch_future: mrc.Future):
                         nonlocal outstanding_requests
-                        mess = self._convert_one_response(output_message, inner_batch, resp)
-
+                        nonlocal batch_offset
+                        mess = self._convert_one_response(output_message, inner_batch, resp, batch_offset)
+                        batch_offset += inner_batch.tensors().count
                         outstanding_requests -= 1

                         batch_future.set_result(mess)
@@ -340,7 +343,8 @@ def _split_batches(msg: ControlMessage, max_batch_size: int) -> typing.List[ControlMessage]:
         return out_resp

     @staticmethod
-    def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: TensorMemory):
+    def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: TensorMemory,
+                              batch_offset: int) -> ControlMessage:  # pylint:disable=unused-argument
         # Make sure we have a continuous list
         # assert inf.mess_offset == saved_offset + saved_count

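In the dispatch loop above, the bookkeeping for the new argument lives in the `on_next` closure: `batch_offset` starts at 0, is captured with `nonlocal` in the per-batch completion callback, and advances by the batch's tensor count after each conversion. A stripped-down, runnable sketch of that pattern with the futures and inference worker removed (hypothetical names, synchronous callbacks):

```python
def dispatch(batches, convert_one):
    """Toy version of the on_next loop: keep a running row offset across
    per-batch completion callbacks."""
    results = []
    batch_offset = 0

    for batch in batches:
        # Bind the current batch as a default argument so each callback keeps
        # its own batch (the real code passes inner_batch into the callback).
        def on_complete(response, inner_batch=batch):
            nonlocal batch_offset
            results.append(convert_one(inner_batch, response, batch_offset))
            batch_offset += len(inner_batch)

        # Morpheus invokes the callback when inference for the batch finishes;
        # here it runs immediately to keep the sketch self-contained.
        on_complete(response=[x * 10 for x in batch])

    return results

print(dispatch([[1, 2], [3, 4, 5]], lambda inner, resp, offset: (offset, resp)))
# -> [(0, [10, 20]), (2, [30, 40, 50])]
```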
2 changes: 1 addition & 1 deletion tests/examples/log_parsing/test_inference.py
@@ -182,7 +182,7 @@ def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.Lis

     input_inf = build_inf_message(filter_probs_df, mess_count=mess_count, count=count, num_cols=num_cols)

-    output_msg = inference_mod.LogParsingInferenceStage._convert_one_response(resp_msg, input_inf, input_res)
+    output_msg = inference_mod.LogParsingInferenceStage._convert_one_response(resp_msg, input_inf, input_res, 0)

     assert isinstance(output_msg, ControlMessage)
     assert output_msg.payload() is input_inf.payload()
8 changes: 5 additions & 3 deletions tests/morpheus/stages/test_inference_stage.py
@@ -105,8 +105,9 @@ def test_convert_one_response():
     res = ResponseMemory(count=4, tensors={"probs": cp.random.rand(4, 3)})
     output = _mk_control_message(mess_count=4, count=4)
     output.tensors(mem)
+    batch_offset = 0

-    cm = InferenceStageT._convert_one_response(output, inf, res)
+    cm = InferenceStageT._convert_one_response(output, inf, res, batch_offset)
     assert cm.payload() == inf.payload()
     assert cm.payload().count == 4
     assert cm.tensors().count == 4
@@ -120,14 +121,15 @@ def test_convert_one_response_error():
     mem = ResponseMemory(count=2, tensors={"probs": cp.zeros((2, 3))})
     output = _mk_control_message(mess_count=2, count=3)
     output.tensors(mem)
-    cm = InferenceStageT._convert_one_response(output, inf, res)
+    cm = InferenceStageT._convert_one_response(output, inf, res, batch_offset)
     assert cm.tensors().get_tensor("probs").tolist() == [[0, 0.6, 0.7], [5.6, 6.7, 9.2]]


 def test_convert_one_response_error():
     inf = _mk_control_message(mess_count=2, count=2)
     res = _mk_control_message(mess_count=1, count=1)
     output = inf
+    batch_offset = 0

     with pytest.raises(AssertionError):
-        InferenceStageT._convert_one_response(output, inf, res.tensors())
+        InferenceStageT._convert_one_response(output, inf, res.tensors(), batch_offset)