Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👌 IMPROVE: Ensure QueryBuilder is passed Backend #5186

Merged
merged 3 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions aiida/cmdline/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
import logging
import os
import sys
from typing import TYPE_CHECKING

from tabulate import tabulate

from . import echo

if TYPE_CHECKING:
from aiida.orm import WorkChainNode

__all__ = ('is_verbose',)


Expand Down Expand Up @@ -306,7 +310,7 @@ def get_process_function_report(node):
return '\n'.join(report)


def get_workchain_report(node, levelname, indent_size=4, max_depth=None):
def get_workchain_report(node: 'WorkChainNode', levelname, indent_size=4, max_depth=None):
"""
Return a multi line string representation of the log messages and output of a given workchain

Expand All @@ -333,7 +337,7 @@ def get_subtree(uuid, level=0):
Get a nested tree of work calculation nodes and their nesting level starting from this uuid.
The result is a list of uuid of these nodes.
"""
builder = orm.QueryBuilder()
builder = orm.QueryBuilder(backend=node.backend)
builder.append(cls=orm.WorkChainNode, filters={'uuid': uuid}, tag='workcalculation')
builder.append(
cls=orm.WorkChainNode,
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/django/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filters must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Comment, filters=filters, project='id').all()
builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id').all()
entities_to_delete = [_[0] for _ in builder]
for entity in entities_to_delete:
self.delete(entity)
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/django/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filters must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Log, filters=filters, project='id')
builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id')
entities_to_delete = builder.all(flat=True)
for entity in entities_to_delete:
self.delete(entity)
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filters must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Comment, filters=filters, project='id')
builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id')
entities_to_delete = builder.all(flat=True)
for entity in entities_to_delete:
self.delete(entity)
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filter must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Log, filters=filters, project='id')
builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id')
entities_to_delete = builder.all(flat=True)
for entity in entities_to_delete:
self.delete(entity)
Expand Down
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/array/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,7 +1803,7 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un
MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""")


def get_bands_and_parents_structure(args):
def get_bands_and_parents_structure(args, backend=None):
"""Search for bands and return bands and the closest structure that is a parent of the instance.

:returns:
Expand All @@ -1817,7 +1817,7 @@ def get_bands_and_parents_structure(args):
from aiida import orm
from aiida.common import timezone

q_build = orm.QueryBuilder()
q_build = orm.QueryBuilder(backend=backend)
if args.all_users is False:
q_build.append(orm.User, tag='creator', filters={'email': orm.User.objects.get_default().email})
else:
Expand Down
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,15 +329,15 @@ def read_cif(fileobj, index=-1, **kwargs):
return struct_list[index]

@classmethod
def from_md5(cls, md5):
def from_md5(cls, md5, backend=None):
"""
Return a list of all CIF files that match a given MD5 hash.

.. note:: the hash has to be stored in a ``_md5`` attribute,
otherwise the CIF file will not be found.
"""
from aiida.orm.querybuilder import QueryBuilder
builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(cls, filters={'attributes.md5': {'==': md5}})
return builder.all(flat=True)

Expand Down
8 changes: 4 additions & 4 deletions aiida/orm/nodes/data/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def get_description(self):
return f'{self.description}'

@classmethod
def get_code_helper(cls, label, machinename=None):
def get_code_helper(cls, label, machinename=None, backend=None):
"""
:param label: the code label identifying the code to load
:param machinename: the machine name where code is setup
Expand All @@ -164,7 +164,7 @@ def get_code_helper(cls, label, machinename=None):
from aiida.orm.computers import Computer
from aiida.orm.querybuilder import QueryBuilder

query = QueryBuilder()
query = QueryBuilder(backend=backend)
query.append(cls, filters={'label': label}, project='*', tag='code')
if machinename:
query.append(Computer, filters={'label': machinename}, with_node='code')
Expand Down Expand Up @@ -249,7 +249,7 @@ def get_from_string(cls, code_string):
raise MultipleObjectsError(f'{code_string} could not be uniquely resolved')

@classmethod
def list_for_plugin(cls, plugin, labels=True):
def list_for_plugin(cls, plugin, labels=True, backend=None):
"""
Return a list of valid code strings for a given plugin.

Expand All @@ -260,7 +260,7 @@ def list_for_plugin(cls, plugin, labels=True):
otherwise a list of integers with the code PKs.
"""
from aiida.orm.querybuilder import QueryBuilder
query = QueryBuilder()
query = QueryBuilder(backend=backend)
query.append(cls, filters={'attributes.input_plugin': {'==': plugin}})
valid_codes = query.all(flat=True)

Expand Down
14 changes: 7 additions & 7 deletions aiida/orm/nodes/data/upf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_pseudos_from_structure(structure, family_name):
return pseudo_list


def upload_upf_family(folder, group_label, group_description, stop_if_existing=True):
def upload_upf_family(folder, group_label, group_description, stop_if_existing=True, backend=None):
"""Upload a set of UPF files in a given group.

:param folder: a path containing all UPF files to be added.
Expand Down Expand Up @@ -120,7 +120,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T

for filename in filenames:
md5sum = md5_file(filename)
builder = orm.QueryBuilder()
builder = orm.QueryBuilder(backend=backend)
builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}})
existing_upf = builder.first()

Expand Down Expand Up @@ -321,7 +321,7 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs
return super().store(*args, **kwargs)

@classmethod
def from_md5(cls, md5):
def from_md5(cls, md5, backend=None):
"""Return a list of all `UpfData` that match the given md5 hash.

.. note:: assumes hash of stored `UpfData` nodes is stored in the `md5` attribute
Expand All @@ -330,7 +330,7 @@ def from_md5(cls, md5):
:return: list of existing `UpfData` nodes that have the same md5 hash
"""
from aiida.orm.querybuilder import QueryBuilder
builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(cls, filters={'attributes.md5': {'==': md5}})
return builder.all(flat=True)

Expand Down Expand Up @@ -366,7 +366,7 @@ def get_upf_family_names(self):
"""Get the list of all upf family names to which the pseudo belongs."""
from aiida.orm import QueryBuilder, UpfFamily

query = QueryBuilder()
query = QueryBuilder(backend=self.backend)
query.append(UpfFamily, tag='group', project='label')
query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group')
return query.all(flat=True)
Expand Down Expand Up @@ -448,7 +448,7 @@ def get_upf_group(cls, group_label):
return UpfFamily.get(label=group_label)

@classmethod
def get_upf_groups(cls, filter_elements=None, user=None):
def get_upf_groups(cls, filter_elements=None, user=None, backend=None):
"""Return all names of groups of type UpfFamily, possibly with some filters.

:param filter_elements: A string or a list of strings.
Expand All @@ -460,7 +460,7 @@ def get_upf_groups(cls, filter_elements=None, user=None):
"""
from aiida.orm import QueryBuilder, UpfFamily, User

builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(UpfFamily, tag='group', project='*')

if user:
Expand Down
8 changes: 4 additions & 4 deletions aiida/orm/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,11 +456,11 @@ def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str
"""
from aiida.orm.utils.links import validate_link

validate_link(source, self, link_type, link_label)
validate_link(source, self, link_type, link_label, backend=self.backend)

# Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules
if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]:
builder = QueryBuilder().append(
builder = QueryBuilder(backend=self.backend).append(
Node, filters={'id': self.pk}, tag='parent').append(
Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable
if builder.count() > 0:
Expand Down Expand Up @@ -537,7 +537,7 @@ def get_stored_link_triples(
if link_label_filter:
edge_filters['label'] = {'like': link_label_filter}

builder = QueryBuilder()
builder = QueryBuilder(backend=self.backend)
builder.append(Node, filters=node_filters, tag='main')

node_project = ['uuid'] if only_uuid else ['*']
Expand Down Expand Up @@ -894,7 +894,7 @@ def _iter_all_same_nodes(self, allow_before_store=False) -> Iterator['Node']:
if not node_hash or not self._cachable:
return iter(())

builder = QueryBuilder()
builder = QueryBuilder(backend=self.backend)
builder.append(self.__class__, filters={'extras._aiida_hash': node_hash}, project='*', subclassing=False)
nodes_identical = (n[0] for n in builder.iterall())

Expand Down
11 changes: 8 additions & 3 deletions aiida/orm/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def __init__(
:param distinct: Whether to return de-duplicated rows

"""
backend = backend or get_manager().get_backend()
self._impl: BackendQueryBuilder = backend.query()
self._backend = backend or get_manager().get_backend()
self._impl: BackendQueryBuilder = self._backend.query()

# SERIALISABLE ATTRIBUTES
# A list storing the path being traversed by the query
Expand Down Expand Up @@ -189,6 +189,11 @@ def __init__(
if order_by:
self.order_by(order_by)

@property
def backend(self) -> 'Backend':
"""Return the backend used by the QueryBuilder."""
return self._backend

def as_dict(self, copy: bool = True) -> QueryDictType:
"""Convert to a JSON serialisable dictionary representation of the query."""
data: QueryDictType = {
Expand Down Expand Up @@ -225,7 +230,7 @@ def __str__(self) -> str:

def __deepcopy__(self, memo) -> 'QueryBuilder':
"""Create deep copy of the instance."""
return type(self)(**self.as_dict()) # type: ignore
return type(self)(backend=self.backend, **self.as_dict()) # type: ignore

def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]:
"""Returns a list of all the vertices that are being used.
Expand Down
8 changes: 4 additions & 4 deletions aiida/orm/utils/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LinkQuadruple = namedtuple('LinkQuadruple', ['source_id', 'target_id', 'link_type', 'link_label'])


def link_triple_exists(source, target, link_type, link_label):
def link_triple_exists(source, target, link_type, link_label, backend=None):
"""Return whether a link with the given type and label exists between the given source and target node.

:param source: node from which the link is outgoing
Expand All @@ -42,15 +42,15 @@ def link_triple_exists(source, target, link_type, link_label):

# Here we have two stored nodes, so we need to check if the same link already exists in the database.
# Finding just a single match is sufficient so we can use the `limit` clause for efficiency
builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(Node, filters={'id': source.id}, project=['id'])
builder.append(Node, filters={'id': target.id}, edge_filters={'type': link_type.value, 'label': link_label})
builder.limit(1)

return builder.count() != 0


def validate_link(source, target, link_type, link_label):
def validate_link(source, target, link_type, link_label, backend=None):
"""
Validate adding a link of the given type and label from a given node to ourself.

Expand Down Expand Up @@ -153,7 +153,7 @@ def validate_link(source, target, link_type, link_label):
if outdegree == 'unique_triple' or indegree == 'unique_triple':
# For a `unique_triple` degree we just have to check if an identical triple already exist, either in the cache
# or stored, in which case, the new proposed link is a duplicate and thus illegal
duplicate_link_triple = link_triple_exists(source, target, link_type, link_label)
duplicate_link_triple = link_triple_exists(source, target, link_type, link_label, backend)

# If the outdegree is `unique` there cannot already be any other outgoing link of that type
if outdegree == 'unique' and source.get_outgoing(link_type=link_type, only_uuid=True).all():
Expand Down
6 changes: 3 additions & 3 deletions aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def clean_remote(transport, path):
pass


def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None):
def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None, backend=None):
"""
Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of
calcjobs will be determined by a query with filters based on the pks, past_days, older_than,
computers and user arguments.

:param pks: onlu include calcjobs with a pk in this list
:param pks: only include calcjobs with a pk in this list
:param past_days: only include calcjobs created since past_days
:param older_than: only include calcjobs older than
:param computers: only include calcjobs that were ran on these computers
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer
if pks:
filters_calc['id'] = {'in': pks}

query = orm.QueryBuilder()
query = orm.QueryBuilder(backend=backend)
query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc)
query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer)
query.append(orm.User, with_node='calc', filters={'email': user.email})
Expand Down
11 changes: 0 additions & 11 deletions aiida/tools/graph/age_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,6 @@ def aiida_cls(self):
"""Class of nodes contained in the entity set (node or group)"""
return self._aiida_cls

def get_entities(self):
"""Iterator that returns the AiiDA entities"""
for entity, in orm.QueryBuilder().append(
self._aiida_cls, project='*', filters={
self._identifier: {
'in': self.keyset
}
}
).iterall():
yield entity


class DirectedEdgeSet(AbstractSetContainer):
"""Extension of AbstractSetContainer
Expand Down
7 changes: 4 additions & 3 deletions aiida/tools/graph/age_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from abc import ABCMeta, abstractmethod
from collections import defaultdict
from copy import deepcopy

import numpy as np

Expand Down Expand Up @@ -65,7 +66,7 @@ class QueryRule(Operation, metaclass=ABCMeta):
found in the last iteration of the query (ReplaceRule).
"""

def __init__(self, querybuilder, max_iterations=1, track_edges=False):
def __init__(self, querybuilder: orm.QueryBuilder, max_iterations=1, track_edges=False):
"""Initialization method

:param querybuilder: an instance of the QueryBuilder class from which to take the
Expand Down Expand Up @@ -107,7 +108,7 @@ def get_spec_from_path(query_dict, idx):
for pathspec in query_dict['path']:
if not pathspec['entity_type']:
pathspec['entity_type'] = 'node.Node.'
self._qbtemplate = orm.QueryBuilder(**query_dict)
self._qbtemplate = deepcopy(querybuilder)
query_dict = self._qbtemplate.as_dict()
self._first_tag = query_dict['path'][0]['tag']
self._last_tag = query_dict['path'][-1]['tag']
Expand Down Expand Up @@ -163,7 +164,7 @@ def _init_run(self, operational_set):

# Copying qbtemplate so there's no problem if it is used again in a later run:
query_dict = self._qbtemplate.as_dict()
self._querybuilder = orm.QueryBuilder.from_dict(query_dict)
self._querybuilder = deepcopy(self._qbtemplate)

self._entity_to_identifier = operational_set[self._entity_to].identifier

Expand Down
Loading