Skip to content

Commit

Permalink
Support also text-returning aggregate function such as SQLite group_c…
Browse files Browse the repository at this point in the history
…oncat.
  • Loading branch information
wenzeslaus committed May 4, 2023
1 parent 6a85434 commit d2cd66f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
29 changes: 28 additions & 1 deletion scripts/v.dissolve/tests/v_dissolve_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_aggregate_column_result(dataset, backend):


def test_sqlite_agg_accepted(dataset):
"""Numeric SQLite aggregate function are accepted
"""Numeric SQLite aggregate functions are accepted
Additionally, it checks:
1. generated column names
Expand Down Expand Up @@ -244,6 +244,33 @@ def test_sqlite_agg_accepted(dataset):
assert sorted(aggregate_n) == [1, 2, 3]


def test_sqlite_concat(dataset):
"""SQLite concat text-returning aggregate function works"""
dissolved_vector = "test_sqlite_concat"
gs.run_command(
"v.dissolve",
input=dataset.vector_name,
column=dataset.str_column_name,
output=dissolved_vector,
aggregate_column=f"group_concat({dataset.int_column_name})",
result_column="concat_values text",
aggregate_backend="sql",
)
records = json.loads(
gs.read_command(
"v.db.select",
map=dissolved_vector,
format="json",
)
)["records"]
# Order of records is ignored - they are just sorted.
# Order within values of group_concat is defined as arbitrary by SQLite.
expected_integers = sorted(["10", "10,10,24", "5,5"])
actual_integers = sorted([record["concat_values"] for record in records])
for expected, actual in zip(expected_integers, actual_integers):
assert sorted(expected.split(",")) == sorted(actual.split(","))


def test_duplicate_columns_and_methods_accepted(dataset):
"""Duplicate aggregate columns and methods are accepted and deduplicated"""
dissolved_vector = "test_duplicates"
Expand Down
62 changes: 56 additions & 6 deletions scripts/v.dissolve/v.dissolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,33 @@ def modify_methods_for_backend(methods, backend):
return new_methods


def quote_from_type(column_type):
"""Returns quote if column values need to be quoted based on their type
Defaults to quoting for unknown types and no quoting for falsely values,
i.e., unknown types are assumed to be in need of quoting while missing type
information is assumed to be associated with numbers which don't need quoting.
"""
# Needs a general solution, e.g., https://github.com/OSGeo/grass/pull/1110
if not column_type or column_type.upper() in [
"INT",
"INTEGER",
"SMALLINT",
"REAL",
"DOUBLE",
"DOUBLE PRECISION",
]:
return ""
return "'"


def updates_to_sql(table, updates):
"""Create SQL from a list of dicts with column, value, where"""
sql = ["BEGIN TRANSACTION"]
for update in updates:
quote = quote_from_type(update.get("type", None))
sql.append(
f"UPDATE {table} SET {update['column']} = {update['value']} "
f"UPDATE {table} SET {update['column']} = {quote}{update['value']}{quote} "
f"WHERE {update['where']};"
)
sql.append("END TRANSACTION")
Expand Down Expand Up @@ -215,6 +236,27 @@ def check_aggregate_methods_or_fatal(methods, backend):
# and open for SQLite depending on its extensions.


def aggregate_columns_exist_or_fatal(vector, layer, columns):
"""Check that all columns exist or end with fatal error"""
column_names = gs.vector_columns(vector, layer).keys()
for column in columns:
if column not in column_names:
if "(" in column:
gs.fatal(
_(
"Column <{column}> does not exist in vector <{vector}> "
"(layer <{layer}>). Specify result columns if you are adding "
"function calls to aggregate columns."
).format(vector=vector, layer=layer, column=column)
)
gs.fatal(
_(
"Column <{column}> selected for aggregation does not exist "
"in vector <{vector}> (layer <{layer}>)"
).format(vector=vector, layer=layer, column=column)
)


def match_columns_and_methods(columns, methods):
"""Return all combinations of columns and methods
Expand Down Expand Up @@ -351,19 +393,26 @@ def aggregate_attributes_sql(
for result_column, column_type in zip(result_columns, column_types):
add_columns.append(f"{result_column} {column_type}")
else:
add_columns = result_columns.copy()
# Column types are part of the result column name list.
add_columns = result_columns.copy() # Ensure we have our own copy.
# Split column definitions into two lists.
result_columns = []
column_types = []
for definition in add_columns:
column_name, column_type = definition.split(" ", maxsplit=1)
result_columns.append(column_name)
column_types.append(column_type)
for row in records:
where = column_value_to_where(column, row[column], quote=quote_column)
for (
result_column,
column_type,
key,
) in zip(result_columns, select_columns):
if not column_types:
# Column types are part of the result column name list.
result_column = result_column.split(" ", maxsplit=1)[0]
) in zip(result_columns, column_types, select_columns):
updates.append(
{
"column": result_column,
"type": column_type,
"value": row[key],
"where": where,
}
Expand Down Expand Up @@ -470,6 +519,7 @@ def main():
user_aggregate_methods, aggregate_backend, provide_defaults=not result_columns
)
if not result_columns:
aggregate_columns_exist_or_fatal(input_vector, layer, columns_to_aggregate)
columns_to_aggregate, user_aggregate_methods = match_columns_and_methods(
columns_to_aggregate, user_aggregate_methods
)
Expand Down

0 comments on commit d2cd66f

Please sign in to comment.