Skip to content

Commit cd92e35

Browse files
prune func calls in meta pkg init (#13742)
* prune func calls in meta pkg init * move calling * prune * coped Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c136ef5 commit cd92e35

File tree

3 files changed

+66
-25
lines changed

3 files changed

+66
-25
lines changed

.actions/setup_tools.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
151151
... lines = [ln.rstrip() for ln in fp.readlines()]
152152
>>> lines = replace_vars_with_imports(lines, import_path)
153153
"""
154+
copied = []
154155
body, tracking, skip_offset = [], False, 0
155156
for ln in lines:
156157
offset = len(ln) - len(ln.lstrip())
@@ -161,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
161162
if var:
162163
name = var.groups()[0]
163164
# skip private or apply white-list for allowed vars
164-
if not name.startswith("__") or name in ("__all__",):
165+
if name not in copied and (not name.startswith("__") or name in ("__all__",)):
165166
body.append(f"{' ' * offset}from {import_path} import {name} # noqa: F401")
167+
copied.append(name)
166168
tracking, skip_offset = True, offset
167169
continue
168170
if not tracking:
@@ -197,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
197199
return body
198200

199201

202+
def prune_func_calls(lines: List[str]) -> List[str]:
203+
"""Prune calling functions from a file, even multi-line.
204+
205+
>>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
206+
>>> import_path = ".".join(["pytorch_lightning", "loggers"])
207+
>>> with open(py_file, encoding="utf-8") as fp:
208+
... lines = [ln.rstrip() for ln in fp.readlines()]
209+
>>> lines = prune_func_calls(lines)
210+
"""
211+
body, tracking, score = [], False, 0
212+
for ln in lines:
213+
# catching callable
214+
calling = re.match(r"^@?[\w_\d\.]+ *\(", ln.lstrip())
215+
if calling and " import " not in ln:
216+
tracking = True
217+
score = 0
218+
if tracking:
219+
score += ln.count("(") - ln.count(")")
220+
if score == 0:
221+
tracking = False
222+
else:
223+
body.append(ln)
224+
return body
225+
226+
200227
def prune_empty_statements(lines: List[str]) -> List[str]:
201228
"""Prune emprty if/else and try/except.
202229
@@ -302,6 +329,15 @@ def parse_version_from_file(pkg_root: str) -> str:
302329
return ver
303330

304331

332+
def prune_duplicate_lines(body):
333+
body_ = []
334+
# drop duplicated lines
335+
for ln in body:
336+
if ln.lstrip() not in body_ or ln.lstrip() in (")", ""):
337+
body_.append(ln)
338+
return body_
339+
340+
305341
def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
306342
"""Parse the real python package and for each module create a mirroe version with repalcing all function and
307343
class implementations by cross-imports to the true package.
@@ -331,34 +367,36 @@ class implementations by cross-imports to the true package.
331367
logging.warning(f"unsupported file: {local_path}")
332368
continue
333369
# ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
334-
body = prune_comments_docstrings(lines)
370+
body = prune_comments_docstrings([ln.rstrip() for ln in lines])
335371
if fname not in ("__init__.py", "__main__.py"):
336372
body = prune_imports_callables(body)
337-
body = replace_block_with_imports([ln.rstrip() for ln in body], import_path, "class")
338-
body = replace_block_with_imports(body, import_path, "def")
339-
body = replace_block_with_imports(body, import_path, "async def")
373+
for key_word in ("class", "def", "async def"):
374+
body = replace_block_with_imports(body, import_path, key_word)
375+
# TODO: fix reimporting which is artefact after replacing var assignment with import;
376+
# after fixing , update CI by remove F811 from CI/check pkg
340377
body = replace_vars_with_imports(body, import_path)
378+
if fname not in ("__main__.py",):
379+
body = prune_func_calls(body)
341380
body_len = -1
342381
# in case of several in-depth statements
343382
while body_len != len(body):
344383
body_len = len(body)
384+
body = prune_duplicate_lines(body)
345385
body = prune_empty_statements(body)
346386
# add try/catch wrapper for whole body,
347387
# so when import fails it tells you what is the package version this meta package was generated for...
348388
body = wrap_try_except(body, pkg_name, pkg_ver)
349389

350390
# todo: apply pre-commit formatting
391+
# clean to many empty lines
351392
body = [ln for ln, _group in groupby(body)]
352-
lines = []
353393
# drop duplicated lines
354-
for ln in body:
355-
if ln + os.linesep not in lines or ln.lstrip() in (")", ""):
356-
lines.append(ln + os.linesep)
394+
body = prune_duplicate_lines(body)
357395
# compose the target file name
358396
new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
359397
os.makedirs(os.path.dirname(new_file), exist_ok=True)
360398
with open(new_file, "w", encoding="utf-8") as fp:
361-
fp.writelines(lines)
399+
fp.writelines([ln + os.linesep for ln in body])
362400

363401

364402
def set_version_today(fpath: str) -> None:
@@ -380,7 +418,6 @@ def _download_frontend(root: str = _PROJECT_ROOT):
380418
directory."""
381419

382420
try:
383-
build_dir = "build"
384421
frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
385422
download_dir = tempfile.mkdtemp()
386423

@@ -390,7 +427,7 @@ def _download_frontend(root: str = _PROJECT_ROOT):
390427
file = tarfile.open(fileobj=response, mode="r|gz")
391428
file.extractall(path=download_dir)
392429

393-
shutil.move(os.path.join(download_dir, build_dir), frontend_dir)
430+
shutil.move(os.path.join(download_dir, "build"), frontend_dir)
394431
print("The Lightning UI has successfully been downloaded!")
395432

396433
# If installing from source without internet connection, we don't want to break the installation

.github/actions/pkg-check/action.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@ runs:
1414
run: pip install "twine==4.0.1" setuptools wheel flake8
1515
shell: bash
1616

17-
- name: Create package
17+
- name: Source check
1818
env:
1919
PACKAGE_NAME: ${{ inputs.pkg-name }}
2020
run: |
2121
python setup.py check --metadata --strict
22-
flake8 src/lightning/ --ignore E402,F401,E501,W391,E303
23-
python setup.py sdist bdist_wheel
22+
# TODO: fix reimporting (F811) which is aftefact after rplacing var assigne with import in meta package
23+
flake8 src/lightning/ --ignore E402,F401,E501,W391,E303,F811
24+
shell: bash
25+
26+
- name: Create package
27+
env:
28+
PACKAGE_NAME: ${{ inputs.pkg-name }}
29+
run: python setup.py sdist bdist_wheel
2430
shell: bash
2531

2632
- name: Check package

src/pytorch_lightning/loggers/__init__.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,24 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from os import environ
14+
import os
1515

16-
from pytorch_lightning.loggers.base import ( # LightningLoggerBase imported for backward compatibility
17-
LightningLoggerBase,
18-
)
16+
# LightningLoggerBase imported for backward compatibility
17+
from pytorch_lightning.loggers.base import LightningLoggerBase
18+
from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
1919
from pytorch_lightning.loggers.csv_logs import CSVLogger
2020
from pytorch_lightning.loggers.logger import Logger, LoggerCollection
21-
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
22-
23-
__all__ = ["CSVLogger", "LightningLoggerBase", "Logger", "LoggerCollection", "TensorBoardLogger"]
24-
25-
from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
2621
from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
2722
from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F401
23+
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
2824
from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F401
2925

26+
__all__ = ["CSVLogger", "LightningLoggerBase", "Logger", "LoggerCollection", "TensorBoardLogger"]
27+
3028
if _COMET_AVAILABLE:
3129
__all__.append("CometLogger")
3230
# needed to prevent ModuleNotFoundError and duplicated logs.
33-
environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
31+
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
3432

3533
if _MLFLOW_AVAILABLE:
3634
__all__.append("MLFlowLogger")

0 commit comments

Comments
 (0)