Skip to content

Commit

Permalink
generalised imports to schema_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Hamzu24 committed Jan 10, 2025
1 parent cb442b2 commit 26b7283
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions packages/schema_wrapper/src/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
SCHEMA_VERSION: Final = "v0.10.0"
SCHEMA_URL_TEMPLATE: Final = "https://raw.githubusercontent.com/uwdata/mosaic/refs/heads/main/docs/public/schema/{version}.json"
KNOWN_PRIMITIVES = {"string": "str", "boolean": "bool", "number": "float", "object": "Dict[str, Any]"}
IMPORTS = {"typing": ["List", "Dict", "Any", "Union"], ".src.SchemaBase": ["SchemaBase"], ".src.utils": ["revert_validation"]}

def generate_import_string(imports: Dict[str, List[str]]) -> str:
import_string = ""
for source, cur_imports in imports.items():
import_string += f"from {source} import {', '.join(cur_imports)}\n"
import_string += '\n'
return import_string

def schema_url(version: str = SCHEMA_VERSION) -> str:
return SCHEMA_URL_TEMPLATE.format(version=version)
Expand Down Expand Up @@ -44,7 +52,7 @@ def generate_additional_properties_class(class_name: str, class_schema: Dict[str
def generate_enum_class(class_name: str, class_schema: Dict[str, Any]) -> str:
enum_options = class_schema.get('enum', [])
enum_type = get_type_hint(class_schema)
class_def = f"class {class_name}(SchemaBase):\n enum_options = {enum_options}\n\n def __init__(self, value: {enum_type}):\n"
class_def = f"class {class_name}(SchemaBase):\n enum_options = {enum_options}\n\n def __init__(self, value: {enum_type}):\n"
class_def += """ if value not in self.enum_options:
raise ValueError(f"Value of enum not in allowed values: {self.enum_options}")
self.value = value\n"""
Expand All @@ -59,9 +67,10 @@ def generate_class(class_name: str, class_schema: Dict[str, Any]) -> str:
return generate_additional_properties_class(class_name, class_schema)
elif 'enum' in class_schema:
return generate_enum_class(class_name, class_schema)
#if 'items' in class_schema:
# print(f"class_name: {class_name}")
return f"class {class_name}(SchemaBase):\n def __init__(self):\n pass\n"
else:
#print(f"class_name: {class_name}\nschema: {class_schema}\n\n")
type_hint = get_type_hint(class_schema)
return f"class {class_name}(SchemaBase):\n def __init__(self, value: {type_hint}):\n self.value = value\n"

# Check for '$ref' and handle it
if '$ref' in class_schema:
Expand Down Expand Up @@ -196,7 +205,8 @@ def generate_schema_wrapper(schema_file: Path, output_file: Path) -> str:
definitions[name] = class_code

generated_classes = "\n\n".join(definitions.values())
generated_classes = "from typing import List, Dict, Any, Union\nfrom .src.SchemaBase import SchemaBase\n\nfrom .src.utils import revert_validation" + generated_classes
import_string = generate_import_string(IMPORTS)
generated_classes = import_string + generated_classes


with open(output_file, 'w') as f:
Expand Down

0 comments on commit 26b7283

Please sign in to comment.