Skip to content

Commit

Permalink
Merge pull request #10 from narenaryan/feat/xml-serialize
Browse files Browse the repository at this point in the history
feat: Add XML & Yaml serializers in addition to JSON
  • Loading branch information
narenaryan authored May 19, 2024
2 parents da19fc4 + 8abdb0a commit 70dc24c
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 10 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
lark
lark==1.1.9
PyYAML==6.0.1
2 changes: 1 addition & 1 deletion src/promptml/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
"""

# pylint: disable=invalid-name
version = "0.5.0"
version = "0.6.0"
21 changes: 13 additions & 8 deletions src/promptml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
prompt = parser.parse()
"""

import json
import os
import re

from lark import Lark, Transformer
from .serializer import SerializerFactory

class PromptMLTransformer(Transformer):
"""
Expand Down Expand Up @@ -182,6 +180,9 @@ def __init__(self, code: str):
self.code = code
self.prompt = {}
self.parser = Lark(promptml_grammar)
self.xml_serializer = SerializerFactory.create_serializer("xml")
self.json_serializer = SerializerFactory.create_serializer("json")
self.yaml_serializer = SerializerFactory.create_serializer("yaml")

def parse(self):
"""
Expand All @@ -201,16 +202,20 @@ def _parse_prompt(self):
self.prompt = PromptParser.transformer.transform(tree)
return self.prompt

def serialize_json(self, indent=None):
def to_json(self, indent=None):
""" Serialize the prompt data to JSON.
"""
return json.dumps(self.prompt, indent=indent)
return self.json_serializer.serialize(self.prompt, indent=indent)

def deserialize_json(self, serialized_data):
""" Deserialize the prompt data from JSON.
def to_yaml(self):
""" Serialize the prompt data to YAML.
"""
self.prompt = json.loads(serialized_data)
return self.yaml_serializer.serialize(self.prompt)

def to_xml(self):
""" Serialize the prompt data to XML.
"""
return self.xml_serializer.serialize(self.prompt)

class PromptParserFromFile(PromptParser):
"""
Expand Down
84 changes: 84 additions & 0 deletions src/promptml/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import json
from xml.etree import ElementTree as ET
from xml.dom import minidom
from abc import ABC, abstractmethod
from enum import Enum
import yaml

class Serializer(ABC):
    """Abstract base for objects that render a prompt dictionary as text.

    Concrete subclasses implement :meth:`serialize` for one output format.
    """

    @abstractmethod
    def serialize(self, prompt: dict, **kwargs) -> str:
        """Render ``prompt`` in the subclass's output format.

        Args:
            prompt: The prompt data to serialize.
            **kwargs: Format-specific options; each subclass decides which
                ones it honours.

        Returns:
            The serialized text. This abstract stub has no body of its own
            and returns ``None`` if invoked directly.
        """

class XMLSerializer(Serializer):
    """Serialize prompt data to a pretty-printed XML document."""

    def _add_node(self, parent, data):
        """Recursively append one child element per ``data`` entry to ``parent``.

        Mapping rules, keyed on the entry's name and value type:

        * ``"examples"``: each list item becomes an ``<example>`` element with
          one child per key of that item.
        * ``"instructions"``: each list item becomes a ``<step>`` element.
        * ``dict``: recursed into.
        * ``list``: scalar items become ``<item>`` elements; dict items are
          recursed into directly under the list's own element.
        * anything else: stored as the element's text via ``str()``.

        Args:
            parent: The element that receives the new children.
            data: Mapping of tag name to value.
        """
        for key, value in data.items():
            node = ET.SubElement(parent, key)

            if key == "examples":
                for example in value:
                    example_node = ET.SubElement(node, "example")
                    for field, field_value in example.items():
                        ET.SubElement(example_node, field).text = str(field_value)
            elif key == "instructions":
                for instruction in value:
                    # Element text must be a str for ET.tostring to succeed;
                    # coerce here the same way every other branch does.
                    ET.SubElement(node, "step").text = str(instruction)
            elif isinstance(value, dict):
                self._add_node(node, value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        # NOTE(review): dict items are merged into the list's
                        # element without a per-item wrapper, so item
                        # boundaries are not recoverable from the output.
                        # Left as-is to keep the emitted XML shape stable.
                        self._add_node(node, item)
                    else:
                        ET.SubElement(node, "item").text = str(item)
            else:
                node.text = str(value)

    def _dict_to_xml(self, data, root_name="prompt"):
        """Convert a dictionary to a pretty-printed XML string.

        Args:
            data: Mapping to convert; keys are used as element tag names.
            root_name: Tag name of the document's root element.

        Returns:
            The XML document as text, including the XML declaration.
        """
        root = ET.Element(root_name)
        self._add_node(root, data)
        # Round-trip through minidom purely to get indented output.
        return minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ")

    def serialize(self, prompt: dict, **kwargs):
        """Serialize ``prompt`` to XML under a ``<prompt>`` root element.

        Args:
            prompt: The prompt data to serialize.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The pretty-printed XML document as a string.
        """
        return self._dict_to_xml(prompt)

class JSONSerializer(Serializer):
    """Serialize prompt data to a JSON string."""

    def serialize(self, prompt: dict, **kwargs):
        """Serialize ``prompt`` to JSON.

        Args:
            prompt: The prompt data to serialize.
            **kwargs: Only ``indent`` is read; it is forwarded to
                ``json.dumps``. Other options are ignored.

        Returns:
            The JSON text.
        """
        # The default of 4 applies only when "indent" is absent; an explicit
        # ``indent=None`` is passed through unchanged.
        indent = kwargs["indent"] if "indent" in kwargs else 4
        return json.dumps(prompt, indent=indent)

class YAMLSerializer(Serializer):
    """Serialize prompt data to a YAML string."""

    def serialize(self, prompt: dict, **kwargs):
        """Serialize ``prompt`` to YAML using ``yaml.dump`` defaults.

        Args:
            prompt: The prompt data to serialize.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Whatever ``yaml.dump`` returns for ``prompt``.
        """
        document = yaml.dump(prompt)
        return document

class SerializerFormat(Enum):
    """Names of the output formats, each mapped to its string identifier."""

    XML = "xml"
    JSON = "json"
    YAML = "yaml"

class SerializerFactory:
    """Create serializer instances from a format identifier."""

    @staticmethod
    def create_serializer(format: "str | SerializerFormat") -> Serializer:
        """Return a new serializer for ``format``.

        Args:
            format: Either a format string (``"xml"``, ``"json"``,
                ``"yaml"``) or the corresponding ``SerializerFormat`` member.

        Returns:
            A fresh ``XMLSerializer``, ``JSONSerializer`` or
            ``YAMLSerializer`` instance.

        Raises:
            ValueError: If ``format`` does not identify a supported format.
        """
        try:
            # Enum lookup by value accepts both the raw string and an
            # existing member, so callers may pass either.
            kind = SerializerFormat(format)
        except ValueError:
            supported = ", ".join(member.value for member in SerializerFormat)
            raise ValueError(
                f"Invalid format {format!r}; expected one of: {supported}"
            ) from None

        serializers = {
            SerializerFormat.XML: XMLSerializer,
            SerializerFormat.JSON: JSONSerializer,
            SerializerFormat.YAML: YAMLSerializer,
        }
        return serializers[kind]()
14 changes: 14 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from unittest import TestCase
from src.promptml.serializer import (
SerializerFactory,
XMLSerializer,
JSONSerializer,
YAMLSerializer
)

class TestSerializer(TestCase):
    """Checks for ``SerializerFactory.create_serializer``."""

    def test_create_serializer(self):
        """Each known format yields its serializer; unknown formats raise."""
        expected_types = (
            ("xml", XMLSerializer),
            ("json", JSONSerializer),
            ("yaml", YAMLSerializer),
        )
        for format_name, serializer_type in expected_types:
            created = SerializerFactory.create_serializer(format_name)
            self.assertIsInstance(created, serializer_type)

        with self.assertRaises(ValueError):
            SerializerFactory.create_serializer("invalid")

0 comments on commit 70dc24c

Please sign in to comment.