Skip to content

Commit

Permalink
extract header from code
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 12, 2024
1 parent 81a06f8 commit db8c156
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions .cross_sync/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,27 @@
"""


def extract_header_comments(file_path) -> str:
"""
Extract the file header. Header is defined as the top-level
comments before any code or imports
"""
header = []
with open(file_path, "r") as f:
for line in f:
if line.startswith("#") or line.strip() == "":
header.append(line)
else:
break
return "".join(header)


class CrossSyncOutputFile:

def __init__(self, file_path: str, ast_tree):
self.file_path = file_path
def __init__(self, output_path: str, ast_tree, header: str | None = None):
self.output_path = output_path
self.tree = ast_tree
self.header = header or ""

def render(self, with_black=True, save_to_disk: bool = False) -> str:
"""
Expand All @@ -37,24 +53,7 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str:
with_black: whether to run the output through black before returning
save_to_disk: whether to write the output to the file path
"""
header = (
"# Copyright 2024 Google LLC\n"
"#\n"
'# Licensed under the Apache License, Version 2.0 (the "License");\n'
"# you may not use this file except in compliance with the License.\n"
"# You may obtain a copy of the License at\n"
"#\n"
"# http://www.apache.org/licenses/LICENSE-2.0\n"
"#\n"
"# Unless required by applicable law or agreed to in writing, software\n"
'# distributed under the License is distributed on an "AS IS" BASIS,\n'
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
"# See the License for the specific language governing permissions and\n"
"# limitations under the License.\n"
"#\n"
"# This file is automatically generated by CrossSync. Do not edit manually.\n"
)
full_str = header + ast.unparse(self.tree)
full_str = self.header + ast.unparse(self.tree)
if with_black:
import black # type: ignore
import autoflake # type: ignore
Expand All @@ -65,8 +64,8 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str:
)
if save_to_disk:
import os
os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
with open(self.file_path, "w") as f:
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
with open(self.output_path, "w") as f:
f.write(full_str)
return full_str

Expand All @@ -87,7 +86,8 @@ def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]:
if output_path is not None:
# contains __CROSS_SYNC_OUTPUT__ annotation
converted_tree = file_transformer.visit(ast_tree)
artifacts.add(CrossSyncOutputFile(output_path, converted_tree))
header = extract_header_comments(file_path)
artifacts.add(CrossSyncOutputFile(output_path, converted_tree, header))
# return set of output artifacts
return artifacts

Expand Down

0 comments on commit db8c156

Please sign in to comment.