
Commit 7aca328

cyyever authored and ArthurZucker committed
Use removeprefix and removesuffix (#41240)
* Use removeprefix and removesuffix
  Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
* More fixes
  Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
---------
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
1 parent 9e60961 commit 7aca328

12 files changed: +14 −28 lines changed
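Every hunk below relies on the same property of Python 3.9's `str.removeprefix`/`str.removesuffix` (PEP 616): when the affix is absent, the string comes back unchanged, so the old `startswith`/`endswith` guards can simply be dropped. A quick sketch (values illustrative):

>>> "checkpoint.bin".removesuffix(".bin")
'checkpoint'
>>> "checkpoint".removesuffix(".bin")  # suffix absent: returned unchanged
'checkpoint'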

src/transformers/commands/chat.py
Lines changed: 1 addition & 2 deletions

@@ -443,8 +443,7 @@ def parse_generate_flags(self, generate_flags: list[str]) -> dict:
         # 2. b. strings should be quoted
         def is_number(s: str) -> bool:
             # handle negative numbers
-            if s.startswith("-"):
-                s = s[1:]
+            s = s.removeprefix("-")
             return s.replace(".", "", 1).isdigit()
 
         generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
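With the guard folded into one call, the helper still classifies negative numbers correctly; an illustrative check of the rewritten function (inputs made up):

>>> def is_number(s: str) -> bool:
...     s = s.removeprefix("-")  # no-op when there is no leading "-"
...     return s.replace(".", "", 1).isdigit()
...
>>> is_number("-3.14"), is_number("abc")
(True, False)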

src/transformers/commands/serving.py
Lines changed: 2 additions & 4 deletions

@@ -1066,8 +1066,7 @@ def generate_with_cache(**kwargs):
             for result in streamer:
                 # Temporary hack for GPTOS 3: don't emit the final "<|return|>"
                 if "gptoss" in model.config.architectures[0].lower():
-                    if result.endswith("<|return|>"):
-                        result = result[: -len("<|return|>")]
+                    result = result.removesuffix("<|return|>")
                 results += result
 
                 # (related to temporary hack 2)
@@ -1325,8 +1324,7 @@ def generate_with_cache(**kwargs):
             for result in streamer:
                 # Temporary hack for GPTOS 3: don't emit the final "<|return|>"
                 if "gptoss" in model.config.architectures[0].lower():
-                    if result.endswith("<|return|>"):
-                        result = result[: -len("<|return|>")]
+                    result = result.removesuffix("<|return|>")
                 results += result
 
                 # (related to temporary hack 2)
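`removesuffix` strips the stop token only when the chunk actually ends with it, so ordinary mid-stream chunks pass through untouched; a sketch with made-up chunk values:

>>> "Hello!<|return|>".removesuffix("<|return|>")
'Hello!'
>>> "Hello, wor".removesuffix("<|return|>")  # mid-stream chunk: unchanged
'Hello, wor'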

src/transformers/dynamic_module_utils.py
Lines changed: 1 addition & 2 deletions

@@ -285,8 +285,7 @@ def get_class_in_module(
         `typing.Type`: The class looked for.
     """
     name = os.path.normpath(module_path)
-    if name.endswith(".py"):
-        name = name[:-3]
+    name = name.removesuffix(".py")
     name = name.replace(os.path.sep, ".")
     module_file: Path = Path(HF_MODULES_CACHE) / module_path
     with _HF_REMOTE_CODE_LOCK:
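The surrounding lines then turn the module path into a dotted module name; a sketch of the pipeline with a hypothetical path:

>>> import os
>>> name = os.path.normpath("custom/modeling.py").removesuffix(".py")
>>> name.replace(os.path.sep, ".")
'custom.modeling'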

src/transformers/modelcard.py
Lines changed: 1 addition & 2 deletions

@@ -794,8 +794,7 @@ def parse_log_history(log_history):
         if idx > 0:
             eval_results = {}
             for key, value in log_history[idx].items():
-                if key.startswith("eval_"):
-                    key = key[5:]
+                key = key.removeprefix("eval_")
                 if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
                     camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
                     eval_results[camel_cased_key] = value
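Since "eval_" is five characters, the old `key[5:]` slice under its guard and `removeprefix` agree exactly; a sketch with an illustrative metric name:

>>> key = "eval_exact_match".removeprefix("eval_")
>>> " ".join(part.capitalize() for part in key.split("_"))
'Exact Match'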

src/transformers/modeling_utils.py
Lines changed: 2 additions & 2 deletions

@@ -5632,7 +5632,7 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal
         for name, module in self.named_modules():
             if remove_prefix:
                 _prefix = f"{self.base_model_prefix}."
-                name = name[len(_prefix) :] if name.startswith(_prefix) else name
+                name = name.removeprefix(_prefix)
             elif add_prefix:
                 name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
 
@@ -5952,7 +5952,7 @@ def _adjust_missing_and_unexpected_keys(
         # in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model
         if loading_task_model_from_base_state_dict:
             _prefix = f"{self.base_model_prefix}."
-            unexpected_keys = [k[len(_prefix) :] if k.startswith(_prefix) else k for k in unexpected_keys]
+            unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys]
 
         return missing_keys, unexpected_keys
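Because `removeprefix` leaves non-matching keys alone, the comprehension is a drop-in replacement for the conditional slice; a sketch assuming a base_model_prefix of "bert." and made-up state-dict keys:

>>> _prefix = "bert."  # illustrative base_model_prefix
>>> [k.removeprefix(_prefix) for k in ["bert.pooler.dense.weight", "cls.predictions.bias"]]
['pooler.dense.weight', 'cls.predictions.bias']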

src/transformers/models/auto/image_processing_auto.py
Lines changed: 1 addition & 3 deletions

@@ -566,9 +566,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
             )
             image_processor_class = get_image_processor_class_from_name(image_processor_type)
         else:
-            image_processor_type_slow = (
-                image_processor_type[:-4] if image_processor_type.endswith("Fast") else image_processor_type
-            )
+            image_processor_type_slow = image_processor_type.removesuffix("Fast")
             image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
             if image_processor_class is None and image_processor_type.endswith("Fast"):
                 raise ValueError(
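"Fast" is four characters, so the guarded `[:-4]` slice was exactly `removesuffix("Fast")`; for example (class name used purely for illustration):

>>> "SiglipImageProcessorFast".removesuffix("Fast")
'SiglipImageProcessor'
>>> "SiglipImageProcessor".removesuffix("Fast")
'SiglipImageProcessor'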

src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py
Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, c
     hf_model = ChineseCLIPModel(config).eval()
 
     pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
-    pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
+    pt_weights = {(name.removeprefix("module.")): value for name, value in pt_weights.items()}
 
     copy_text_model_and_projection(hf_model, pt_weights)
     copy_vision_model_and_projection(hf_model, pt_weights)
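"module." is seven characters, which is what the old `name[7:]` slice assumed; a sketch with made-up DataParallel-style checkpoint keys:

>>> "module.visual.proj".removeprefix("module."), "logit_scale".removeprefix("module.")
('visual.proj', 'logit_scale')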

src/transformers/models/longt5/modeling_longt5.py
Lines changed: 1 addition & 2 deletions

@@ -1272,8 +1272,7 @@ def dummy_inputs(self):
 
     def _try_load_missing_tied_module(self, key):
         module = self
-        if key.endswith(".weight"):
-            key = key[: -len(".weight")]
+        key = key.removesuffix(".weight")
         for sub_key in key.split("."):
             if not hasattr(module, sub_key):
                 return
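After the suffix is stripped, the remaining dotted path is walked one attribute at a time; an illustrative key:

>>> "decoder.embed_tokens.weight".removesuffix(".weight").split(".")
['decoder', 'embed_tokens']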

src/transformers/models/rag/retrieval_rag.py
Lines changed: 1 addition & 4 deletions

@@ -509,10 +509,7 @@ def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=N
         def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
             # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
             # TODO(piktus): better handling of truncation
-            if doc_title.startswith('"'):
-                doc_title = doc_title[1:]
-            if doc_title.endswith('"'):
-                doc_title = doc_title[:-1]
+            doc_title = doc_title.removeprefix('"').removesuffix('"')
             if prefix is None:
                 prefix = ""
             out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
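The chained calls strip at most one quote from each end and leave unquoted titles alone, matching the pair of guards they replace; made-up titles:

>>> '"Aaron Aaronson"'.removeprefix('"').removesuffix('"')
'Aaron Aaronson'
>>> 'Zanzibar'.removeprefix('"').removesuffix('"')
'Zanzibar'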

src/transformers/utils/auto_docstring.py
Lines changed: 1 addition & 2 deletions

@@ -1215,8 +1215,7 @@ def get_checkpoint_from_config_class(config_class):
     # For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`
     for ckpt_name, ckpt_link in checkpoints:
         # allow the link to end with `/`
-        if ckpt_link.endswith("/"):
-            ckpt_link = ckpt_link[:-1]
+        ckpt_link = ckpt_link.removesuffix("/")
 
         # verify the checkpoint name corresponds to the checkpoint link
         ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
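`removesuffix` normalizes the trailing slash without touching links that already lack it; using the checkpoint link from the comment above:

>>> "https://huggingface.co/google-bert/bert-base-uncased/".removesuffix("/")
'https://huggingface.co/google-bert/bert-base-uncased'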
