Skip to content

Commit ab96ab6

Browse files
committed
refactor(llm): standardize access to parsed task output text
- Updated all instances to use `result.text` for accessing parsed task output text.
- Replaced direct usage of `result` with `text` where applicable for clarity and consistency.
- Adjusted related logic to ensure proper handling of stripped and formatted text.
1 parent 92ede8b commit ab96ab6

File tree

7 files changed

+49
-11
lines changed

7 files changed

+49
-11
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -442,8 +442,9 @@ async def generate_user_intent(
442442
result = self.llm_task_manager.parse_task_output(
443443
Task.GENERATE_USER_INTENT, output=result
444444
)
445+
text = result.text
445446

446-
user_intent = get_first_nonempty_line(result)
447+
user_intent = get_first_nonempty_line(text)
447448
if user_intent is None:
448449
user_intent = "unknown message"
449450

@@ -527,9 +528,11 @@ async def generate_user_intent(
527528
prompt,
528529
custom_callback_handlers=[streaming_handler_var.get()],
529530
)
530-
text = self.llm_task_manager.parse_task_output(
531+
result = self.llm_task_manager.parse_task_output(
531532
Task.GENERAL, output=text
532533
)
534+
text = result.text
535+
text = text.strip()
533536
else:
534537
# Initialize the LLMCallInfo object
535538
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
@@ -562,9 +565,10 @@ async def generate_user_intent(
562565
stop=["User:"],
563566
)
564567

565-
text = self.llm_task_manager.parse_task_output(
568+
result = self.llm_task_manager.parse_task_output(
566569
Task.GENERAL, output=result
567570
)
571+
text = result.text
568572
text = text.strip()
569573
if text.startswith('"'):
570574
text = text[1:-1]
@@ -646,10 +650,11 @@ async def generate_next_step(
646650
result = self.llm_task_manager.parse_task_output(
647651
Task.GENERATE_NEXT_STEPS, output=result
648652
)
653+
text = result.text
649654

650655
# If we don't have multi-step generation enabled, we only look at the first line.
651656
if not self.config.enable_multi_step_generation:
652-
result = get_first_nonempty_line(result)
657+
result = get_first_nonempty_line(text)
653658

654659
if result and result.startswith("bot "):
655660
bot_intent = result[4:]
@@ -687,7 +692,7 @@ async def generate_next_step(
687692
# Otherwise, we parse the output as a single flow.
688693
# If we have a parsing error, we try to reduce size of the flow, potentially
689694
# up to a single step.
690-
lines = result.split("\n")
695+
lines = text.split("\n")
691696
while True:
692697
try:
693698
parse_colang_file("dynamic.co", content="\n".join(lines))
@@ -896,10 +901,15 @@ async def generate_bot_message(
896901
llm, prompt, custom_callback_handlers=[streaming_handler]
897902
)
898903

904+
# it seems that removing the reasoning traces is llm_call responsibility
905+
#
906+
899907
result = self.llm_task_manager.parse_task_output(
900908
Task.GENERAL, output=result
901909
)
902910

911+
result = result.text
912+
903913
log.info(
904914
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
905915
time() - t0,
@@ -963,6 +973,8 @@ async def generate_bot_message(
963973
Task.GENERATE_BOT_MESSAGE, output=result
964974
)
965975

976+
result = result.text
977+
966978
# TODO: catch openai.error.InvalidRequestError from exceeding max token length
967979

968980
result = get_multiline_response(result)
@@ -1055,10 +1067,11 @@ async def generate_value(
10551067
result = self.llm_task_manager.parse_task_output(
10561068
Task.GENERATE_VALUE, output=result
10571069
)
1070+
text = result.text
10581071

10591072
# We only use the first line for now
10601073
# TODO: support multi-line values?
1061-
value = result.strip().split("\n")[0]
1074+
value = text.strip().split("\n")[0]
10621075

10631076
# Because of conventions from other languages, sometimes the LLM might add
10641077
# a ";" at the end of the line. We remove that
@@ -1266,22 +1279,23 @@ async def generate_intent_steps_message(
12661279
result = self.llm_task_manager.parse_task_output(
12671280
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
12681281
)
1282+
text = result.text
12691283

12701284
# TODO: Implement logic for generating more complex Colang next steps (multi-step),
12711285
# not just a single bot intent.
12721286

12731287
# Get the next 2 non-empty lines, these should contain:
12741288
# line 1 - user intent, line 2 - bot intent.
12751289
# Afterwards we have the bot message.
1276-
next_three_lines = get_top_k_nonempty_lines(result, k=2)
1290+
next_three_lines = get_top_k_nonempty_lines(text, k=2)
12771291
user_intent = next_three_lines[0] if len(next_three_lines) > 0 else None
12781292
bot_intent = next_three_lines[1] if len(next_three_lines) > 1 else None
12791293
bot_message = None
12801294
if bot_intent:
1281-
pos = result.find(bot_intent)
1295+
pos = text.find(bot_intent)
12821296
if pos != -1:
12831297
# The bot message could be multiline
1284-
bot_message = result[pos + len(bot_intent) :]
1298+
bot_message = text[pos + len(bot_intent) :]
12851299
bot_message = get_multiline_response(bot_message)
12861300
bot_message = strip_quotes(bot_message)
12871301
# Quick hack for degenerated / empty bot messages
@@ -1348,7 +1362,8 @@ async def generate_intent_steps_message(
13481362
result = self.llm_task_manager.parse_task_output(
13491363
Task.GENERAL, output=result
13501364
)
1351-
text = result.strip()
1365+
text = result.text
1366+
text = text.strip()
13521367
if text.startswith('"'):
13531368
text = text[1:-1]
13541369

nemoguardrails/actions/v2_x/generation.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ async def _collect_user_intent_and_examples(
197197

198198
# We add these in reverse order so the most relevant is towards the end.
199199
for result in reversed(results):
200-
examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n"
200+
examples += f'user action: user said "{result.text}"\nuser intent: {result.meta["intent"]}\n\n'
201201
if result.meta["intent"] not in potential_user_intents:
202202
potential_user_intents.append(result.meta["intent"])
203203

@@ -302,6 +302,8 @@ async def generate_user_intent(
302302
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
303303
)
304304

305+
result = result.text
306+
305307
user_intent = get_first_nonempty_line(result)
306308
# GTP-4o often adds 'user intent: ' in front
307309
if user_intent and ":" in user_intent:
@@ -378,6 +380,8 @@ async def generate_user_intent_and_bot_action(
378380
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
379381
)
380382

383+
result = result.text
384+
381385
user_intent = get_first_nonempty_line(result)
382386

383387
if user_intent and ":" in user_intent:
@@ -458,6 +462,8 @@ async def passthrough_llm_action(
458462

459463
text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
460464

465+
text = result.text
466+
461467
return text
462468

463469
@action(name="CheckValidFlowExistsAction", is_system_action=True)
@@ -541,6 +547,8 @@ async def generate_flow_from_instructions(
541547
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
542548
)
543549

550+
result = result.text
551+
544552
# TODO: why this is not part of a filter or output_parser?
545553
#
546554
lines = _remove_leading_empty_lines(result).split("\n")
@@ -613,6 +621,8 @@ async def generate_flow_from_name(
613621
task=Task.GENERATE_FLOW_FROM_NAME, output=result
614622
)
615623

624+
result = result.text
625+
616626
lines = _remove_leading_empty_lines(result).split("\n")
617627

618628
if lines[0].startswith("flow"):
@@ -680,6 +690,8 @@ async def generate_flow_continuation(
680690
task=Task.GENERATE_FLOW_CONTINUATION, output=result
681691
)
682692

693+
result = result.text
694+
683695
lines = _remove_leading_empty_lines(result).split("\n")
684696

685697
if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
@@ -806,6 +818,8 @@ async def generate_value(
806818
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
807819
)
808820

821+
result = result.text
822+
809823
# We only use the first line for now
810824
# TODO: support multi-line values?
811825
value = result.strip().split("\n")[0]
@@ -913,6 +927,8 @@ async def generate_flow(
913927
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
914928
)
915929

930+
result = result.text
931+
916932
result = _remove_leading_empty_lines(result)
917933
lines = result.split("\n")
918934
if "codeblock" in lines[0]:

nemoguardrails/library/content_safety/actions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def content_safety_check_input(
8080
result = await llm_call(llm, check_input_prompt, stop=stop)
8181

8282
result = llm_task_manager.parse_task_output(task, output=result)
83+
result = result.text
8384

8485
try:
8586
is_safe, violated_policies = result
@@ -162,6 +163,8 @@ async def content_safety_check_output(
162163

163164
result = llm_task_manager.parse_task_output(task, output=result)
164165

166+
result = result.text
167+
165168
try:
166169
is_safe, violated_policies = result
167170
except TypeError:

nemoguardrails/library/self_check/facts/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ async def self_check_facts(
8282
task, output=response, forced_output_parser="is_content_safe"
8383
)
8484

85+
result = result.text
8586
is_not_safe, _ = result
8687

8788
result = float(not is_not_safe)

nemoguardrails/library/self_check/input_check/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ async def self_check_input(
8383
task, output=response, forced_output_parser="is_content_safe"
8484
)
8585

86+
result = result.text
8687
is_safe, _ = result
8788

8889
if not is_safe:

nemoguardrails/library/self_check/output_check/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def self_check_output(
8787
task, output=response, forced_output_parser="is_content_safe"
8888
)
8989

90+
result = result.text
9091
is_safe, _ = result
9192

9293
return is_safe

tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ async def custom_llm_request(
5050
result = await llm_call(llm, prompt, stop=stop)
5151

5252
result = llm_task_manager.parse_task_output(prompt_template_name, output=result)
53+
result = result.text
5354

5455
# Any additional parsing of the output
5556
value = result.strip().split("\n")[0]

0 commit comments

Comments (0)