From db8c1561568751eb2aac4905e62d35ec366c62f6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 15:18:12 -0700 Subject: [PATCH] extract header from code --- .cross_sync/generate.py | 46 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 321ad82fa..618dfbc6c 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -23,11 +23,27 @@ """ +def extract_header_comments(file_path) -> str: + """ + Extract the file header. Header is defined as the top-level + comments before any code or imports + """ + header = [] + with open(file_path, "r") as f: + for line in f: + if line.startswith("#") or line.strip() == "": + header.append(line) + else: + break + return "".join(header) + + class CrossSyncOutputFile: - def __init__(self, file_path: str, ast_tree): - self.file_path = file_path + def __init__(self, output_path: str, ast_tree, header: str | None = None): + self.output_path = output_path self.tree = ast_tree + self.header = header or "" def render(self, with_black=True, save_to_disk: bool = False) -> str: """ @@ -37,24 +53,7 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: with_black: whether to run the output through black before returning save_to_disk: whether to write the output to the file path """ - header = ( - "# Copyright 2024 Google LLC\n" - "#\n" - '# Licensed under the Apache License, Version 2.0 (the "License");\n' - "# you may not use this file except in compliance with the License.\n" - "# You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing, software\n" - '# distributed under the License is distributed on an "AS IS" BASIS,\n' - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "# See the License for the specific language governing permissions and\n" - "# limitations under the License.\n" - "#\n" - "# This file is automatically generated by CrossSync. Do not edit manually.\n" - ) - full_str = header + ast.unparse(self.tree) + full_str = self.header + ast.unparse(self.tree) if with_black: import black # type: ignore import autoflake # type: ignore @@ -65,8 +64,8 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: ) if save_to_disk: import os - os.makedirs(os.path.dirname(self.file_path), exist_ok=True) - with open(self.file_path, "w") as f: + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + with open(self.output_path, "w") as f: f.write(full_str) return full_str @@ -87,7 +86,8 @@ def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: if output_path is not None: # contains __CROSS_SYNC_OUTPUT__ annotation converted_tree = file_transformer.visit(ast_tree) - artifacts.add(CrossSyncOutputFile(output_path, converted_tree)) + header = extract_header_comments(file_path) + artifacts.add(CrossSyncOutputFile(output_path, converted_tree, header)) # return set of output artifacts return artifacts