|
7 | 7 | from typing import Any, Dict, Optional, Type, Union |
8 | 8 |
|
9 | 9 | import huggingface_hub |
10 | | -from huggingface_hub import (file_exists, hf_hub_download, |
| 10 | +from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, |
11 | 11 | try_to_load_from_cache) |
12 | 12 | from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, |
13 | 13 | LocalEntryNotFoundError, |
@@ -395,18 +395,28 @@ def get_sentence_transformer_tokenizer_config(model: str, |
395 | 395 | - dict: A dictionary containing the configuration parameters |
396 | 396 | for the Sentence Transformer BERT model. |
397 | 397 | """ |
398 | | - for config_name in [ |
399 | | - "sentence_bert_config.json", |
400 | | - "sentence_roberta_config.json", |
401 | | - "sentence_distilbert_config.json", |
402 | | - "sentence_camembert_config.json", |
403 | | - "sentence_albert_config.json", |
404 | | - "sentence_xlm-roberta_config.json", |
405 | | - "sentence_xlnet_config.json", |
406 | | - ]: |
407 | | - encoder_dict = get_hf_file_to_dict(config_name, model, revision) |
408 | | - if encoder_dict: |
409 | | - break |
| 398 | + sentence_transformer_config_files = [ |
| 399 | + "sentence_bert_config.json", |
| 400 | + "sentence_roberta_config.json", |
| 401 | + "sentence_distilbert_config.json", |
| 402 | + "sentence_camembert_config.json", |
| 403 | + "sentence_albert_config.json", |
| 404 | + "sentence_xlm-roberta_config.json", |
| 405 | + "sentence_xlnet_config.json", |
| 406 | + ] |
| 407 | + try: |
| 408 | + # If model is on HuggingfaceHub, get the repo files |
| 409 | + repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) |
| 410 | + except Exception as e: |
| 411 | + logger.debug("Error getting repo files", e) |
| 412 | + repo_files = [] |
| 413 | + |
| 414 | + encoder_dict = None |
| 415 | + for config_name in sentence_transformer_config_files: |
| 416 | + if config_name in repo_files or Path(model).exists(): |
| 417 | + encoder_dict = get_hf_file_to_dict(config_name, model, revision) |
| 418 | + if encoder_dict: |
| 419 | + break |
410 | 420 |
|
411 | 421 | if not encoder_dict: |
412 | 422 | return None |
|
0 commit comments