diff --git a/src/converter/pytorch_node.py b/src/converter/pytorch_node.py index 86b59acc..e2850104 100644 --- a/src/converter/pytorch_node.py +++ b/src/converter/pytorch_node.py @@ -47,7 +47,7 @@ class PyTorchNode: pg_name (str): Process Group name for the inter-GPU communication. """ - SUPPORTED_VERSIONS = ["1.0.2-chakra.0.0.4", "1.0.3-chakra.0.0.4", "1.1.0-chakra.0.0.4"] + SUPPORTED_VERSIONS = ["1.0.2-chakra.0.0.4", "1.0.3-chakra.0.0.4", "1.1.0-chakra.0.0.4", "1.1.1-chakra.0.0.4"] def __init__(self, schema: str, node_data: Dict[str, Any]) -> None: """ @@ -86,7 +86,7 @@ def parse_data(self, node_data: Dict[str, Any]) -> None: node_data (Dict[str, Any]): The node data to be parsed. """ if self.schema in self.SUPPORTED_VERSIONS: - if self.schema in ["1.0.2-chakra.0.0.4", "1.0.3-chakra.0.0.4", "1.1.0-chakra.0.0.4"]: + if self.schema in ["1.0.2-chakra.0.0.4", "1.0.3-chakra.0.0.4", "1.1.0-chakra.0.0.4", "1.1.1-chakra.0.0.4"]: self._parse_data_1_0_3_chakra_0_0_4(node_data) else: raise ValueError(