From 2f33a36cd44f3d3ff002670d9d02b17a790222e5 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 13 Sep 2024 19:40:10 +0100 Subject: [PATCH] auto-generate aggregation classes --- utils/generator.py | 97 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 90 insertions(+), 7 deletions(-) diff --git a/utils/generator.py b/utils/generator.py index c0a4f3a0..e60aaa5e 100644 --- a/utils/generator.py +++ b/utils/generator.py @@ -32,6 +32,7 @@ lstrip_blocks=True, ) query_py = jinja_env.get_template("query.py.tpl") +aggs_py = jinja_env.get_template("aggs.py.tpl") types_py = jinja_env.get_template("types.py.tpl") # map with name replacements for Elasticsearch attributes @@ -43,6 +44,22 @@ "_types.query_dsl:DistanceFeatureQuery": "_types.query_dsl:DistanceFeatureQueryBase", } +# some aggregation types are complicated to determine from the schema, so they +# have their correct type here +AGG_TYPES = { + "bucket_count_ks_test": "Pipeline", + "bucket_correlation": "Pipeline", + "bucket_sort": "Bucket", + "categorize_text": "Bucket", + "filter": "Bucket", + "moving_avg": "Pipeline", + "variable_width_histogram": "Bucket", +} + + +def property_to_class_name(name): + return "".join([w.title() if w != "ip" else "IP" for w in name.split("_")]) + def wrapped_doc(text, width=70, initial_indent="", subsequent_indent=""): """Formats a docstring as a list of lines of up to the request width.""" @@ -101,6 +118,18 @@ def find_type(self, name, namespace=None): ): return t + def inherits_from(self, type_, name, namespace=None): + while "inherits" in type_: + type_ = self.find_type( + type_["inherits"]["type"]["name"], + type_["inherits"]["type"]["namespace"], + ) + if type_["name"]["name"] == name and ( + namespace is None or type_["name"]["namespace"] == namespace + ): + return True + return False + def get_python_type(self, schema_type): """Obtain Python typing details for a given schema type @@ -156,7 +185,9 @@ def get_python_type(self, schema_type): # for dicts we use Mapping[key_type, value_type] key_type, key_param = self.get_python_type(schema_type["key"]) value_type, value_param = self.get_python_type(schema_type["value"]) - return f"Mapping[{key_type}, {value_type}]", None + return f"Mapping[{key_type}, {value_type}]", ( + {**value_param, "hash": True} if value_param else None + ) elif schema_type["kind"] == "union_of": if ( @@ -334,17 +365,38 @@ def property_to_python_class(self, p): """ k = { "property_name": p["name"], - "name": "".join([w.title() for w in p["name"].split("_")]), + "name": property_to_class_name(p["name"]), } k["docstring"] = wrapped_doc(p.get("description") or "") + other_classes = [] kind = p["type"]["kind"] if kind == "instance_of": namespace = p["type"]["type"]["namespace"] name = p["type"]["type"]["name"] if f"{namespace}:{name}" in TYPE_REPLACEMENTS: namespace, name = TYPE_REPLACEMENTS[f"{namespace}:{name}"].split(":") - type_ = schema.find_type(name, namespace) + if name == "QueryContainer" and namespace == "_types.query_dsl": + type_ = { + "kind": "interface", + "properties": [p], + } + else: + type_ = schema.find_type(name, namespace) + if p["name"] in AGG_TYPES: + k["parent"] = AGG_TYPES[p["name"]] + if type_["kind"] == "interface": + # set the correct parent for bucket and pipeline aggregations + if self.inherits_from( + type_, "PipelineAggregationBase", "_types.aggregations" + ): + k["parent"] = "Pipeline" + elif self.inherits_from( + type_, "BucketAggregationBase", "_types.aggregations" + ): + k["parent"] = "Bucket" + + # generate class attributes k["args"] = [] k["params"] = [] if "behaviors" in type_: @@ -397,6 +449,21 @@ def property_to_python_class(self, p): ) else: break + + elif type_["kind"] == "type_alias": + if type_["type"]["kind"] == "union_of": + # for unions we create sub-classes + for other in type_["type"]["items"]: + other_class = self.interface_to_python_class( + other["type"]["name"], self.interfaces, for_types_py=False + ) + other_class["parent"] = k["name"] + other_classes.append(other_class) + else: + raise RuntimeError( + "Cannot generate code for instances of type_alias instances that are not unions." + ) + else: raise RuntimeError( f"Cannot generate code for instances of kind '{type_['kind']}'" @@ -444,9 +511,9 @@ def property_to_python_class(self, p): else: raise RuntimeError(f"Cannot generate code for type {p['type']}") - return k + return [k] + other_classes - def interface_to_python_class(self, interface, interfaces): + def interface_to_python_class(self, interface, interfaces, for_types_py=True): """Return a dictionary with template data necessary to render an interface a Python class. @@ -477,7 +544,7 @@ def interface_to_python_class(self, interface, interfaces): k = {"name": interface, "args": []} while True: for arg in type_["properties"]: - schema.add_attribute(k, arg, for_types_py=True) + schema.add_attribute(k, arg, for_types_py=for_types_py) if "inherits" not in type_ or "type" not in type_["inherits"]: break @@ -500,13 +567,28 @@ def generate_query_py(schema, filename): classes = [] query_container = schema.find_type("QueryContainer", "_types.query_dsl") for p in query_container["properties"]: - classes.append(schema.property_to_python_class(p)) + classes += schema.property_to_python_class(p) with open(filename, "wt") as f: f.write(query_py.render(classes=classes, parent="Query")) print(f"Generated {filename}.") +def generate_aggs_py(schema, filename): + """Generate aggs.py with all the properties of `AggregationContainer` as + Python classes. + """ + classes = [] + aggs_container = schema.find_type("AggregationContainer", "_types.aggregations") + for p in aggs_container["properties"]: + if "containerProperty" not in p or not p["containerProperty"]: + classes += schema.property_to_python_class(p) + + with open(filename, "wt") as f: + f.write(aggs_py.render(classes=classes, parent="Agg")) + print(f"Generated {filename}.") + + def generate_types_py(schema, filename): """Generate types.py""" classes = {} @@ -542,4 +624,5 @@ def generate_types_py(schema, filename): if __name__ == "__main__": schema = ElasticsearchSchema() generate_query_py(schema, "elasticsearch_dsl/query.py") + generate_aggs_py(schema, "elasticsearch_dsl/aggs.py") generate_types_py(schema, "elasticsearch_dsl/types.py")