@@ -151,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
151151 ... lines = [ln.rstrip() for ln in fp.readlines()]
152152 >>> lines = replace_vars_with_imports(lines, import_path)
153153 """
154+ copied = []
154155 body , tracking , skip_offset = [], False , 0
155156 for ln in lines :
156157 offset = len (ln ) - len (ln .lstrip ())
@@ -161,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
161162 if var :
162163 name = var .groups ()[0 ]
163164 # skip private or apply white-list for allowed vars
164- if not name .startswith ("__" ) or name in ("__all__" ,):
165+ if name not in copied and ( not name .startswith ("__" ) or name in ("__all__" ,) ):
165166 body .append (f"{ ' ' * offset } from { import_path } import { name } # noqa: F401" )
167+ copied .append (name )
166168 tracking , skip_offset = True , offset
167169 continue
168170 if not tracking :
@@ -197,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
197199 return body
198200
199201
202+ def prune_func_calls (lines : List [str ]) -> List [str ]:
203+ """Prune calling functions from a file, even multi-line.
204+
205+ >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
206+ >>> import_path = ".".join(["pytorch_lightning", "loggers"])
207+ >>> with open(py_file, encoding="utf-8") as fp:
208+ ... lines = [ln.rstrip() for ln in fp.readlines()]
209+ >>> lines = prune_func_calls(lines)
210+ """
211+ body , tracking , score = [], False , 0
212+ for ln in lines :
213+ # catching callable
214+ calling = re .match (r"^@?[\w_\d\.]+ *\(" , ln .lstrip ())
215+ if calling and " import " not in ln :
216+ tracking = True
217+ score = 0
218+ if tracking :
219+ score += ln .count ("(" ) - ln .count (")" )
220+ if score == 0 :
221+ tracking = False
222+ else :
223+ body .append (ln )
224+ return body
225+
226+
200227def prune_empty_statements (lines : List [str ]) -> List [str ]:
201228 """Prune emprty if/else and try/except.
202229
@@ -302,6 +329,15 @@ def parse_version_from_file(pkg_root: str) -> str:
302329 return ver
303330
304331
332+ def prune_duplicate_lines (body ):
333+ body_ = []
334+ # drop duplicated lines
335+ for ln in body :
336+ if ln .lstrip () not in body_ or ln .lstrip () in (")" , "" ):
337+ body_ .append (ln )
338+ return body_
339+
340+
305341def create_meta_package (src_folder : str , pkg_name : str = "pytorch_lightning" , lit_name : str = "pytorch" ):
306342 """Parse the real python package and for each module create a mirroe version with repalcing all function and
307343 class implementations by cross-imports to the true package.
@@ -331,34 +367,36 @@ class implementations by cross-imports to the true package.
331367 logging .warning (f"unsupported file: { local_path } " )
332368 continue
333369 # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
334- body = prune_comments_docstrings (lines )
370+ body = prune_comments_docstrings ([ ln . rstrip () for ln in lines ] )
335371 if fname not in ("__init__.py" , "__main__.py" ):
336372 body = prune_imports_callables (body )
337- body = replace_block_with_imports ([ln .rstrip () for ln in body ], import_path , "class" )
338- body = replace_block_with_imports (body , import_path , "def" )
339- body = replace_block_with_imports (body , import_path , "async def" )
373+ for key_word in ("class" , "def" , "async def" ):
374+ body = replace_block_with_imports (body , import_path , key_word )
375+ # TODO: fix reimporting which is artefact after replacing var assignment with import;
376+ # after fixing , update CI by remove F811 from CI/check pkg
340377 body = replace_vars_with_imports (body , import_path )
378+ if fname not in ("__main__.py" ,):
379+ body = prune_func_calls (body )
341380 body_len = - 1
342381 # in case of several in-depth statements
343382 while body_len != len (body ):
344383 body_len = len (body )
384+ body = prune_duplicate_lines (body )
345385 body = prune_empty_statements (body )
346386 # add try/catch wrapper for whole body,
347387 # so when import fails it tells you what is the package version this meta package was generated for...
348388 body = wrap_try_except (body , pkg_name , pkg_ver )
349389
350390 # todo: apply pre-commit formatting
391+ # clean to many empty lines
351392 body = [ln for ln , _group in groupby (body )]
352- lines = []
353393 # drop duplicated lines
354- for ln in body :
355- if ln + os .linesep not in lines or ln .lstrip () in (")" , "" ):
356- lines .append (ln + os .linesep )
394+ body = prune_duplicate_lines (body )
357395 # compose the target file name
358396 new_file = os .path .join (src_folder , "lightning" , lit_name , local_path )
359397 os .makedirs (os .path .dirname (new_file ), exist_ok = True )
360398 with open (new_file , "w" , encoding = "utf-8" ) as fp :
361- fp .writelines (lines )
399+ fp .writelines ([ ln + os . linesep for ln in body ] )
362400
363401
364402def set_version_today (fpath : str ) -> None :
@@ -380,7 +418,6 @@ def _download_frontend(root: str = _PROJECT_ROOT):
380418 directory."""
381419
382420 try :
383- build_dir = "build"
384421 frontend_dir = pathlib .Path (root , "src" , "lightning_app" , "ui" )
385422 download_dir = tempfile .mkdtemp ()
386423
@@ -390,7 +427,7 @@ def _download_frontend(root: str = _PROJECT_ROOT):
390427 file = tarfile .open (fileobj = response , mode = "r|gz" )
391428 file .extractall (path = download_dir )
392429
393- shutil .move (os .path .join (download_dir , build_dir ), frontend_dir )
430+ shutil .move (os .path .join (download_dir , "build" ), frontend_dir )
394431 print ("The Lightning UI has successfully been downloaded!" )
395432
396433 # If installing from source without internet connection, we don't want to break the installation
0 commit comments