Update mauve implementation and pytest
wade3han committed Feb 3, 2022
1 parent 08e8cd8 commit 593e108
Showing 2 changed files with 30 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/mauve/utils.py
@@ -27,6 +27,7 @@ def get_model(model_name, tokenizer, device_id):
    if 'gpt2' in model_name:
        model = AutoModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).to(device)
        model = model.eval()
        model = model.double()
    else:
        raise ValueError(f'Unknown model: {model_name}')
    return model
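
Note (not part of the commit): the added model = model.double() line casts the GPT-2 featurization model to float64, presumably so that the extracted features do not drift with batch size, which is exactly what the updated tests below check. A minimal standalone sketch of the same idea, assuming the Hugging Face transformers API that mauve already uses:

# Not part of the commit: a minimal sketch of the float64 cast, assuming the
# Hugging Face transformers AutoModel/AutoTokenizer API used elsewhere in mauve.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
model = AutoModel.from_pretrained('gpt2-large', pad_token_id=tokenizer.eos_token_id)
model = model.eval().double()  # cast every parameter and buffer to torch.float64

# Hidden states computed in float64 should be (nearly) independent of batching.
with torch.no_grad():
    ids = tokenizer("an example sentence", return_tensors="pt").input_ids
    features = model(ids).last_hidden_state[:, -1, :]  # final-token hidden state

Float64 roughly doubles memory use and slows inference, so this trades speed for numerical reproducibility.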
40 changes: 29 additions & 11 deletions tests/test_mauve.py
@@ -1,13 +1,14 @@
import math

import mauve
import numpy as np
import pytest

import mauve
from examples import load_gpt2_dataset
from mauve.compute_mauve import get_features_from_input


class TestMauve:

    @pytest.fixture(scope="class")
    def human_texts(self):
        return load_gpt2_dataset('data/amazon.valid.jsonl', num_examples=100)
@@ -16,23 +17,40 @@ def human_texts(self):
    def generated_texts(self):
        return load_gpt2_dataset('data/amazon-xl-1542M.valid.jsonl', num_examples=100)

    def test_default_mauve(self, human_texts, generated_texts):
    @pytest.mark.parametrize(
        "batch_size",
        [16, 8, 4, 3, 2],
    )
    def test_batchify_mauve(self, human_texts, generated_texts, batch_size):
        out = mauve.compute_mauve(p_text=human_texts,
                                  q_text=generated_texts,
                                  device_id=0,
                                  max_text_length=256,
                                  batch_size=batch_size,
                                  verbose=False)
        assert math.isclose(out.mauve, 0.9917, abs_tol=1e-4)
        assert math.isclose(out.mauve, 0.99168, abs_tol=1e-4), f"{out.mauve} != 0.99168"

    @pytest.mark.parametrize(
        "batch_size",
        [1, 2, 3, 4, 8, 16],
    )
    def test_batchify_mauve(self, human_texts, generated_texts, batch_size):
    def test_default_mauve(self, human_texts, generated_texts):
        out = mauve.compute_mauve(p_text=human_texts,
                                  q_text=generated_texts,
                                  device_id=0,
                                  max_text_length=256,
                                  batch_size=batch_size,
                                  verbose=False)
        assert math.isclose(out.mauve, 0.9917, abs_tol=1e-4)
        assert math.isclose(out.mauve, 0.99168, abs_tol=1e-4)

    @pytest.mark.parametrize(
        "batch_size",
        [16, 8, 4, 3, 2],
    )
    def test_batchify_mauve_feature_level(self, human_texts, batch_size):
        p_features_original = get_features_from_input(
            None, None, human_texts, 'gpt2-large', 1024,
            -1, name="p", verbose=False, batch_size=1,
        )
        p_features_batched = get_features_from_input(
            None, None, human_texts, 'gpt2-large', 1024,
            -1, name="p", verbose=False, batch_size=batch_size,
        )
        norm_of_difference = np.linalg.norm(p_features_original - p_features_batched, axis=1)  # shape = (n,)
        # ensure that new features are close to old features
        assert np.max(norm_of_difference) < 1e-5 * np.max(np.linalg.norm(p_features_original, axis=1))
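
For reference (not from the commit): the feature-level assertion above is a relative tolerance on per-example L2 norms rather than an exact equality check. A tiny self-contained illustration with synthetic arrays; all names below are made up for the example:

import numpy as np

rng = np.random.default_rng(0)
features_a = rng.standard_normal((100, 1280))                       # e.g. features computed with batch_size=1
features_b = features_a + 1e-9 * rng.standard_normal((100, 1280))   # tiny numerical perturbation

# Per-example L2 norm of the difference, shape (n,)
row_diff = np.linalg.norm(features_a - features_b, axis=1)
# Accept if the worst-case difference is tiny relative to the largest feature norm.
assert np.max(row_diff) < 1e-5 * np.max(np.linalg.norm(features_a, axis=1))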
