Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ def __init__(self, *, config_list: List[Dict] = None, **base_config):
openai_config, extra_kwargs = self._separate_openai_config(base_config)
if type(config_list) is list and len(config_list) == 0:
logger.warning("openai client was provided with an empty config_list, which may not be intended.")
if isinstance(config_list, list) and any((c.get("model") is None for c in config_list)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check shouldn't be here because one can provide the model in base_config and omit model in config_list.

logger.warning(
"OpenAIWrapper: one or more configs in config_list do not have a model specified, which may not be intended."
)
if config_list:
self._check_model_in_config_list(model=base_config.get("model"), config_list=config_list)
self._default_model = base_config.get("model", None)
config_list = [config.copy() for config in config_list] # make a copy before modifying
self._clients = [self._client(config, openai_config) for config in config_list] # could modify the config
self._config_list = [
Expand Down Expand Up @@ -184,6 +190,16 @@ def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> D
]
return params

@staticmethod
def _check_model_in_config_list(model, config_list=None):
    """Verify that ``model`` is present in at least one entry of ``config_list``.

    Silently does nothing when either ``model`` or ``config_list`` is None,
    since in that case there is nothing to cross-check.

    Raises:
        ValueError: if no entry in ``config_list`` declares ``model``.
    """
    if model is None or config_list is None:
        return

    # A single matching entry is enough; stop scanning as soon as one is found.
    found = any(entry.get("model", None) == model for entry in config_list)
    if not found:
        raise ValueError(
            f"model {model} is not found in the config_list. Please add the model with the corresponding api_key when creating the OpenAIWrapper."
        )

Comment on lines +193 to +202
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check doesn't make sense as config_list is used to update the base_config. It's valid to have model in the base_config and not specified in config_list.

def create(self, **config):
"""Make a completion for a given config using openai's clients.
Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
Expand Down Expand Up @@ -215,9 +231,14 @@ def yes_or_no_filter(context, response):
if ERROR:
raise ERROR
last = len(self._clients) - 1
model_to_use = config.get("model", self._default_model)
self._check_model_in_config_list(model=model_to_use, config_list=self._config_list)

for i, client in enumerate(self._clients):
if model_to_use and model_to_use != self._config_list[i].get("model"):
continue
# merge the input config with the i-th config in the config list
full_config = {**config, **self._config_list[i]}
full_config = {**self._config_list[i], **config}
# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
# process for azure
Expand Down
1 change: 1 addition & 0 deletions test/agentchat/contrib/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def setUp(self):
"timeout": 600,
"seed": 42,
"config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": "Fake"}],
"model": "llava-fake",
},
)

Expand Down
1 change: 1 addition & 0 deletions test/agentchat/contrib/test_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def setUp(self):
"timeout": 600,
"seed": 42,
"config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-fake"}],
"model": "gpt-4-vision-preview",
},
)

Expand Down
40 changes: 40 additions & 0 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,46 @@ def test_usage_summary():
assert client.actual_usage_summary is None, "No actual cost should be recorded"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_create_with_different_models():
    """Exercise per-call and per-wrapper model selection in OpenAIWrapper.

    Covers four scenarios: creating with a model present in config_list,
    creating with a model absent from config_list (expects ValueError),
    initializing the wrapper with a default model, and initializing with a
    model inconsistent with config_list (expects ValueError).
    """
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-4"]},
    )
    messages = [{"role": "user", "content": "2+2="}]

    # Per-call model that exists in config_list should be honored.
    client = OpenAIWrapper(config_list=config_list)
    resp = client.create(messages=messages, model="gpt-3.5-turbo")
    assert resp.model in ["gpt-3.5-turbo", "gpt-35-turbo"], "Model not consistent."
    resp = client.create(messages=messages, model="gpt-4")
    assert resp.model in ["gpt-4", "gpt-4-0613"], "Model not consistent."

    # Per-call model missing from config_list must be rejected.
    rejected = False
    try:
        client.create(messages=messages, model="gpt-3.5-turbo-16k")
    except ValueError as err:
        print(err)
        rejected = True
    if not rejected:
        raise ValueError("Expected ValueError")

    # Wrapper-level default model is used, and a per-call model overrides it.
    client = OpenAIWrapper(config_list=config_list, model="gpt-4-1106-preview")
    resp = client.create(messages=messages)
    assert resp.model in ["gpt-4", "gpt-4-0613"], "Should create with the specified model."
    resp = client.create(messages=messages, model="gpt-3.5-turbo")
    assert resp.model in ["gpt-3.5-turbo", "gpt-35-turbo"], "Initialized model should be overwritten."

    # Wrapper-level default inconsistent with config_list must be rejected at init.
    rejected = False
    try:
        OpenAIWrapper(config_list=config_list, model="gpt-3.5-turbo-16k")
    except ValueError as err:
        print(err)
        rejected = True
    if not rejected:
        raise ValueError("Expected ValueError")


if __name__ == "__main__":
test_aoai_chat_completion()
test_chat_completion()
Expand Down