Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refresh Cohere #141

Merged
merged 12 commits into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,832 changes: 0 additions & 4,832 deletions results/cohere-chat/model_outputs.json

This file was deleted.

1,606 changes: 803 additions & 803 deletions results/cohere/model_outputs.json

Large diffs are not rendered by default.

24 changes: 13 additions & 11 deletions src/alpaca_eval/decoders/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import multiprocessing
import os
import random
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

import cohere
import tqdm
Expand Down Expand Up @@ -51,24 +51,24 @@ def cohere_completions(

with utils.Timer() as t:
if num_procs == 1:
completions = [_cohere_completion_helper(prompt, **kwargs) for prompt in tqdm.tqdm(prompts, desc="prompts")]
completions_and_token_counts = [_cohere_completion_helper(prompt, **kwargs) for prompt in tqdm.tqdm(prompts, desc="prompts")]
else:
with multiprocessing.Pool(num_procs) as p:
partial_completion_helper = functools.partial(_cohere_completion_helper, **kwargs)
completions = list(
completions_and_token_counts = list(
tqdm.tqdm(
p.imap(partial_completion_helper, prompts),
desc="prompts",
total=len(prompts),
)
)
logging.info(f"Completed {n_examples} examples in {t}.")

# cohere charges $2.5 for every 1000 call to API that is less than 1000 characters. Only counting prompts here
price = [2.5 / 1000 * math.ceil(len(prompt) / 1000) for prompt in prompts]
completions, num_tokens = zip(*completions_and_token_counts)
price_per_token = 0.000015 # cohere charges $0.000015 per token.
price_per_example = [price_per_token * n for n in num_tokens]
avg_time = [t.duration / n_examples] * len(completions)

return dict(completions=completions, price_per_example=price, time_per_example=avg_time)
return dict(completions=list(completions), price_per_example=price_per_example, time_per_example=avg_time)


def _cohere_completion_helper(
Expand All @@ -79,7 +79,7 @@ def _cohere_completion_helper(
max_tries=5,
mode="instruct",
**kwargs,
) -> str:
) -> Tuple[str,int]:
cohere_api_key = random.choice(cohere_api_keys)
client = cohere.Client(cohere_api_key)

Expand All @@ -89,20 +89,22 @@ def _cohere_completion_helper(
for trynum in range(max_tries): # retry errors
try:
if mode == "instruct":
response = client.generate(prompt=prompt, **curr_kwargs)
response = client.generate(prompt=prompt, return_likelihoods="ALL", **curr_kwargs)
text = response[0].text
num_tokens = len(response[0].token_likelihoods)
elif mode == "chat":
response = client.chat(prompt, **curr_kwargs)
text = response.text
num_tokens = 0 # not implemented for chat
else:
raise ValueError(f"Invalid mode {mode} for cohere_completions")

if text == "":
raise CohereError("Empty string response")

return text
return text, num_tokens

except CohereError as e:
print(f"Try #{trynum+1}/{max_tries}: Error running prompt {repr(prompt)}: {e}")

return " " # placeholder response for errors, doesn't allow empty string
return " ", 0 # placeholder response for errors, doesn't allow empty string
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ openbuddy-llama2-70b-v10.1,87.67123287671232,1.1508417516577765,701,96,6,803,com
openchat-v2-w-13b,87.1268656716418,1.1769197439396015,699,102,3,804,community,1566.0
openbuddy-llama-65b-v8,86.53366583541147,1.2029182403474274,693,107,2,802,community,1162.0
wizardlm-13b-v1.1,86.31840796019901,1.2063217831272972,692,108,4,804,community,1525.0
cohere,85.0560398505604,1.2558329840021718,682,119,2,803,community,1715.0
openchat-v2-13b,84.96894409937889,1.2572979835605944,683,120,2,805,community,1564.0
humpback-llama-65b,83.70646766169155,1.3071034735987248,672,130,2,804,community,1269.0
vicuna-13b-v1.3,82.11180124223603,1.348769957803504,660,143,2,805,verified,1132.0
Expand Down Expand Up @@ -52,8 +53,6 @@ falcon-40b-instruct,45.71428571428572,1.7524717060805597,366,435,4,805,minimal,6
alpaca-farm-ppo-sim-gpt4-20k,44.099378881987576,1.7399772578861137,350,445,10,805,verified,511.0
pythia-12b-mix-sft,41.86335403726708,1.737637146007538,336,467,2,805,verified,913.0
alpaca-farm-ppo-human,41.24223602484472,1.7271813123250834,328,469,8,805,minimal,803.0
cohere-chat,29.565217391304348,1.5949050483247118,232,561,12,805,community,779.0
cohere,28.385093167701864,1.5717547121761728,221,569,15,805,community,682.0
alpaca-7b,26.459627329192543,1.535711469748,205,584,16,805,minimal,396.0
oasst-sft-pythia-12b,25.962732919254663,1.5261079289535309,201,588,16,805,verified,726.0
falcon-7b-instruct,23.60248447204969,1.4898235369056625,187,612,6,805,verified,478.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ falcon-40b-instruct,46.70807453416149,1.7551420072945083,minimal,4,805,374,427,,
alpaca-farm-ppo-human,46.45962732919255,1.750131850347461,minimal,8,805,370,427,,803
pythia-12b-mix-sft,43.22981366459627,1.7449120766669366,verified,2,805,347,456,,913
oasst-sft-pythia-12b,32.79503105590062,1.6369108459870174,verified,16,805,256,533,,726
cohere-chat,32.79503105590062,1.6416235300873216,community,12,805,258,535,,779
cohere,32.608695652173914,1.635641080422956,community,15,805,255,535,,682
alpaca-7b,32.298136645962735,1.630307861230374,minimal,16,805,252,537,,396
falcon-7b-instruct,29.565217391304348,1.6021542242903124,verified,6,805,235,564,,478
text_davinci_001,21.490683229813666,1.421716368655911,minimal,20,805,163,622,,296
8 changes: 0 additions & 8 deletions src/alpaca_eval/models_configs/cohere-chat/configs.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions src/alpaca_eval/models_configs/cohere/configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cohere:
prompt_template: "cohere/prompt.txt"
fn_completions: "cohere_completions"
completions_kwargs:
model_name: "command"
model_name: "command-nightly"
mode: "instruct"
max_tokens: 2048
pretty_name: "Cohere"
pretty_name: "Cohere Command"
2 changes: 1 addition & 1 deletion src/alpaca_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def make_prompts(
if df.empty:
return [], df

text_to_format = re.findall("{([^ \s]+?)}", template)
text_to_format = re.findall(r"{([^ \s]+?)}", template)
n_occurrences = Counter(text_to_format)

if not all([n == batch_size for n in n_occurrences.values()]):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_decoders_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_anthropic_completions(mocker):
def test_cohere_completions(mocker):
mocker.patch(
"alpaca_eval.decoders.cohere._cohere_completion_helper",
return_value="Mocked completion text",
return_value=["Mocked completion text",42],
)
result = cohere_completions(["Prompt 1", "Prompt 2"], num_procs=1)
_run_all_asserts_completions(result)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pairwise_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def expected_annotations():
def single_annotator():
return SinglePairwiseAnnotator(
prompt_template="text_davinci_003/basic_prompt.txt",
completion_parser_kwargs=dict(outputs_to_match={1: "(?:^|\n) ?Output \(a\)", 2: "(?:^|\n) ?Output \(b\)"}),
completion_parser_kwargs=dict(outputs_to_match={1: r"(?:^|\n) ?Output \(a\)", 2: "(?:^|\n) ?Output \(b\)"}),
is_randomize_output_order=False,
is_shuffle=False,
)
Expand Down