Skip to content

Commit

Permalink
[WIP] Fix Pyright static type checking by replacing if-else imports w…
Browse files Browse the repository at this point in the history
…ith try-except (huggingface#16578)

* rebase and isort

* modify cookiecutter init

* fix cookiecutter auto imports

* fix clean_frameworks_in_init

* fix add_model_to_main_init

* blackify

* replace unnecessary f-strings

* update yolos imports

* fix roberta import bug

* fix yolos missing dependency

* fix add_model_like and cookiecutter bug

* fix repository consistency error

* modify cookiecutter, fix add_new_model_like

* remove stale line

Co-authored-by: Dom Miketa <dmiketa@exscientia.co.uk>
  • Loading branch information
d-miketa and exs-dmiketa authored May 9, 2022
1 parent 7783fa6 commit df735d1
Show file tree
Hide file tree
Showing 116 changed files with 3,842 additions and 754 deletions.
268 changes: 169 additions & 99 deletions src/transformers/__init__.py

Large diffs are not rendered by default.

21 changes: 14 additions & 7 deletions src/transformers/commands/add_new_model_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,9 @@ def clean_frameworks_in_init(
return

remove_pattern = "|".join(to_remove)
re_conditional_imports = re.compile(rf"^\s*if is_({remove_pattern})_available\(\):\s*$")
re_conditional_imports = re.compile(rf"^\s*if not is_({remove_pattern})_available\(\):\s*$")
re_try = re.compile(r"\s*try:")
re_else = re.compile(r"\s*else:")
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")

with open(init_file, "r", encoding="utf-8") as f:
Expand All @@ -776,11 +778,15 @@ def clean_frameworks_in_init(
new_lines = []
idx = 0
while idx < len(lines):
# Conditional imports
if re_conditional_imports.search(lines[idx]) is not None:
# Conditional imports in try-except-else blocks
if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
# Remove the preceding `try:`
new_lines.pop()
idx += 1
while is_empty_line(lines[idx]):
# Iterate until `else:`
while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
idx += 1
idx += 1
indent = find_indent(lines[idx])
while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
idx += 1
Expand All @@ -790,6 +796,7 @@ def clean_frameworks_in_init(
for framework in to_remove:
line = line.replace(f", is_{framework}_available", "")
line = line.replace(f"is_{framework}_available, ", "")
line = line.replace(f"is_{framework}_available,", "")
line = line.replace(f"is_{framework}_available", "")

if len(line.strip()) > 0:
Expand Down Expand Up @@ -836,11 +843,11 @@ def add_model_to_main_init(
while idx < len(lines):
if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
framework = None
elif lines[idx].lstrip().startswith("if is_torch_available"):
elif lines[idx].lstrip().startswith("if not is_torch_available"):
framework = "pt"
elif lines[idx].lstrip().startswith("if is_tf_available"):
elif lines[idx].lstrip().startswith("if not is_tf_available"):
framework = "tf"
elif lines[idx].lstrip().startswith("if is_flax_available"):
elif lines[idx].lstrip().startswith("if not is_flax_available"):
framework = "flax"

# Skip if we are in a framework not wanted.
Expand Down
71 changes: 61 additions & 10 deletions src/transformers/models/albert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
Expand All @@ -32,13 +33,28 @@
"configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig", "AlbertOnnxConfig"],
}

if is_sentencepiece_available():
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_albert"] = ["AlbertTokenizer"]

if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_albert_fast"] = ["AlbertTokenizerFast"]

if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_albert"] = [
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"AlbertForMaskedLM",
Expand All @@ -52,7 +68,12 @@
"load_tf_weights_in_albert",
]

if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_albert"] = [
"TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFAlbertForMaskedLM",
Expand All @@ -66,7 +87,12 @@
"TFAlbertPreTrainedModel",
]

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_albert"] = [
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
Expand All @@ -81,13 +107,28 @@
if TYPE_CHECKING:
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig

if is_sentencepiece_available():
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_albert import AlbertTokenizer

if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_albert_fast import AlbertTokenizerFast

if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AlbertForMaskedLM,
Expand All @@ -101,7 +142,12 @@
load_tf_weights_in_albert,
)

if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_albert import (
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAlbertForMaskedLM,
Expand All @@ -115,7 +161,12 @@
TFAlbertPreTrainedModel,
)

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
Expand Down
50 changes: 43 additions & 7 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)


_import_structure = {
Expand All @@ -29,7 +35,12 @@
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
}

if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_auto"] = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_AUDIO_XVECTOR_MAPPING",
Expand Down Expand Up @@ -81,7 +92,12 @@
"AutoModelWithLMHead",
]

if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_auto"] = [
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
Expand Down Expand Up @@ -115,7 +131,12 @@
"TFAutoModelWithLMHead",
]

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_auto"] = [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
Expand Down Expand Up @@ -151,7 +172,12 @@
from .processing_auto import PROCESSOR_MAPPING, AutoProcessor
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_AUDIO_XVECTOR_MAPPING,
Expand Down Expand Up @@ -203,7 +229,12 @@
AutoModelWithLMHead,
)

if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
Expand Down Expand Up @@ -237,7 +268,12 @@
TFAutoModelWithLMHead,
)

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
Expand Down
65 changes: 56 additions & 9 deletions src/transformers/models/bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,35 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
"configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig", "BartOnnxConfig"],
"tokenization_bart": ["BartTokenizer"],
}

if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_bart_fast"] = ["BartTokenizerFast"]

if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_bart"] = [
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
"BartForCausalLM",
Expand All @@ -40,10 +57,20 @@
"PretrainedBartModel",
]

if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_bart"] = [
"FlaxBartDecoderPreTrainedModel",
"FlaxBartForCausalLM",
Expand All @@ -58,10 +85,20 @@
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, BartOnnxConfig
from .tokenization_bart import BartTokenizer

if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_bart_fast import BartTokenizerFast

if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BartForCausalLM,
Expand All @@ -73,10 +110,20 @@
PretrainedBartModel,
)

if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_bart import (
FlaxBartDecoderPreTrainedModel,
FlaxBartForCausalLM,
Expand Down
Loading

0 comments on commit df735d1

Please sign in to comment.