Skip to content

Commit

Permalink
add correlation study of overfitting
Browse files Browse the repository at this point in the history
  • Loading branch information
frankxu2004 committed Feb 2, 2022
1 parent 0891f62 commit 626a8d3
Show file tree
Hide file tree
Showing 7 changed files with 33,596 additions and 25 deletions.
17 changes: 3 additions & 14 deletions cluster/interpolation.py → analysis/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,10 @@

tokens = np.load('tokens.npy')
lm_scores = np.load('scores.npy')
kmeans_tokens = np.load('kmeans_tokens.npy')
kmeans_scores = np.load('kmeans_scores.npy')
knn_scores = np.load('knn_only_scores.npy')

finetuned_centroid_scores = np.load('finetuned_centroids_scores.npy')


assert len(tokens) == len(lm_scores)
assert len(kmeans_tokens) == len(tokens)
assert len(knn_scores) == len(tokens)
assert len(finetuned_centroid_scores) == len(tokens)


# calculate unigram probability
Expand All @@ -43,16 +36,12 @@
# unigram_scores = unigram_scores.astype(np.float32)
# unigram_scores = torch.from_numpy(unigram_scores)

# kmeans_scores = knn_scores

kmeans_scores = finetuned_centroid_scores

lm_scores = torch.from_numpy(lm_scores)
kmeans_scores = torch.from_numpy(kmeans_scores)
knn_scores = torch.from_numpy(knn_scores)

combine_probs = torch.stack([lm_scores, kmeans_scores], dim=0)
combine_probs = torch.stack([lm_scores, knn_scores], dim=0)

with open('finetuned_interpolation_result.txt', 'w') as outfile:
with open('small_interpolation_result.txt', 'w') as outfile:
for lmbda in tqdm(np.linspace(0.0, 0.99, num=50)):
coeffs = torch.ones_like(combine_probs)
coeffs[0] = np.log(1 - lmbda)
Expand Down
51 changes: 51 additions & 0 deletions analysis/overfit_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import math

from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm

import numpy as np
import torch

from fairseq.data import Dictionary

# Script body: measure how strongly the base LM scores and the scores of a
# series of "overfit" LM checkpoints correlate with the kNN-only scores.
# All inputs are .npy arrays read from the current working directory and are
# expected to have the same length as tokens.npy (checked below).

# Load the fairseq vocabulary and report its size.
dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt')
print(len(dictionary))

# Vocabulary indices whose symbol ends with the BPE continuation marker.
# NOTE(review): bpe_toks and bpe_len are not referenced by the correlation
# analysis below; they are kept so the script's module-level names are
# unchanged.
bpe_cont = "@@"
bpe_toks = {
    i
    for i in range(len(dictionary))
    if dictionary[i].endswith(bpe_cont)
}

bpe_len = len(bpe_cont)

tokens = np.load('tokens.npy')
lm_scores = np.load('scores.npy')
knn_scores = np.load('knn_only_scores.npy')

# Validate alignment once, up front, before any statistic is computed.
# These checks do not depend on the checkpoint, so they do not belong in
# the loop below.
assert len(tokens) == len(lm_scores)
assert len(knn_scores) == len(tokens)

# calculate Pearson's correlation
corr, _ = pearsonr(lm_scores, knn_scores)
print('LM vs KNN Pearsons correlation: %.3f' % corr)

# calculate Spearman's rank correlation
corr, _ = spearmanr(lm_scores, knn_scores)
print('LM vs KNN Spearmans correlation: %.3f' % corr)


# One scores file per checkpoint id; the ids are presumably training epochs
# (the file names use the "checkpoint<N>" stem) -- confirm against the
# script that produced them.
for ep in [10, 20, 30, 50, 100, 150, 200]:
    print(ep)
    overfit_scores = np.load('overfit_lm_scores_checkpoint' + str(ep) + '.npy')
    # Only this array changes per iteration; check it before it is used.
    assert len(overfit_scores) == len(tokens)

    # calculate Pearson's correlation
    corr, _ = pearsonr(overfit_scores, knn_scores)
    print('OverfitLM vs KNN Pearsons correlation: %.3f' % corr)

    # calculate Spearman's rank correlation
    corr, _ = spearmanr(overfit_scores, knn_scores)
    print('OverfitLM vs KNN Spearmans correlation: %.3f' % corr)


6 changes: 3 additions & 3 deletions fairseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ def load_checkpoint(
state = checkpoint_utils.load_checkpoint_to_cpu(filename)

# KNNLM-distill: handling replaced output embedding weight matrix shape
if self.get_model().decoder.embed_out.shape[0] != state["model"]['decoder.embed_out'].shape[0]:
self.get_model().decoder.embed_out = \
torch.nn.Parameter(self.get_model().decoder.embed_out.new(state["model"]['decoder.embed_out'].shape))
# if self.get_model().decoder.embed_out.shape[0] != state["model"]['decoder.embed_out'].shape[0]:
# self.get_model().decoder.embed_out = \
# torch.nn.Parameter(self.get_model().decoder.embed_out.new(state["model"]['decoder.embed_out'].shape))

# load model parameters
try:
Expand Down
4 changes: 2 additions & 2 deletions fairseq_cli/eval_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ def main(parsed_args):
))

# np.save('knnlm_tokens.npy', np.concatenate(all_token_ids))
# np.save('finetuned_centroids_scores.npy', np.concatenate(all_scores))
# np.save('overfit_lm_scores_' + parsed_args.path.split('/')[-1].split('.')[0] + '.npy', np.concatenate(all_scores))
if all_knn_scores:
np.save('knn_only_scores.npy', np.concatenate(all_knn_scores))
np.save('recompute_knn_only_scores.npy', np.concatenate(all_knn_scores))

if args.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
Expand Down
13 changes: 7 additions & 6 deletions fairseq_cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def main(args, init_distributed=False):
criterion = task.build_criterion(args)

logger.info(model)
# if only tune centroid matrix
if args.finetune_centroids:
print('Finetune only centroid matrix!!!')
for name, param in model.named_parameters():
if name != 'decoder.embed_out':
param.requires_grad = False

logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
logger.info('num. model params: {} (num. trained: {})'.format(
sum(p.numel() for p in model.parameters()),
Expand All @@ -83,12 +90,6 @@ def main(args, init_distributed=False):
# corresponding train iterator
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

# if only tune centroid matrix
if args.finetune_centroids:
for name, param in model.named_parameters():
if name != 'decoder.embed_out':
param.requires_grad = False

# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
Expand Down
Loading

0 comments on commit 626a8d3

Please sign in to comment.