Implement batch-supported mauve
wade3han committed Feb 4, 2022
1 parent d60d6fd commit fc4a602
Showing 6 changed files with 101 additions and 21 deletions.
README.md: 3 changes (1 addition, 2 deletions)
@@ -153,6 +153,7 @@ You can also use different forms as inputs for `p` and `q`, e.g.,
- `mauve_scaling_factor`: "c" from the paper. Default 5.
- `verbose`: If True (default), print running time updates.
- `seed`: random seed to initialize *k*-means cluster assignments.
- `batch_size`: Batch size for feature extraction.

Note: `p` and `q` can be of different lengths, but it is
recommended that they are the same length.
@@ -167,8 +168,6 @@ If you would like to contribute, please submit a pull request.
We encourage and highly value community contributions.

Some features which would be good to have are:
- batched implementation of featurization (the current implementation featurizes generations sequentially);
this requires appropriate padding/masking
- featurization in HuggingFace Transformers with a TensorFlow backend.

## Best Practices for MAUVE
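The new `batch_size` option documented in the README hunk above can be exercised as in the following minimal sketch. Here `p_text` and `q_text` are hypothetical stand-ins for your own lists of human-written and model-generated texts (realistically a few hundred samples each, as in the tests added by this commit):

```python
import mauve

# p_text / q_text: hypothetical lists of strings (human vs. generated).
out = mauve.compute_mauve(
    p_text=p_text,
    q_text=q_text,
    max_text_length=256,
    batch_size=8,  # featurize 8 texts per forward pass (default is 1)
    verbose=False,
)
print(out.mauve)
```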
requirements.txt: 1 change (1 addition, 0 deletions)
@@ -3,6 +3,7 @@ scikit-learn
faiss
tqdm
requests
pytest
## Optional
# torch
# transformers
src/mauve/compute_mauve.py: 15 changes (9 additions, 6 deletions)
@@ -38,7 +38,7 @@ def compute_mauve(
kmeans_num_redo=5, kmeans_max_iter=500,
featurize_model_name='gpt2-large', device_id=-1, max_text_length=1024,
divergence_curve_discretization_size=25, mauve_scaling_factor=5,
verbose=False, seed=25
verbose=False, seed=25, batch_size=1, use_float64=False,
):

"""
@@ -69,6 +69,7 @@
See `Best Practices <index.html#best-practices-for-mauve>`_ for details.
:param ``verbose``: If True, print running time updates.
:param ``seed``: random seed to initialize k-means cluster assignments.
    :param ``batch_size``: Batch size for feature extraction.
    :param ``use_float64``: If True, run the feature extraction model in float64 (double precision).
:return: an object with fields p_hist, q_hist, divergence_curve and mauve.
@@ -85,11 +86,11 @@
raise ValueError('Supply at least one of q_features, q_tokens, q_text')
p_features = get_features_from_input(
p_features, p_tokens, p_text, featurize_model_name, max_text_length,
device_id, name="p", verbose=verbose
device_id, name="p", verbose=verbose, batch_size=batch_size, use_float64=use_float64,
)
q_features = get_features_from_input(
q_features, q_tokens, q_text, featurize_model_name, max_text_length,
device_id, name="q", verbose=verbose
device_id, name="q", verbose=verbose, batch_size=batch_size, use_float64=use_float64,
)
if num_buckets == 'auto':
# heuristic: use num_clusters = num_generations / 10
@@ -131,8 +132,8 @@
return to_return

def get_features_from_input(features, tokenized_texts, texts,
featurize_model_name, max_len, device_id, name,
verbose=False):
featurize_model_name, max_len, device_id, name, batch_size,
verbose=False, use_float64=False):
global MODEL, TOKENIZER, MODEL_NAME
if features is None:
# Featurizing is necessary. Make sure the required packages are available
@@ -169,8 +170,10 @@
MODEL_NAME = featurize_model_name
else:
MODEL = MODEL.to(get_device_from_arg(device_id))
if use_float64:
MODEL = MODEL.double()
if verbose: print('Featurizing tokens')
features = featurize_tokens_from_model(MODEL, tokenized_texts, name).detach().cpu().numpy()
features = featurize_tokens_from_model(MODEL, tokenized_texts, batch_size, name).detach().cpu().numpy()
else:
features = np.asarray(features)
return features
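Because `batch_size` and `use_float64` are threaded through `get_features_from_input`, that function can also be called directly to precompute features once and then reuse them through the `p_features`/`q_features` arguments of `compute_mauve`. A minimal sketch mirroring the call in the new tests, where `texts` is a hypothetical list of strings and `device_id=-1` selects the CPU:

```python
from mauve.compute_mauve import get_features_from_input

# `texts` is a hypothetical list of strings to featurize.
# Pass None for precomputed features and tokens so the raw texts are used.
p_features = get_features_from_input(
    None, None, texts, 'gpt2-large', 1024,
    -1, name="p", verbose=False, batch_size=8, use_float64=False,
)
# p_features: NumPy array with one feature vector per input text.
```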
src/mauve/utils.py: 45 changes (32 additions, 13 deletions)
@@ -84,25 +84,44 @@ def decode_samples_from_lst(tokenizer, tokenized_texts):
return output

@torch.no_grad()
def featurize_tokens_from_model(model, tokenized_texts, name="", verbose=False):
"""Featurize tokenized texts using models
def featurize_tokens_from_model(model, tokenized_texts, batch_size, name="", verbose=False):
"""Featurize tokenized texts using models, support batchify
:param model: HF Transformers model
:param batch_size: Batch size used during forward pass
:param tokenized_texts: list of torch.LongTensor of shape (1, length)
:param verbose: If True, print status and time
    :return: torch.Tensor of shape (num_texts, dim), one feature vector per input text
"""
device = next(model.parameters()).device
t1 = time.time()
feats = []
for sen in tqdm(tokenized_texts, desc=f"Featurizing {name}"):
if isinstance(sen, list):
sen = torch.LongTensor(sen).unsqueeze(0)
sen = sen.to(device)
outs = model(input_ids=sen, past_key_values=None,
output_hidden_states=True, return_dict=True)
h = outs.hidden_states[-1] # (batch_size, seq_len, dim)
feats.append(h[:, -1, :].cpu())
feats, chunks, chunk_sent_lengths = [], [], []
chunk_idx = 0

while chunk_idx * batch_size < len(tokenized_texts):
_chunk = [_t.view(-1) for _t in tokenized_texts[chunk_idx * batch_size: (chunk_idx + 1) * batch_size]]
chunks.append(_chunk)
chunk_sent_lengths.append([len(_c) for _c in _chunk])
chunk_idx += 1

for chunk, chunk_sent_length in tqdm(list(zip(chunks, chunk_sent_lengths)), desc=f"Featurizing {name}"):
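        # Right-pad the chunk into a rectangle and build a matching
        # attention mask so the model ignores the pad positions.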
padded_chunk = torch.nn.utils.rnn.pad_sequence(chunk,
batch_first=True,
padding_value=0).to(device)
attention_mask = torch.nn.utils.rnn.pad_sequence(
[torch.ones(sent_length).long() for sent_length in chunk_sent_length],
batch_first=True,
padding_value=0).to(device)
outs = model(input_ids=padded_chunk,
attention_mask=attention_mask,
past_key_values=None,
output_hidden_states=True,
return_dict=True)
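        # With right padding, each text's feature lives at its last real
        # (non-pad) token, i.e. index sent_length - 1, not at position -1.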
h = []
for hidden_state, sent_length in zip(outs.hidden_states[-1], chunk_sent_length):
h.append(hidden_state[sent_length - 1])
h = torch.stack(h, dim=0)
feats.append(h.cpu())
t2 = time.time()
if verbose: print(f'Featurize time: {round(t2-t1, 2)}')
if verbose:
print(f'Featurize time: {round(t2-t1, 2)}')
return torch.cat(feats)
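The core of the batched loop above is the padding/masking bookkeeping: with right padding, each text's feature must be read at its last real token rather than at position -1. A self-contained sketch of that indexing with toy tensors (all names here are illustrative, not part of the library):

```python
import torch

# Three "tokenized texts" of different lengths, as 1-D LongTensors.
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9]), torch.tensor([10])]
lengths = [len(s) for s in seqs]

# Right-pad into a rectangle and build the matching attention mask.
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
mask = torch.nn.utils.rnn.pad_sequence(
    [torch.ones(n).long() for n in lengths], batch_first=True, padding_value=0)

# Stand-in for outs.hidden_states[-1]: shape (batch, seq_len, dim).
hidden = torch.randn(len(seqs), padded.shape[1], 4)

# Feature per text = hidden state at the last non-pad position.
feats = torch.stack([h[n - 1] for h, n in zip(hidden, lengths)], dim=0)
print(feats.shape)  # torch.Size([3, 4])
```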
tests/__init__.py: empty file added
tests/test_mauve.py: 58 changes (58 additions, 0 deletions)
@@ -0,0 +1,58 @@
import math

import numpy as np
import pytest

import mauve
from examples import load_gpt2_dataset
from mauve.compute_mauve import get_features_from_input


class TestMauve:
@pytest.fixture(scope="class")
def human_texts(self):
return load_gpt2_dataset('data/amazon.valid.jsonl', num_examples=100)

@pytest.fixture(scope="class")
def generated_texts(self):
return load_gpt2_dataset('data/amazon-xl-1542M.valid.jsonl', num_examples=100)

@pytest.mark.parametrize(
"batch_size",
[16, 8, 4, 3, 2],
)
def test_batchify_mauve(self, human_texts, generated_texts, batch_size):
out = mauve.compute_mauve(p_text=human_texts,
q_text=generated_texts,
device_id=0,
max_text_length=256,
batch_size=batch_size,
verbose=False,
use_float64=True)
assert math.isclose(out.mauve, 0.99168, abs_tol=1e-4), f"{out.mauve} != 0.99168"

def test_default_mauve(self, human_texts, generated_texts):
out = mauve.compute_mauve(p_text=human_texts,
q_text=generated_texts,
device_id=0,
max_text_length=256,
verbose=False,
use_float64=True)
assert math.isclose(out.mauve, 0.99168, abs_tol=1e-4)

@pytest.mark.parametrize(
"batch_size",
[16, 8, 4, 3, 2],
)
def test_batchify_mauve_feature_level(self, human_texts, batch_size):
p_features_original = get_features_from_input(
None, None, human_texts, 'gpt2-large', 1024,
-1, name="p", verbose=False, batch_size=1, use_float64=True,
)
p_features_batched = get_features_from_input(
None, None, human_texts, 'gpt2-large', 1024,
-1, name="p", verbose=False, batch_size=batch_size, use_float64=True,
)
norm_of_difference = np.linalg.norm(p_features_original - p_features_batched, axis=1) # shape = (n,)
# ensure that new features are close to old features
assert np.max(norm_of_difference) < 1e-5 * np.max(np.linalg.norm(p_features_original, axis=1))
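With `pytest` added to requirements.txt, this suite can be run as `pytest tests/test_mauve.py`. The first two tests pass `device_id=0` and therefore expect a GPU, while the feature-level test runs on CPU (`device_id=-1`); all of them run the featurizing model in double precision (`use_float64=True`), presumably so that batched and sequential featurization agree within the stated tolerances.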
