Commit
Implement batch-supported mauve
wade3han committed Dec 5, 2021
1 parent d60d6fd commit dc8e84a
Showing 3 changed files with 190 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/mauve/__init__.py
@@ -1,3 +1,4 @@
 from .compute_mauve import compute_mauve
+from .compute_mauve import fast_compute_mauve
 
-__all__ = ['compute_mauve']
+__all__ = ['compute_mauve', 'fast_compute_mauve']
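
With this export in place, the batched entry point can be imported directly from the package (a usage sketch, not part of this diff):

    from mauve import fast_compute_mauve
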
145 changes: 145 additions & 0 deletions src/mauve/compute_mauve.py
@@ -11,6 +11,8 @@
from sklearn.decomposition import PCA
from sklearn.metrics import auc as compute_area_under_curve

from .utils import fast_featurize_tokens_from_model

try:
    import torch
    FOUND_TORCH = True
@@ -130,6 +132,101 @@ def compute_mauve(
    )
    return to_return

def fast_compute_mauve(
    p_features=None, q_features=None,
    p_tokens=None, q_tokens=None,
    p_text=None, q_text=None,
    num_buckets='auto', pca_max_data=-1, kmeans_explained_var=0.9,
    kmeans_num_redo=5, kmeans_max_iter=500,
    featurize_model_name='gpt2-large', device_id=-1, max_text_length=1024,
    divergence_curve_discretization_size=25, mauve_scaling_factor=5,
    verbose=False, seed=25, batch_size=8,
):
    """
    Compute the MAUVE score between two text generations P and Q.

    P is specified as one of ``p_features``, ``p_tokens``, or ``p_text``. Same with Q.

    :param ``p_features``: ``numpy.ndarray`` of shape (n, d), where n is the number of generations.
    :param ``q_features``: ``numpy.ndarray`` of shape (n, d), where n is the number of generations.
    :param ``p_tokens``: list of length n, where each entry is a torch.LongTensor of shape (1, length).
    :param ``q_tokens``: list of length n, where each entry is a torch.LongTensor of shape (1, length).
    :param ``p_text``: list of length n, where each entry is a string.
    :param ``q_text``: list of length n, where each entry is a string.
    :param ``num_buckets``: the size of the histogram to quantize P and Q. Options: ``'auto'`` (default, which is n/10) or an integer.
    :param ``pca_max_data``: the number of data points to use for PCA. If ``-1``, use all the data. Default -1.
    :param ``kmeans_explained_var``: amount of variance of the data to keep in dimensionality reduction by PCA. Default 0.9.
    :param ``kmeans_num_redo``: number of times to redo k-means clustering (the best objective is kept). Default 5.
        Try reducing this to 1 to cut the running time.
    :param ``kmeans_max_iter``: maximum number of k-means iterations. Default 500.
        Try reducing this to 100 to cut the running time.
    :param ``featurize_model_name``: name of the model from which features are obtained. Default 'gpt2-large'.
        We support all models which can be loaded from ``transformers.AutoModel.from_pretrained(featurize_model_name)``.
    :param ``device_id``: device for featurization. Supply a GPU id (e.g. 0 or 3) to use the GPU, or -1 to use the CPU.
    :param ``max_text_length``: maximum number of tokens to consider. Default 1024.
    :param ``divergence_curve_discretization_size``: number of points to consider on the divergence curve. Default 25.
        Larger values do not offer much of a difference.
    :param ``mauve_scaling_factor``: the constant ``c`` from the paper. Default 5.
        See `Best Practices <index.html#best-practices-for-mauve>`_ for details.
    :param ``verbose``: if True, print running time updates.
    :param ``seed``: random seed to initialize k-means cluster assignments.
    :param ``batch_size``: batch size for feature extraction.
    :return: an object with fields p_hist, q_hist, divergence_curve and mauve.

    * ``out.mauve`` is a number between 0 and 1, the MAUVE score. Higher values mean that P is closer to Q.
    * ``out.frontier_integral`` is a number between 0 and 1. Lower values mean that P is closer to Q.
    * ``out.p_hist`` is the obtained histogram for P. Same for ``out.q_hist``.
    * ``out.divergence_curve`` contains the points on the divergence curve. It is of shape (m, 2), where m is ``divergence_curve_discretization_size``.
    """
    if p_features is None and p_tokens is None and p_text is None:
        raise ValueError('Supply at least one of p_features, p_tokens, p_text')
    if q_features is None and q_tokens is None and q_text is None:
        raise ValueError('Supply at least one of q_features, q_tokens, q_text')
    p_features = fast_get_features_from_input(
        p_features, p_tokens, p_text, featurize_model_name, max_text_length,
        device_id, name="p", verbose=verbose, batch_size=batch_size,
    )
    q_features = fast_get_features_from_input(
        q_features, q_tokens, q_text, featurize_model_name, max_text_length,
        device_id, name="q", verbose=verbose, batch_size=batch_size,
    )
    if num_buckets == 'auto':
        # heuristic: use num_clusters = num_generations / 10
        num_buckets = max(2, int(round(min(p_features.shape[0], q_features.shape[0]) / 10)))
    elif not isinstance(num_buckets, int):
        raise ValueError('num_buckets is expected to be an integer or "auto"')

    # Actual binning
    t1 = time.time()
    p, q = cluster_feats(p_features, q_features,
                         num_clusters=num_buckets,
                         norm='l2', whiten=False,
                         pca_max_data=pca_max_data,
                         explained_variance=kmeans_explained_var,
                         num_redo=kmeans_num_redo,
                         max_iter=kmeans_max_iter,
                         seed=seed, verbose=verbose)
    t2 = time.time()
    if verbose:
        print('total discretization time:', round(t2 - t1, 2), 'seconds')

    # Divergence curve and mauve
    mixture_weights = np.linspace(1e-6, 1 - 1e-6, divergence_curve_discretization_size)
    divergence_curve = get_divergence_curve_for_multinomials(p, q, mixture_weights, mauve_scaling_factor)
    x, y = divergence_curve.T
    idxs1 = np.argsort(x)
    idxs2 = np.argsort(y)
    mauve_score = 0.5 * (
        compute_area_under_curve(x[idxs1], y[idxs1]) +
        compute_area_under_curve(y[idxs2], x[idxs2])
    )
    fi_score = get_fronter_integral(p, q)
    to_return = SimpleNamespace(
        p_hist=p, q_hist=q, divergence_curve=divergence_curve,
        mauve=mauve_score,
        frontier_integral=fi_score,
        num_buckets=num_buckets,
    )
    return to_return
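
A minimal usage sketch for the new batched entry point (illustrative only: the texts, device id, and batch size below are made-up values, and the call assumes working torch and transformers installs):

    # Hypothetical example: compare two small sets of generations with batched featurization.
    p_text = ["The quick brown fox jumps over the lazy dog."] * 50
    q_text = ["A fast auburn fox leapt over a sleepy hound."] * 50
    out = fast_compute_mauve(p_text=p_text, q_text=q_text,
                             featurize_model_name='gpt2-large',
                             device_id=0,    # GPU 0; use -1 for CPU
                             batch_size=16,  # larger batches speed up featurization
                             verbose=True)
    print(out.mauve, out.frontier_integral)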

def get_features_from_input(features, tokenized_texts, texts,
                            featurize_model_name, max_len, device_id, name,
                            verbose=False):
Expand Down Expand Up @@ -175,6 +272,54 @@ def get_features_from_input(features, tokenized_texts, texts,
features = np.asarray(features)
return features

def fast_get_features_from_input(features, tokenized_texts, texts,
                                 featurize_model_name, max_len, device_id, name, batch_size,
                                 verbose=False):
    global MODEL, TOKENIZER, MODEL_NAME
    if features is None:
        # Featurizing is necessary. Make sure the required packages are available
        if not FOUND_TORCH:
            raise ModuleNotFoundError(
                """PyTorch not found. Please install PyTorch if you would like to use the featurization.
                   For details, see `https://github.com/krishnap25/mauve`
                   and `https://pytorch.org/get-started/locally/`.
                """)
        if not FOUND_TRANSFORMERS:
            raise ModuleNotFoundError(
                """Transformers not found. Please install Transformers if you would like to use the featurization.
                   For details, see `https://github.com/krishnap25/mauve`
                   and `https://huggingface.co/transformers/installation.html`.
                """)

        if tokenized_texts is None:
            # tokenize texts
            if TOKENIZER is None or MODEL_NAME != featurize_model_name:
                if verbose: print('Loading tokenizer')
                TOKENIZER = get_tokenizer(featurize_model_name)
            if verbose: print('Tokenizing text...')
            tokenized_texts = [
                TOKENIZER.encode(sen, return_tensors='pt', truncation=True, max_length=max_len)
                for sen in texts
            ]
        # use tokenized_texts to featurize
        if TOKENIZER is None or MODEL_NAME != featurize_model_name:
            if verbose: print('Loading tokenizer')
            TOKENIZER = get_tokenizer(featurize_model_name)
        if MODEL is None or MODEL_NAME != featurize_model_name:
            if verbose: print('Loading model')
            MODEL = get_model(featurize_model_name, TOKENIZER, device_id)
            MODEL_NAME = featurize_model_name
        else:
            MODEL = MODEL.to(get_device_from_arg(device_id))
        if verbose: print('Featurizing tokens')
        features = fast_featurize_tokens_from_model(MODEL,
                                                    tokenized_texts,
                                                    batch_size=batch_size,
                                                    name=name).detach().cpu().numpy()
    else:
        features = np.asarray(features)
    return features
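
As with the non-batched path, passing precomputed features short-circuits tokenization and model loading entirely. A sketch with random features (hypothetical shapes; 1280 is the gpt2-large hidden size):

    import numpy as np

    p_feats = np.random.randn(100, 1280).astype(np.float32)
    feats = fast_get_features_from_input(p_feats, None, None,
                                         'gpt2-large', 1024, -1,
                                         name="p", batch_size=8)
    assert feats.shape == (100, 1280)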

def cluster_feats(p, q, num_clusters,
                  norm='none', whiten=True,
                  pca_max_data=-1,
43 changes: 43 additions & 0 deletions src/mauve/utils.py
@@ -106,3 +106,46 @@ def featurize_tokens_from_model(model, tokenized_texts, name="", verbose=False):
    t2 = time.time()
    if verbose: print(f'Featurize time: {round(t2-t1, 2)}')
    return torch.cat(feats)

@torch.no_grad()
def fast_featurize_tokens_from_model(model, tokenized_texts, batch_size, name="", verbose=False):
    """Featurize tokenized texts using the model, with support for batching.

    :param model: HF Transformers model used for featurization.
    :param batch_size: batch size used during the forward pass.
    :param tokenized_texts: list of torch.LongTensor of shape (1, length).
    :param verbose: if True, print status and time.
    :return: torch.Tensor of shape (n, d), one feature vector per input text.
    """
    device = next(model.parameters()).device
    t1 = time.time()
    feats, chunks, chunk_sent_lengths = [], [], []
    chunk_idx = 0

    # Split the inputs into batches of size batch_size, flattening each (1, length) tensor.
    while chunk_idx * batch_size < len(tokenized_texts):
        _chunk = [_t.view(-1) for _t in tokenized_texts[chunk_idx * batch_size: (chunk_idx + 1) * batch_size]]
        chunks.append(_chunk)
        chunk_sent_lengths.append([len(_c) for _c in _chunk])
        chunk_idx += 1

    for chunk, chunk_sent_length in tqdm(list(zip(chunks, chunk_sent_lengths)), desc=f"Featurizing {name}"):
        # Right-pad each batch to the length of its longest sequence, and mask out the padding.
        padded_chunk = torch.nn.utils.rnn.pad_sequence(chunk,
                                                       batch_first=True,
                                                       padding_value=0).to(device)
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [torch.ones(sent_length).long() for sent_length in chunk_sent_length],
            batch_first=True,
            padding_value=0).to(device)
        outs = model(input_ids=padded_chunk,
                     attention_mask=attention_mask,
                     past_key_values=None,
                     output_hidden_states=True,
                     return_dict=True)
        # Feature of each sequence: last-layer hidden state at its final non-padding token.
        h = []
        for hidden_state, sent_length in zip(outs.hidden_states[-1], chunk_sent_length):
            h.append(hidden_state[sent_length - 1])
        h = torch.stack(h, dim=0)
        feats.append(h.cpu())
    t2 = time.time()
    if verbose:
        print(f'Featurize time: {round(t2-t1, 2)}')
    return torch.cat(feats)
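
The padding-and-indexing step can be sanity-checked in isolation. A small self-contained sketch (toy tensors, no model) showing that indexing a padded row at sent_length - 1 recovers the last real token of each sequence:

    import torch

    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    lengths = [len(s) for s in seqs]
    padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
    # padded == tensor([[5, 6, 7], [8, 9, 0]])
    last = torch.stack([row[n - 1] for row, n in zip(padded, lengths)])
    # last == tensor([7, 9]): the final non-padding entry of each row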
