Skip to content

Django style filtering for GetEntitySetRequest #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ venv
dist
.idea
.coverage
env/
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Fixed
- URL encode $filter contents - Barton Ip
- Headers attribute on ODataHttpRequest - Barton Ip

## [1.5.0]

Expand Down
30 changes: 30 additions & 0 deletions docs/usage/querying.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,36 @@ Print unique identification (Id) of all employees with name John Smith:
print(smith.EmployeeID)


Get entities matching a filter in ORM style
---------------------------------------------------

Print unique identification (Id) of all employees with name John Smith:

.. code-block:: python

from pyodata.v2.service import GetEntitySetFilter as esf

smith_employees_request = northwind.entity_sets.Employees.get_entities()
smith_employees_request = smith_employees_request.filter(FirstName="John", LastName="Smith")
for smith in smith_employees_request.execute():
print(smith.EmployeeID)


Get entities matching a complex filter in ORM style
---------------------------------------------------

Print unique identification (Id) of all employees with name John Smith:

.. code-block:: python

from pyodata.v2.service import GetEntitySetFilter as esf

smith_employees_request = northwind.entity_sets.Employees.get_entities()
smith_employees_request = smith_employees_request.filter(FirstName__contains="oh", LastName__startswith="Smi")
for smith in smith_employees_request.execute():
print(smith.EmployeeID)


Get a count of entities
-----------------------

Expand Down
243 changes: 224 additions & 19 deletions pyodata/v2/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init__(self, url, connection, handler, headers=None):
self._connection = connection
self._url = url
self._handler = handler
self._headers = headers
self._headers = headers or dict()
self._logger = logging.getLogger(LOGGER_NAME)

@property
Expand Down Expand Up @@ -267,10 +267,32 @@ def get_body(self):
# pylint: disable=no-self-use
return None

def get_headers(self):
"""Get dict of HTTP headers"""
def get_default_headers(self):
"""Get dict of Child specific HTTP headers"""
# pylint: disable=no-self-use
return None
return dict()

def get_headers(self):
"""Get dict of HTTP headers which is union of return value
of the method get_default_headers() and the headers
added via the method add_headers() where the latter
headers have priority - same keys get value of the latter.
"""

headers = self.get_default_headers()
headers.update(self._headers)

return headers

def add_headers(self, value):
"""Add the give dictionary of HTTP headers to
HTTP request sent by this ODataHttpRequest instance.
"""

if not isinstance(value, dict):
raise TypeError("Headers must be of type 'dict' not {}".format(type(value)))

self._headers.update(value)

def execute(self):
"""Fetches HTTP response and returns processed result
Expand All @@ -284,21 +306,17 @@ def execute(self):
# pylint: disable=assignment-from-none
body = self.get_body()

headers = {} if self._headers is None else self._headers

# pylint: disable=assignment-from-none
extra_headers = self.get_headers()
if extra_headers is not None:
headers.update(extra_headers)
headers = self.get_headers()

self._logger.debug('Send (execute) %s request to %s', self.get_method(), url)
self._logger.debug(' query params: %s', self.get_query_params())
self._logger.debug(' headers: %s', headers)
if body:
self._logger.debug(' body: %s', body)

params = "&".join("%s=%s" % (k, v) for k, v in self.get_query_params().items())
response = self._connection.request(
self.get_method(), url, headers=headers, params=self.get_query_params(), data=body)
self.get_method(), url, headers=headers, params=params, data=body)

self._logger.debug('Received response')
self._logger.debug(' url: %s', response.url)
Expand Down Expand Up @@ -350,7 +368,7 @@ def expand(self, expand):
def get_path(self):
return self._entity_set_proxy.last_segment + self._entity_key.to_key_string()

def get_headers(self):
def get_default_headers(self):
return {'Accept': 'application/json'}

def get_query_params(self):
Expand Down Expand Up @@ -447,7 +465,7 @@ def _get_body(self):
def get_body(self):
return json.dumps(self._get_body())

def get_headers(self):
def get_default_headers(self):
return {'Accept': 'application/json', 'Content-Type': 'application/json', 'X-Requested-With': 'X'}

@staticmethod
Expand Down Expand Up @@ -547,7 +565,7 @@ def get_body(self):
body[key] = val
return json.dumps(body)

def get_headers(self):
def get_default_headers(self):
return {'Accept': 'application/json', 'Content-Type': 'application/json'}

def set(self, **kwargs):
Expand Down Expand Up @@ -609,7 +627,7 @@ def filter(self, filter_val):
self._filter = quote(filter_val)
return self

# def nav(self, key_value, nav_property):
# def nav(self, key_value, nav_property):æ
# """Navigates to a referenced collection using a collection-valued navigation property."""
# # returns QueryRequest
# raise NotImplementedError
Expand Down Expand Up @@ -640,7 +658,7 @@ def get_path(self):

return self._last_segment

def get_headers(self):
def get_default_headers(self):
if self._count:
return {}

Expand Down Expand Up @@ -703,9 +721,9 @@ def parameter(self, name, value):
def get_method(self):
return self._function_import.http_method

def get_headers(self):
def get_default_headers(self):
return {
'Accept': 'application/json',
'Accept': 'application/json'
}


Expand Down Expand Up @@ -976,6 +994,176 @@ def __gt__(self, value):
return GetEntitySetFilter.format_filter(self._proprty, 'gt', value)


class FilterExpression:
"""A filter expression object comparable to Django's Q class"""
def __init__(self, **kwargs):
self.expressions = kwargs
self.other = None
self.operator = None

def __or__(self, other):
self.other = other
self.operator = "or"
return self

def __and__(self, other):
self.other = other
self.operator = "and"
return self


class GetEntitySetFilterChainable:
"""
Example expressions
FirstName="Tim"
FirstName__contains="Tim"
Age__gt=56
Age__gte=6
Age__lt=78
Age__lte=90
Age__range=(5,9)
FirstName__in=["Tim", "Bob", "Sam"]
FirstName__startswith="Tim"
FirstName__endswith="mothy"
Addresses__Suburb="Chatswood"
Addresses__Suburb__contains="wood"
"""

operators = [
"startswith",
"endswith",
"lt",
"lte",
"gt",
"gte",
"contains",
"range",
"in",
"length",
"eq"
]

def __init__(self, request, filter_expressions, exprs):
self.request = request
self.expressions = exprs
self.filter_expressions = filter_expressions

def proprty_obj(self, name):
"""Returns value for a particular proprty"""
return self.request._entity_type.proprty(name) # pylint: disable=protected-access

def process_query_objects(self):
"""Processes FilterExpression objects to OData lookups"""
filter_expressions = []
for filter_expression in self.filter_expressions:
lhs_expressions = []
rhs_expressions = []
for expr, val in filter_expression.expressions.items():
lhs_expressions.append(self.decode_expression(expr, val))
lhs_expression = self.combine_expressions(lhs_expressions)

if filter_expression.other:
for expr, val in filter_expression.other.expressions.items():
rhs_expressions.append(self.decode_expression(expr, val))
rhs_expression = self.combine_expressions(rhs_expressions)

filter_expressions.append(
f"({lhs_expression}) {filter_expression.operator} ({rhs_expression})"
)
else:
filter_expressions.append(lhs_expression)

return filter_expressions

def process_expressions(self):
"""Processes filter kwargs into OData expressions"""
filter_expressions = []
for expr, val in self.expressions.items():
filter_expressions.append(self.decode_expression(expr, val))

filter_expressions.extend(self.process_query_objects())
return filter_expressions

def decode_expression(self, expr, val):
"""Decodes Django-like syntax into OData expressions"""
properties = self.request._entity_type._properties.keys() # pylint: disable=protected-access
field = None
# field_heirarchy = []
operator = "eq"
exprs = expr.split("__")

for part in exprs:
if part in properties:
field = part
# field_heirarchy.append(part)
elif part in self.__class__.operators:
operator = part
else:
raise ValueError("'{}' is not a valid property or operator".format(part))
# field = "/".join(field_heirarchy)

# target_field = self.proprty_obj(field_heirarchy[-1])
expression = self.build_expression(field, operator, val)

return expression

def combine_expressions(self, expressions):
"""Combines expressions"""
# pylint: disable=no-self-use
return " and ".join(expressions)

def build_expression(self, field_name, operator, value):
"""Builds expression from Django-like operator"""
# pylint: disable=too-many-branches, too-many-return-statements, no-else-return
target_field = self.proprty_obj(field_name)
if operator not in ["length", "in", "range"]:
value = target_field.to_literal(value)
if operator == "lt":
return f"{field_name} lt {value}"
elif operator == "lte":
return f"{field_name} le {value}"
elif operator == "gte":
return f"{field_name} ge {value}"
elif operator == "gt":
return f"{field_name} gt {value}"
elif operator == "startswith":
return f"startswith({field_name}, {value}) eq true"
elif operator == "endswith":
return f"endswith({field_name}, {value}) eq true"
elif operator == "length":
value = int(value)
return f"length({field_name}) eq {value}"
elif operator in ["contains"]:
return f"substringof({value}, {field_name}) eq true"
elif operator == "range":
if not isinstance(value, (tuple, list)):
raise TypeError(
"Range must be tuple or list not {}".format(type(value))
)
if len(value) != 2:
raise ValueError("Only two items can be passed in a range.")

value_0 = target_field.to_literal(value[0])
value_1 = target_field.to_literal(value[1])
return f"{field_name} gte {value_0} and {field_name} lte {value_1}"
elif operator == "in":
literal_values = []
for val in value:
val = target_field.to_literal(val)
literal_values.append(f"{field_name} eq {val}")
return " or ".join(literal_values)
elif operator == "eq":
return f"{field_name} eq {value}"
else:
raise ValueError(f"Invalid expression {operator}")

def as_filter_string(self):
"""Returns final filter string for this filter"""
expressions = self.process_expressions()
result = self.combine_expressions(expressions)
return quote(result)


class GetEntitySetRequest(QueryRequest):
"""GET on EntitySet"""

Expand All @@ -988,6 +1176,23 @@ def __getattr__(self, name):
proprty = self._entity_type.proprty(name)
return GetEntitySetFilter(proprty)

def set_filter(self, filter_val):
"""Chain filter"""
filter_text = self._filter + " and " if self._filter else ""
filter_text += filter_val
self._filter = filter_text

def filter(self, *args, **kwargs):
# pylint: disable=no-else-return
if args and isinstance(args[0], str):
self._filter = args[0]
return self
else:
self.set_filter(
GetEntitySetFilterChainable(self, args, kwargs).as_filter_string()
)
return self


class EntitySetProxy:
"""EntitySet Proxy"""
Expand Down Expand Up @@ -1461,7 +1666,7 @@ def get_boundary(self):
"""Get boundary used for request parts"""
return self.id

def get_headers(self):
def get_default_headers(self):
# pylint: disable=no-self-use
return {'Content-Type': 'multipart/mixed;boundary={}'.format(self.get_boundary())}

Expand Down
Loading