Skip to content

Commit

Permalink
auto-generate aggregation classes
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Sep 13, 2024
1 parent bacfc74 commit 588f58e
Show file tree
Hide file tree
Showing 2 changed files with 736 additions and 7 deletions.
97 changes: 90 additions & 7 deletions utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
lstrip_blocks=True,
)
query_py = jinja_env.get_template("query.py.tpl")
aggs_py = jinja_env.get_template("aggs.py.tpl")
types_py = jinja_env.get_template("types.py.tpl")

# map with name replacements for Elasticsearch attributes
Expand All @@ -43,6 +44,22 @@
"_types.query_dsl:DistanceFeatureQuery": "_types.query_dsl:DistanceFeatureQueryBase",
}

# Some aggregation types are complicated to determine from the schema, so
# their correct type is hard-coded here.
#
# Keys are aggregation property names as they appear in the schema; values
# name the parent ("Bucket" or "Pipeline") that is stored as the generated
# class's "parent" entry when the property name is listed in this table.
AGG_TYPES = {
"bucket_count_ks_test": "Pipeline",
"bucket_correlation": "Pipeline",
"bucket_sort": "Bucket",
"categorize_text": "Bucket",
"filter": "Bucket",
"moving_avg": "Pipeline",
"variable_width_histogram": "Bucket",
}


def property_to_class_name(name):
    """Build a class name from an underscore-separated property name.

    Each underscore-delimited word is title-cased and the words are joined
    without a separator. The exact word ``ip`` is special-cased and rendered
    as ``IP`` (so ``ip_range`` becomes ``IPRange``).
    """
    pieces = []
    for word in name.split("_"):
        # "ip" is an acronym, so plain title-casing ("Ip") would look wrong
        pieces.append("IP" if word == "ip" else word.title())
    return "".join(pieces)


def wrapped_doc(text, width=70, initial_indent="", subsequent_indent=""):
"""Formats a docstring as a list of lines of up to the request width."""
Expand Down Expand Up @@ -101,6 +118,18 @@ def find_type(self, name, namespace=None):
):
return t

def inherits_from(self, type_, name, namespace=None):
    """Return True if the given schema type descends from the named type.

    Walks the chain of ``inherits`` references starting at ``type_``,
    resolving each parent with ``self.find_type()``. Only ancestors are
    compared; ``type_`` itself is never considered a match.

    ``type_`` is the schema type definition (a dictionary) to start from,
    ``name`` is the name of the ancestor to look for, and ``namespace``,
    when given, is the namespace the ancestor must belong to. With the
    default of ``None`` an ancestor with a matching name in any namespace
    counts.

    Returns ``False`` when the chain ends without a match, and also when it
    cannot be followed any further: an ``inherits`` entry that carries no
    ``type`` reference, or a parent that ``find_type()`` cannot resolve.
    """
    while True:
        # An "inherits" entry without a "type" reference cannot be followed;
        # treat it the same as having no parent at all.
        parent_ref = (type_.get("inherits") or {}).get("type")
        if parent_ref is None:
            return False

        type_ = self.find_type(parent_ref["name"], parent_ref["namespace"])
        if type_ is None:
            # the parent is not defined in the schema, so the walk ends here
            return False

        if type_["name"]["name"] == name and (
            namespace is None or type_["name"]["namespace"] == namespace
        ):
            return True

def get_python_type(self, schema_type):
"""Obtain Python typing details for a given schema type
Expand Down Expand Up @@ -156,7 +185,9 @@ def get_python_type(self, schema_type):
# for dicts we use Mapping[key_type, value_type]
key_type, key_param = self.get_python_type(schema_type["key"])
value_type, value_param = self.get_python_type(schema_type["value"])
return f"Mapping[{key_type}, {value_type}]", None
return f"Mapping[{key_type}, {value_type}]", (
{**value_param, "hash": True} if value_param else None
)

elif schema_type["kind"] == "union_of":
if (
Expand Down Expand Up @@ -334,17 +365,38 @@ def property_to_python_class(self, p):
"""
k = {
"property_name": p["name"],
"name": "".join([w.title() for w in p["name"].split("_")]),
"name": property_to_class_name(p["name"]),
}
k["docstring"] = wrapped_doc(p.get("description") or "")
other_classes = []
kind = p["type"]["kind"]
if kind == "instance_of":
namespace = p["type"]["type"]["namespace"]
name = p["type"]["type"]["name"]
if f"{namespace}:{name}" in TYPE_REPLACEMENTS:
namespace, name = TYPE_REPLACEMENTS[f"{namespace}:{name}"].split(":")
type_ = schema.find_type(name, namespace)
if name == "QueryContainer" and namespace == "_types.query_dsl":
type_ = {
"kind": "interface",
"properties": [p],
}
else:
type_ = schema.find_type(name, namespace)
if p["name"] in AGG_TYPES:
k["parent"] = AGG_TYPES[p["name"]]

if type_["kind"] == "interface":
# set the correct parent for bucket and pipeline aggregations
if self.inherits_from(
type_, "PipelineAggregationBase", "_types.aggregations"
):
k["parent"] = "Pipeline"
elif self.inherits_from(
type_, "BucketAggregationBase", "_types.aggregations"
):
k["parent"] = "Bucket"

# generate class attributes
k["args"] = []
k["params"] = []
if "behaviors" in type_:
Expand Down Expand Up @@ -397,6 +449,21 @@ def property_to_python_class(self, p):
)
else:
break

elif type_["kind"] == "type_alias":
if type_["type"]["kind"] == "union_of":
# for unions we create sub-classes
for other in type_["type"]["items"]:
other_class = self.interface_to_python_class(
other["type"]["name"], self.interfaces, for_types_py=False
)
other_class["parent"] = k["name"]
other_classes.append(other_class)
else:
raise RuntimeError(
"Cannot generate code for instances of type_alias instances that are not unions."
)

else:
raise RuntimeError(
f"Cannot generate code for instances of kind '{type_['kind']}'"
Expand Down Expand Up @@ -444,9 +511,9 @@ def property_to_python_class(self, p):

else:
raise RuntimeError(f"Cannot generate code for type {p['type']}")
return k
return [k] + other_classes

def interface_to_python_class(self, interface, interfaces):
def interface_to_python_class(self, interface, interfaces, for_types_py=True):
"""Return a dictionary with template data necessary to render an
interface a Python class.
Expand Down Expand Up @@ -477,7 +544,7 @@ def interface_to_python_class(self, interface, interfaces):
k = {"name": interface, "args": []}
while True:
for arg in type_["properties"]:
schema.add_attribute(k, arg, for_types_py=True)
schema.add_attribute(k, arg, for_types_py=for_types_py)

if "inherits" not in type_ or "type" not in type_["inherits"]:
break
Expand All @@ -500,13 +567,28 @@ def generate_query_py(schema, filename):
classes = []
query_container = schema.find_type("QueryContainer", "_types.query_dsl")
for p in query_container["properties"]:
classes.append(schema.property_to_python_class(p))
classes += schema.property_to_python_class(p)

with open(filename, "wt") as f:
f.write(query_py.render(classes=classes, parent="Query"))
print(f"Generated {filename}.")


def generate_aggs_py(schema, filename):
    """Render aggs.py from the properties of `AggregationContainer`.

    Every property of the container is turned into one or more Python
    classes, except those flagged with a truthy `containerProperty`, which
    are skipped. The collected classes are rendered through the aggs
    template and written to `filename`.
    """
    container = schema.find_type("AggregationContainer", "_types.aggregations")

    classes = []
    for prop in container["properties"]:
        # properties flagged as container properties get no class of their own
        if prop.get("containerProperty"):
            continue
        classes.extend(schema.property_to_python_class(prop))

    with open(filename, "wt") as out:
        rendered = aggs_py.render(classes=classes, parent="Agg")
        out.write(rendered)
    print(f"Generated {filename}.")


def generate_types_py(schema, filename):
"""Generate types.py"""
classes = {}
Expand Down Expand Up @@ -542,4 +624,5 @@ def generate_types_py(schema, filename):
if __name__ == "__main__":
    # Load the schema once and run each generator against it, in order.
    schema = ElasticsearchSchema()
    for generate, target in (
        (generate_query_py, "elasticsearch_dsl/query.py"),
        (generate_aggs_py, "elasticsearch_dsl/aggs.py"),
        (generate_types_py, "elasticsearch_dsl/types.py"),
    ):
        generate(schema, target)
Loading

0 comments on commit 588f58e

Please sign in to comment.