Skip to content

Commit

Permalink
Fix tokenizers caching (#502)
Browse files Browse the repository at this point in the history
* fix tokenizers caching

* register regex saver only if regex is available

* style

* add _transformers_available in utils

* quality

* add test for methods

* test for tokenizers' cache instead of having a list of tokenizers classes with cache
  • Loading branch information
lhoestq authored Aug 19, 2020
1 parent 83ad751 commit 0a45189
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 6 deletions.
7 changes: 7 additions & 0 deletions src/nlp/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name

try:
import transformers

_transformers_available = True # pylint: disable=invalid-name
logger.info("transformers version {} available.".format(transformers.__version__))
except ImportError:
_transformers_available = False # pylint: disable=invalid-name

hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
Expand Down
48 changes: 43 additions & 5 deletions src/nlp/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import dill
import numpy as np

from .file_utils import _transformers_available


# NOTE: When used on an instance method, the cache is shared across all
# instances and IS NOT per-instance.
Expand Down Expand Up @@ -268,14 +270,29 @@ class Pickler(dill.Pickler):

def dump(obj, file):
    """Pickle an object to a file-like object using the project's Pickler.

    ``recurse=True`` makes dill also serialize the global objects a function
    actually references, so the pickled bytes (and any hash computed from
    them) change when those referenced objects change.
    """
    Pickler(file, recurse=True).dump(obj)
    return


@contextlib.contextmanager
def _no_cache_fields(obj):
    """Temporarily empty a tokenizer's ``cache`` dict while inside the block.

    Pickling a tokenizer whose internal cache has filled up would make the
    serialized bytes (and hashes derived from them) depend on transient
    state; swapping in an empty dict for the duration avoids that. For any
    object that is not such a tokenizer this is a no-op.
    """
    clear_cache = False
    if _transformers_available:
        import transformers as tr

        # Only tokenizers exposing a dict-valued `cache` attribute need this.
        if (
            isinstance(obj, tr.PreTrainedTokenizerBase)
            and hasattr(obj, "cache")
            and isinstance(obj.cache, dict)
        ):
            clear_cache = True
    if clear_cache:
        with temporary_assignment(obj, "cache", {}):
            yield
    else:
        yield


def dumps(obj):
    """Pickle an object to a string.

    The object's volatile ``cache`` field (if it is a tokenizer with one) is
    temporarily emptied via ``_no_cache_fields`` so the returned bytes are
    independent of transient cache contents.
    """
    file = StringIO()
    with _no_cache_fields(obj):
        dump(obj, file)
    return file.getvalue()


Expand All @@ -288,7 +305,7 @@ def proxy(func):


@pklregister(CodeType)
def save_code(pickler, obj):
def _save_code(pickler, obj):
"""
From dill._dill.save_code
This is a modified version that removes the origin (filename + line no.)
Expand All @@ -297,9 +314,11 @@ def save_code(pickler, obj):
dill._dill.log.info("Co: %s" % obj)
# Filenames of functions created in notebooks or shells start with '<'
# ex: <ipython-input-13-9ed2afe61d25> for ipython, and <stdin> for shell
# Moreover lambda functions have a special name: '<lambda>'
# ex: (lambda x: x).__code__.co_name == "<lambda>" # True
# Only those two lines are different from the original implementation:
co_filename = "" if obj.co_filename.startswith("<") else obj.co_filename
co_firstlineno = 1 if obj.co_filename.startswith("<") else obj.co_firstlineno
co_filename = "" if obj.co_filename.startswith("<") or obj.co_name == "<lambda>" else obj.co_filename
co_firstlineno = 1 if obj.co_filename.startswith("<") or obj.co_name == "<lambda>" else obj.co_firstlineno
# The rest is the same as in the original dill implementation
if dill._dill.PY3:
if hasattr(obj, "co_posonlyargcount"):
Expand Down Expand Up @@ -363,3 +382,22 @@ def save_code(pickler, obj):

def copyfunc(func):
return types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__)


try:
    import regex

    # Teach the Pickler how to serialize compiled `regex` patterns: reduce
    # each one to (pattern, flags) and rebuild with regex.compile, so two
    # patterns compiled from the same source pickle identically.
    @pklregister(type(regex.Regex("", 0)))
    def _save_regex(pickler, obj):
        dill._dill.log.info("Re: %s" % obj)
        args = (
            obj.pattern,
            obj.flags,
        )
        pickler.save_reduce(regex.compile, args, obj=obj)
        dill._dill.log.info("# Re")
        return


except ImportError:
    # `regex` is optional: only register the saver when it is installed.
    pass
134 changes: 134 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from hashlib import md5
from types import CodeType, FunctionType
from unittest import TestCase

import regex

import nlp

from .utils import require_transformers


class Foo:
    """Minimal callable fixture: stores a value and returns it when called."""

    def __init__(self, foo):
        # Payload handed back verbatim by __call__.
        self.foo = foo

    def __call__(self):
        """Return the payload given at construction time."""
        return self.foo


class TokenizersCachingTest(TestCase):
    """Checks that hashing objects via nlp.utils.dumps is deterministic:
    equivalent tokenizers/patterns hash equally, different ones differently.

    NOTE(review): the tokenizer tests load pretrained models with
    ``AutoTokenizer.from_pretrained`` and therefore need network access (or a
    pre-populated local cache) — confirm CI provides that.
    """

    @require_transformers
    def test_hash_tokenizer(self):
        from transformers import AutoTokenizer

        def encode(x):
            return tokenizer(x)

        # TODO: add hash consistency tests across sessions
        # The same pretrained tokenizer loaded twice must hash identically —
        # directly, when captured by a lambda, and when captured by `encode`.
        # A different pretrained tokenizer must hash differently in all three.
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        hash1 = md5(nlp.utils.dumps(tokenizer)).hexdigest()
        hash1_lambda = md5(nlp.utils.dumps(lambda x: tokenizer(x))).hexdigest()
        hash1_encode = md5(nlp.utils.dumps(encode)).hexdigest()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        hash2 = md5(nlp.utils.dumps(tokenizer)).hexdigest()
        hash2_lambda = md5(nlp.utils.dumps(lambda x: tokenizer(x))).hexdigest()
        hash2_encode = md5(nlp.utils.dumps(encode)).hexdigest()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        hash3 = md5(nlp.utils.dumps(tokenizer)).hexdigest()
        hash3_lambda = md5(nlp.utils.dumps(lambda x: tokenizer(x))).hexdigest()
        hash3_encode = md5(nlp.utils.dumps(encode)).hexdigest()
        self.assertEqual(hash1, hash3)
        self.assertNotEqual(hash1, hash2)
        self.assertEqual(hash1_lambda, hash3_lambda)
        self.assertNotEqual(hash1_lambda, hash2_lambda)
        self.assertEqual(hash1_encode, hash3_encode)
        self.assertNotEqual(hash1_encode, hash2_encode)

    @require_transformers
    def test_hash_tokenizer_with_cache(self):
        from transformers import AutoTokenizer

        # Using the tokenizer mutates its internal cache; the hash must be
        # unaffected by that transient state (gpt2's tokenizer has a dict
        # `cache` attribute — presumably why it is the model used here).
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        hash1 = md5(nlp.utils.dumps(tokenizer)).hexdigest()
        tokenizer("Hello world !")  # call once to change the tokenizer's cache
        hash2 = md5(nlp.utils.dumps(tokenizer)).hexdigest()
        self.assertEqual(hash1, hash2)

    def test_hash_regex(self):
        # Compiled `regex` patterns: identical pattern strings hash equally,
        # different ones differently.
        pat = regex.Regex("foo")
        hash1 = md5(nlp.utils.dumps(pat)).hexdigest()
        pat = regex.Regex("bar")
        hash2 = md5(nlp.utils.dumps(pat)).hexdigest()
        pat = regex.Regex("foo")
        hash3 = md5(nlp.utils.dumps(pat)).hexdigest()
        self.assertEqual(hash1, hash3)
        self.assertNotEqual(hash1, hash2)


class RecurseDumpTest(TestCase):
    """Checks that nlp.utils.dumps pickles recursively: hashes reflect the
    free variables a function uses and the attributes an object holds."""

    def test_recurse_dump_for_function(self):
        def func():
            return foo

        # `foo` is a free variable of `func`; rebinding it between dumps must
        # change the hash, and rebinding back must restore it.
        foo = [0]
        hash1 = md5(nlp.utils.dumps(func)).hexdigest()
        foo = [1]
        hash2 = md5(nlp.utils.dumps(func)).hexdigest()
        foo = [0]
        hash3 = md5(nlp.utils.dumps(func)).hexdigest()
        self.assertEqual(hash1, hash3)
        self.assertNotEqual(hash1, hash2)

    def test_recurse_dump_for_class(self):

        # Instance state participates in the hash: equal payloads hash
        # equally, different payloads differently.
        hash1 = md5(nlp.utils.dumps(Foo([0]))).hexdigest()
        hash2 = md5(nlp.utils.dumps(Foo([1]))).hexdigest()
        hash3 = md5(nlp.utils.dumps(Foo([0]))).hexdigest()
        self.assertEqual(hash1, hash3)
        self.assertNotEqual(hash1, hash2)

    def test_recurse_dump_for_method(self):

        # Bound methods hash through the instance they are bound to.
        hash1 = md5(nlp.utils.dumps(Foo([0]).__call__)).hexdigest()
        hash2 = md5(nlp.utils.dumps(Foo([1]).__call__)).hexdigest()
        hash3 = md5(nlp.utils.dumps(Foo([0]).__call__)).hexdigest()
        self.assertEqual(hash1, hash3)
        self.assertNotEqual(hash1, hash2)

    def test_dump_ipython_function(self):

        # Code-object attributes in CodeType's positional order, used to
        # rebuild a code object with a substituted co_filename.
        # NOTE(review): co_posonlyargcount (added in Python 3.8) is not in
        # this tuple — confirm it matches the Python versions CI runs on.
        code_args = (
            "co_argcount",
            "co_kwonlyargcount",
            "co_nlocals",
            "co_stacksize",
            "co_flags",
            "co_code",
            "co_consts",
            "co_names",
            "co_varnames",
            "co_filename",
            "co_name",
            "co_firstlineno",
            "co_lnotab",
            "co_freevars",
            "co_cellvars",
        )

        def create_ipython_func(co_filename, returned_obj):
            # Clone a trivial function but override co_filename to mimic code
            # defined in a notebook/shell (filenames starting with '<').
            def func():
                return returned_obj

            code = func.__code__
            code = CodeType(*[getattr(code, k) if k != "co_filename" else co_filename for k in code_args])
            return FunctionType(code, func.__globals__, func.__name__, func.__defaults__, func.__closure__)

        # The notebook-style filename must NOT affect the hash (hash1 == hash3
        # despite different co_filename), while the returned value must.
        co_filename, returned_obj = "<ipython-input-2-e0383a102aae>", [0]
        hash1 = md5(nlp.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
        co_filename, returned_obj = "<ipython-input-2-e0383a102aae>", [1]
        hash2 = md5(nlp.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
        co_filename, returned_obj = "<ipython-input-5-713f6613acf3>", [0]
        hash3 = md5(nlp.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
        self.assertEqual(hash1, hash3)
        self.assertNotEqual(hash1, hash2)
14 changes: 13 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
from distutils.util import strtobool

from nlp.utils.file_utils import _tf_available, _torch_available
from nlp.utils.file_utils import _tf_available, _torch_available, _transformers_available


def parse_flag_from_env(key, default=False):
Expand Down Expand Up @@ -50,6 +50,18 @@ def require_tf(test_case):
return test_case


def require_transformers(test_case):
    """
    Decorator marking a test that requires transformers.
    These tests are skipped when transformers isn't installed.
    """
    if _transformers_available:
        return test_case
    return unittest.skip("test requires transformers")(test_case)


def slow(test_case):
"""
Decorator marking a test as slow.
Expand Down

0 comments on commit 0a45189

Please sign in to comment.