|
20 | 20 | import yaml |
21 | 21 |
|
22 | 22 | from nemoguardrails.llm.types import Task |
23 | | -from nemoguardrails.rails.llm.config import RailsConfig, TaskPrompt |
| 23 | +from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt |
24 | 24 |
|
25 | 25 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) |
26 | 26 |
|
@@ -116,21 +116,33 @@ def _get_prompt( |
116 | 116 | raise ValueError(f"Could not find prompt for task {task_name} and model {model}") |
117 | 117 |
|
118 | 118 |
|
119 | | -def get_prompt(config: RailsConfig, task: Union[str, Task]) -> TaskPrompt: |
120 | | - """Return the prompt for the given task.""" |
121 | | - |
| 119 | +def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Model: |
| 120 | + """Return the model for the given task in the current config.""" |
122 | 121 | # Fetch current task parameters like name, models to use, and the prompting mode |
123 | 122 | task_name = str(task.value) if isinstance(task, Task) else task |
124 | 123 |
|
125 | | - task_model = "unknown" |
126 | 124 | if config.models: |
127 | 125 | _models = [model for model in config.models if model.type == task_name] |
128 | 126 | if not _models: |
129 | 127 | _models = [model for model in config.models if model.type == "main"] |
130 | 128 |
|
131 | | - task_model = _models[0].engine |
132 | | - if _models[0].model: |
133 | | - task_model += "/" + _models[0].model |
| 129 | + return _models[0] |
| 130 | + |
| 131 | + return None |
| 132 | + |
| 133 | + |
| 134 | +def get_prompt(config: RailsConfig, task: Union[str, Task]) -> TaskPrompt: |
| 135 | + """Return the prompt for the given task.""" |
| 136 | + |
| 137 | + # Fetch current task parameters like name, models to use, and the prompting mode |
| 138 | + task_name = str(task.value) if isinstance(task, Task) else task |
| 139 | + |
| 140 | + task_model = "unknown" |
| 141 | + _model = get_task_model(config, task) |
| 142 | + if _model: |
| 143 | + task_model = _model.engine |
| 144 | + if _model.model: |
| 145 | + task_model += "/" + _model.model |
134 | 146 |
|
135 | 147 | task_prompting_mode = "standard" |
136 | 148 | if config.prompting_mode: |
|
0 commit comments