Remove djsonb dependency and move its relevant logic into Grout core #7

Merged 7 commits on Jul 24, 2018
Changes from 2 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@
Changes we're making en route to version 2.0. When we cut a release, this
section will form the changelog for v2.0.

- Removed external `djsonb` dependency and moved its lookup logic into
  Grout core ([#7](https://github.com/azavea/grout/pull/7))
**Contributor:** 🎆

- Removed extraneous location fields from the `Record` data model
  ([#5](https://github.com/azavea/grout/pull/5)), including:
- `city`
5 changes: 2 additions & 3 deletions grout/filters.py
@@ -7,6 +7,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.contrib.gis.db import models as gis_models
from django.contrib.gis.gdal.error import GDALException
+from django.contrib.postgres.fields import JSONField
from django.db.models import Q

from rest_framework.exceptions import ParseError, NotFound
@@ -16,12 +17,10 @@
from grout.models import Boundary, BoundaryPolygon, Record, RecordType
from grout.exceptions import QueryParameterException

-from djsonb import fields as jsb


# Map custom fields to CharField so that django-filter knows how to handle them.
FILTER_OVERRIDES = {
-    jsb.JsonBField: {
+    JSONField: {
        'filter_class': django_filters.CharFilter
    },
    gis_models.PointField: {
293 changes: 293 additions & 0 deletions grout/lookups.py
@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
import json
import re
import shlex

from django.db.models import Lookup
from django.contrib.postgres.fields import JSONField


class FilterTree:
"""This class should properly assemble the pieces necessary to write the WHERE clause of
**Contributor:** Nitpick: could this block comment be rephrased to open with a single line, separated from the rest with a blank line, per the docstring PEP?

**Contributor (author):** Can do. I'll flesh out all the docstrings in this module.
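For reference, a PEP 257-style opening along the lines the reviewer suggests might look like this (a sketch only; the wording is hypothetical, not the text that was ultimately committed):

```python
class FilterTree:
    """Assemble the WHERE clause of a Postgres query from a JSONB filter spec.

    The jsonb_filter_field property of the view designates the name of the
    column to filter by. Manual filtering via Django's ORM might look like:
    Something.objects.filter(<jsonb_field>__jsonb=<filter_specification>)
    """
```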

    a postgres query
    The jsonb_filter_field property of your view should designate the
    name of the column to filter by.
    Manually filtering by way of Django's ORM might look like:
    Something.objects.filter(<jsonb_field>__jsonb=<filter_specification>)

    Check out the tests for some real examples"""

    def __init__(self, tree, field):
        self.field = field
        self.tree = tree
        self.sql_generators = {
            "intrange": FilterTree.intrange_filter,
            "containment": FilterTree.containment_filter,
            "containment_multiple": FilterTree.multiple_containment_filter
        }
        self.rules = self.get_rules(self.tree)

    def is_rule(self, obj):
        """Check for bottoming out the recursion in `get_rules`"""
        if '_rule_type' in obj and obj['_rule_type'] in self.sql_generators:
            return True
        return False

    def get_rules(self, obj, current_path=[]):
        """Recursively crawl a dict looking for filtering rules"""
        # If node isn't a rule or dictionary
        if type(obj) != dict:
            return []

        # If node is a rule return its location and its details
        if self.is_rule(obj):
            return [([self.field] + current_path, obj)]

        rules = []
        for path, val in obj.items():
            rules = rules + self.get_rules(val, current_path + [path])
        return rules

    def sql(self):
        """Produce output that can be compiled into SQL by Django and psycopg2.

        The format of the output should be a tuple of a (template) string followed by a list
        of parameters for compiling that template
        """
        rule_specs = []

        patterns = {}
        pattern_specs = []

        for rule in self.rules:
            # If not a properly registered rule type
            if not self.is_rule(rule[1]):
**Contributor:** How do we know index 1 will be in bounds for `rule`? Could it be made more obvious what data is at the 0 and 1 offsets?

**Contributor (author):** Great question! Based on the parameter list for the methods that use these values (like `FilterTree.containment_filter`), it looks like `rule[0]` corresponds to the path of the nested query, while `rule[1]` corresponds to the rule itself. Both elements have to exist, since `self.get_rules` can only return either A) an empty list or B) a list of two-tuples.

I'll add this info in a comment and redo the iteration assignment to change it from:

`for rule in self.rules:`

to something more expressive:

`for path, rule in self.rules:`
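For illustration, each element of `self.rules` is a two-tuple pairing a traversal path with its rule dict. With a hypothetical `data` column and filter spec (names are illustrative, not from the PR):

```python
# filter spec: {'driver': {'age': {'_rule_type': 'intrange', 'min': 18}}}
# get_rules prepends the column name, so self.rules becomes:
[(['data', 'driver', 'age'], {'_rule_type': 'intrange', 'min': 18})]
```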

                pass
            rule_type = rule[1]['_rule_type']
            sql_tuple = self.sql_generators[rule_type](rule[0], rule[1])
            if sql_tuple is not None:
                rule_specs.append(sql_tuple)

            # The check on 'pattern' here allows us to apply a pattern filter on top of others
            if 'pattern' in rule[1]:
                # Don't filter as an exact match on the text entered; match per word.
                for pattern in shlex.split(rule[1]['pattern']):
                    if rule[1]['_rule_type'] == 'containment_multiple':
                        sql_tuple = FilterTree.text_similarity_filter(rule[0], pattern, True)
                    else:
                        sql_tuple = FilterTree.text_similarity_filter(rule[0], pattern, False)
                    # add to the list of rules generated for this pattern (one per field)
                    patterns.setdefault(pattern, []).append(sql_tuple)

        rule_string = ' AND '.join([rule[0] for rule in rule_specs])

        pattern_rules = patterns.values()
        pattern_strings = []

        # check if any of the fields for this string pattern match
        for rule_list in pattern_rules:
            pattern_strings.append(' OR '.join([rule[0] for rule in rule_list]))
            pattern_specs += rule_list

        # check that record has a match for all of the string patterns in some field
        pattern_string = '(' + ') AND ('.join(pattern_strings) + ')' if pattern_strings else ''

        if rule_string != '' and pattern_string != '':
            filter_string = '(' + (' AND ('.join([rule_string, pattern_string])) + ')' + ')'
        elif rule_string != '' or pattern_string != '':
            filter_string = '(' + ''.join([rule_string, pattern_string]) + ')'
        else:
            filter_string = ''

        # flatten the rule_paths
        rule_paths_first = ([rule[1] for rule in rule_specs] +
                            [rule[1] for rule in pattern_specs])
        rule_paths = [item for sublist in rule_paths_first
                      for item in sublist]

        outcome = (filter_string, tuple(rule_paths))
        return outcome

    # Filters
    @classmethod
    def containment_filter(cls, path, rule):
        """Filter for objects that contain the specified value at some location"""
        template = reconstruct_object(path[1:])
**Contributor:** Similar to the questions about `rule` above, could it be made more obvious what the slices of `path` contain, and why it's safe to assume `path` has a length of at least two?

        has_containment = 'contains' in rule
        abstract_contains_str = path[0] + " @> %s"

        if has_containment:
            all_contained = rule.get('contains')
        else:
            return None

        contains_params = []
        json_path = [json.dumps(x) for x in path[1:]]
        for contained in all_contained:
            interpolants = tuple(json_path + [json.dumps(contained)])
            contains_params.append(template % interpolants)

        contains_str = ' OR '.join([abstract_contains_str] * len(all_contained))

        if contains_str != '':
            return ('(' + contains_str + ')', contains_params)
        else:
            return None

    @classmethod
    def multiple_containment_filter(cls, path, rule):
        """Filter for objects that contain the specified value in any of the objects in a
        given list"""
        template = reconstruct_object_multiple(path[1:])
        has_containment = 'contains' in rule
        abstract_contains_str = path[0] + " @> %s"

        if has_containment:
            all_contained = rule.get('contains')
        else:
            return None

        contains_params = []
        json_path = [json.dumps(x) for x in path[1:]]
        for contained in all_contained:
            interpolants = tuple(json_path + [json.dumps(contained)])
            contains_params.append(template % interpolants)

        contains_str = ' OR '.join([abstract_contains_str] * len(all_contained))

        if contains_str != '':
            return ('(' + contains_str + ')', contains_params)
        else:
            return None

    @classmethod
    def intrange_filter(cls, path, rule):
        """Filter for numbers that match boundaries provided by a rule"""
        traversed_int = "(" + extract_value_at_path(path) + ")::int"
        has_min = 'min' in rule and rule['min'] is not None
        has_max = 'max' in rule and rule['max'] is not None

        if has_min:
            minimum = rule['min']
            more_than = ("{traversal_int} >= %s"
                         .format(traversal_int=traversed_int))
        if has_max:
            maximum = rule['max']
            less_than = ("{traversal_int} <= %s"
                         .format(traversal_int=traversed_int))

        if has_min and not has_max:
            sql_template = '(' + more_than + ')'
            return (sql_template, path[1:] + [minimum])
        elif has_max and not has_min:
            sql_template = '(' + less_than + ')'
            return (sql_template, path[1:] + [maximum])
        elif has_max and has_min:
            sql_template = '(' + less_than + ' AND ' + more_than + ')'
            return (sql_template, path[1:] + [maximum] + path[1:] + [minimum])
        else:
            return None
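
    # For example (hypothetical path and rule): path=['data', 'age'] with
    # {'_rule_type': 'intrange', 'min': 18, 'max': 65} yields the template
    # "((data->>%s)::int <= %s AND (data->>%s)::int >= %s)"
    # with params ['age', 65, 'age', 18].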

    @classmethod
    def text_similarity_filter(cls, path, pattern, path_multiple=False):
        """Filter for objects that contain members (at the specified addresses)
        which match against a provided pattern
        If path_multiple is true, this function generates a regular expression to parse
        the json array of objects. This regular expression works by finding the key and
        attempting to match a string against that key's associated value. This unfortunate
        use of regex is necessitated by Postgres' inability to iterate in a WHERE clause
        and the requirement that we deal with records that have multiple related objects."""
        has_similarity = pattern is not None
        if not has_similarity:
            return None

        if path_multiple:
            traversed_text = "(" + extract_value_at_path(path[:-1]) + ")"
        else:
            traversed_text = "(" + extract_value_at_path(path) + ")"

        sql_template = ("{traversed_text}::text ~* %s"
                        .format(traversed_text=traversed_text))

        if path_multiple:
            return (sql_template, path[1:-1] + ['{key}": "([^"]*?{val}.*?)"'
                                                .format(key=re.escape(path[-1]),
                                                        val=re.escape(pattern))])
        else:
            return (sql_template, path[1:] + [re.escape(pattern)])
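
    # For example (hypothetical path and pattern): path=['data', 'drivers', 'name']
    # with pattern='bob' and path_multiple=True produces "(data->>%s)::text ~* %s"
    # with params ['drivers', 'name": "([^"]*?bob.*?)"'], i.e. a regex scan of the
    # serialized JSON array for a "name" key whose value contains 'bob'.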


# Utility functions
def extract_value_at_path(path):
    return operator_at_traversal_path(path, '->>')


# N.B. This only returns useful query snippets if the parent path
# exists. That is, if you try to query "a"->"b"?"c" but your objects don't have a
# "b" key, you will always get zero rows back, whereas if they do have a "b" key, then
# you will get true if it contains "c" and false otherwise.
def contains_key_at_path(path):
    return operator_at_traversal_path(path, '?')


def operator_at_traversal_path(path, op):
    """Construct traversal instructions for Postgres from a list of nodes; apply op as last step
    like: '%s->%s->%s->>%s' for path={a: {b: {c: value } } }, op='->>'

    Don't use this unless extract_value_at_path and contains_key_at_path don't work for you
    """
    fmt_strs = [path[0]] + ['%s' for leaf in path[1:]]
    traversal = '->'.join(fmt_strs[:-1]) + '{op}%s'.format(op=op)
    return traversal


def reconstruct_object(path):
    """Reconstruct the object from root to leaf, recursively"""
    if len(path) == 0:
        return '%s'
    else:
        # The indexed query on `path` below is the means by which we recurse
        # Every iteration pushes it closer to a length of 0 and, thus, bottoming out
        return '{{%s: {recons}}}'.format(recons=reconstruct_object(path[1:]))
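
# For example (hypothetical keys): reconstruct_object(['a', 'b']) returns
# '{%s: {%s: %s}}'; interpolated with json.dumps'd path parts and a value it
# becomes '{"a": {"b": "some value"}}', ready for a jsonb @> containment test.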


def reconstruct_object_multiple(path):
    """Reconstruct the object from root to leaf, recursively"""
    if len(path) == 0:
        return '%s'
    elif len(path) == 2:
        return '{{%s: [{recons}]}}'.format(recons=reconstruct_object_multiple(path[1:]))
    else:
        # The indexed query on `path` below is the means by which we recurse
        # Every iteration pushes it closer to a length of 0 and, thus, bottoming out
        # This function differs from the singular reconstruction in that the final object
        # gets wrapped in a list (when length is 2, there should be a key and a value left)
        return '{{%s: {recons}}}'.format(recons=reconstruct_object_multiple(path[1:]))


class DriverLookup(Lookup):
    lookup_name = 'jsonb'

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)

        return FilterTree(rhs_params[0], lhs).sql()


@JSONField.register_lookup
class JSONLookup(Lookup):
    lookup_name = 'jsonb'

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)

        field = lhs
        # JSONField formats query values for the database by wrapping them in psycopg2's
        # JsonAdapter, but we need the raw Python dict so that we can parse the
        # query tree. Intercept the query parameter (it'll always be the first
        # element in the parameter list, since the jsonb filter only accepts one argument)
        # and revert it back to a Python dict for tree parsing.
        tree = rhs_params[0].adapted

        return FilterTree(tree, field).sql()
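With the lookup registered on `JSONField`, any model with a JSONB column can be filtered through the `jsonb` lookup, as the `FilterTree` docstring describes. A minimal usage sketch (the filter keys are hypothetical):

```python
from grout.models import Record

# Hypothetical spec: records whose data->driver->age falls between 18 and 65.
filter_spec = {
    'driver': {
        'age': {'_rule_type': 'intrange', 'min': 18, 'max': 65},
    }
}

# `data` is Record's JSONField; the registered `jsonb` lookup hands the dict
# to FilterTree, which compiles it into a parameterized WHERE clause.
matching = Record.objects.filter(data__jsonb=filter_spec)
```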
8 changes: 4 additions & 4 deletions grout/migrations/0001_initial.py
@@ -2,8 +2,8 @@
from __future__ import unicode_literals

from django.db import models, migrations
-import djsonb.fields
import django.contrib.gis.db.models.fields
+from django.contrib.postgres.fields import JSONField
import uuid


@@ -20,7 +20,7 @@ class Migration(migrations.Migration):
                ('created', models.DateTimeField(auto_now_add=True)),
                ('modified', models.DateTimeField(auto_now=True)),
                ('version', models.IntegerField()),
-                ('schema', djsonb.fields.JsonBField()),
+                ('schema', JSONField()),
                ('label', models.CharField(max_length=50)),
                ('slug', models.CharField(unique=True, max_length=50)),
            ],
@@ -36,7 +36,7 @@ class Migration(migrations.Migration):
                ('label', models.CharField(max_length=50)),
                ('slug', models.CharField(max_length=50)),
                ('geom', django.contrib.gis.db.models.fields.PointField(srid=4326)),
-                ('data', djsonb.fields.JsonBField()),
+                ('data', JSONField()),
            ],
            options={
                'abstract': False,
@@ -49,7 +49,7 @@ class Migration(migrations.Migration):
                ('created', models.DateTimeField(auto_now_add=True)),
                ('modified', models.DateTimeField(auto_now=True)),
                ('version', models.IntegerField()),
-                ('schema', djsonb.fields.JsonBField()),
+                ('schema', JSONField()),
                ('record_type', models.CharField(max_length=50)),
            ],
        ),
4 changes: 2 additions & 2 deletions grout/migrations/0004_boundary.py
@@ -2,7 +2,7 @@
from __future__ import unicode_literals

from django.db import models, migrations
-import djsonb.fields
+from django.contrib.postgres.fields import JSONField
import uuid
import django.contrib.gis.db.models.fields

@@ -22,7 +22,7 @@ class Migration(migrations.Migration):
                ('modified', models.DateTimeField(auto_now=True)),
                ('status', models.CharField(default=b'pen', max_length=10, choices=[(b'pen', b'Pending'), (b'pro', b'Processing'), (b'war', b'Warning'), (b'err', b'Error'), (b'com', b'Complete')])),
                ('label', models.CharField(max_length=64)),
-                ('errors', djsonb.fields.JsonBField(null=True, blank=True)),
+                ('errors', JSONField(null=True, blank=True)),
                ('source_file', models.FileField(upload_to=b'boundaries/%Y/%m/%d')),
                ('geom', django.contrib.gis.db.models.fields.MultiPolygonField(srid=4326, null=True, blank=True)),
            ],