From e7aa3c3a8f676b197725a74ab961db150e9798f4 Mon Sep 17 00:00:00 2001 From: fariquelme Date: Thu, 2 May 2024 10:49:29 -0400 Subject: [PATCH] :bug:: Add function to replace env vars when using multi-llm config --- taskweaver/config/config_mgt.py | 6 +++++- taskweaver/utils/__init__.py | 1 + taskweaver/utils/json_replace.py | 12 ++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 taskweaver/utils/json_replace.py diff --git a/taskweaver/config/config_mgt.py b/taskweaver/config/config_mgt.py index 4a71532d..afd4cbac 100644 --- a/taskweaver/config/config_mgt.py +++ b/taskweaver/config/config_mgt.py @@ -4,6 +4,7 @@ import re from dataclasses import dataclass from typing import Any, Dict, List, Literal, NamedTuple, Optional +from ..utils import replace_env_vars AppConfigSourceType = Literal["override", "env", "json", "app", "default"] AppConfigValueType = Literal["str", "int", "float", "bool", "list", "enum", "path", "dict"] @@ -96,7 +97,10 @@ def _get_config_value( return val if var_name in self.json_file_store.keys(): - return self.json_file_store.get(var_name, default_value) + val = self.json_file_store.get(var_name, default_value) + # e.g., llm.api_base -> LLM_API_BASE + val = replace_env_vars(val) + return val if default_value is not None: return default_value diff --git a/taskweaver/utils/__init__.py b/taskweaver/utils/__init__.py index 51b8dd41..023e564a 100644 --- a/taskweaver/utils/__init__.py +++ b/taskweaver/utils/__init__.py @@ -11,6 +11,7 @@ from datetime import datetime from hashlib import md5 from typing import Any, Dict, List, Union +from .json_replace import replace_env_vars def create_id(length: int = 4) -> str: diff --git a/taskweaver/utils/json_replace.py b/taskweaver/utils/json_replace.py new file mode 100644 index 00000000..5b5592cb --- /dev/null +++ b/taskweaver/utils/json_replace.py @@ -0,0 +1,12 @@ +import os + +# Function to replace env variables on nested JSON +def replace_env_vars(data, key=None): + if isinstance(data, list): + return data + if isinstance(data, dict): + return {k: replace_env_vars(v) for k, v in data.items()} + env_var_name = data.upper().replace(".", "_") + env_val = os.environ.get(env_var_name, None) + return env_val if env_val != None else data +