diff --git a/zntrack/core/jupyter.py b/zntrack/core/jupyter.py index 9083393a..83958f36 100644 --- a/zntrack/core/jupyter.py +++ b/zntrack/core/jupyter.py @@ -10,6 +10,7 @@ def jupyter_class_to_file(silent, nb_name, module_name): """Extract the class definition form an ipynb file""" + # TOOD is it really module_name and not class name? if not silent: log.warning( @@ -27,6 +28,7 @@ def jupyter_class_to_file(silent, nb_name, module_name): ) reading_class = False + found_node = False imports = "" @@ -52,6 +54,15 @@ def jupyter_class_to_file(silent, nb_name, module_name): if line.startswith("@"): # handle decorators reading_class = True class_definition += line + if line.startswith(f"class {module_name}") or line.startswith( + f"def {module_name}" + ): + found_node = True + if found_node and not reading_class: + if re.match(r"#.*zntrack:.*break", line): + # stop converting the file after this line if the Node was already + # found + break src = imports + "\n\n" + class_definition