diff --git a/generate.py b/generate.py index 9b6eb2792..ab8c7e138 100644 --- a/generate.py +++ b/generate.py @@ -1,8 +1,8 @@ # generate.py generates a new recipe scraper. import ast import json -import os import sys +from pathlib import Path import requests @@ -14,7 +14,8 @@ def generate_scraper(class_name, host_name): - with open("templates/scraper.py") as source: + template_path = Path("templates/scraper.py") + with template_path.open() as source: code = source.read() program = ast.parse(code) @@ -23,14 +24,13 @@ def generate_scraper(class_name, host_name): if not state.step(node): break - output = f"recipe_scrapers/{class_name.lower()}.py" - with open(output, "w") as target: - target.write(state.result()) + output = Path(f"recipe_scrapers/{class_name.lower()}.py") + output.write_text(state.result()) def generate_scraper_test(class_name, host_name): - if not os.path.isdir(f"tests/test_data/{host_name}"): - os.mkdir(f"tests/test_data/{host_name}") + test_data_dir = Path(f"tests/test_data/{host_name}") + test_data_dir.mkdir(parents=True, exist_ok=True) testjson = { "host": host_name, @@ -47,13 +47,13 @@ def generate_scraper_test(class_name, host_name): "description": "", } - output = f"tests/test_data/{host_name}/{class_name.lower()}.json" - with open(output, "w") as target: - json.dump(testjson, target, indent=2) + output = test_data_dir / f"{class_name.lower()}.json" + output.write_text(json.dumps(testjson, indent=2)) def init_scraper(class_name): - with open("recipe_scrapers/__init__.py", "r+") as source: + init_file = Path("recipe_scrapers/__init__.py") + with init_file.open("r+") as source: code = source.read() program = ast.parse(code) @@ -68,8 +68,8 @@ def init_scraper(class_name): def generate_test_data(class_name, host_name, content): - output = f"tests/test_data/{host_name}/{class_name.lower()}.testhtml" - with open(output, "w", encoding="utf-8") as target: + output = Path(f"tests/test_data/{host_name}/{class_name.lower()}.testhtml") + with output.open("w", encoding="utf-8") as target: target.write(content.decode(encoding="utf-8")) @@ -218,9 +218,12 @@ def get_line_offsets(code): def main(): - if len(sys.argv) < 3: - print("Usage: generate.py ScraperClassName url") - exit(1) + if len(sys.argv) != 3: + print("Usage: python generate.py ") + print( + "Example: python generate.py ExampleClassName https://www.example.com/recipe/12345/example-recipe/" + ) + sys.exit(1) class_name = sys.argv[1] url = sys.argv[2] @@ -232,6 +235,8 @@ def main(): generate_test_data(class_name, host_name, testhtml) init_scraper(class_name) + print(f"Successfully generated scraper for {class_name} ({host_name})") + if __name__ == "__main__": main()