From cb2a930c2502b7f7fd9539f691735e91eeac7d7d Mon Sep 17 00:00:00 2001
From: Laurel Orr
Date: Sat, 1 Jul 2023 22:45:45 -0700
Subject: [PATCH] fix: dummy client to output tokens and random responses

---
 manifest/clients/dummy.py | 167 ++++++++++--------
 manifest/response.py      |   2 +-
 tests/test_client.py      | 138 +++++++++++++--
 tests/test_manifest.py    | 362 +++++++++++++++++++++++++++++---------
 4 files changed, 502 insertions(+), 167 deletions(-)

diff --git a/manifest/clients/dummy.py b/manifest/clients/dummy.py
index 3a15577..d29a4b7 100644
--- a/manifest/clients/dummy.py
+++ b/manifest/clients/dummy.py
@@ -1,6 +1,10 @@
 """Dummy client."""
+import hashlib
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import tiktoken
 
 from manifest.clients.client import Client
 from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request
@@ -14,7 +18,13 @@ class DummyClient(Client):
 
     # User param -> (client param, default value)
    PARAMS = {
-        "n": ("num_results", 1),
+        "engine": ("model", "text-davinci-003"),
+        "temperature": ("temperature", 0.0),
+        "max_tokens": ("max_tokens", 10),
+        "n": ("n", 1),
+        "top_p": ("top_p", 1.0),
+        "top_k": ("best_of", 1),
+        "batch_size": ("batch_size", 20),
     }
     REQUEST_CLS = LMRequest
     NAME = "dummy"
@@ -33,6 +43,9 @@ def connect(
             connection_str: connection string.
             client_args: client arguments.
         """
+        # We use tiktoken as it is faster than HF for tokenizing
+        # Use any model to create the tokenizer
+        self.encoder = tiktoken.get_encoding("cl100k_base")
         for key in self.PARAMS:
             setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
 
@@ -74,7 +87,65 @@ def get_model_params(self) -> Dict:
         Returns:
             model params.
         """
-        return {"engine": "dummy"}
+        return {"engine": "dummy", "model": getattr(self, "engine")}
+
+    def get_mock_output(
+        self, output_toks: int, is_completion: bool, seed: Optional[int] = None
+    ) -> LMModelChoice:
+        """Return mock model output by generating random tokens."""
+        np.random.seed(seed)
+        random_tokens = np.random.randint(
+            0, self.encoder.max_token_value + 1, output_toks
+        )
+        response = self.encoder.decode(random_tokens)  # type: ignore
+        if is_completion:
+            np.random.seed(seed)
+            random_logprobs = np.random.uniform(
+                low=-2, high=-0.00001, size=output_toks
+            ).tolist()
+        else:
+            # Return all Nones to mimic chat models
+            # OpenAI chat models do not return logprobs
+            random_logprobs = [None] * output_toks
+        return LMModelChoice(
+            text=response,
+            token_logprobs=random_logprobs,
+            tokens=random_tokens.tolist(),
+        )
+
+    def get_mock_choices(
+        self,
+        prompt_list: List[str],
+        request_params: Dict,
+        is_completion: bool,
+    ) -> Tuple[List[LMModelChoice], List[Usage]]:
+        """Get choices and usages of mock output."""
+        choices = []
+        usages = []
+        for prompt in prompt_list:
+            num_prompt_tokens = len(self.encoder.encode(prompt))
+            if request_params["temperature"] == 0:
+                # Get integer seed from hash of prompt
+                seed = (
+                    int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16)
+                    % 10**8
+                )
+            else:
+                # Get random seed
+                seed = None
+            for _ in range(int(request_params["n"])):
+                choice = self.get_mock_output(
+                    request_params["max_tokens"], is_completion=is_completion, seed=seed
+                )
+                choices.append(choice)
+                usages.append(
+                    Usage(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=request_params["max_tokens"],
+                        total_tokens=num_prompt_tokens + request_params["max_tokens"],
+                    )
+                )
+        return choices, usages
 
     def run_request(self, request: Request) -> Response:
         """
@@ 
-88,32 +159,19 @@ def run_request(self, request: Request) -> Response: request parameters as dict. """ if isinstance(request.prompt, list): - num_results = len(request.prompt) + prompt_list = request.prompt else: - num_results = 1 + prompt_list = [request.prompt] request_params = request.to_dict(self.PARAMS) + choices, usages = self.get_mock_choices( + prompt_list, request_params, is_completion=True + ) return Response( - response=ModelChoices( - choices=[LMModelChoice(text="hello")] # type: ignore - * int(request_params["num_results"]) - * num_results - ), + response=ModelChoices(choices=choices), cached=False, request=request, - usages=Usages( - usages=[ - Usage( - **{ - "prompt_tokens": 1, - "completion_tokens": 1, - "total_tokens": 2, - } - ) - ] - * int(request_params["num_results"]) - * num_results - ), + usages=Usages(usages=usages), response_type="text", request_type=self.REQUEST_CLS, ) @@ -145,35 +203,17 @@ def run_chat_request( Returns: response. """ - num_results = 1 - response_dict = { - "choices": [ - { - "text": request.prompt[0]["content"], - } - for i in range(num_results) - ] - } + prompt_list = ["_".join(pmp["content"] for pmp in request.prompt)] + request_params = request.to_dict(self.PARAMS) + + choices, usages = self.get_mock_choices( + prompt_list, request_params, is_completion=False + ) return Response( - response=ModelChoices( - choices=[ - LMModelChoice(**choice) # type: ignore - for choice in response_dict["choices"] - ] - ), + response=ModelChoices(choices=choices), cached=False, request=request, - usages=Usages( - usages=[ - Usage( - **{ - "prompt_tokens": 1, - "completion_tokens": 1, - "total_tokens": 2, - } - ) - ] - ), + usages=Usages(usages=usages), response_type="text", request_type=LMChatRequest, ) @@ -193,30 +233,19 @@ def run_score_prompt_request( request parameters as dict. 
""" if isinstance(request.prompt, list): - num_results = len(request.prompt) + prompt_list = request.prompt else: - num_results = 1 - response_dict = { - "choices": [ - { - "text": request.prompt - if isinstance(request.prompt, str) - else request.prompt[i], - "token_logprobs": [0.3], - } - for i in range(num_results) - ] - } + prompt_list = [request.prompt] + request_params = request.to_dict(self.PARAMS) + + choices, usages = self.get_mock_choices( + prompt_list, request_params, is_completion=True + ) return Response( - response=ModelChoices( - choices=[ - LMModelChoice(**choice) # type: ignore - for choice in response_dict["choices"] - ] - ), + response=ModelChoices(choices=choices), cached=False, request=request, - usages=None, + usages=Usages(usages=usages), response_type="text", request_type=LMScoreRequest, ) diff --git a/manifest/response.py b/manifest/response.py index 7760b8b..7e61b7b 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -53,7 +53,7 @@ class LMModelChoice(BaseModel): """Model single completion.""" text: str - token_logprobs: Optional[List[float]] = None + token_logprobs: Optional[List[Optional[float]]] = None tokens: Optional[List[str]] = None diff --git a/tests/test_client.py b/tests/test_client.py index 5f5810b..b2b9fb0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -19,8 +19,19 @@ def test_init() -> None: def test_get_params() -> None: """Test get param functions.""" client = DummyClient(connection_str=None) - assert client.get_model_params() == {"engine": "dummy"} - assert client.get_model_inputs() == ["n"] + assert client.get_model_params() == { + "engine": "dummy", + "model": "text-davinci-003", + } + assert client.get_model_inputs() == [ + "engine", + "temperature", + "max_tokens", + "n", + "top_p", + "top_k", + "batch_size", + ] def test_get_request() -> None: @@ -31,43 +42,148 @@ def test_get_request() -> None: response = client.run_request(request_params) assert client.get_cache_key(request_params) == { "prompt": "hello", - "num_results": 3, + "model": "text-davinci-003", + "n": 3, + "temperature": 0.0, + "max_tokens": 10, + "top_p": 1.0, + "best_of": 1, "engine": "dummy", "request_cls": "LMRequest", } assert response.get_json_response() == { - "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 3, + "choices": [ + { + "text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501 + "token_logprobs": [ + -0.2649905035732101, + -1.210794839387105, + -1.2173929801003434, + -0.7758233850171001, + -0.7165940659570416, + -1.7430328887209088, + -1.5379414228820203, + -1.7838011423472508, + -1.139095076944217, + -0.6321855879833425, + ], + "tokens": [ + "70470", + "80723", + "52693", + "39743", + "38983", + "1303", + "56072", + "22306", + "17738", + "53176", + ], + } + ] + * 3 } assert response.get_usage_obj().dict() == { - "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3, + "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}] + * 3, } request_params = client.get_request("hello", {"n": 5}) response = client.run_request(request_params) assert client.get_cache_key(request_params) == { "prompt": "hello", - "num_results": 5, + "model": "text-davinci-003", + "n": 5, + "temperature": 0.0, + "max_tokens": 10, + "top_p": 1.0, + "best_of": 1, "engine": "dummy", "request_cls": "LMRequest", } assert response.get_json_response() == { - "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5, + "choices": [ + { + "text": " 
probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501 + "token_logprobs": [ + -0.2649905035732101, + -1.210794839387105, + -1.2173929801003434, + -0.7758233850171001, + -0.7165940659570416, + -1.7430328887209088, + -1.5379414228820203, + -1.7838011423472508, + -1.139095076944217, + -0.6321855879833425, + ], + "tokens": [ + "70470", + "80723", + "52693", + "39743", + "38983", + "1303", + "56072", + "22306", + "17738", + "53176", + ], + } + ] + * 5 } assert response.get_usage_obj().dict() == { - "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, + "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}] + * 5, } request_params = client.get_request(["hello"] * 5, {"n": 1}) response = client.run_request(request_params) assert client.get_cache_key(request_params) == { "prompt": ["hello"] * 5, - "num_results": 1, + "model": "text-davinci-003", + "n": 1, + "temperature": 0.0, + "max_tokens": 10, + "top_p": 1.0, + "best_of": 1, "engine": "dummy", "request_cls": "LMRequest", } assert response.get_json_response() == { - "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5, + "choices": [ + { + "text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501 + "token_logprobs": [ + -0.2649905035732101, + -1.210794839387105, + -1.2173929801003434, + -0.7758233850171001, + -0.7165940659570416, + -1.7430328887209088, + -1.5379414228820203, + -1.7838011423472508, + -1.139095076944217, + -0.6321855879833425, + ], + "tokens": [ + "70470", + "80723", + "52693", + "39743", + "38983", + "1303", + "56072", + "22306", + "17738", + "53176", + ], + } + ] + * 5 } assert response.get_usage_obj().dict() == { - "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, + "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}] + * 5, } diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 8b7ac6c..12cf291 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -73,6 +73,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: cache_name="sqlite", cache_connection=sqlite_cache, n=n, + temperature=0.0, ) prompt = "This is a prompt" @@ -80,8 +81,6 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: result = manifest.run(prompt, return_response=return_response, bad_input=5) assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized." 
- # Allow params in the request object but not in the client to go through - assert "top_k" not in manifest.client_pool.get_next_client().PARAMS result = manifest.run(prompt, return_response=return_response, top_k=5) assert result is not None @@ -96,21 +95,30 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: res = result.get_response(manifest.stop_token) else: res = cast(str, result) + assert ( manifest.cache.get( { - "prompt": "This is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "This is a prompt", "request_cls": "LMRequest", - "num_results": n, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) if n == 1: - assert res == "hello" + assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines" else: - assert res == ["hello", "hello"] + assert res == [ + "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + ] prompt = "This is a prompt" result = manifest.run(prompt, run_id="34", return_response=return_response) @@ -126,19 +134,27 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: assert ( manifest.cache.get( { - "prompt": "This is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "This is a prompt", "request_cls": "LMRequest", - "num_results": n, + "temperature": 0.0, + "top_p": 1.0, "run_id": "34", } ) is not None ) if n == 1: - assert res == "hello" + assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines" else: - assert res == ["hello", "hello"] + assert res == [ + "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + ] prompt = "Hello is a prompt" result = manifest.run(prompt, return_response=return_response) @@ -154,45 +170,60 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: assert ( manifest.cache.get( { - "prompt": "Hello is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "Hello is a prompt", "request_cls": "LMRequest", - "num_results": n, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) if n == 1: - assert res == "hello" + assert res == "appersstoff210 currentNodeleh norm unified_voice DIYHam" else: - assert res == ["hello", "hello"] + assert res == [ + "appersstoff210 currentNodeleh norm unified_voice DIYHam", + "appersstoff210 currentNodeleh norm unified_voice DIYHam", + ] prompt = "Hello is a prompt" - result = manifest.run(prompt, stop_token="ll", return_response=return_response) + result = manifest.run( + prompt, stop_token=" current", return_response=return_response + ) if return_response: assert isinstance(result, Response) result = cast(Response, result) assert len(result.get_usage_obj().usages) == len( result.get_response_obj().choices ) - res = result.get_response(stop_token="ll") + res = result.get_response(stop_token=" current") else: res = cast(str, result) assert ( manifest.cache.get( { - "prompt": "Hello is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "Hello is a prompt", "request_cls": "LMRequest", - "num_results": n, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) if n == 1: - assert res == "he" + assert res == "appersstoff210" else: - assert res == ["he", "he"] + assert res == ["appersstoff210", "appersstoff210"] 
@pytest.mark.usefixtures("sqlite_cache") @@ -205,6 +236,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None: cache_name="sqlite", cache_connection=sqlite_cache, n=n, + temperature=0.0, ) prompt = ["This is a prompt"] if n == 2: @@ -222,15 +254,20 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None: res = result.get_response(manifest.stop_token, is_batch=True) else: res = cast(str, result) - assert res == ["hello"] + assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"] assert ( manifest.cache.get( { - "prompt": "This is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "This is a prompt", "request_cls": "LMRequest", - "num_results": n, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) @@ -246,15 +283,23 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None: res = result.get_response(manifest.stop_token, is_batch=True) else: res = cast(str, result) - assert res == ["hello", "hello"] + assert res == [ + "appersstoff210 currentNodeleh norm unified_voice DIYHam", + "appersstoff210 currentNodeleh norm unified_voice DIYHam", + ] assert ( manifest.cache.get( { - "prompt": "Hello is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "Hello is a prompt", "request_cls": "LMRequest", - "num_results": n, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) @@ -266,11 +311,16 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None: assert ( manifest.cache.get( { - "prompt": "New prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": n, + "prompt": "New prompt", "request_cls": "LMRequest", - "num_results": n, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is None ) @@ -287,20 +337,25 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None: assert result.is_cached() else: res = cast(str, result) - assert res == ["hello", "hello"] + assert res == [ + "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + ".vol.deserializebigmnchantment ROTıl='')\najsС", + ] prompt = ["Hello is a prompt", "Hello is a prompt"] - result = manifest.run(prompt, stop_token="ll", return_response=return_response) + result = manifest.run( + prompt, stop_token=" current", return_response=return_response + ) if return_response: assert isinstance(result, Response) result = cast(Response, result) assert len(result.get_usage_obj().usages) == len( result.get_response_obj().choices ) - res = result.get_response(stop_token="ll", is_batch=True) + res = result.get_response(stop_token=" current", is_batch=True) else: res = cast(str, result) - assert res == ["he", "he"] + assert res == ["appersstoff210", "appersstoff210"] @pytest.mark.usefixtures("sqlite_cache") @@ -310,6 +365,7 @@ def test_abatch_run(sqlite_cache: str) -> None: client_name="dummy", cache_name="sqlite", cache_connection=sqlite_cache, + temperature=0.0, ) prompt = ["This is a prompt"] result = cast( @@ -318,15 +374,20 @@ def test_abatch_run(sqlite_cache: str) -> None: assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) res = result.get_response(manifest.stop_token, is_batch=True) - assert res == ["hello"] + assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"] assert ( manifest.cache.get( { - "prompt": "This is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 
10, + "model": "text-davinci-003", + "n": 1, + "prompt": "This is a prompt", "request_cls": "LMRequest", - "num_results": 1, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) @@ -338,15 +399,23 @@ def test_abatch_run(sqlite_cache: str) -> None: assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) res = result.get_response(manifest.stop_token, is_batch=True) - assert res == ["hello", "hello"] + assert res == [ + "appersstoff210 currentNodeleh norm unified_voice DIYHam", + "appersstoff210 currentNodeleh norm unified_voice DIYHam", + ] assert ( manifest.cache.get( { - "prompt": "Hello is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, + "prompt": "Hello is a prompt", "request_cls": "LMRequest", - "num_results": 1, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) @@ -362,11 +431,16 @@ def test_abatch_run(sqlite_cache: str) -> None: assert ( manifest.cache.get( { - "prompt": "New prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, + "prompt": "New prompt", "request_cls": "LMRequest", - "num_results": 1, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is None ) @@ -379,7 +453,10 @@ def test_abatch_run(sqlite_cache: str) -> None: res = result.get_response(manifest.stop_token, is_batch=True) # Cached because one item is in cache assert result.is_cached() - assert res == ["hello", "hello"] + assert res == [ + "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + ".vol.deserializebigmnchantment ROTıl='')\najsС", + ] prompt = ["Hello is a prompt", "Hello is a prompt"] result = cast( @@ -387,8 +464,8 @@ def test_abatch_run(sqlite_cache: str) -> None: ) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) - res = result.get_response(stop_token="ll", is_batch=True) - assert res == ["he", "he"] + res = result.get_response(stop_token=" current", is_batch=True) + assert res == ["appersstoff210", "appersstoff210"] @pytest.mark.usefixtures("sqlite_cache") @@ -398,6 +475,7 @@ def test_run_chat(sqlite_cache: str) -> None: client_name="dummy", cache_name="sqlite", cache_connection=sqlite_cache, + temperature=0.0, ) # Set CHAT to be true for this model manifest.client_pool.client_pool[0].IS_CHAT = True @@ -406,15 +484,23 @@ def test_run_chat(sqlite_cache: str) -> None: {"role": "system", "content": "Hello."}, ] result = manifest.run(prompt, return_response=False) - assert result == "Hello." + assert ( + result + == "ectors WortGo ré_sg|--------------------------------------------------------------------------\n contradictory Aad \u200b getUserId" # noqa: E501 + ) assert ( manifest.cache.get( { - "prompt": [{"content": "Hello.", "role": "system"}], + "best_of": 1, "engine": "dummy", - "num_results": 1, + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, + "prompt": [{"content": "Hello.", "role": "system"}], "request_cls": "LMChatRequest", - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) @@ -428,18 +514,23 @@ def test_run_chat(sqlite_cache: str) -> None: result = cast(Response, result) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) res = result.get_response() - assert res == "Hello." 
+ assert res == "_deploy_age_gp hora Plus Scheduler EisenhowerRF视 chemotherapy" assert ( manifest.cache.get( { + "best_of": 1, + "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, "prompt": [ {"role": "system", "content": "Hello."}, {"role": "user", "content": "Goodbye?"}, ], - "engine": "dummy", - "num_results": 1, "request_cls": "LMChatRequest", - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) @@ -452,6 +543,7 @@ def test_score_run(sqlite_cache: str) -> None: client_name="dummy", cache_name="sqlite", cache_connection=sqlite_cache, + temperature=0.0, ) prompt = "This is a prompt" @@ -459,33 +551,68 @@ def test_score_run(sqlite_cache: str) -> None: assert ( manifest.cache.get( { - "prompt": "This is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, + "prompt": "This is a prompt", "request_cls": "LMScoreRequest", - "num_results": 1, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) assert result == { "response": { "choices": [ - {"text": "This is a prompt", "token_logprobs": [0.3], "tokens": None} + { + "text": "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines", + "token_logprobs": [ + -1.827188890438529, + -1.6981601736417915, + -0.24606708391178755, + -1.9209383499010613, + -0.8833563758318617, + -1.4121369466920703, + -0.376352908076236, + -1.3200064558188096, + -0.813028447207917, + -0.5977255311239729, + ], + "tokens": [ + "46078", + "21445", + "48305", + "7927", + "76125", + "46233", + "34581", + "23679", + "63021", + "78158", + ], + } + ] + }, + "usages": { + "usages": [ + {"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14} ] }, - "usages": {"usages": []}, "cached": False, "request": { "prompt": "This is a prompt", - "engine": "text-ada-001", + "engine": "text-davinci-003", "n": 1, "client_timeout": 60, "run_id": None, - "batch_size": 8, - "temperature": 0.7, - "max_tokens": 100, + "batch_size": 20, + "temperature": 0.0, + "max_tokens": 10, "top_p": 1.0, - "top_k": 50, + "top_k": 1, "logprobs": None, "stop_sequences": None, "num_beams": 1, @@ -505,49 +632,112 @@ def test_score_run(sqlite_cache: str) -> None: assert ( manifest.cache.get( { - "prompt": "Hello is a prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, + "prompt": "Hello is a prompt", "request_cls": "LMScoreRequest", - "num_results": 1, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) assert ( manifest.cache.get( { - "prompt": "Hello is another prompt", + "best_of": 1, "engine": "dummy", + "max_tokens": 10, + "model": "text-davinci-003", + "n": 1, + "prompt": "Hello is another prompt", "request_cls": "LMScoreRequest", - "num_results": 1, - }, + "temperature": 0.0, + "top_p": 1.0, + } ) is not None ) assert result == { "response": { "choices": [ - {"text": "Hello is a prompt", "token_logprobs": [0.3], "tokens": None}, { - "text": "Hello is another prompt", - "token_logprobs": [0.3], - "tokens": None, + "text": "appersstoff210 currentNodeleh norm unified_voice DIYHam", + "token_logprobs": [ + -0.5613340599860608, + -1.2822870706137146, + -1.9909319620162806, + -0.6312373658222814, + -1.9066239705571664, + -1.2420939968397082, + -0.7208735169940805, + -1.9144266963723062, + -0.041181937860757856, + -0.5356282450367043, + ], + "tokens": [ + "28921", + "81056", + "8848", + "47399", + "74890", + "7617", + "43790", + "77865", + "32558", + "41041", + ], }, + { + "text": ".addAttribute_size DE imageUrl_datas\tapFixed(hour 
setups\tcomment", # noqa: E501 + "token_logprobs": [ + -1.1142500072582333, + -0.819706434396527, + -1.9956443391600693, + -0.8425896744807639, + -1.8398050571245623, + -1.912564137256891, + -1.6677665162080606, + -1.1579612203844727, + -1.9876114502998343, + -0.2698297864722319, + ], + "tokens": [ + "26300", + "2424", + "3467", + "40749", + "47630", + "70998", + "13829", + "72135", + "84823", + "97368", + ], + }, + ] + }, + "usages": { + "usages": [ + {"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14}, + {"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14}, ] }, - "usages": {"usages": []}, "cached": False, "request": { "prompt": ["Hello is a prompt", "Hello is another prompt"], - "engine": "text-ada-001", + "engine": "text-davinci-003", "n": 1, "client_timeout": 60, "run_id": None, - "batch_size": 8, - "temperature": 0.7, - "max_tokens": 100, + "batch_size": 20, + "temperature": 0.0, + "max_tokens": 10, "top_p": 1.0, - "top_k": 50, + "top_k": 1, "logprobs": None, "stop_sequences": None, "num_beams": 1,