Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

integrate groq / cerebras #466

Merged
Merged 1 commit into the base branch on Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions refact_known_models/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,92 @@
"pp1000t_generated": 600, # $0.60 / 1M tokens
"filter_caps": ["chat", "tools"],
},
"groq-llama-3.1-8b": {
"backend": "litellm",
"provider": "groq",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "groq/llama-3.1-8b-instant",
"T": 128_000,
"T_out": 8000,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"groq-llama-3.1-70b": {
"backend": "litellm",
"provider": "groq",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "groq/llama-3.1-70b-versatile",
"T": 128_000,
"T_out": 8000,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"groq-llama-3.2-1b": {
"backend": "litellm",
"provider": "groq",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "groq/llama-3.2-1b-preview",
"T": 128_000,
"T_out": 8000,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"groq-llama-3.2-3b": {
"backend": "litellm",
"provider": "groq",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "groq/llama-3.2-3b-preview",
"T": 128_000,
"T_out": 8000,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"groq-llama-3.2-11b-vision": {
"backend": "litellm",
"provider": "groq",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "groq/llama-3.2-11b-vision-preview",
"T": 128_000,
"T_out": 8000,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"groq-llama-3.2-90b-vision": {
"backend": "litellm",
"provider": "groq",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "groq/llama-3.2-90b-vision-preview",
"T": 128_000,
"T_out": 8000,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"cerebras-llama3.1-8b": {
"backend": "litellm",
"provider": "cerebras",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "cerebras/llama3.1-8b",
"T": 8192,
"T_out": 4096,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
},
"cerebras-llama3.1-70b": {
"backend": "litellm",
"provider": "cerebras",
"tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
"resolve_as": "cerebras/llama3.1-70b",
"T": 8192,
"T_out": 4096,
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat"],
}
}
6 changes: 6 additions & 0 deletions refact_utils/finetune/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def _add_results_for_passthrough_provider(provider: str) -> None:
if data.get('anthropic_api_enable'):
_add_results_for_passthrough_provider('anthropic')

if data.get('cerebras_api_enable'):
_add_results_for_passthrough_provider('cerebras')

if data.get('groq_api_enable'):
_add_results_for_passthrough_provider('groq')

for k, v in data.get("model_assign", {}).items():
if model_dict := [d for d in data['models'] if d['name'] == k]:
model_dict = model_dict[0]
Expand Down
7 changes: 6 additions & 1 deletion refact_webgui/webgui/selfhost_fastapi_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def _integrations_env_setup(env_var_name: str, api_key_name: str, api_enable_nam

_integrations_env_setup("OPENAI_API_KEY", "openai_api_key", "openai_api_enable")
_integrations_env_setup("ANTHROPIC_API_KEY", "anthropic_api_key", "anthropic_api_enable")
_integrations_env_setup("GROQ_API_KEY", "groq_api_key", "groq_api_enable")
_integrations_env_setup("CEREBRAS_API_KEY", "cerebras_api_key", "cerebras_api_enable")

def _models_available_dict_rewrite(self, models_available: List[str]) -> Dict[str, Any]:
rewrite_dict = {}
Expand Down Expand Up @@ -608,7 +610,10 @@ async def chat_completion_streamer():
log(err_msg)
yield prefix + json.dumps({"error": err_msg}) + postfix

if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
if model_dict.get('backend') == 'litellm':
model_name = model_dict.get('resolve_as', post.model)
if model_name not in litellm.model_list:
log(f"warning: requested model {model_name} is not in the litellm.model_list (this might not be the issue for some providers)")
log(f"chat/completions: model resolve {post.model} -> {model_name}")
prompt_tokens_n = litellm.token_counter(model_name, messages=messages)
if post.tools:
Expand Down
4 changes: 4 additions & 0 deletions refact_webgui/webgui/selfhost_model_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def first_run(self):
},
"openai_api_enable": False,
"anthropic_api_enable": False,
"groq_api_enable": False,
"cerebras_api_enable": False,
}
self.models_to_watchdog_configs(default_config)

Expand Down Expand Up @@ -255,6 +257,8 @@ def models_info(self):
def model_assignment(self):
if os.path.exists(env.CONFIG_INFERENCE):
j = json.load(open(env.CONFIG_INFERENCE, "r"))
j["groq_api_enable"] = j.get("groq_api_enable", False)
j["cerebras_api_enable"] = j.get("cerebras_api_enable", False)
else:
j = {"model_assign": {}}

Expand Down
4 changes: 4 additions & 0 deletions refact_webgui/webgui/selfhost_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def _add_models_for_passthrough_provider(provider):
_add_models_for_passthrough_provider('openai')
if j.get("anthropic_api_enable"):
_add_models_for_passthrough_provider('anthropic')
if j.get("groq_api_enable"):
_add_models_for_passthrough_provider('groq')
if j.get("cerebras_api_enable"):
_add_models_for_passthrough_provider('cerebras')

return self._models_available

Expand Down
8 changes: 8 additions & 0 deletions refact_webgui/webgui/static/tab-model-hosting.html
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ <h3>3rd Party APIs</h3>
<input class="form-check-input" type="checkbox" role="switch" id="enable_anthropic">
<label class="form-check-label" for="enable_anthropic">Enable Anthropic API</label>
</div>
<div class="form-check form-switch">
<input class="form-check-input" type="checkbox" role="switch" id="enable_groq">
<label class="form-check-label" for="enable_groq">Enable Groq API</label>
</div>
<div class="form-check form-switch">
<input class="form-check-input" type="checkbox" role="switch" id="enable_cerebras">
<label class="form-check-label" for="enable_cerebras">Enable Cerebras API</label>
</div>
<div class="chat-enabler-status">
To enable Chat GPT add your API key in the <span id="redirect2credentials" class="main-tab-button fake-link" data-tab="settings">API Keys tab</span>.
</div>
Expand Down
6 changes: 6 additions & 0 deletions refact_webgui/webgui/static/tab-model-hosting.js
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ function get_models()

integration_switch_init('enable_chat_gpt', models_data['openai_api_enable']);
integration_switch_init('enable_anthropic', models_data['anthropic_api_enable']);
integration_switch_init('enable_groq', models_data['groq_api_enable']);
integration_switch_init('enable_cerebras', models_data['cerebras_api_enable']);

const more_gpus_notification = document.querySelector('.model-hosting-error');
if(data.hasOwnProperty('more_models_than_gpus') && data.more_models_than_gpus) {
Expand All @@ -140,12 +142,16 @@ function get_models()
function save_model_assigned() {
const openai_enable = document.querySelector('#enable_chat_gpt');
const anthropic_enable = document.querySelector('#enable_anthropic');
const groq_enable = document.querySelector('#enable_groq');
const cerebras_enable = document.querySelector('#enable_cerebras');
const data = {
model_assign: {
...models_data.model_assign,
},
openai_api_enable: openai_enable.checked,
anthropic_api_enable: anthropic_enable.checked,
groq_api_enable: groq_enable.checked,
cerebras_api_enable: cerebras_enable.checked,
};
console.log(data);
fetch("/tab-host-models-assign", {
Expand Down
4 changes: 4 additions & 0 deletions refact_webgui/webgui/static/tab-settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ <h2>API Integrations</h2>
<input type="text" name="openai_api_key" value="" class="form-control" id="openai_api_key">
<label for="anthropic_api_key" class="form-label mt-4">Anthropic API Key</label>
<input type="text" name="anthropic_api_key" value="" class="form-control" id="anthropic_api_key">
<label for="groq_api_key" class="form-label mt-4">Groq API Key</label>
<input type="text" name="groq_api_key" value="" class="form-control" id="groq_api_key">
<label for="cerebras_api_key" class="form-label mt-4">Cerebras API Key</label>
<input type="text" name="cerebras_api_key" value="" class="form-control" id="cerebras_api_key">
<!-- <div class="d-flex flex-row-reverse mt-3"><button type="button" class="btn btn-primary" id="integrations-save">Save</button></div>-->
</div>
</div>
Expand Down
8 changes: 8 additions & 0 deletions refact_webgui/webgui/static/tab-settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ function throw_int_saved_success_toast(msg) {
function save_integration_api_keys() {
const openai_api_key = document.getElementById('openai_api_key');
const anthropic_api_key = document.getElementById('anthropic_api_key');
const groq_api_key = document.getElementById('groq_api_key');
const cerebras_api_key = document.getElementById('cerebras_api_key');
const huggingface_api_key = document.getElementById('huggingface_api_key');
fetch("/tab-settings-integrations-save", {
method: "POST",
Expand All @@ -181,6 +183,8 @@ function save_integration_api_keys() {
body: JSON.stringify({
openai_api_key: openai_api_key.getAttribute('data-value'),
anthropic_api_key: anthropic_api_key.getAttribute('data-value'),
groq_api_key: groq_api_key.getAttribute('data-value'),
cerebras_api_key: cerebras_api_key.getAttribute('data-value'),
huggingface_api_key: huggingface_api_key.getAttribute('data-value'),
})
})
Expand All @@ -189,6 +193,8 @@ function save_integration_api_keys() {
throw_int_saved_success_toast('API Key saved')
openai_api_key.setAttribute('data-saved-value', openai_api_key.getAttribute('data-value'))
anthropic_api_key.setAttribute('data-saved-value', anthropic_api_key.getAttribute('data-value'))
groq_api_key.setAttribute('data-saved-value', groq_api_key.getAttribute('data-value'))
cerebras_api_key.setAttribute('data-saved-value', cerebras_api_key.getAttribute('data-value'))
huggingface_api_key.setAttribute('data-saved-value', huggingface_api_key.getAttribute('data-value'))
});
}
Expand Down Expand Up @@ -222,6 +228,8 @@ export function tab_settings_integrations_get() {
.then(function(data) {
integrations_input_init(document.getElementById('openai_api_key'), data['openai_api_key']);
integrations_input_init(document.getElementById('anthropic_api_key'), data['anthropic_api_key']);
integrations_input_init(document.getElementById('groq_api_key'), data['groq_api_key']);
integrations_input_init(document.getElementById('cerebras_api_key'), data['cerebras_api_key']);
integrations_input_init(document.getElementById('huggingface_api_key'), data['huggingface_api_key']);
});
}
Expand Down
2 changes: 2 additions & 0 deletions refact_webgui/webgui/tab_models_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class TabHostModelsAssign(BaseModel):
# integrations
openai_api_enable: bool = False
anthropic_api_enable: bool = False
groq_api_enable: bool = False
cerebras_api_enable: bool = False

model_config = ConfigDict(protected_namespaces=()) # avoiding model_ namespace protection

Expand Down
2 changes: 2 additions & 0 deletions refact_webgui/webgui/tab_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class SSHKey(BaseModel):
class Integrations(BaseModel):
openai_api_key: Optional[str] = None
anthropic_api_key: Optional[str] = None
groq_api_key: Optional[str] = None
cerebras_api_key: Optional[str] = None
huggingface_api_key: Optional[str] = None

def __init__(self, models_assigner: ModelAssigner, *args, **kwargs):
Expand Down