Skip to content

Commit

Permalink
Cleaning up code
Browse files Browse the repository at this point in the history
  • Loading branch information
wade3han committed Dec 5, 2021
1 parent a1a46a0 commit 4fe6739
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions src/mauve/compute_mauve.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def fast_compute_mauve(
kmeans_num_redo=5, kmeans_max_iter=500,
featurize_model_name='gpt2-large', device_id=-1, max_text_length=1024,
divergence_curve_discretization_size=25, mauve_scaling_factor=5,
verbose=False, seed=25
verbose=False, seed=25, batch_size=8,
):
"""
Compute the MAUVE score between two text generations P and Q.
Expand Down Expand Up @@ -72,6 +72,7 @@ def fast_compute_mauve(
See `Best Practices <index.html#best-practices-for-mauve>`_ for details.
:param ``verbose``: If True, print running time updates.
:param ``seed``: random seed to initialize k-means cluster assignments.
:param ``batch_size``: Batch size used when featurizing tokenized texts during feature extraction.
:return: an object with fields p_hist, q_hist, divergence_curve and mauve.
Expand All @@ -88,11 +89,11 @@ def fast_compute_mauve(
raise ValueError('Supply at least one of q_features, q_tokens, q_text')
p_features = fast_get_features_from_input(
p_features, p_tokens, p_text, featurize_model_name, max_text_length,
device_id, name="p", verbose=verbose
device_id, name="p", verbose=verbose, batch_size=batch_size,
)
q_features = fast_get_features_from_input(
q_features, q_tokens, q_text, featurize_model_name, max_text_length,
device_id, name="q", verbose=verbose
device_id, name="q", verbose=verbose, batch_size=batch_size,
)
if num_buckets == 'auto':
# heuristic: use num_clusters = num_generations / 10
Expand Down Expand Up @@ -194,7 +195,6 @@ def compute_mauve(
q_features, q_tokens, q_text, featurize_model_name, max_text_length,
device_id, name="q", verbose=verbose
)
print(p_features.shape)
if num_buckets == 'auto':
# heuristic: use num_clusters = num_generations / 10
num_buckets = max(2, int(round(min(p_features.shape[0], q_features.shape[0]) / 10)))
Expand Down Expand Up @@ -282,7 +282,7 @@ def get_features_from_input(features, tokenized_texts, texts,


def fast_get_features_from_input(features, tokenized_texts, texts,
featurize_model_name, max_len, device_id, name,
featurize_model_name, max_len, device_id, name, batch_size,
verbose=False):
global MODEL, TOKENIZER, MODEL_NAME
if features is None:
Expand Down Expand Up @@ -321,7 +321,10 @@ def fast_get_features_from_input(features, tokenized_texts, texts,
else:
MODEL = MODEL.to(get_device_from_arg(device_id))
if verbose: print('Featurizing tokens')
features = fast_featurize_tokens_from_model(MODEL, tokenized_texts, batch_size=8, name=name).detach().cpu().numpy()
features = fast_featurize_tokens_from_model(MODEL,
tokenized_texts,
batch_size=batch_size,
name=name).detach().cpu().numpy()
else:
features = np.asarray(features)
return features
Expand Down
2 changes: 1 addition & 1 deletion src/mauve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def fast_featurize_tokens_from_model(model, tokenized_texts, batch_size, name=""
chunk_sent_lengths.append([len(_c) for _c in _chunk])
chunk_idx += 1

for chunk, chunk_sent_length in tqdm(zip(chunks, chunk_sent_lengths), desc=f"Featurizing {name}"):
for chunk, chunk_sent_length in tqdm(list(zip(chunks, chunk_sent_lengths)), desc=f"Featurizing {name}"):
padded_chunk = torch.nn.utils.rnn.pad_sequence(chunk,
batch_first=True,
padding_value=0).to(device)
Expand Down

0 comments on commit 4fe6739

Please sign in to comment.