Skip to content

Commit

Permalink
Test rating arena isolation
Browse files Browse the repository at this point in the history
  • Loading branch information
SupraSummus committed Nov 8, 2024
1 parent 3c58d5e commit a5c0cce
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
11 changes: 8 additions & 3 deletions warriors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,16 +556,18 @@ def map_field_name(self, field_name):
return field_name
if field_name in (
'warrior_1',
'warrior_2',
'warrior_1_id',
'warrior_2',
'warrior_2_id',
'warrior_arena_1',
'warrior_arena_2',
):
return self.map_field_name_x(field_name)
if field_name in (
'text_unit_1_2',
'text_unit_1_2_id',
'text_unit_2_1',
'text_unit_2_1_id',
'finish_reason_1_2',
'finish_reason_2_1',
'llm_version_1_2',
Expand Down Expand Up @@ -596,8 +598,11 @@ def map_field_name_x_x(self, field_name):
if self.viewpoint == '1':
return field_name
elif self.viewpoint == '2':
base, n, m = field_name.rsplit('_', 2)
return f'{base}_{m}_{n}'
if '1_2' in field_name:
return field_name.replace('1_2', '2_1')
elif '2_1' in field_name:
return field_name.replace('2_1', '1_2')
assert False

def map_field_name_x_x_x(self, field_name):
if self.viewpoint == '1':
Expand Down
63 changes: 62 additions & 1 deletion warriors/rating_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from .models import WarriorArena
from .rating_models import update_rating
from .tests.factories import BattleFactory, WarriorArenaFactory
from .tests.factories import (
ArenaFactory, BattleFactory, WarriorArenaFactory, WarriorFactory,
)


@pytest.mark.django_db
Expand Down Expand Up @@ -45,6 +47,65 @@ def test_update_rating_takes_newer_battles(battle):
assert warrior_arena_1.rating > warrior_arena_2.rating


@pytest.mark.django_db
def test_rating_is_isolated_for_each_arena():
now = timezone.now()
warrior_1 = WarriorFactory()
warrior_2 = WarriorFactory()
if warrior_1.id > warrior_2.id:
warrior_1, warrior_2 = warrior_2, warrior_1

arena_1 = ArenaFactory()
warrior_1_arena_1 = WarriorArenaFactory(warrior=warrior_1, arena=arena_1)
warrior_2_arena_1 = WarriorArenaFactory(warrior=warrior_2, arena=arena_1)
BattleFactory(
arena=arena_1,
warrior_1=warrior_1,
warrior_2=warrior_2,
resolved_at_1_2=now,
lcs_len_1_2_1=10,
lcs_len_1_2_2=1,
resolved_at_2_1=now,
lcs_len_2_1_1=10,
lcs_len_2_1_2=1,
)

arena_2 = ArenaFactory()
warrior_1_arena_2 = WarriorArenaFactory(warrior=warrior_1, arena=arena_2)
warrior_2_arena_2 = WarriorArenaFactory(warrior=warrior_2, arena=arena_2)
BattleFactory(
arena=arena_2,
warrior_1=warrior_1,
warrior_2=warrior_2,
resolved_at_1_2=now,
lcs_len_1_2_1=1,
lcs_len_1_2_2=10,
resolved_at_2_1=now,
lcs_len_2_1_1=1,
lcs_len_2_1_2=10,
)

for _ in range(2):
warrior_1_arena_1.refresh_from_db()
warrior_1_arena_1.update_rating()
warrior_2_arena_1.refresh_from_db()
warrior_2_arena_1.update_rating()
warrior_1_arena_2.refresh_from_db()
warrior_1_arena_2.update_rating()
warrior_2_arena_2.refresh_from_db()
warrior_2_arena_2.update_rating()

warrior_1_arena_1.refresh_from_db()
warrior_2_arena_1.refresh_from_db()
warrior_1_arena_2.refresh_from_db()
warrior_2_arena_2.refresh_from_db()

assert warrior_1_arena_1.rating > 40
assert warrior_1_arena_1.rating == -warrior_2_arena_1.rating
assert warrior_2_arena_2.rating == -warrior_1_arena_2.rating
assert warrior_1_arena_1.rating == warrior_2_arena_2.rating


@pytest.mark.django_db
@pytest.mark.parametrize('battle', [{
'resolved_at_1_2': timezone.now(),
Expand Down

0 comments on commit a5c0cce

Please sign in to comment.