tests for dry_run in completer_for
phelps-sg committed Nov 10, 2023
1 parent f0f1b0e · commit a2ca6c4
Showing 7 changed files with 53 additions and 16 deletions.
llm_cooperation/__init__.py (26 additions, 4 deletions)
@@ -4,12 +4,22 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Callable, Dict, Hashable, Iterable, List, Protocol, Tuple, TypeVar
+from typing import (
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Protocol,
+    Tuple,
+    TypeVar,
+)
 
 import numpy as np
 import openai_pygenerator
 import pandas as pd
-from openai_pygenerator import Completer
+from openai_pygenerator import Completer, Completion, Completions, History
 from plotly.basedatatypes import itertools
 
 logger = logging.getLogger(__name__)
@@ -51,6 +61,7 @@ class ModelSetup:
     model: str
     temperature: float
     max_tokens: int
+    dry_run: Optional[str] = None
 
 
 Experiment = Callable[[ModelSetup, int], Results]
@@ -71,6 +82,10 @@ class ModelSetup:
 Payoffs = Tuple[float, float]
 
 
+def assistant_message(description: str) -> Completion:
+    return {"role": "assistant", "content": description}
+
+
 def all_combinations(grid: Grid) -> itertools.product:
     return itertools.product(*grid.values())
 
@@ -85,12 +100,12 @@ def settings_from_combinations(
 
 
 def randomized(n: int, grid: Grid) -> Iterable[Settings]:
-    variables = list(grid.keys())
+    keys = list(grid.keys())
     combinations = list(all_combinations(grid))
     num_combinations = len(combinations)
     for __i__ in range(n):
         random_index: int = int(np.random.randint(num_combinations))
-        yield settings_from_combinations(variables, combinations[random_index])
+        yield settings_from_combinations(keys, combinations[random_index])
 
 
 def exhaustive(grid: Grid) -> Iterable[Settings]:
@@ -104,6 +119,13 @@ def amount_as_str(amount: float) -> str:
 
 
 def completer_for(model_setup: ModelSetup) -> Completer:
+    if model_setup.dry_run is not None:
+        dummy_completions: Tuple[Completion] = (assistant_message(model_setup.dry_run),)
+
+        def dummy_completer(__history__: History, __n__: int) -> Completions:
+            return iter(dummy_completions)
+
+        return dummy_completer
     return openai_pygenerator.completer(
         model=model_setup.model,
        temperature=model_setup.temperature,
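
The new dry_run field short-circuits completer_for: when it is set, no OpenAI API call is made and every completion is the canned assistant message. A minimal usage sketch (the model name and prompts are illustrative):

    from llm_cooperation import ModelSetup, completer_for
    from openai_pygenerator import content, user_message

    setup = ModelSetup(
        model="gpt-3.5-turbo",  # illustrative; never contacted in dry-run mode
        temperature=0.2,
        max_tokens=100,
        dry_run="canned reply",
    )
    completer = completer_for(setup)
    # The dummy completer ignores its history and n arguments and yields
    # the single canned assistant message.
    reply = next(completer([user_message("hello")], 1))
    assert content(reply) == "canned reply"

The assistant_message helper added here also replaces the duplicate helper previously defined in tests/test_ultimatum.py (see the test changes below).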
llm_cooperation/experiments/dilemma.py (2 additions, 2 deletions)
@@ -8,7 +8,7 @@
 import numpy as np
 from openai_pygenerator import Completion
 
-from llm_cooperation import ModelSetup, Payoffs, Settings, exhaustive
+from llm_cooperation import ModelSetup, Payoffs, Settings, randomized
 from llm_cooperation.experiments import AI_PARTICIPANTS, run_and_record_experiment
 from llm_cooperation.gametypes import simultaneous
 from llm_cooperation.gametypes.repeated import (
@@ -196,7 +196,7 @@ def run(model_setup: ModelSetup, sample_size: int) -> RepeatedGameResults:
         extract_choice=extract_choice_pd,
         next_round=simultaneous.next_round,
         analyse_rounds=simultaneous.analyse_rounds,
-        participant_condition_sampling=exhaustive,
+        participant_condition_sampling=partial(randomized, n=10),
         model_setup=model_setup,
     )
     measurement_setup: MeasurementSetup[DilemmaChoice] = MeasurementSetup(
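
For the Prisoner's Dilemma experiment this swaps exhaustive enumeration of the participant-condition grid for 10 uniform random draws (with replacement, so conditions can repeat). A minimal sketch of the difference, using an illustrative grid:

    from llm_cooperation import exhaustive, randomized

    grid = {"temperature": [0.2, 0.3], "model": ["x", "y"]}  # illustrative

    list(exhaustive(grid))      # all 4 combinations, each exactly once
    list(randomized(10, grid))  # 10 random draws from the 4 combinations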
llm_cooperation/gametypes/repeated.py (2 additions, 1 deletion)
@@ -19,6 +19,7 @@
     Payoffs,
     Results,
     Settings,
+    exhaustive,
 )
 from llm_cooperation.gametypes import PromptGenerator, start_game
 
@@ -52,7 +53,7 @@ class GameSetup(Generic[CT, RT]):
     payoffs: PayoffFunction[CT]
     extract_choice: ChoiceExtractor[CT]
     model_setup: ModelSetup
-    participant_condition_sampling: Callable[[Grid], Iterable[Settings]]
+    participant_condition_sampling: Callable[[Grid], Iterable[Settings]] = exhaustive
 
 
 @dataclass(frozen=True)
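
Making exhaustive the dataclass default preserves the old behaviour for any GameSetup constructed without an explicit participant_condition_sampling, while dilemma.py above opts into partial(randomized, n=10).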
tests/test_alternating.py (1 addition, 1 deletion)
@@ -3,6 +3,7 @@
 import pytest
 from openai_pygenerator import Completion, user_message
 
+from llm_cooperation import assistant_message
 from llm_cooperation.experiments.ultimatum import (
     Accept,
     ProposerChoice,
@@ -12,7 +13,6 @@
 )
 from llm_cooperation.gametypes import alternating
 from llm_cooperation.gametypes.repeated import Choices
-from tests.test_ultimatum import assistant_message
 
 
 @pytest.mark.parametrize(
tests/test_main.py (14 additions, 2 deletions)
@@ -1,9 +1,9 @@
 from functools import partial
 
 import pytest
-from openai_pygenerator import Callable
+from openai_pygenerator import Callable, content, user_message
 
-from llm_cooperation import ModelSetup, Settings, exhaustive, randomized
+from llm_cooperation import ModelSetup, Settings, completer_for, exhaustive, randomized
 from llm_cooperation.main import (
     Configuration,
     Grid,
@@ -59,6 +59,18 @@ def test_run_all(mocker, grid):
     assert run_and_record.call_count == 6 * len(list(experiments.items()))
 
 
+def test_dry_run():
+    model_setup = ModelSetup(
+        model="test-model",
+        temperature=0.2,
+        max_tokens=100,
+        dry_run="That is the question.",
+    )
+    completer = completer_for(model_setup)
+    response = list(completer([user_message("To be or not to be")], 1))[0]
+    assert content(response) == "That is the question."
+
+
 @pytest.fixture
 def grid() -> Grid:
     return {"temperature": [0.2, 0.3], "max_tokens": [100], "model": ["x", "y", "z"]}
tests/test_repeated.py (8 additions, 2 deletions)
@@ -5,7 +5,14 @@
 import pytest
 from openai_pygenerator import content, user_message
 
-from llm_cooperation import DEFAULT_MODEL_SETUP, Choice, Grid, Group, exhaustive
+from llm_cooperation import (
+    DEFAULT_MODEL_SETUP,
+    Choice,
+    Grid,
+    Group,
+    assistant_message,
+    exhaustive,
+)
 from llm_cooperation.experiments.dilemma import (
     Cooperate,
     Defect,
@@ -32,7 +39,6 @@
     run_experiment,
 )
 from llm_cooperation.gametypes.simultaneous import next_round
-from tests.test_ultimatum import assistant_message
 
 
 def test_play_game(mocker):
tests/test_ultimatum.py (0 additions, 4 deletions)
@@ -137,10 +137,6 @@ def test_ultimatum_choice():
     assert Reject == ResponderChoice(ResponderEnum.Reject)
 
 
-def assistant_message(description):
-    return {"role": "assistant", "content": description}
-
-
 @pytest.mark.parametrize(
     "user_response, user_proposal",
     [(Accept, ProposerChoice(5.0)), (Reject, ProposerChoice(10.0))],
