Skip to content

Commit

Permalink
added function to benchmark experimental vocab batch lookup (#1291)
Browse files Browse the repository at this point in the history
  • Loading branch information
parmeet authored Apr 26, 2021
1 parent d2a0776 commit 0790ce6
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions benchmark/benchmark_experimental_vocab.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import argparse
from collections import (Counter, OrderedDict)
import time

import random
import string
from timeit import default_timer as timer
from matplotlib import pyplot as plt
import torch
from torchtext.experimental.datasets import DATASETS
from torchtext.experimental.vocab import (
Expand All @@ -16,6 +19,42 @@
from torchtext.experimental.transforms import basic_english_normalize


def compare_legacy_and_experimental_batch_lookup(num_tokens=1000, num_letters=6, num_lines=100000,
                                                 max_line_length=100, plot_path="speedup.jpg"):
    r"""Benchmark batched ``lookup_indices`` on the legacy vs. the experimental vocab.

    A synthetic vocabulary of random fixed-length tokens is generated and used to
    build one ``Vocab`` and one ``VocabExperimental`` object. For every line length
    in ``range(2, max_line_length)``, ``num_lines`` random lines of that length are
    sampled from the vocabulary and the total time of calling ``lookup_indices``
    once per line is measured for both objects. The ratio
    ``legacy_time / experimental_time`` is printed for each length and plotted
    against the line length; the plot is saved to ``plot_path``.

    The defaults reproduce the original hard-coded benchmark configuration.

    Args:
        num_tokens: number of random tokens to generate (duplicates are possible,
            so the number of distinct tokens may be smaller).
        num_letters: number of ASCII letters in every generated token.
        num_lines: number of lines looked up per line length.
        max_line_length: exclusive upper bound of the benchmarked line lengths
            (the lower bound is fixed at 2). ``random.sample`` raises ``ValueError``
            when a line length exceeds ``num_tokens``.
        plot_path: file path handed to ``plt.savefig`` for the speed-up plot.
    """
    def _time_batch_lookup(vocab_obj, lines):
        # Total wall-clock time for one ``lookup_indices`` call per line.
        start_time = timer()
        for text in lines:
            vocab_obj.lookup_indices(text)
        return timer() - start_time

    # The alphabet is repeated ``num_letters`` times so that a letter may occur
    # more than once inside a token even though ``random.sample`` draws without
    # replacement.
    alphabet = string.ascii_letters * num_letters
    vocab = [''.join(random.sample(alphabet, num_letters)) for _ in range(num_tokens)]
    counter = Counter(vocab)
    legacy_vocab = Vocab(counter)
    experimental_vocab = VocabExperimental(counter)

    token_lengths = list(range(2, max_line_length))
    speed_ups = []
    for length in token_lengths:
        # Every line holds exactly ``length`` distinct entries of ``vocab``.
        lines = [random.sample(vocab, length) for _ in range(num_lines)]

        legacy_time = _time_batch_lookup(legacy_vocab, lines)
        experimental_time = _time_batch_lookup(experimental_vocab, lines)

        speed_up = legacy_time / experimental_time
        speed_ups.append(speed_up)
        print("speed-up={} for average length={}".format(speed_up, length))
        # Drop the batch before the next (larger) one is allocated.
        del lines

    plt.close()
    fig, ax = plt.subplots(1, 1)
    ax.plot(token_lengths, speed_ups)
    ax.set_xlabel('Average Tokens per line')
    ax.set_ylabel('Speed-up')
    plt.savefig(plot_path)


def legacy_vocab_from_file_object(file_like_object, **kwargs):
r"""Create a `Vocab` object from a file like object.
Expand Down Expand Up @@ -76,7 +115,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
print("Construction time:", time.monotonic() - t0)


def benchmark_experimental_vocab_lookup(vocab_file_path=None, dataset = 'AG_NEWS'):
def benchmark_experimental_vocab_lookup(vocab_file_path=None, dataset='AG_NEWS'):
def _run_benchmark_lookup(tokens, vocab):
t0 = time.monotonic()
# list lookup
Expand Down

0 comments on commit 0790ce6

Please sign in to comment.