From 4fe67390b99a6eb623beafdc5e7a70684a1226a4 Mon Sep 17 00:00:00 2001
From: Seungju Han
Date: Sun, 5 Dec 2021 23:24:05 +0900
Subject: [PATCH] Cleaning up code

Expose a batch_size argument on fast_compute_mauve and
fast_get_features_from_input instead of hard-coding batch_size=8 at the
call site, drop a leftover debug print from compute_mauve, and
materialize the zip in fast_featurize_tokens_from_model so tqdm can
report progress over the chunks.

---
 src/mauve/compute_mauve.py | 15 +++++++++------
 src/mauve/utils.py         |  2 +-
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/mauve/compute_mauve.py b/src/mauve/compute_mauve.py
index 4b60d00..21ffcb4 100644
--- a/src/mauve/compute_mauve.py
+++ b/src/mauve/compute_mauve.py
@@ -42,7 +42,7 @@ def fast_compute_mauve(
         kmeans_num_redo=5, kmeans_max_iter=500,
         featurize_model_name='gpt2-large', device_id=-1, max_text_length=1024,
         divergence_curve_discretization_size=25, mauve_scaling_factor=5,
-        verbose=False, seed=25
+        verbose=False, seed=25, batch_size=8,
 ):
     """
     Compute the MAUVE score between two text generations P and Q.
@@ -72,6 +72,7 @@ def fast_compute_mauve(
         See `Best Practices`_ for details.
     :param ``verbose``: If True, print running time updates.
     :param ``seed``: random seed to initialize k-means cluster assignments.
+    :param ``batch_size``: Batch size for feature extraction.
 
     :return: an object with fields p_hist, q_hist, divergence_curve and mauve.
 
@@ -88,11 +89,11 @@ def fast_compute_mauve(
         raise ValueError('Supply at least one of q_features, q_tokens, q_text')
     p_features = fast_get_features_from_input(
         p_features, p_tokens, p_text, featurize_model_name, max_text_length,
-        device_id, name="p", verbose=verbose
+        device_id, name="p", verbose=verbose, batch_size=batch_size,
     )
     q_features = fast_get_features_from_input(
         q_features, q_tokens, q_text, featurize_model_name, max_text_length,
-        device_id, name="q", verbose=verbose
+        device_id, name="q", verbose=verbose, batch_size=batch_size,
     )
     if num_buckets == 'auto':
         # heuristic: use num_clusters = num_generations / 10
@@ -194,7 +195,6 @@ def compute_mauve(
         q_features, q_tokens, q_text, featurize_model_name, max_text_length,
         device_id, name="q", verbose=verbose
     )
-    print(p_features.shape)
     if num_buckets == 'auto':
         # heuristic: use num_clusters = num_generations / 10
         num_buckets = max(2, int(round(min(p_features.shape[0], q_features.shape[0]) / 10)))
@@ -282,7 +282,7 @@ def get_features_from_input(features, tokenized_texts, texts,
 
 
 def fast_get_features_from_input(features, tokenized_texts, texts,
-                                 featurize_model_name, max_len, device_id, name,
+                                 featurize_model_name, max_len, device_id, name, batch_size,
                                  verbose=False):
     global MODEL, TOKENIZER, MODEL_NAME
     if features is None:
@@ -321,7 +321,10 @@ def fast_get_features_from_input(features, tokenized_texts, texts,
         else:
             MODEL = MODEL.to(get_device_from_arg(device_id))
         if verbose: print('Featurizing tokens')
-        features = fast_featurize_tokens_from_model(MODEL, tokenized_texts, batch_size=8, name=name).detach().cpu().numpy()
+        features = fast_featurize_tokens_from_model(MODEL,
+                                                    tokenized_texts,
+                                                    batch_size=batch_size,
+                                                    name=name).detach().cpu().numpy()
     else:
         features = np.asarray(features)
     return features
diff --git a/src/mauve/utils.py b/src/mauve/utils.py
index 3cfe9c7..6911e37 100644
--- a/src/mauve/utils.py
+++ b/src/mauve/utils.py
@@ -129,7 +129,7 @@ def fast_featurize_tokens_from_model(model, tokenized_texts, batch_size, name=""
             chunk_sent_lengths.append([len(_c) for _c in _chunk])
             chunk_idx += 1
 
-    for chunk, chunk_sent_length in tqdm(zip(chunks, chunk_sent_lengths), desc=f"Featurizing {name}"):
+    for chunk, chunk_sent_length in tqdm(list(zip(chunks, chunk_sent_lengths)), desc=f"Featurizing {name}"):
         padded_chunk = torch.nn.utils.rnn.pad_sequence(chunk, batch_first=True,
                                                        padding_value=0).to(device)
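
A minimal usage sketch of the batch_size knob added above. This is an
illustration, not part of the patch: the import path follows the
src/mauve/compute_mauve.py layout in the diff, the toy texts are made up,
and every keyword except batch_size is taken unchanged from the
fast_compute_mauve signature shown in the first hunk.

    # Hypothetical usage sketch -- keyword names come from the
    # fast_compute_mauve signature in this patch; the texts are toy data.
    from mauve.compute_mauve import fast_compute_mauve

    p_text = [f"Sample human sentence number {i}." for i in range(200)]
    q_text = [f"Sample model sentence number {i}!" for i in range(200)]

    # batch_size now flows through to fast_featurize_tokens_from_model
    # instead of the previous hard-coded batch_size=8; larger values
    # speed up featurization on a GPU at the cost of memory.
    out = fast_compute_mauve(p_text=p_text, q_text=q_text, device_id=-1,
                             max_text_length=256, batch_size=32, verbose=True)
    print(out.mauve)  # out also carries p_hist, q_hist, divergence_curve

Callers of compute_mauve are unaffected: only the fast_* path gains the
parameter, and its default (8) matches the value that was hard-coded before.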