update from_pretrained method #15

Open · wants to merge 1 commit into base: development
59 changes: 35 additions & 24 deletions src/autotiktokenizer/autotiktokenizer.py
@@ -4,6 +4,32 @@

import tiktoken
from huggingface_hub import snapshot_download
+from functools import wraps
+
+PRE_DOC = """
+Loads a pretrained tokenizer from the specified path or name and returns the TikToken encoding.
+
+Args:
+    tokenizer_name_or_path (str):
+        The name or path of the pretrained tokenizer.
+    repo_type """
+POST_DOC = """
+Returns:
+    encoding (Encoding): The TikToken encoding.
+"""
+def set_doc(is_snapshot=False):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+        if is_snapshot:
+            wrapper.__doc__ = snapshot_download.__doc__
+            return wrapper
+        middle_doc = ((snapshot_download.__doc__.split("repo_type")[1]).split("Returns:")[0]).strip()
+        wrapper.__doc__ = PRE_DOC + middle_doc + POST_DOC
+        return wrapper
+
+    return decorator

class AutoTikTokenizer:
    """
@@ -65,21 +91,13 @@ def _bytes_to_unicode(self) -> Dict[int, str]:
        cs = [chr(n) for n in cs]
        return dict(zip(bs, cs))

+    @set_doc(is_snapshot=True)
    def _download_from_hf_hub(self,
-                              repo_name: str) -> str:
-        """
-        Downloads the necessary files from the HuggingFace Hub for the tokenizer.
-
-        Args:
-            repo_name (str): The name of the repository to download the files from.
-
-        Returns:
-            path (str): The path to the downloaded files.
-        """
+                              kwargs) -> str:

        # Download all the necessary files from HF Hub
-        files_needed = ['config.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt']
-        path = snapshot_download(repo_id=repo_name, allow_patterns=files_needed)
+        kwargs["allow_patterns"] = ['config.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt']
+        path = snapshot_download(**kwargs)
        files_downloaded = os.listdir(path)

        # Assertions to make sure the necessary files are there
@@ -253,18 +271,10 @@ def _detect_tokenizer_type(self, tokenizer: Dict, tokenizer_config: Dict) -> str:
        if tokenizer.get('model', '').get('type', '').lower() == 'wordpiece':
            return 'wordpiece'
        return 'bpe'

    @classmethod
-    def from_pretrained(cls, tokenizer_name_or_path: str) -> tiktoken.Encoding:
-        """
-        Loads a pretrained tokenizer from the specified path or name and returns the TikToken encoding.
-
-        Args:
-            tokenizer_name_or_path (str): The name or path of the pretrained tokenizer.
-
-        Returns:
-            encoding (Encoding): The TikToken encoding.
-        """
+    @set_doc()
+    def from_pretrained(cls, tokenizer_name_or_path: str, **kwargs) -> tiktoken.Encoding:
        #init instance
        instance = cls()

@@ -274,7 +284,8 @@ def from_pretrained(cls, tokenizer_name_or_path: str) -> tiktoken.Encoding:
            tokenizer = instance._read_json(os.path.join(tokenizer_name_or_path, 'tokenizer.json'))
        else:
            try :
-                path = instance._download_from_hf_hub(tokenizer_name_or_path)
+                kwargs["repo_id"] = tokenizer_name_or_path
+                path = instance._download_from_hf_hub(kwargs)
                tokenizer = instance._read_json(os.path.join(path, 'tokenizer.json'))
            except Exception as e:
                print("Tokenizer could not be loaded from a local directory nor from the hub")
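
For context, a minimal usage sketch of the updated from_pretrained call, assuming the package exposes AutoTikTokenizer at the top level; the repo id "gpt2" and the revision/cache_dir keyword arguments are illustrative, chosen only because snapshot_download accepts them and this PR forwards **kwargs straight to it:

# Hypothetical usage sketch; repo id and keyword arguments are illustrative.
from autotiktokenizer import AutoTikTokenizer

# Extra keyword arguments are forwarded to huggingface_hub.snapshot_download
# (repo_id and allow_patterns are filled in by the library itself).
encoding = AutoTikTokenizer.from_pretrained(
    "gpt2",
    revision="main",         # passed through to snapshot_download
    cache_dir="./hf_cache",  # passed through to snapshot_download
)
print(encoding.encode("hello world"))

# The set_doc decorator splices snapshot_download's parameter documentation
# into the method docstring, so help(AutoTikTokenizer.from_pretrained)
# would also list the forwarded options.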