diff --git a/pytket/stub_generation/regenerate_stubs b/pytket/stub_generation/regenerate_stubs index a927b3912c..1cc4ff895e 100755 --- a/pytket/stub_generation/regenerate_stubs +++ b/pytket/stub_generation/regenerate_stubs @@ -8,13 +8,17 @@ import pytket import pkgutil -def replace_in_file_string(file_string: str, matcher: str, replacement: str) -> tuple[str, int]: +def replace_in_file_string( + file_string: str, matcher: str, replacement: str +) -> tuple[str, int]: split_text = re.split(matcher, file_string) modified_text = replacement.join(split_text) return modified_text, len(split_text) - 1 -def replace_in_file_string_multiline(file_string: str, matcher: str, replacement: str) -> tuple[str, int]: +def replace_in_file_string_multiline( + file_string: str, matcher: str, replacement: str +) -> tuple[str, int]: split_text = re.split(matcher, file_string, re.MULTILINE) modified_text = replacement.join(split_text) return modified_text, len(split_text) - 1 @@ -27,8 +31,12 @@ class StubFixer: handle_circuit_depth_needed: bool = False def handle_args_kwargs_types(self, file_string: str) -> str: - modified, n_args = replace_in_file_string(file_string, "\*args(?=[^:])", "*args: Any") - modified, n_kwargs = replace_in_file_string(modified, "\*\*kwargs(?=[^:])", "**kwargs: Any") + modified, n_args = replace_in_file_string( + file_string, "\*args(?=[^:])", "*args: Any" + ) + modified, n_kwargs = replace_in_file_string( + modified, "\*\*kwargs(?=[^:])", "**kwargs: Any" + ) n_total = n_args + n_kwargs if n_total > 0 and not re.search("from typing import .*Any", modified): self.handle_args_kwargs_needed = True @@ -40,50 +48,68 @@ class StubFixer: modified_text = "NDArray[".join(split_text) if len(split_text) > 1: modified_text, _ = replace_in_file_string(modified_text, "\[., .\]", "") - modified_text, _ = replace_in_file_string(modified_text, "NDArray\[bool\]", "NDArray[numpy.bool_]") - modified_text = "from numpy.typing import NDArray" + os.linesep + modified_text + modified_text, _ = replace_in_file_string( + modified_text, "NDArray\[bool\]", "NDArray[numpy.bool_]" + ) + modified_text = ( + "from numpy.typing import NDArray" + os.linesep + modified_text + ) self.handle_numpy_stuff_needed = True return modified_text return file_string def print_info(self): if not self.handle_args_kwargs_needed: - print(f"Info: No file required the transformation `{self.handle_args_kwargs_types.__name__}`. Is it still needed?") + print( + f"Info: No file required the transformation `{self.handle_args_kwargs_types.__name__}`. Is it still needed?" + ) if not self.handle_numpy_stuff_needed: - print(f"Info: No file required the transformation `{self.handle_numpy_stuff.__name__}`. Is it still needed?") + print( + f"Info: No file required the transformation `{self.handle_numpy_stuff.__name__}`. Is it still needed?" + ) if __name__ == "__main__": parser = argparse.ArgumentParser( - prog='StubGenerator', - description='generates type stubs for pybind11 modules', + prog="StubGenerator", + description="generates type stubs for pybind11 modules", ) pytket_dir = Path(__file__).parent.parent.resolve() gen_root_dir = pytket_dir for mod in pkgutil.iter_modules(pytket._tket.__path__): - print(f"Generate {mod.name}.pyi") - subprocess.run([ - "pybind11-stubgen", - f"pytket._tket.{mod.name}", - "--enum-class-locations", "CXConfigType|BasisOrder|OpType:pytket._tket.circuit", - "--enum-class-locations", "PauliSynthStrat:pytket._tket.transform", - "--enum-class-locations", "GraphColourMethod|PauliPartitionStrat:pytket._tket.partition", - "--enum-class-locations", "SafetyMode:pytket._tket.passes", - "--enum-class-locations", "ZXWireType|QuantumType:pytket._tket.zx", - "-o", gen_root_dir]) + if not mod.name.startswith("lib"): + print(f"Generate {mod.name}.pyi") + subprocess.run( + [ + "pybind11-stubgen", + f"pytket._tket.{mod.name}", + "--enum-class-locations", + "CXConfigType|BasisOrder|OpType:pytket._tket.circuit", + "--enum-class-locations", + "PauliSynthStrat:pytket._tket.transform", + "--enum-class-locations", + "GraphColourMethod|PauliPartitionStrat:pytket._tket.partition", + "--enum-class-locations", + "SafetyMode:pytket._tket.passes", + "--enum-class-locations", + "ZXWireType|QuantumType:pytket._tket.zx", + "-o", + gen_root_dir, + ] + ) print("Cleanup:") stub_fixer = StubFixer() - for path in Path(f'{gen_root_dir}/pytket/_tket').iterdir(): + for path in Path(f"{gen_root_dir}/pytket/_tket").iterdir(): if path.is_file() and path.suffix == ".pyi": print(f" Fixing {path}") - with path.open('r+') as file: + with path.open("r+") as file: text = file.read() text = stub_fixer.handle_args_kwargs_types(text) text = stub_fixer.handle_numpy_stuff(text) - with path.open('w') as file: + with path.open("w") as file: file.write(text) print("")