From 4325741d134c75b2aeb071fc13880f4eed367ab2 Mon Sep 17 00:00:00 2001 From: jesko Date: Fri, 27 Sep 2024 21:41:34 +0200 Subject: [PATCH] improves automatic path discovery in xml extractor --- refinery/lib/xml.py | 47 +++++++++++++++--- refinery/units/formats/__init__.py | 79 ++++++++++++++++-------------- refinery/units/formats/xml.py | 11 +---- 3 files changed, 81 insertions(+), 56 deletions(-) diff --git a/refinery/lib/xml.py b/refinery/lib/xml.py index bf6629588..edd17ca1e 100644 --- a/refinery/lib/xml.py +++ b/refinery/lib/xml.py @@ -9,6 +9,7 @@ import weakref import re import defusedxml.ElementTree as et +import collections from typing import Any, Dict, Iterable, List, Optional from xml.parsers import expat @@ -78,7 +79,7 @@ class XMLNodeBase: different types of XML node classes to represent e.g. leaves / text nodes from others. """ - __slots__ = 'tag', 'children', 'empty', 'attributes', 'content', '_parent', '__weakref__' + __slots__ = 'tag', 'index', 'children', 'empty', 'attributes', 'content', '_parent', '__weakref__' attributes: Dict[str, Any] children: List[XMLNodeBase] @@ -91,14 +92,18 @@ class XMLNodeBase: def __init__( self, tag: str, + index: Optional[int], parent: Optional[XMLNodeBase] = None, content: Optional[str] = None, empty: bool = False, attributes: Optional[Dict[str, Any]] = None, ): + if parent is None and index is not None: + raise ValueError('Cannot set index for XML node without parent.') if attributes is None: attributes = {} self.tag = tag + self.index = index self.content = content self.empty = empty self.children = [] @@ -118,6 +123,29 @@ def parent(self, parent): parent = weakref.ref(parent) self._parent = parent + def __hash__(self): + return hash((hash(self.parent), self.tag, self.index)) + + def __eq__(self, other: XMLNodeBase): + return self.parent == other.parent and self.tag == other.tag and self.index == other.index + + @property + def basename(self): + name = self.tag + if self.index is not None: + name = F'{name}[{self.index}]' + return name + + @property + def path(self): + name = self.basename + if self.parent is None: + return name + return F'{self.parent.path}/{name}' + + def __repr__(self): + return F'<{self.__class__.__name__}:{self.path}>' + def __iter__(self): return iter(self.children) @@ -161,9 +189,9 @@ class XMLNode(XMLNodeBase): source: Optional[Element] - def __init__(self, tag: str): - super().__init__(tag) - self.source = None + def __init__(self, tag: str, index: int, parent: Optional[XMLNode] = None, source: Optional[Element] = None): + super().__init__(tag, index, parent) + self.source = source def write(self, stream): """ @@ -181,16 +209,19 @@ def parse(data) -> XMLNode: tree that is generated by the standard library. """ def translate(element: Element, cursor: XMLNode, level: int = 0): + total = collections.Counter(child.tag for child in element) + count = collections.Counter() for child in element: - node = XMLNode(child.tag) + tag = child.tag + index = None if total[tag] == 1 else count[tag] + node = XMLNode(tag, index, cursor, child) + count[tag] += 1 translate(child, node, level + 1) - node.parent = cursor - node.source = child cursor.children.append(node) cursor.attributes = element.attrib cursor.content = element.text or element.tail or '' return cursor root = ForgivingParse(data).getroot() - rt = translate(root, XMLNode(root.tag)) + rt = translate(root, XMLNode(root.tag, None)) rt.source = root return rt diff --git a/refinery/units/formats/__init__.py b/refinery/units/formats/__init__.py index 70e6798f6..83a1b1922 100644 --- a/refinery/units/formats/__init__.py +++ b/refinery/units/formats/__init__.py @@ -310,50 +310,53 @@ def _make_path_builder( root: XMLNodeBase ) -> Callable[[XMLNodeBase, Optional[int]], str]: - path_attributes = Counter() - - def walk(node: XMLNodeBase): - total = 1 - for key, val in node.attributes.items(): - if re.fullmatch(R'[-\s\w+,.;@(){}]{1,64}', self._normalize_val(val)): - path_attributes[key] += 1 - for child in node.children: - total += walk(child) - return total - - total = walk(root) - - if not path_attributes: - path_attribute = None - count = 0 - else: - path_attribute, count = path_attributes.most_common(1)[0] - if 3 * count <= 2 * total: - path_attribute = None - + nfmt = self.args.format nkey = self._normalize_key nval = self._normalize_val - node_format = self.args.format + nmap = {} + + if nfmt is None: + def rank_attribute(attribute: str): + length = len(attribute) + scount = length - len(re.sub(r'\s+', '', attribute)) + return (1 / length, scount) + + def walk(node: XMLNodeBase): + candidates = [ + candidate for candidate, count in Counter( + key + for child in node.children + for key, val in child.attributes.items() + if re.fullmatch(R'[-\s\w+,.;@(){}]{2,64}', nval(val)) + ).items() + if count == len(node.children) + ] + if not candidates: + attr = None + else: + candidates.sort(key=rank_attribute) + attr = candidates[0] + for child in node.children: + nmap[child.path] = attr + walk(child) - def path_builder(node: XMLNodeBase, index: Optional[int] = None) -> str: + walk(root) + + def path_builder(node: XMLNodeBase) -> str: attrs = node.attributes - if node_format and meta: + if nfmt and meta is not None: try: - return meta.format_str( - node_format, - self.codec, - node.tag, **{ - nkey(key): nval(val) - for key, val in attrs.items() - } - ) + symbols = {nkey(key): nval(val) for key, val in attrs.items()} + return meta.format_str(nfmt, self.codec, node.tag, symbols) except KeyError: pass - if path_attribute is not None and path_attribute in attrs: - return nval(attrs[path_attribute]) - out = nval(node.tag) - if index is not None: - out = F'{out}/{index}' - return out + try: + return nval(attrs[nmap[node.path]]) + except KeyError: + index = node.index + name = nval(node.tag) + if index is not None: + name = F'{name}/{index}' + return name return path_builder diff --git a/refinery/units/formats/xml.py b/refinery/units/formats/xml.py index 04354098a..fa1af0cfa 100644 --- a/refinery/units/formats/xml.py +++ b/refinery/units/formats/xml.py @@ -25,17 +25,8 @@ def extract(node: xml.XMLNode = node): with MemoryFile() as stream: node.write(stream) return bytes(stream.getbuffer() | ppxml) - tag_pre_count = Counter() - tag_run_count = Counter() - for child in node.children: - tag_pre_count[child.tag] += 1 yield UnpackResult('/'.join(parts), extract, **node.attributes) for child in node.children: - if tag_pre_count[child.tag] == 1: - yield from walk(child, *parts, path(child)) - continue - tag_run_count[child.tag] += 1 - index = tag_run_count[child.tag] - yield from walk(child, *parts, path(child, index)) + yield from walk(child, *parts, path(child)) yield from walk(root, path(root))