diff --git a/.circleci/config.yml b/.circleci/config.yml
index 2072bed8789..a17dfe72560 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -37,6 +37,8 @@ defaults:
       CKAN_POSTGRES_PWD: pass
       PGPASSWORD: ckan
       PYTEST_COMMON_OPTIONS: -v --ckan-ini=test-core-circle-ci.ini --cov=ckan --cov=ckanext --junitxml=~/junit/result/junit.xml
+      # report usage of deprecated features
+      SQLALCHEMY_WARN_20: 1
   pg_image: &pg_image
     image: postgres:10
     environment:
diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml
index 4a7b5343b8d..90e2f84ca43 100644
--- a/.github/workflows/pyright.yml
+++ b/.github/workflows/pyright.yml
@@ -1,7 +1,7 @@
 name: Check types
 on: [pull_request]
 env:
-  NODE_VERSION: '16'
+  NODE_VERSION: '18'
   PYTHON_VERSION: '3.8'
 
 permissions:
@@ -15,7 +15,7 @@ jobs:
       - uses: actions/setup-python@v2
         with:
           python-version: ${{ env.PYTHON_VERSION }}
-      - uses: actions/setup-node@v2-beta
+      - uses: actions/setup-node@v3
         with:
           node-version: ${{ env.NODE_VERSION }}
       - name: Install python deps
diff --git a/changes/7583.misc b/changes/7583.misc
new file mode 100644
index 00000000000..3380cf45f61
--- /dev/null
+++ b/changes/7583.misc
@@ -0,0 +1,30 @@
+:py:meth:`~ckanext.datastore.interfaces.IDatastore.datastore_search` of the
+:py:class:`~ckanext.datastore.interfaces.IDatastore` interface is not fully
+compatible with the old version.
+
+The ``where`` key of the ``query_dict`` returned from this method has a
+different format. Previously it was a collection of tuples, each holding an
+SQL where-clause with positional/named ``%``-style placeholders in the first
+position, followed by an arbitrary number of parameters::
+
+    return {
+        ...,
+        "where": [('"age" BETWEEN %s AND %s', param1, param2, ...), ...]
+    }
+
+Now every element of the collection must be a tuple that contains an SQL
+where-clause with **named** ``:``-style placeholders and a dict with the
+values for all the placeholders::
+
+
+    return {
+        ...,
+        "where": [(
+            '"age" BETWEEN :my_ext_min AND :my_ext_max',
+            {"my_ext_min": age_between[0], "my_ext_max": age_between[1]},
+        )]
+    }
+
+In order to avoid name conflicts with placeholders from different plugins,
+don't use simple names such as ``val``, ``min`` or ``name``; add a unique
+prefix to all the placeholders.
diff --git a/changes/7583.removal b/changes/7583.removal
new file mode 100644
index 00000000000..48204759eb4
--- /dev/null
+++ b/changes/7583.removal
@@ -0,0 +1,16 @@
+SQLAlchemy's MetaData object (:py:attr:`ckan.model.meta.metadata`) is no
+longer bound to the DB engine. `A number of operations
+<https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#implicit-and-connectionless-execution-bound-metadata-removed>`__,
+such as ``table.exists()``, ``table.create()``, ``metadata.create_all()`` and
+``metadata.reflect()``, now produce an error
+:py:class:`sqlalchemy.exc.UnboundExecutionError`.
+
+Depending on the situation, the following changes may be required:
+
+* Instead of creating tables via a custom CLI command or during application
+  startup, use `Alembic migrations
+  <https://docs.ckan.org/en/2.10/extensions/best-practices.html#use-migrations-when-introducing-new-models>`__.
+
+* If there is no other way, change ``table.create()``/``table.exists()`` to
+  ``table.create(engine)``/``table.exists(engine)``. Get ``engine`` by calling
+  :py:func:`~ckan.model.ensure_engine`.
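For extension code hit by the second bullet of ``changes/7583.removal``, a
minimal sketch of the ``ensure_engine()`` pattern. ``my_plugin_table`` and
``ensure_my_table`` are hypothetical names used only for illustration;
``ensure_engine`` is the helper added to ``ckan/model/__init__.py`` later in
this diff::

    import sqlalchemy as sa

    from ckan import model
    from ckan.model import meta

    # Hypothetical plugin-defined table, for illustration only.
    my_plugin_table = sa.Table(
        "my_plugin_table", meta.metadata,
        sa.Column("id", sa.UnicodeText, primary_key=True),
    )


    def ensure_my_table() -> None:
        # ensure_engine() raises CkanConfigurationException when the model
        # has not been initialized yet, instead of failing later with an
        # obscure UnboundExecutionError.
        engine = model.ensure_engine()
        # table.exists() now needs an explicit engine; an Inspector check is
        # equivalent and also works on SQLAlchemy 2.0.
        if not sa.inspect(engine).has_table("my_plugin_table"):
            my_plugin_table.create(engine)

Prefer the Alembic-migration route where possible; the sketch above is only
for code that must create its table lazily at runtime.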
diff --git a/ckan/authz.py b/ckan/authz.py index 48c44084dc4..c70e519101e 100644 --- a/ckan/authz.py +++ b/ckan/authz.py @@ -344,7 +344,7 @@ def _has_user_permission_for_groups( # get any roles the user has for the group q: Any = (model.Session.query(model.Member.capacity) # type_ignore_reason: attribute has no method - .filter(model.Member.group_id.in_(group_ids)) # type: ignore + .filter(model.Member.group_id.in_(group_ids)) .filter(model.Member.table_name == 'user') .filter(model.Member.state == 'active') .filter(model.Member.table_id == user_id)) @@ -402,7 +402,7 @@ def has_user_permission_for_some_org( .filter(model.Member.table_name == 'user') .filter(model.Member.state == 'active') # type_ignore_reason: attribute has no method - .filter(model.Member.capacity.in_(roles)) # type: ignore + .filter(model.Member.capacity.in_(roles)) .filter(model.Member.table_id == user_id)) group_ids = [] for row in q: @@ -416,8 +416,7 @@ def has_user_permission_for_some_org( model.Session.query(model.Group) .filter(model.Group.is_organization == True) .filter(model.Group.state == 'active') - # type_ignore_reason: attribute has no method - .filter(model.Group.id.in_(group_ids)).exists() # type: ignore + .filter(model.Group.id.in_(group_ids)).exists() ).scalar() return permission_exists @@ -491,8 +490,7 @@ def user_is_collaborator_on_dataset( if capacity: if isinstance(capacity, str): capacity = [capacity] - # type_ignore_reason: attribute has no method - q = q.filter(model.PackageMember.capacity.in_(capacity)) # type: ignore + q = q.filter(model.PackageMember.capacity.in_(capacity)) return model.Session.query(q.exists()).scalar() diff --git a/ckan/cli/search_index.py b/ckan/cli/search_index.py index bb1304081fa..cacde57b9df 100644 --- a/ckan/cli/search_index.py +++ b/ckan/cli/search_index.py @@ -161,7 +161,10 @@ def rebuild_fast(): db_url = config['sqlalchemy.url'] engine = sa.create_engine(db_url) package_ids = [] - result = engine.execute(u"select id from package where state = 'active';") + with engine.connect() as conn: + result = conn.execute( + sa.text("SELECT id FROM package where state = 'active'") + ) for row in result: package_ids.append(row[0]) diff --git a/ckan/config/environment.py b/ckan/config/environment.py index 2290111f56e..26b51edae2e 100644 --- a/ckan/config/environment.py +++ b/ckan/config/environment.py @@ -252,4 +252,4 @@ def update_config() -> None: # Close current session and open database connections to ensure a clean # clean environment even if an error occurs later on model.Session.remove() - model.Session.bind.dispose() + model.Session.bind.dispose() # type: ignore diff --git a/ckan/lib/dictization/__init__.py b/ckan/lib/dictization/__init__.py index d93c08acd8b..9e2e8f89e2d 100644 --- a/ckan/lib/dictization/__init__.py +++ b/ckan/lib/dictization/__init__.py @@ -6,7 +6,7 @@ import sqlalchemy from sqlalchemy import Table -from sqlalchemy.engine import Row # type: ignore +from sqlalchemy.engine import Row from sqlalchemy.orm import class_mapper from ckan.model.core import State @@ -27,7 +27,7 @@ def table_dictize(obj: Any, context: Context, **kw: Any) -> dict[str, Any]: result_dict: dict[str, Any] = {} if isinstance(obj, Row): - fields = obj.keys() + fields = obj._fields else: ModelClass = obj.__class__ table = class_mapper(ModelClass).persist_selectable diff --git a/ckan/lib/dictization/model_dictize.py b/ckan/lib/dictization/model_dictize.py index 07387738b20..4211bc435e0 100644 --- a/ckan/lib/dictization/model_dictize.py +++ b/ckan/lib/dictization/model_dictize.py @@ -188,7 
+188,7 @@ def package_dictize( # resources res = model.resource_table - q = select([res]).where(res.c["package_id"] == pkg.id) + q = select(res).where(res.c["package_id"] == pkg.id) result = execute(q, res, context) result_dict["resources"] = resource_list_dictize(result, context) result_dict['num_resources'] = len(result_dict.get('resources', [])) @@ -196,9 +196,9 @@ def package_dictize( # tags tag = model.tag_table pkg_tag = model.package_tag_table - q = select([tag, pkg_tag.c["state"]], - from_obj=pkg_tag.join(tag, tag.c["id"] == pkg_tag.c["tag_id"]) - ).where(pkg_tag.c["package_id"] == pkg.id) + q = select(tag, pkg_tag.c["state"]).join( + pkg_tag, tag.c["id"] == pkg_tag.c["tag_id"] + ).where(pkg_tag.c["package_id"] == pkg.id) result = execute(q, pkg_tag, context) result_dict["tags"] = d.obj_list_dictize(result, context, lambda x: x["name"]) @@ -213,18 +213,21 @@ def package_dictize( # extras - no longer revisioned, so always provide latest extra = model.package_extra_table - q = select([extra]).where(extra.c["package_id"] == pkg.id) + q = select(extra).where(extra.c["package_id"] == pkg.id) result = execute(q, extra, context) result_dict["extras"] = extras_list_dictize(result, context) # groups member = model.member_table group = model.group_table - q = select([group, member.c["capacity"]], - from_obj=member.join(group, group.c["id"] == member.c["group_id"]) - ).where(member.c["table_id"] == pkg.id)\ - .where(member.c["state"] == 'active') \ - .where(group.c["is_organization"] == False) + q = select(group, member.c["capacity"]).join( + member, group.c["id"] == member.c["group_id"] + ).where( + member.c["table_id"] == pkg.id, + member.c["state"] == 'active', + group.c["is_organization"] == False + ) + result = execute(q, member, context) context['with_capacity'] = False # no package counts as cannot fetch from search index at the same @@ -235,9 +238,9 @@ def package_dictize( # owning organization group = model.group_table - q = select([group] - ).where(group.c["id"] == pkg.owner_org) \ - .where(group.c["state"] == 'active') + q = select(group).where( + group.c["id"] == pkg.owner_org + ).where(group.c["state"] == 'active') result = execute(q, group, context) organizations = d.obj_list_dictize(result, context) if organizations: @@ -247,11 +250,15 @@ def package_dictize( # relations rel = model.package_relationship_table - q = select([rel]).where(rel.c["subject_package_id"] == pkg.id) + q = select( + rel + ).where(rel.c["subject_package_id"] == pkg.id) result = execute(q, rel, context) result_dict["relationships_as_subject"] = \ d.obj_list_dictize(result, context) - q = select([rel]).where(rel.c["object_package_id"] == pkg.id) + q = select( + rel + ).where(rel.c["object_package_id"] == pkg.id) result = execute(q, rel, context) result_dict["relationships_as_object"] = \ d.obj_list_dictize(result, context) diff --git a/ckan/lib/pagination.py b/ckan/lib/pagination.py index 8f0c3ed25c5..ff6ab1442ae 100644 --- a/ckan/lib/pagination.py +++ b/ckan/lib/pagination.py @@ -53,7 +53,7 @@ class BasePage(List[Any]): - a sequence - an SQLAlchemy query - e.g.: Session.query(MyModel) - - an SQLAlchemy select - e.g.: sqlalchemy.select([my_table]) + - an SQLAlchemy select - e.g.: sqlalchemy.select(my_table) A "Page" instance maintains pagination logic associated with each page, where it begins, what the first/last item on the page is, etc. 
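The pagination docstring above reflects the ``select()`` calling-convention
change applied throughout this diff. A standalone sketch of the old and new
styles, where the ``tag``/``package_tag`` tables are illustrative stand-ins
for the core tables dictized above::

    import sqlalchemy as sa

    metadata = sa.MetaData()
    tag = sa.Table(
        "tag", metadata,
        sa.Column("id", sa.UnicodeText, primary_key=True),
        sa.Column("name", sa.UnicodeText),
    )
    pkg_tag = sa.Table(
        "package_tag", metadata,
        sa.Column("tag_id", sa.UnicodeText, sa.ForeignKey("tag.id")),
        sa.Column("package_id", sa.UnicodeText),
        sa.Column("state", sa.UnicodeText),
    )

    # 1.3 style, removed in SQLAlchemy 2.0: columns passed as a list, the
    # join passed through from_obj.
    # q = sa.select([tag, pkg_tag.c.state],
    #               from_obj=pkg_tag.join(tag, tag.c.id == pkg_tag.c.tag_id))

    # 1.4/2.0 style: positional columns and an explicit Select.join() with
    # an onclause; filter conditions go through where().
    q = (
        sa.select(tag, pkg_tag.c.state)
        .join(pkg_tag, tag.c.id == pkg_tag.c.tag_id)
        .where(pkg_tag.c.package_id == "some-package-id")
    )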
diff --git a/ckan/logic/action/create.py b/ckan/logic/action/create.py index 76acee506a4..37530170bb1 100644 --- a/ckan/logic/action/create.py +++ b/ckan/logic/action/create.py @@ -401,8 +401,7 @@ def resource_view_create( last_view = model.Session.query(model.ResourceView)\ .filter_by(resource_id=resource_id) \ .order_by( - # type_ignore_reason: incomplete SQLAlchemy types - model.ResourceView.order.desc() # type: ignore + model.ResourceView.order.desc() ).first() if not last_view: @@ -610,10 +609,7 @@ def member_create(context: Context, filter(model.Member.table_name == obj_type).\ filter(model.Member.table_id == obj.id).\ filter(model.Member.group_id == group.id).\ - order_by( - # type_ignore_reason: incomplete SQLAlchemy types - model.Member.state.asc() # type: ignore - ).first() + order_by(model.Member.state.asc()).first() if member: user_obj = model.User.get(user) if user_obj and member.table_name == u'user' and \ diff --git a/ckan/logic/action/get.py b/ckan/logic/action/get.py index 5e0055144c2..f12c5432a8e 100644 --- a/ckan/logic/action/get.py +++ b/ckan/logic/action/get.py @@ -81,7 +81,7 @@ def package_list(context: Context, data_dict: DataDict) -> ActionResult.PackageL package_table = model.package_table col = (package_table.c["id"] if api == 2 else package_table.c["name"]) - query = _select([col]) + query = _select(col) query = query.where(_and_( package_table.c["state"] == 'active', package_table.c["private"] == False, @@ -97,7 +97,9 @@ def package_list(context: Context, data_dict: DataDict) -> ActionResult.PackageL query = query.offset(offset) ## Returns the first field in each result record - return [r[0] for r in query.execute() or []] + return context["session"].scalars( + query + ).all() @logic.validate(ckan.logic.schema.default_package_list_schema) @@ -385,15 +387,13 @@ def _group_or_org_list( query = query.filter(model.Group.state == 'active') if groups: - # type_ignore_reason: incomplete SQLAlchemy types - query = query.filter(model.Group.name.in_(groups)) # type: ignore + query = query.filter(model.Group.name.in_(groups)) if q: q = u'%{0}%'.format(q) query = query.filter(_or_( - # type_ignore_reason: incomplete SQLAlchemy types - model.Group.name.ilike(q), # type: ignore - model.Group.title.ilike(q), # type: ignore - model.Group.description.ilike(q), # type: ignore + model.Group.name.ilike(q), + model.Group.title.ilike(q), + model.Group.description.ilike(q), )) query = query.filter(model.Group.is_organization == is_org) @@ -592,8 +592,7 @@ def group_list_authz(context: Context, q: Any = model.Session.query(model.Member.group_id) \ .filter(model.Member.table_name == 'user') \ .filter( - # type_ignore_reason: incomplete SQLAlchemy types - model.Member.capacity.in_(roles) # type: ignore + model.Member.capacity.in_(roles) ).filter(model.Member.table_id == user_id) \ .filter(model.Member.state == 'active') group_ids = [] @@ -608,8 +607,7 @@ def group_list_authz(context: Context, .filter(model.Group.state == 'active') if not sysadmin or am_member: - # type_ignore_reason: incomplete SQLAlchemy types - q = q.filter(model.Group.id.in_(group_ids)) # type: ignore + q = q.filter(model.Group.id.in_(group_ids)) groups = q.all() @@ -702,8 +700,7 @@ def organization_list_for_user(context: Context, q: Query[tuple[model.Member, model.Group]] = model.Session.query(model.Member, model.Group) \ .filter(model.Member.table_name == 'user') \ .filter( - # type_ignore_reason: incomplete SQLAlchemy types - model.Member.capacity.in_(roles) # type: ignore + model.Member.capacity.in_(roles) ) \ 
.filter(model.Member.table_id == user_id) \ .filter(model.Member.state == 'active') \ @@ -731,8 +728,7 @@ def organization_list_for_user(context: Context, if not group_ids: return [] - # type_ignore_reason: incomplete SQLAlchemy types - orgs_q = orgs_q.filter(model.Group.id.in_(group_ids)) # type: ignore + orgs_q = orgs_q.filter(model.Group.id.in_(group_ids)) orgs_and_capacities = [ (org, group_ids_to_capacities[org.id]) for org in orgs_q.all()] @@ -848,18 +844,16 @@ def user_list( if all_fields: query: 'Query[Any]' = model.Session.query( model.User, - # type_ignore_reason: incomplete SQLAlchemy types - model.User.name.label('name'), # type: ignore - model.User.fullname.label('fullname'), # type: ignore - model.User.about.label('about'), # type: ignore - model.User.email.label('email'), # type: ignore - model.User.created.label('created'), # type: ignore - _select([_func.count(model.Package.id)], - _and_( - model.Package.creator_user_id == model.User.id, - model.Package.state == 'active', - model.Package.private == False, - )).label('number_created_packages') + model.User.name.label('name'), + model.User.fullname.label('fullname'), + model.User.about.label('about'), + model.User.email.label('email'), + model.User.created.label('created'), + _select(_func.count(model.Package.id)).where( + model.Package.creator_user_id == model.User.id, + model.Package.state == 'active', + model.Package.private == False, + ).label('number_created_packages') ) else: query = model.Session.query(model.User.name) @@ -887,16 +881,10 @@ def user_list( pass if order_by == 'display_name' or order_by_field is None: query = query.order_by( - _case( - [( - _or_( - model.User.fullname == None, - model.User.fullname == '' - ), - model.User.name - )], - else_=model.User.fullname - ) + _case((_or_( + model.User.fullname.is_(None), + model.User.fullname == '' + ), model.User.name), else_=model.User.fullname) ) elif order_by_field == 'number_created_packages' \ or order_by_field == 'fullname' \ @@ -1593,8 +1581,7 @@ def format_autocomplete(context: Context, data_dict: DataDict) -> ActionResult.F .filter(_and_( model.Resource.state == 'active', )) - # type_ignore_reason: incomplete SQLAlchemy types - .filter(model.Resource.format.ilike(like_q)) # type: ignore + .filter(model.Resource.format.ilike(like_q)) .group_by(model.Resource.format) .order_by(text('total DESC')) .limit(limit)) @@ -1967,9 +1954,8 @@ def package_search(context: Context, data_dict: DataDict) -> ActionResult.Packag group_names.extend(facets.get(field_name, {}).keys()) groups = (session.query(model.Group.name, model.Group.title) - # type_ignore_reason: incomplete SQLAlchemy types - .filter(model.Group.name.in_(group_names)) # type: ignore - .all() + .filter(model.Group.name.in_(group_names)) + .all() if group_names else []) group_titles_by_name = dict(groups) @@ -2206,7 +2192,7 @@ def _tag_search( q = q.filter(model.Tag.vocabulary_id == vocab.id) else: # If no vocabulary_name in data dict then show free tags only. - q = q.filter(model.Tag.vocabulary_id == None) + q = q.filter(model.Tag.vocabulary_id.is_(None)) # If we're searching free tags, limit results to tags that are # currently applied to a package. 
q: Query[model.Tag] = q.distinct().join(model.Tag.package_tags) @@ -2221,9 +2207,7 @@ def _tag_search( for term in terms: escaped_term = misc.escape_sql_like_special_characters( term, escape='\\') - q = q.filter( - # type_ignore_reason: incomplete SQLAlchemy types - model.Tag.name.ilike('%' + escaped_term + '%')) # type: ignore + q = q.filter(model.Tag.name.ilike('%' + escaped_term + '%')) count = q.count() q = q.offset(offset) @@ -2367,7 +2351,7 @@ def term_translation_show( trans_table = model.term_translation_table - q = _select([trans_table]) + q = _select(trans_table) if 'terms' not in data_dict: raise ValidationError({'terms': 'terms not in data'}) diff --git a/ckan/logic/action/update.py b/ckan/logic/action/update.py index 83cbe69b277..418f42398f7 100644 --- a/ckan/logic/action/update.py +++ b/ckan/logic/action/update.py @@ -1109,8 +1109,7 @@ def _bulk_update_dataset( model = context['model'] model.Session.query(model.package_table) \ .filter( - # type_ignore_reason: incomplete SQLAlchemy types - model.Package.id.in_(datasets) # type: ignore + model.Package.id.in_(datasets) ) .filter(model.Package.owner_org == org_id) \ .update(update_dict, synchronize_session=False) diff --git a/ckan/migration/migrate_package_activity.py b/ckan/migration/migrate_package_activity.py index 2e80c0d6611..338006c699b 100644 --- a/ckan/migration/migrate_package_activity.py +++ b/ckan/migration/migrate_package_activity.py @@ -32,6 +32,7 @@ import argparse from collections import defaultdict from typing import Any +from sqlalchemy import text import sys @@ -53,23 +54,23 @@ def get_context(): return _context -def num_unmigrated(engine): - num_unmigrated = engine.execute(''' - SELECT count(*) FROM activity a JOIN package p ON a.object_id=p.id - WHERE a.activity_type IN ('new package', 'changed package') - AND a.data NOT LIKE '%%{"actor"%%' - AND p.private = false; - ''').fetchone()[0] +def num_unmigrated(conn): + num_unmigrated = conn.execute(text(''' + SELECT count(*) FROM activity a JOIN package p ON a.object_id=p.id + WHERE a.activity_type IN ('new package', 'changed package') + AND a.data NOT LIKE '%%{"actor"%%' + AND p.private = false; + ''')).scalar() return num_unmigrated def num_activities_migratable(): from ckan import model - num_activities = model.Session.execute(u''' + num_activities = model.Session.execute(text(''' SELECT count(*) FROM activity a JOIN package p ON a.object_id=p.id WHERE a.activity_type IN ('new package', 'changed package') AND p.private = false; - ''').fetchall()[0][0] + ''')).fetchall()[0][0] return num_activities @@ -225,9 +226,9 @@ def migrate_dataset(dataset_name, errors): def wipe_activity_detail(delete_activity_detail): from ckan import model activity_detail_has_rows = \ - bool(model.Session.execute( - u'SELECT count(*) ' - 'FROM (SELECT * FROM "activity_detail" LIMIT 1) as t;') + bool(model.Session.execute(text( + 'SELECT count(*) ' + 'FROM (SELECT * FROM "activity_detail" LIMIT 1) as t;')) .fetchall()[0][0]) if not activity_detail_has_rows: print(u'\nactivity_detail table is aleady emptied') @@ -245,7 +246,7 @@ def wipe_activity_detail(delete_activity_detail): if delete_activity_detail.lower()[:1] != u'y': return from ckan import model - model.Session.execute(u'DELETE FROM "activity_detail";') + model.Session.execute(text('DELETE FROM "activity_detail";')) model.Session.commit() print(u'activity_detail deleted') diff --git a/ckan/migration/revision_legacy_code.py b/ckan/migration/revision_legacy_code.py index de4ce54e3b7..be8362fde74 100644 --- 
a/ckan/migration/revision_legacy_code.py +++ b/ckan/migration/revision_legacy_code.py @@ -57,7 +57,7 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): result = pkg else: package_rev = revision_model.package_revision_table - q = select([package_rev]).where(package_rev.c.id == pkg.id) + q = select(package_rev).where(package_rev.c.id == pkg.id) result = execute(q, package_rev, context).first() if not result: raise logic.NotFound @@ -76,7 +76,7 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): mm_col = res._columns.get(u'metadata_modified') if mm_col is not None: res._columns.remove(mm_col) - q = select([res]).where(res.c.package_id == pkg.id) + q = select(res).where(res.c.package_id == pkg.id) result = execute(q, res, context) result_dict["resources"] = resource_list_dictize(result, context) result_dict['num_resources'] = len(result_dict.get(u'resources', [])) @@ -87,9 +87,9 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): pkg_tag = model.package_tag_table else: pkg_tag = revision_model.package_tag_revision_table - q = select([tag, pkg_tag.c.state], - from_obj=pkg_tag.join(tag, tag.c.id == pkg_tag.c.tag_id) - ).where(pkg_tag.c.package_id == pkg.id) + q = select(tag, pkg_tag.c.state).join( + pkg_tag, tag.c.id == pkg_tag.c.tag_id + ).where(pkg_tag.c.package_id == pkg.id) # type: ignore result = execute(q, pkg_tag, context) result_dict["tags"] = d.obj_list_dictize(result, context, lambda x: x["name"]) @@ -107,7 +107,7 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): extra = model.package_extra_table else: extra = revision_model.extra_revision_table - q = select([extra]).where(extra.c.package_id == pkg.id) + q = select(extra).where(extra.c.package_id == pkg.id) result = execute(q, extra, context) result_dict["extras"] = extras_list_dictize(result, context) @@ -117,11 +117,13 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): else: member = revision_model.member_revision_table group = model.group_table - q = select([group, member.c.capacity], - from_obj=member.join(group, group.c.id == member.c.group_id) - ).where(member.c.table_id == pkg.id)\ - .where(member.c.state == u'active') \ - .where(group.c.is_organization == False) # noqa + q = select(group, member.c.capacity).join( + member, group.c.id == member.c.group_id + ).where( # type: ignore + member.c.table_id == pkg.id, + member.c.state == u'active', + group.c.is_organization == False # noqa + ) # noqa result = execute(q, member, context) context['with_capacity'] = False # no package counts as cannot fetch from search index at the same @@ -135,7 +137,7 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): group = model.group_table else: group = revision_model.group_revision_table - q = select([group] + q = select(group ).where(group.c.id == result_dict['owner_org']) \ .where(group.c.state == u'active') result = execute(q, group, context) @@ -151,11 +153,11 @@ def package_dictize_with_revisions(pkg, context, include_plugin_data=False): else: rel = revision_model \ .package_relationship_revision_table - q = select([rel]).where(rel.c.subject_package_id == pkg.id) + q = select(rel).where(rel.c.subject_package_id == pkg.id) result = execute(q, rel, context) result_dict["relationships_as_subject"] = \ d.obj_list_dictize(result, context) - q = select([rel]).where(rel.c.object_package_id == pkg.id) + q = select(rel).where(rel.c.object_package_id == pkg.id) result = execute(q, 
rel, context) result_dict["relationships_as_object"] = \ d.obj_list_dictize(result, context) diff --git a/ckan/migration/versions/020_69a0b0efc609_add_changeset.py b/ckan/migration/versions/020_69a0b0efc609_add_changeset.py index e8fc4db68f2..e7e0e016c83 100644 --- a/ckan/migration/versions/020_69a0b0efc609_add_changeset.py +++ b/ckan/migration/versions/020_69a0b0efc609_add_changeset.py @@ -66,6 +66,6 @@ def upgrade(): def downgrade(): - op.drop_table('changeset') op.drop_table('change') + op.drop_table('changeset') op.drop_table('changemask') diff --git a/ckan/migration/versions/021_765143af2ba3_postgresql_upgrade_sql.py b/ckan/migration/versions/021_765143af2ba3_postgresql_upgrade_sql.py index babdd7c8199..3f5b905b7d5 100644 --- a/ckan/migration/versions/021_765143af2ba3_postgresql_upgrade_sql.py +++ b/ckan/migration/versions/021_765143af2ba3_postgresql_upgrade_sql.py @@ -102,4 +102,4 @@ def upgrade(): def downgrade(): for name, table, _ in indexes: - op.drop_index(name, table) + op.execute(f"ALTER TABLE \"{table}\" DROP CONSTRAINT IF EXISTS {name}") diff --git a/ckan/migration/versions/074_a4ca55f0f45e_remove_resource_groups.py b/ckan/migration/versions/074_a4ca55f0f45e_remove_resource_groups.py index 0a4c5b1e6ad..9c68ddd5308 100644 --- a/ckan/migration/versions/074_a4ca55f0f45e_remove_resource_groups.py +++ b/ckan/migration/versions/074_a4ca55f0f45e_remove_resource_groups.py @@ -100,8 +100,6 @@ def downgrade(): sa.Column( 'resource_group_id', sa.UnicodeText, - nullable=False, - server_default='' ) ) op.add_column( @@ -109,8 +107,6 @@ def downgrade(): sa.Column( 'resource_group_id', sa.UnicodeText, - nullable=False, - server_default='' ) ) op.create_index( diff --git a/ckan/migration/versions/083_f98d8fa2a7f7_remove_related_items.py b/ckan/migration/versions/083_f98d8fa2a7f7_remove_related_items.py index 6164819c1e5..14a9b6d471f 100644 --- a/ckan/migration/versions/083_f98d8fa2a7f7_remove_related_items.py +++ b/ckan/migration/versions/083_f98d8fa2a7f7_remove_related_items.py @@ -33,7 +33,9 @@ def upgrade(): if skip_based_on_legacy_engine_version(op, __name__): return conn = op.get_bind() - existing = conn.execute("SELECT COUNT(*) FROM related;").fetchone() + existing = conn.execute( + sa.text("SELECT COUNT(*) FROM related;") + ).fetchone() if existing[0] > 0: print(WARNING) return diff --git a/ckan/migration/versions/084_d85ce5783688_add_metadata_created.py b/ckan/migration/versions/084_d85ce5783688_add_metadata_created.py index 7e78c915f16..0981235a984 100644 --- a/ckan/migration/versions/084_d85ce5783688_add_metadata_created.py +++ b/ckan/migration/versions/084_d85ce5783688_add_metadata_created.py @@ -24,7 +24,7 @@ def upgrade(): ) op.add_column('package', sa.Column('metadata_created', sa.TIMESTAMP)) conn = op.get_bind() - conn.execute( + conn.execute(sa.text( ''' UPDATE package SET metadata_created= (SELECT revision_timestamp @@ -33,7 +33,7 @@ def upgrade(): ORDER BY revision_timestamp ASC LIMIT 1); ''' - ) + )) def downgrade(): diff --git a/ckan/migration/versions/085_f9bf3d5c4b4d_adjust_activity_timestamps.py b/ckan/migration/versions/085_f9bf3d5c4b4d_adjust_activity_timestamps.py index aee1a932503..ee8819b4240 100644 --- a/ckan/migration/versions/085_f9bf3d5c4b4d_adjust_activity_timestamps.py +++ b/ckan/migration/versions/085_f9bf3d5c4b4d_adjust_activity_timestamps.py @@ -8,6 +8,7 @@ """ import datetime from alembic import op +from sqlalchemy import text from ckan.migration import skip_based_on_legacy_engine_version # revision identifiers, used by Alembic. 
revision = 'f9bf3d5c4b4d' @@ -28,8 +29,8 @@ def upgrade(): return connection = op.get_bind() - sql = u"update activity set timestamp = timestamp + (%s - %s);" - connection.execute(sql, utc_date, local_date) + sql = text("update activity set timestamp = timestamp + (:utc - :local);") + connection.execute(sql, {"utc": utc_date, "local": local_date}) def downgrade(): @@ -42,5 +43,5 @@ def downgrade(): return connection = op.get_bind() - sql = u"update activity set timestamp = timestamp - (%s - %s);" - connection.execute(sql, utc_date, local_date) + sql = text("update activity set timestamp = timestamp - (:utc - :local);") + connection.execute(sql, {"utc": utc_date, "local": local_date}) diff --git a/ckan/migration/versions/088_3537d5420e0e_delete_extrase_which_are_deleted_state.py b/ckan/migration/versions/088_3537d5420e0e_delete_extrase_which_are_deleted_state.py index 71a0c97ad83..5f0c6519209 100644 --- a/ckan/migration/versions/088_3537d5420e0e_delete_extrase_which_are_deleted_state.py +++ b/ckan/migration/versions/088_3537d5420e0e_delete_extrase_which_are_deleted_state.py @@ -7,7 +7,7 @@ """ from alembic import op - +from sqlalchemy import text # revision identifiers, used by Alembic. revision = u'3537d5420e0e' down_revision = u'ff1b303cab77' @@ -24,8 +24,10 @@ def upgrade(): ) conn = op.get_bind() - conn.execute(u'''DELETE FROM "package_extra" WHERE state='deleted';''') - conn.execute(u'''DELETE FROM "group_extra" WHERE state='deleted';''') + conn.execute(text( + '''DELETE FROM "package_extra" WHERE state='deleted';''' + )) + conn.execute(text('''DELETE FROM "group_extra" WHERE state='deleted';''')) def downgrade(): diff --git a/ckan/migration/versions/089_23c92480926e_package_activity_migration_check.py b/ckan/migration/versions/089_23c92480926e_package_activity_migration_check.py index df93829c44c..a82d014f062 100644 --- a/ckan/migration/versions/089_23c92480926e_package_activity_migration_check.py +++ b/ckan/migration/versions/089_23c92480926e_package_activity_migration_check.py @@ -9,7 +9,7 @@ from __future__ import print_function from alembic import op - +from sqlalchemy import text from ckan.migration.migrate_package_activity import num_unmigrated # revision identifiers, used by Alembic. 
@@ -50,7 +50,7 @@ def upgrade():
     else:
         # there are no unmigrated package activities
         are_any_datasets = bool(
-            conn.execute(u'SELECT id FROM PACKAGE LIMIT 1').rowcount
+            conn.execute(text('SELECT 1 FROM package LIMIT 1')).scalar()
         )
         # no need to tell the user if there are no datasets - this could just
         # be a fresh CKAN install
diff --git a/ckan/migration/versions/093_7f70d7d15445_remove_activity_revision_id.py b/ckan/migration/versions/093_7f70d7d15445_remove_activity_revision_id.py
index dc1cc543e46..f7f9d2b9537 100644
--- a/ckan/migration/versions/093_7f70d7d15445_remove_activity_revision_id.py
+++ b/ckan/migration/versions/093_7f70d7d15445_remove_activity_revision_id.py
@@ -51,9 +51,6 @@ def downgrade():
     op.add_column(u'system_info',
                   sa.Column(u'revision_id', sa.TEXT(),
                             autoincrement=False, nullable=True))
-    op.create_foreign_key(u'resource_view_resource_id_fkey', u'resource_view',
-                          u'resource', ['resource_id'], ['id'],
-                          onupdate=u'CASCADE', ondelete=u'CASCADE')
     op.add_column(u'resource',
                   sa.Column(u'revision_id', sa.TEXT(),
                             autoincrement=False, nullable=True))
     op.create_foreign_key(u'resource_revision_id_fkey', u'resource',
diff --git a/ckan/model/__init__.py b/ckan/model/__init__.py
index 0a77ec776dc..9dd6ccca632 100644
--- a/ckan/model/__init__.py
+++ b/ckan/model/__init__.py
@@ -8,6 +8,7 @@
 from time import sleep
 from typing import Any, Optional
 
+import sqlalchemy as sa
 from sqlalchemy import MetaData, Table, inspect
 
 from alembic.command import (
@@ -19,8 +20,8 @@
 
 import ckan.model.meta as meta
 
-from ckan.model.meta import Session
-
+from ckan.model.meta import Session, registry
+from ckan.exceptions import CkanConfigurationException
 from ckan.model.core import (
     State,
 )
@@ -117,7 +118,7 @@
 from ckan.types import AlchemySession
 
 __all__ = [
-    "Session", "State", "System", "Package", "PackageMember",
+    "registry", "Session", "State", "System", "Package", "PackageMember",
     "PACKAGE_NAME_MIN_LENGTH", "PACKAGE_NAME_MAX_LENGTH",
     "PACKAGE_VERSION_MAX_LENGTH", "package_table", "package_member_table",
     "Tag", "PackageTag", "MAX_TAG_LENGTH", "MIN_TAG_LENGTH", "tag_table",
@@ -148,13 +149,11 @@ def init_model(engine: Engine) -> None:
     meta.Session.configure(bind=engine)
     meta.create_local_session.configure(bind=engine)
     meta.engine = engine
-    meta.metadata.bind = engine
     # sqlalchemy migrate version table
     import sqlalchemy.exc
     for i in reversed(range(DB_CONNECT_RETRIES)):
         try:
-            Table('alembic_version', meta.metadata, autoload=True)
-            break
+            Table('alembic_version', meta.metadata, autoload_with=engine)
         except sqlalchemy.exc.NoSuchTableError:
             break
         except sqlalchemy.exc.OperationalError as e:
@@ -162,6 +161,29 @@ def init_model(engine: Engine) -> None:
                 sleep(DB_CONNECT_RETRIES - i)
                 continue
             raise
+        else:
+            break
+
+
+def ensure_engine() -> Engine:
+    """Return an initialized SQLAlchemy engine or raise an error.
+
+    This function guarantees that the engine is initialized, and it provides
+    a hint when someone attempts to use the database before the model is
+    properly initialized.
+
+    Prefer using this function instead of accessing the engine directly via
+    `meta.engine`.
+
+    """
+    if not meta.engine:
+        log.error(
+            "%s:%s must be called before any interaction with the database",
+            init_model.__module__, init_model.__name__
+
+        )
+        raise CkanConfigurationException("Model is not initialized")
+    return meta.engine
 
 
 class Repository():
@@ -196,7 +218,7 @@ def init_db(self) -> None:
         that may have been setup with either upgrade_db or a previous run of init_db.
''' - warnings.filterwarnings('ignore', 'SAWarning') + self.session.rollback() self.session.remove() @@ -207,12 +229,17 @@ def init_db(self) -> None: def clean_db(self) -> None: self.commit_and_remove() - meta.metadata = MetaData(self.metadata.bind) + meta.metadata = MetaData() + + engine = ensure_engine() + with warnings.catch_warnings(): warnings.filterwarnings('ignore', '.*(reflection|tsvector).*') - meta.metadata.reflect() + meta.metadata.reflect(engine) + + with engine.begin() as conn: + meta.metadata.drop_all(conn) - meta.metadata.drop_all() self.tables_created_and_initialised = False log.info('Database tables dropped') @@ -221,7 +248,8 @@ def create_db(self) -> None: i.e. the same as init_db APART from when running tests, when init_db has shortcuts. ''' - self.metadata.create_all(bind=self.metadata.bind) + with ensure_engine().begin() as conn: + self.metadata.create_all(conn) log.info('Database tables created') def rebuild_db(self) -> None: @@ -246,7 +274,7 @@ def delete_all(self) -> None: for table in tables: if table.name == 'alembic_version': continue - connection.execute('delete from "%s"' % table.name) + connection.execute(sa.delete(table)) self.session.commit() log.info('Database table data deleted') @@ -264,19 +292,20 @@ def take_alembic_output(self, return output def setup_migration_version_control(self) -> None: - assert isinstance(self.metadata.bind, Engine) self.reset_alembic_output() alembic_config = AlembicConfig(self._alembic_ini) alembic_config.set_main_option( "sqlalchemy.url", config.get("sqlalchemy.url") ) + engine = ensure_engine() sqlalchemy_migrate_version = 0 - db_inspect = inspect(self.metadata.bind) + db_inspect = inspect(engine) if db_inspect.has_table("migrate_version"): - sqlalchemy_migrate_version = self.metadata.bind.execute( - u'select version from migrate_version' - ).scalar() + with engine.connect() as conn: + sqlalchemy_migrate_version = conn.execute( + sa.text('select version from migrate_version') + ).scalar() # this value is used for graceful upgrade from # sqlalchemy-migrate to alembic @@ -318,13 +347,14 @@ def upgrade_db(self, version: str='head') -> None: @param version: version to upgrade to (if None upgrade to latest) ''' - assert meta.engine - _assert_engine_msg: str = ( - u'Database migration - only Postgresql engine supported (not %s).' 
- ) % meta.engine.name - assert meta.engine.name in ( - u'postgres', u'postgresql' - ), _assert_engine_msg + engine = ensure_engine() + if engine.name not in ('postgres', 'postgresql'): + log.error( + 'Only Postgresql engine supported (not %s).', + engine.name, + ) + raise CkanConfigurationException(engine.name) + self.setup_migration_version_control() version_before = self.current_version() alembic_upgrade(self.alembic_config, version) @@ -340,10 +370,10 @@ def upgrade_db(self, version: str='head') -> None: log.info(u'CKAN database version remains as: %s', version_after) def are_tables_created(self) -> bool: - meta.metadata = MetaData(self.metadata.bind) + meta.metadata = MetaData() with warnings.catch_warnings(): warnings.filterwarnings('ignore', '.*(reflection|geometry).*') - meta.metadata.reflect() + meta.metadata.reflect(meta.engine) return bool(meta.metadata.tables) diff --git a/ckan/model/api_token.py b/ckan/model/api_token.py index 075b1949a55..d5974850f12 100644 --- a/ckan/model/api_token.py +++ b/ckan/model/api_token.py @@ -16,6 +16,8 @@ __all__ = [u"ApiToken", u"api_token_table"] +Mapped = orm.Mapped + def _make_token() -> str: nbytes = config.get(u"api_token.nbytes") @@ -35,13 +37,13 @@ def _make_token() -> str: class ApiToken(DomainObject): - id: str - name: str - user_id: Optional[str] - created_at: datetime.datetime - last_access: Optional[datetime.datetime] - plugin_extras: dict[str, Any] - owner: Optional[User] + id: Mapped[str] + name: Mapped[str] + user_id: Mapped[Optional[str]] + created_at: Mapped[datetime.datetime] + last_access: Mapped[Optional[datetime.datetime]] + plugin_extras: Mapped[dict[str, Any]] + owner: Mapped[Optional[User]] def __init__( self, user_id: Optional[str] = None, @@ -79,11 +81,11 @@ def set_extra(self, key: str, value: Any, commit: bool = False) -> None: meta.Session.commit() -meta.mapper( +meta.registry.map_imperatively( ApiToken, api_token_table, properties={ - u"owner": orm.relation( + u"owner": orm.relationship( User, backref=orm.backref(u"api_tokens", cascade=u"all, delete") ) diff --git a/ckan/model/base.py b/ckan/model/base.py index 13a3d6d21b3..9e9ba22f9e9 100644 --- a/ckan/model/base.py +++ b/ckan/model/base.py @@ -7,12 +7,12 @@ import sqlalchemy as sa from sqlalchemy import orm -from sqlalchemy.ext.declarative import declarative_base from typing_extensions import Self -from .meta import metadata, Session +from .meta import registry, Session -BaseModel = declarative_base(metadata=metadata) + +BaseModel = registry.generate_base() class SessionMixin: diff --git a/ckan/model/dashboard.py b/ckan/model/dashboard.py index f2eb22c40e7..97783382f80 100644 --- a/ckan/model/dashboard.py +++ b/ckan/model/dashboard.py @@ -2,6 +2,7 @@ import datetime import sqlalchemy +from sqlalchemy.orm import Mapped import ckan.model.meta as meta from typing import Optional from typing_extensions import Self @@ -20,9 +21,9 @@ class Dashboard(object): '''Saved data used for the user's dashboard.''' - user_id: str - activity_stream_last_viewed: datetime.datetime - email_last_sent: datetime.datetime + user_id: Mapped[str] + activity_stream_last_viewed: Mapped[datetime.datetime] + email_last_sent: Mapped[datetime.datetime] def __init__(self, user_id: str) -> None: self.user_id = user_id @@ -41,4 +42,4 @@ def get(cls, user_id: str) -> Optional[Self]: query = query.filter(Dashboard.user_id == user_id) return query.first() -meta.mapper(Dashboard, dashboard_table) +meta.registry.map_imperatively(Dashboard, dashboard_table) diff --git a/ckan/model/follower.py 
b/ckan/model/follower.py index d6c1b9c8b70..841733d30eb 100644 --- a/ckan/model/follower.py +++ b/ckan/model/follower.py @@ -16,7 +16,7 @@ from ckan.types import Query - +Mapped = sqlalchemy.orm.Mapped Follower = TypeVar("Follower", bound='ckan.model.User') Followed = TypeVar( "Followed", 'ckan.model.User', 'ckan.model.Package', 'ckan.model.Group') @@ -24,9 +24,9 @@ class ModelFollowingModel(domain_object.DomainObject, Generic[Follower, Followed]): - follower_id: str - object_id: str - datetime: _datetime.datetime + follower_id: Mapped[str] + object_id: Mapped[str] + datetime: Mapped[_datetime.datetime] def __init__(self, follower_id: str, object_id: str) -> None: self.follower_id = follower_id @@ -161,7 +161,7 @@ def _object_class(cls): sqlalchemy.Column('datetime', sqlalchemy.types.DateTime, nullable=False), ) -meta.mapper(UserFollowingUser, user_following_user_table) +meta.registry.map_imperatively(UserFollowingUser, user_following_user_table) class UserFollowingDataset( ModelFollowingModel['ckan.model.User', 'ckan.model.Package']): @@ -193,7 +193,7 @@ def _object_class(cls): sqlalchemy.Column('datetime', sqlalchemy.types.DateTime, nullable=False), ) -meta.mapper(UserFollowingDataset, user_following_dataset_table) +meta.registry.map_imperatively(UserFollowingDataset, user_following_dataset_table) class UserFollowingGroup( @@ -225,4 +225,4 @@ def _object_class(cls): sqlalchemy.Column('datetime', sqlalchemy.types.DateTime, nullable=False), ) -meta.mapper(UserFollowingGroup, user_following_group_table) +meta.registry.map_imperatively(UserFollowingGroup, user_following_group_table) diff --git a/ckan/model/group.py b/ckan/model/group.py index 3a3d893aad7..74d0f6233f1 100644 --- a/ckan/model/group.py +++ b/ckan/model/group.py @@ -23,6 +23,7 @@ 'Member', 'member_table'] +Mapped = orm.Mapped member_table = Table('member', meta.metadata, Column('id', types.UnicodeText, @@ -78,12 +79,12 @@ class Member(core.StatefulObjectMixin, in a hierarchy. 
- capacity is 'parent' ''' - id: str - table_name: Optional[str] - table_id: Optional[str] - capacity: str - group_id: Optional[str] - state: str + id: Mapped[str] + table_name: Mapped[Optional[str]] + table_id: Mapped[Optional[str]] + capacity: Mapped[str] + group_id: Mapped[Optional[str]] + state: Mapped[str] group: Optional['Group'] @@ -156,20 +157,20 @@ def __str__(self): class Group(core.StatefulObjectMixin, domain_object.DomainObject): - id: str - name: str - title: str | None - type: str - description: str - image_url: str - created: datetime.datetime - is_organization: bool - approval_status: str - state: str - - _extras: dict[str, Any] # list['GroupExtra'] + id: Mapped[str] + name: Mapped[str] + title: Mapped[str | None] + type: Mapped[str] + description: Mapped[str] + image_url: Mapped[str] + created: Mapped[datetime.datetime] + is_organization: Mapped[bool] + approval_status: Mapped[str] + state: Mapped[str] + + _extras: Mapped[dict[str, Any]] extras: AssociationProxy - member_all: list[Member] + member_all: Mapped[list[Member]] def __init__(self, name: str = u'', title: str = u'', description: str = u'', image_url: str = u'', @@ -208,8 +209,7 @@ def all(cls, group_type: Optional[str] = None, """ q = meta.Session.query(cls) if state: - # type_ignore_reason: incomplete SQLAlchemy types - q = q.filter(cls.state.in_(state)) # type: ignore + q = q.filter(cls.state.in_(state)) if group_type: q = q.filter(cls.type == group_type) @@ -298,7 +298,7 @@ def get_top_level_groups(cls, type: str='group') -> list[Self]: Member.table_name == 'group', Member.state == 'active')).\ filter( - Member.id == None # type: ignore + Member.id.is_(None) ).filter(Group.type == type).\ filter(Group.state == 'active').\ order_by(Group.title).all() @@ -407,10 +407,9 @@ def search_by_name_or_title( cls, text_query: str, group_type: Optional[str] = None, is_org: bool = False, limit: int = 20) -> Query[Self]: text_query = text_query.strip().lower() - # type_ignore_reason: incomplete SQLAlchemy types q = meta.Session.query(cls) \ - .filter(or_(cls.name.contains(text_query), # type: ignore - cls.title.ilike('%' + text_query + '%'))) # type: ignore + .filter(or_(cls.name.contains(text_query), + cls.title.ilike('%' + text_query + '%'))) if is_org: q = q.filter(cls.type == 'organization') else: @@ -435,12 +434,13 @@ def add_package_by_name(self, package_name: str) -> None: def __repr__(self): return '' % self.name -meta.mapper(Group, group_table) +meta.registry.map_imperatively(Group, group_table) -meta.mapper(Member, member_table, properties={ - 'group': orm.relation(Group, +meta.registry.map_imperatively(Member, member_table, properties={ + 'group': orm.relationship(Group, backref=orm.backref('member_all', - cascade='all, delete-orphan')), + cascade='all, delete-orphan', + cascade_backrefs=False)), }) diff --git a/ckan/model/group_extra.py b/ckan/model/group_extra.py index c94e13c91b6..65fe00f0398 100644 --- a/ckan/model/group_extra.py +++ b/ckan/model/group_extra.py @@ -14,6 +14,7 @@ __all__ = ['GroupExtra', 'group_extra_table'] +Mapped = orm.Mapped group_extra_table = Table('group_extra', meta.metadata, Column('id', types.UnicodeText, primary_key=True, default=_types.make_uuid), Column('group_id', types.UnicodeText, ForeignKey('group.id')), @@ -25,17 +26,17 @@ class GroupExtra(core.StatefulObjectMixin, domain_object.DomainObject): - id: str - group_id: str - key: str - value: str - state: str + id: Mapped[str] + group_id: Mapped[str] + key: Mapped[str] + value: Mapped[str] + state: Mapped[str] group: group.Group 
-# type_ignore_reason: incomplete SQLAlchemy types
-meta.mapper(GroupExtra, group_extra_table, properties={
-    'group': orm.relation(group.Group,
+
+meta.registry.map_imperatively(GroupExtra, group_extra_table, properties={
+    'group': orm.relationship(group.Group,
         backref=orm.backref(
             '_extras',
             collection_class=orm.collections.attribute_mapped_collection(u'key'),  # type: ignore
diff --git a/ckan/model/meta.py b/ckan/model/meta.py
index 58ba9d7c0c9..2048d29cd33 100644
--- a/ckan/model/meta.py
+++ b/ckan/model/meta.py
@@ -11,7 +11,6 @@
 
 __all__ = ['Session']
 
-
 # SQLAlchemy database engine. Updated by model.init_model()
 engine: Optional[Engine] = None
 
@@ -79,9 +78,7 @@ def ckan_after_rollback(session: Any):
         del session._object_cache
 
-#mapper = Session.mapper
 mapper = orm.mapper
 
-# Global metadata. If you have multiple databases with overlapping table
-# names, you'll need a metadata for each database
 metadata = MetaData()
+registry = orm.registry(metadata=metadata)
diff --git a/ckan/model/package.py b/ckan/model/package.py
index bda6232074f..f49179bcbd4 100644
--- a/ckan/model/package.py
+++ b/ckan/model/package.py
@@ -16,6 +16,7 @@
 from sqlalchemy import orm, types, Column, Table, ForeignKey
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.ext.associationproxy import AssociationProxy
 
 from ckan.common import config
 
@@ -35,7 +36,7 @@
     Group,
 )
 
-
+Mapped = orm.Mapped
 PrintableRelationship: TypeAlias = "tuple[Package, str, Optional[str]]"
 
 logger = logging.getLogger(__name__)
@@ -92,34 +93,34 @@ class Package(core.StatefulObjectMixin,
               domain_object.DomainObject):
 
-    id: str
-    name: str
-    title: str
-    version: str
-    url: str
-    author: str
-    author_email: str
-    maintainer: str
-    maintainer_email: str
-    notes: str
-    licensce_id: str
-    type: str
-    owner_org: Optional[str]
-    creator_user_id: str
-    metadata_created: datetime.datetime
-    metadata_modified: datetime.datetime
-    private: bool
-    state: str
-    plugin_data: dict[str, Any]
-
-    package_tags: list["PackageTag"]
-
-    resources_all: list["Resource"]
-    _extras: dict[str, Any]  # list['PackageExtra']
-    extras: dict[str, Any]
-
-    relationships_as_subject: 'PackageRelationship'
-    relationships_as_object: 'PackageRelationship'
+    id: Mapped[str]
+    name: Mapped[str]
+    title: Mapped[str]
+    version: Mapped[str]
+    url: Mapped[str]
+    author: Mapped[str]
+    author_email: Mapped[str]
+    maintainer: Mapped[str]
+    maintainer_email: Mapped[str]
+    notes: Mapped[str]
+    license_id: Mapped[str]
+    type: Mapped[str]
+    owner_org: Mapped[Optional[str]]
+    creator_user_id: Mapped[str]
+    metadata_created: Mapped[datetime.datetime]
+    metadata_modified: Mapped[datetime.datetime]
+    private: Mapped[bool]
+    state: Mapped[str]
+    plugin_data: Mapped[dict[str, Any]]
+
+    package_tags: Mapped[list["PackageTag"]]
+
+    resources_all: Mapped[list["Resource"]]
+    _extras: Mapped[dict[str, Any]]
+    extras: AssociationProxy
+
+    relationships_as_subject: Mapped['PackageRelationship']
+    relationships_as_object: Mapped['PackageRelationship']
 
     _license_register: ClassVar['_license.LicenseRegister']
 
@@ -128,8 +129,7 @@ class Package(core.StatefulObjectMixin,
     @classmethod
     def search_by_name(cls, text_query: str) -> Query[Self]:
         return meta.Session.query(cls).filter(
-            # type_ignore_reason: incomplete SQLAlchemy types
-            cls.name.contains(text_query.lower())  # type: ignore
+            cls.name.contains(text_query.lower())
         )
 
     @classmethod
@@ -229,7 +229,7 @@ def get_tags(self, vocab: Optional["Vocabulary"] = None) -> list["Tag"]:
         if vocab:
             query = 
query.filter(model.Tag.vocabulary_id == vocab.id) else: - query = query.filter(model.Tag.vocabulary_id == None) + query = query.filter(model.Tag.vocabulary_id.is_(None)) query = query.order_by(model.Tag.name) tags = query.all() return tags @@ -482,17 +482,16 @@ def extras_list(self) -> list['PackageExtra']: class PackageMember(domain_object.DomainObject): - package_id: str - user_id: str - capacity: str - modified: datetime.datetime + package_id: Mapped[str] + user_id: Mapped[str] + capacity: Mapped[str] + modified: Mapped[datetime.datetime] # import here to prevent circular import from ckan.model import tag -# type_ignore_reason: incomplete SQLAlchemy types -meta.mapper(Package, package_table, properties={ +meta.registry.map_imperatively(Package, package_table, properties={ # delete-orphan on cascade does NOT work! # Why? Answer: because of way SQLAlchemy/our code works there are points # where PackageTag object is created *and* flushed but does not yet have @@ -500,11 +499,13 @@ class PackageMember(domain_object.DomainObject): # second commit happens in which the package_id is correctly set. # However after first commit PackageTag does not have Package and # delete-orphan kicks in to remove it! - 'package_tags':orm.relation(tag.PackageTag, backref='package', + 'package_tags':orm.relationship( + tag.PackageTag, backref='package', cascade='all, delete', #, delete-orphan', - ), - }) + cascade_backrefs=False + ) +}) -meta.mapper(tag.PackageTag, tag.package_tag_table) +meta.registry.map_imperatively(tag.PackageTag, tag.package_tag_table) -meta.mapper(PackageMember, package_member_table) +meta.registry.map_imperatively(PackageMember, package_member_table) diff --git a/ckan/model/package_extra.py b/ckan/model/package_extra.py index 3da85bd0cf7..09a0d5899ba 100644 --- a/ckan/model/package_extra.py +++ b/ckan/model/package_extra.py @@ -14,6 +14,7 @@ __all__ = ['PackageExtra', 'package_extra_table'] +Mapped = orm.Mapped package_extra_table = Table('package_extra', meta.metadata, Column('id', types.UnicodeText, primary_key=True, default=_types.make_uuid), # NB: only (package, key) pair is unique @@ -25,11 +26,11 @@ class PackageExtra(core.StatefulObjectMixin, domain_object.DomainObject): - id: str - package_id: str - key: str - value: str - state: str + id: Mapped[str] + package_id: Mapped[str] + key: Mapped[str] + value: Mapped[str] + state: Mapped[str] package: _package.Package @@ -37,9 +38,8 @@ def related_packages(self) -> list[_package.Package]: return [self.package] -# type_ignore_reason: incomplete SQLAlchemy types -meta.mapper(PackageExtra, package_extra_table, properties={ - 'package': orm.relation(_package.Package, +meta.registry.map_imperatively(PackageExtra, package_extra_table, properties={ + 'package': orm.relationship(_package.Package, backref=orm.backref('_extras', collection_class=orm.collections.attribute_mapped_collection(u'key'), # type: ignore cascade='all, delete, delete-orphan', diff --git a/ckan/model/package_relationship.py b/ckan/model/package_relationship.py index c5215f0805a..cf8dd1558d5 100644 --- a/ckan/model/package_relationship.py +++ b/ckan/model/package_relationship.py @@ -25,6 +25,7 @@ def _(*args: Any, **kwargs: Any) -> str: __all__ = ['PackageRelationship', 'package_relationship_table'] +Mapped = orm.Mapped package_relationship_table = Table('package_relationship', meta.metadata, Column('id', types.UnicodeText, primary_key=True, default=_types.make_uuid), @@ -44,15 +45,15 @@ class PackageRelationship(core.StatefulObjectMixin, from both packages in the 
relationship and the type is swapped from forward to reverse accordingly, for meaningful display to the user.''' - id: str - subject_package_id: str - object_package_id: str - type: str - comment: str - state: str + id: Mapped[str] + subject_package_id: Mapped[str] + object_package_id: Mapped[str] + type: Mapped[str] + comment: Mapped[str] + state: Mapped[str] - object: _package.Package - subject: _package.Package + object: Mapped[_package.Package] + subject: Mapped[_package.Package] all_types: Optional[list[str]] fwd_types: Optional[list[str]] @@ -185,10 +186,15 @@ def make_type_printable(cls, type_: str) -> str: return cls.types_printable[i][j] raise TypeError(type_) -meta.mapper(PackageRelationship, package_relationship_table, properties={ - 'subject':orm.relation(_package.Package, primaryjoin=\ - package_relationship_table.c["subject_package_id"]==_package.Package.id, - backref='relationships_as_subject'), - 'object':orm.relation(_package.Package, primaryjoin=package_relationship_table.c["object_package_id"]==_package.Package.id, - backref='relationships_as_object'), - }) +meta.registry.map_imperatively(PackageRelationship, package_relationship_table, properties={ + 'subject':orm.relationship( + _package.Package, primaryjoin=\ + package_relationship_table.c["subject_package_id"]==_package.Package.id, + backref=orm.backref('relationships_as_subject', cascade_backrefs=False), + ), + 'object':orm.relationship( + _package.Package, + primaryjoin=package_relationship_table.c["object_package_id"]==_package.Package.id, + backref=orm.backref('relationships_as_object', cascade_backrefs=False) + ), +}) diff --git a/ckan/model/resource.py b/ckan/model/resource.py index ec27295ee66..20b0adda86d 100644 --- a/ckan/model/resource.py +++ b/ckan/model/resource.py @@ -22,6 +22,7 @@ __all__ = ['Resource', 'resource_table'] +Mapped = orm.Mapped CORE_RESOURCE_COLUMNS = ['url', 'format', 'description', 'hash', 'name', 'resource_type', 'mimetype', 'mimetype_inner', 'size', 'created', 'last_modified', @@ -60,29 +61,29 @@ class Resource(core.StatefulObjectMixin, domain_object.DomainObject): - id: str - package_id: Optional[str] - url: str - format: str - description: str - hash: str - position: int - name: str - resource_type: str - mimetype: str - size: int - created: datetime.datetime - last_modified: datetime.datetime - metadata_modified: datetime.datetime - cache_url: str - cache_last_update: datetime.datetime - url_type: str + id: Mapped[str] + package_id: Mapped[Optional[str]] + url: Mapped[str] + format: Mapped[str] + description: Mapped[str] + hash: Mapped[str] + position: Mapped[int] + name: Mapped[str] + resource_type: Mapped[str] + mimetype: Mapped[str] + size: Mapped[int] + created: Mapped[datetime.datetime] + last_modified: Mapped[datetime.datetime] + metadata_modified: Mapped[datetime.datetime] + cache_url: Mapped[str] + cache_last_update: Mapped[datetime.datetime] + url_type: Mapped[str] extras: dict[str, Any] - state: str + state: Mapped[str] extra_columns: ClassVar[Optional[list[str]]] = None - package: Package + package: Mapped[Package] url_changed: Optional[bool] @@ -164,8 +165,8 @@ def related_packages(self) -> list[Package]: ## Mappers -meta.mapper(Resource, resource_table, properties={ - 'package': orm.relation( +meta.registry.map_imperatively(Resource, resource_table, properties={ + 'package': orm.relationship( Package, # all resources including deleted # formally package_resources_all diff --git a/ckan/model/resource_view.py b/ckan/model/resource_view.py index c9c087d1cea..3ee4321feef 
100644 --- a/ckan/model/resource_view.py +++ b/ckan/model/resource_view.py @@ -4,6 +4,7 @@ from typing import Any, Collection, Optional import sqlalchemy as sa +from sqlalchemy.orm import Mapped from typing_extensions import Self import ckan.model.meta as meta @@ -28,13 +29,13 @@ class ResourceView(domain_object.DomainObject): - id: str - resource_id: str - title: Optional[str] - description: Optional[str] - view_type: str - order: int - config: dict[str, Any] + id: Mapped[str] + resource_id: Mapped[str] + title: Mapped[Optional[str]] + description: Mapped[Optional[str]] + view_type: Mapped[str] + order: Mapped[int] + config: Mapped[dict[str, Any]] @classmethod def get(cls, reference: str) -> Optional[Self]: @@ -56,18 +57,16 @@ def get_count_not_in_view_types( view_type = cls.view_type query: 'Query[tuple[str, int]]' = meta.Session.query( view_type, sa.func.count(cls.id)).group_by(view_type).filter( - # type_ignore_reason: incomplete SQLAlchemy types - sa.not_(view_type.in_(view_types))) # type: ignore + sa.not_(view_type.in_(view_types))) return query.all() @classmethod def delete_not_in_view_types(cls, view_types: Collection[str]) -> int: '''Delete the Resource Views not in the received view types list''' - query = meta.Session.query(cls) \ - .filter(sa.not_( - # type_ignore_reason: incomplete SQLAlchemy types - cls.view_type.in_(view_types))) # type: ignore + query = meta.Session.query(cls).filter( + sa.not_(cls.view_type.in_(view_types)) + ) return query.delete(synchronize_session='fetch') @@ -77,11 +76,9 @@ def delete_all(cls, view_types: Optional[Collection[str]] = None) -> int: query = meta.Session.query(cls) if view_types: - query = query.filter( - # type_ignore_reason: incomplete SQLAlchemy types - cls.view_type.in_(view_types)) # type: ignore + query = query.filter(cls.view_type.in_(view_types)) return query.delete(synchronize_session='fetch') -meta.mapper(ResourceView, resource_view_table) +meta.registry.map_imperatively(ResourceView, resource_view_table) diff --git a/ckan/model/system_info.py b/ckan/model/system_info.py index acb08cce401..238bd2502d1 100644 --- a/ckan/model/system_info.py +++ b/ckan/model/system_info.py @@ -10,6 +10,7 @@ from typing import Any, Optional from sqlalchemy import types, Column, Table +from sqlalchemy.orm import Mapped from sqlalchemy.exc import ProgrammingError @@ -31,10 +32,10 @@ class SystemInfo(core.StatefulObjectMixin, domain_object.DomainObject): - id: int - key: str - value: str - state: str + id: Mapped[int] + key: Mapped[str] + value: Mapped[str] + state: Mapped[str] def __init__(self, key: str, value: Any) -> None: @@ -44,7 +45,7 @@ def __init__(self, key: str, value: Any) -> None: self.value = str(value) -meta.mapper(SystemInfo, system_info_table) +meta.registry.map_imperatively(SystemInfo, system_info_table) def get_system_info(key: str, default: Optional[str]=None) -> Optional[str]: diff --git a/ckan/model/tag.py b/ckan/model/tag.py index c047c94f359..3034f50696c 100644 --- a/ckan/model/tag.py +++ b/ckan/model/tag.py @@ -3,7 +3,7 @@ from typing import Optional, Any -from sqlalchemy.orm import relation +from sqlalchemy.orm import relationship, Mapped from sqlalchemy import types, Column, Table, ForeignKey, UniqueConstraint from typing_extensions import Self @@ -45,12 +45,12 @@ class Tag(domain_object.DomainObject): - id: str - name: str - vocabulary_id: Optional[str] + id: Mapped[str] + name: Mapped[str] + vocabulary_id: Mapped[Optional[str]] - package_tags: list['PackageTag'] - vocabulary: Optional['ckan.model.Vocabulary'] + 
package_tags: Mapped[list['PackageTag']] + vocabulary: Mapped[Optional['ckan.model.Vocabulary']] def __init__(self, name: str='', vocabulary_id: Optional[str]=None) -> None: self.name = name @@ -108,7 +108,7 @@ def by_name( Tag.vocabulary_id==vocab.id) else: query = meta.Session.query(Tag).filter(Tag.name==name).filter( - Tag.vocabulary_id==None) + Tag.vocabulary_id.is_(None)) query = query.autoflush(autoflush) tag = query.first() return tag @@ -183,8 +183,7 @@ def search_by_name( else: query = meta.Session.query(Tag) search_term = search_term.strip().lower() - # type_ignore_reason: incomplete SQLAlchemy types - query = query.filter(Tag.name.contains(search_term)) # type: ignore + query = query.filter(Tag.name.contains(search_term)) query: 'Query[Tag]' = query.distinct().join(Tag.package_tags) return query @@ -216,7 +215,7 @@ def all(cls, vocab_id_or_name: Optional[str]=None) -> Query[Self]: filter(PackageTag.state == 'active').subquery() query = meta.Session.query(Tag).\ - filter(Tag.vocabulary_id == None).\ + filter(Tag.vocabulary_id.is_(None)).\ distinct().\ join(subquery, Tag.id==subquery.c.tag_id) @@ -242,14 +241,14 @@ def __repr__(self) -> str: class PackageTag(core.StatefulObjectMixin, domain_object.DomainObject): - id: str - package_id: str - tag_id: str - state: Optional[str] + id: Mapped[str] + package_id: Mapped[str] + tag_id: Mapped[str] + state: Mapped[Optional[str]] - pkg: Optional['ckan.model.Package'] - package: Optional['ckan.model.Package'] - tag: Optional[Tag] + pkg: Mapped[Optional['ckan.model.Package']] + package: Mapped[Optional['ckan.model.Package']] + tag: Mapped[Optional[Tag]] def __init__( self, package: Optional['ckan.model.Package'] = None, @@ -318,12 +317,12 @@ def related_packages(self) -> list['ckan.model.Package']: return [self.package] return [] -# type_ignore_reason: incomplete SQLAlchemy types -meta.mapper(Tag, tag_table, properties={ - 'package_tags': relation(PackageTag, backref='tag', - cascade='all, delete, delete-orphan', +meta.registry.map_imperatively(Tag, tag_table, properties={ + 'package_tags': relationship(PackageTag, backref='tag', + cascade='all, delete, delete-orphan', + cascade_backrefs=False, ), - 'vocabulary': relation(vocabulary.Vocabulary, + 'vocabulary': relationship(vocabulary.Vocabulary, order_by=tag_table.c["name"]) }) diff --git a/ckan/model/task_status.py b/ckan/model/task_status.py index 4758127e8f4..23db147916c 100644 --- a/ckan/model/task_status.py +++ b/ckan/model/task_status.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional from sqlalchemy import types, Column, Table, UniqueConstraint +from sqlalchemy.orm import Mapped from typing_extensions import Self import ckan.model.meta as meta @@ -25,15 +26,15 @@ ) class TaskStatus(domain_object.DomainObject): - id: str - entity_id: str - entuty_type: str - task_type: str - key: str - value: str - state: str - error: str - last_updated: datetime + id: Mapped[str] + entity_id: Mapped[str] + entuty_type: Mapped[str] + task_type: Mapped[str] + key: Mapped[str] + value: Mapped[str] + state: Mapped[str] + error: Mapped[str] + last_updated: Mapped[datetime] @classmethod def get(cls, reference: str) -> Optional[Self]: @@ -44,4 +45,4 @@ def get(cls, reference: str) -> Optional[Self]: task = meta.Session.query(cls).get(reference) return task -meta.mapper(TaskStatus, task_status_table) +meta.registry.map_imperatively(TaskStatus, task_status_table) diff --git a/ckan/model/types.py b/ckan/model/types.py index 79f06c61767..1e46ab63057 100644 --- a/ckan/model/types.py +++ 
b/ckan/model/types.py @@ -43,6 +43,8 @@ class JsonType(types.TypeDecorator): # type: ignore ''' impl = types.UnicodeText + cache_ok = False + def process_bind_param(self, value: Any, dialect: Any): # ensure we stores nulls in db not json "null" if value is None or value == {}: @@ -71,6 +73,8 @@ class JsonDictType(JsonType): impl = types.UnicodeText + cache_ok = False + def process_bind_param(self, value: Any, dialect: Any): # ensure we stores nulls in db not json "null" if value is None or value == {}: diff --git a/ckan/model/user.py b/ckan/model/user.py index acbafcf6f64..36b0d46c4fd 100644 --- a/ckan/model/user.py +++ b/ckan/model/user.py @@ -10,7 +10,7 @@ import passlib.utils from passlib.hash import pbkdf2_sha512 from sqlalchemy.sql.expression import or_ -from sqlalchemy.orm import synonym +from sqlalchemy.orm import synonym, Mapped from sqlalchemy import types, Column, Table, func from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.mutable import MutableDict @@ -65,22 +65,22 @@ def set_api_key() -> Optional[str]: class User(core.StatefulObjectMixin, domain_object.DomainObject): - id: str - name: str + id: Mapped[str] + name: Mapped[str] # password: str - fullname: Optional[str] - email: Optional[str] - apikey: Optional[str] - created: datetime.datetime - reset_key: str - about: str - activity_streams_email_notifications: bool - sysadmin: bool - state: str - image_url: str - plugin_extras: dict[str, Any] - - api_tokens: list['ApiToken'] + fullname: Mapped[Optional[str]] + email: Mapped[Optional[str]] + apikey: Mapped[Optional[str]] + created: Mapped[datetime.datetime] + reset_key: Mapped[str] + about: Mapped[str] + activity_streams_email_notifications: Mapped[bool] + sysadmin: Mapped[bool] + state: Mapped[str] + image_url: Mapped[str] + plugin_extras: Mapped[dict[str, Any]] + + api_tokens: Mapped[list['ApiToken']] VALID_NAME = re.compile(r"^[a-zA-Z0-9_\-]{3,255}$") DOUBLE_SLASH = re.compile(r':\/([^/])') @@ -230,8 +230,8 @@ def number_created_packages(self, include_private_and_draft: bool=False) -> int: q = q.filter_by(state='active', private=False) result: int = meta.Session.execute( - q.statement.with_only_columns( - [func.count()] + q.statement.with_only_columns( # type: ignore + func.count() ).order_by( None ) @@ -300,15 +300,13 @@ def search(cls, querystr: str, query = sqlalchemy_query qstr = '%' + querystr + '%' filters: list[Any] = [ - # type_ignore_reason: incomplete SQLAlchemy types - cls.name.ilike(qstr), # type: ignore - cls.fullname.ilike(qstr), # type: ignore + cls.name.ilike(qstr), + cls.fullname.ilike(qstr), ] # sysadmins can search on user emails import ckan.authz as authz if user_name and authz.is_sysadmin(user_name): - # type_ignore_reason: incomplete SQLAlchemy types - filters.append(cls.email.ilike(qstr)) # type: ignore + filters.append(cls.email.ilike(qstr)) query = query.filter(or_(*filters)) return query @@ -320,9 +318,8 @@ def user_ids_for_name_or_id(cls, user_list: Iterable[str]=()) -> list[str]: names or ids ''' query: Any = meta.Session.query(cls.id) - # type_ignore_reason: incomplete SQLAlchemy types - query = query.filter(or_(cls.name.in_(user_list), # type: ignore - cls.id.in_(user_list))) # type: ignore + query = query.filter(or_(cls.name.in_(user_list), + cls.id.in_(user_list))) return [user.id for user in query.all()] def get_id(self) -> str: @@ -365,7 +362,7 @@ class AnonymousUser(AnonymousUserMixin): email: str = "" -meta.mapper( +meta.registry.map_imperatively( User, user_table, properties={'password': synonym('_password', 
map_column=True)} ) diff --git a/ckan/model/vocabulary.py b/ckan/model/vocabulary.py index 850d795a065..a5b41986881 100644 --- a/ckan/model/vocabulary.py +++ b/ckan/model/vocabulary.py @@ -3,6 +3,7 @@ from typing import Optional, TYPE_CHECKING from sqlalchemy import types, Column, Table +from sqlalchemy.orm import Mapped from typing_extensions import Self import ckan.model.meta as meta @@ -27,8 +28,8 @@ class Vocabulary(domain_object.DomainObject): - id: str - name: str + id: Mapped[str] + name: Mapped[str] def __init__(self, name: str) -> None: self.id = _types.make_uuid() @@ -51,4 +52,4 @@ def tags(self) -> Query[Tag]: query = meta.Session.query(tag.Tag) return query.filter(tag.Tag.vocabulary_id == self.id) -meta.mapper(Vocabulary, vocabulary_table) +meta.registry.map_imperatively(Vocabulary, vocabulary_table) diff --git a/ckan/templates/user/list.html b/ckan/templates/user/list.html index 08296fb0914..935581740ae 100644 --- a/ckan/templates/user/list.html +++ b/ckan/templates/user/list.html @@ -16,7 +16,7 @@

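The model changes above all follow the same SQLAlchemy 2.0 migration pattern: column attributes gain ``Mapped[...]`` annotations for the type checker, and the removed ``meta.mapper()`` call becomes ``meta.registry.map_imperatively()``. A minimal sketch of that pattern, using a hypothetical ``widget`` table rather than any real CKAN model::

    import sqlalchemy as sa
    from sqlalchemy import orm

    mapper_registry = orm.registry()

    # hypothetical table, standing in for resource_table, tag_table, etc.
    widget_table = sa.Table(
        "widget", mapper_registry.metadata,
        sa.Column("id", sa.UnicodeText, primary_key=True),
        sa.Column("name", sa.UnicodeText),
    )

    class Widget:
        # Mapped[...] describes the mapped column types to type checkers;
        # the class stays a plain class and is mapped imperatively below
        id: orm.Mapped[str]
        name: orm.Mapped[str]

    # the 1.4/2.0 replacement for the legacy mapper()/meta.mapper() call
    mapper_registry.map_imperatively(Widget, widget_table)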
diff --git a/ckan/tests/cli/test_db.py b/ckan/tests/cli/test_db.py index 8bf62e52d10..2f184567b4f 100644 --- a/ckan/tests/cli/test_db.py +++ b/ckan/tests/cli/test_db.py @@ -3,6 +3,8 @@ import os import pytest +from sqlalchemy import inspect + import ckan.migration as migration import ckanext.example_database_migrations.plugin as example_plugin @@ -47,9 +49,10 @@ def test_current_migration_version(self): assert version == "base" def check_upgrade(self, has_x, has_y, expected_version): - has_table = model.Session.bind.has_table - assert has_table("example_database_migrations_x") is has_x - assert has_table("example_database_migrations_y") is has_y + inspector = inspect(model.Session.bind) + + assert inspector.has_table("example_database_migrations_x") is has_x + assert inspector.has_table("example_database_migrations_y") is has_y version = db.current_revision("example_database_migrations") assert version == expected_version diff --git a/ckan/tests/logic/action/test_create.py b/ckan/tests/logic/action/test_create.py index df84ed207a0..a2676e825e2 100644 --- a/ckan/tests/logic/action/test_create.py +++ b/ckan/tests/logic/action/test_create.py @@ -6,7 +6,7 @@ import operator import unittest.mock as mock import pytest - +import sqlalchemy as sa import ckan.logic as logic from ckan.logic.action.get import package_show as core_package_show @@ -1783,7 +1783,7 @@ def test_stored_on_create_if_sysadmin(self): plugin_extras_from_db = ( model.Session.execute( - 'SELECT plugin_extras FROM "user" WHERE id=:id', + sa.text('SELECT plugin_extras FROM "user" WHERE id=:id'), {"id": created_user["id"]}, ) .first()[0] @@ -2178,7 +2178,7 @@ def test_stored_on_create_if_sysadmin(self): } } plugin_data_from_db = model.Session.execute( - 'SELECT plugin_data FROM "package" WHERE id=:id', + sa.text('SELECT plugin_data FROM "package" WHERE id=:id'), {'id': created_pkg["id"]} ).first()[0] @@ -2205,7 +2205,7 @@ def test_ignored_on_create_if_non_sysadmin(self): assert "plugin_data" not in created_pkg plugin_data_from_db = model.Session.execute( - 'SELECT plugin_data FROM "package" WHERE id=:id', + sa.text('SELECT plugin_data FROM "package" WHERE id=:id'), {'id': created_pkg["id"]} ).first()[0] assert plugin_data_from_db is None diff --git a/ckan/tests/logic/action/test_update.py b/ckan/tests/logic/action/test_update.py index ee00190abab..830a39c75bb 100644 --- a/ckan/tests/logic/action/test_update.py +++ b/ckan/tests/logic/action/test_update.py @@ -4,6 +4,7 @@ import unittest.mock as mock import pytest +import sqlalchemy as sa import ckan.lib.app_globals as app_globals import ckan.logic as logic @@ -1983,7 +1984,7 @@ def test_stored_on_update_if_sysadmin(self): plugin_extras_from_db = ( model.Session.execute( - 'SELECT plugin_extras FROM "user" WHERE id=:id', + sa.text('SELECT plugin_extras FROM "user" WHERE id=:id'), {"id": user["id"]}, ) .first()[0] @@ -2081,7 +2082,7 @@ def test_nested_updates_are_reflected_in_db(self): plugin_extras = ( model.Session.execute( - 'SELECT plugin_extras FROM "user" WHERE id=:id', + sa.text('SELECT plugin_extras FROM "user" WHERE id=:id'), {"id": user["id"]}, ) .first()[0] @@ -2444,7 +2445,7 @@ def test_stored_on_update_if_sysadmin(self): } plugin_data_from_db = model.Session.execute( - 'SELECT plugin_data from "package" where id=:id', + sa.text('SELECT plugin_data from "package" where id=:id'), {"id": dataset["id"]} ).first() diff --git a/ckan/tests/pytest_ckan/ckan_setup.py b/ckan/tests/pytest_ckan/ckan_setup.py index 2f2af13bddb..c65e38ac2fa 100644 --- 
a/ckan/tests/pytest_ckan/ckan_setup.py +++ b/ckan/tests/pytest_ckan/ckan_setup.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - import ckan.plugins as plugins from ckan.config.middleware import make_app from ckan.cli import load_config diff --git a/ckan/tests/pytest_ckan/test_fixtures.py b/ckan/tests/pytest_ckan/test_fixtures.py index 674a7ef7a37..ab50a0970d8 100644 --- a/ckan/tests/pytest_ckan/test_fixtures.py +++ b/ckan/tests/pytest_ckan/test_fixtures.py @@ -4,6 +4,7 @@ import pytest from urllib.parse import urlparse +from sqlalchemy import inspect import ckan.plugins as plugins from ckan.common import config, asbool @@ -115,14 +116,16 @@ class TestMigrateDbFor(object): @pytest.mark.usefixtures("with_plugins", "clean_db") def test_migrations_applied(self, migrate_db_for): import ckan.model as model - has_table = model.Session.bind.has_table - assert not has_table("example_database_migrations_x") - assert not has_table("example_database_migrations_y") + + inspector = inspect(model.Session.bind) + assert not inspector.has_table("example_database_migrations_x") + assert not inspector.has_table("example_database_migrations_y") migrate_db_for("example_database_migrations") - assert has_table("example_database_migrations_x") - assert has_table("example_database_migrations_y") + inspector = inspect(model.Session.bind) + assert inspector.has_table("example_database_migrations_x") + assert inspector.has_table("example_database_migrations_y") @pytest.mark.usefixtures("non_clean_db") diff --git a/ckan/types/model.py b/ckan/types/model.py index 8e77e4e1a53..06a3668a6b0 100644 --- a/ckan/types/model.py +++ b/ckan/types/model.py @@ -4,7 +4,7 @@ from typing_extensions import Protocol from sqlalchemy.orm.scoping import ScopedSession -from sqlalchemy.orm import Query, sessionmaker +from sqlalchemy.orm import Query, sessionmaker, Session from sqlalchemy import Table if TYPE_CHECKING: @@ -16,7 +16,7 @@ class Meta(Protocol): - create_local_session: sessionmaker + create_local_session: "sessionmaker[Session]" class Model(Protocol): diff --git a/ckan/views/admin.py b/ckan/views/admin.py index e406095d3b8..d2e2c051bab 100644 --- a/ckan/views/admin.py +++ b/ckan/views/admin.py @@ -28,8 +28,7 @@ def _get_sysadmins() -> "Query[model.User]": q = model.Session.query(model.User).filter( - # type_ignore_reason: incomplete SQLAlchemy types - model.User.sysadmin.is_(True), # type: ignore + model.User.sysadmin.is_(True), model.User.state == u'active') return q diff --git a/ckan/views/api.py b/ckan/views/api.py index 07452e8c5fc..277548da240 100644 --- a/ckan/views/api.py +++ b/ckan/views/api.py @@ -241,7 +241,7 @@ def action(logic_function: str, ver: int = API_DEFAULT_VERSION) -> Response: u'api_version': ver, u'auth_user_obj': current_user } - model.Session()._context = context + model.Session()._context = context # type: ignore return_dict: dict[str, Any] = { u'help': url_for(u'api.action', diff --git a/ckan/views/user.py b/ckan/views/user.py index 6a1efda8fbb..7aeb44163fd 100644 --- a/ckan/views/user.py +++ b/ckan/views/user.py @@ -13,6 +13,7 @@ import ckan.lib.authenticator as authenticator import ckan.lib.base as base import ckan.lib.captcha as captcha +from ckan.lib.dictization import model_dictize import ckan.lib.helpers as h import ckan.lib.mailer as mailer import ckan.lib.maintain as maintain @@ -86,6 +87,10 @@ def index(): order_by = request.args.get('order_by', 'name') default_limit: int = config.get('ckan.user_list_limit') limit = int(request.args.get('limit', default_limit)) + offset = page_number * limit - 
limit + + # get a SQLAlchemy Query object from the action to avoid dictizing all + # existing users at once context: Context = { u'return_query': True, u'user': current_user.name, @@ -93,8 +98,8 @@ } data_dict = { - u'q': q, - u'order_by': order_by + 'q': q, + 'order_by': order_by, } try: @@ -104,9 +109,18 @@ users_list = logic.get_action(u'user_list')(context, data_dict) + # in the template we don't need complex row objects from the query; dictize + # only the subset of users shown on the current page + users = [ + model_dictize.user_dictize(user[0], context) + for user in + users_list.limit(limit).offset(offset) + ] + page = h.Page( - collection=users_list, + collection=users, page=page_number, + presliced_list=True, url=h.pager_url, item_count=users_list.count(), items_per_page=limit) diff --git a/ckanext/activity/model/activity.py b/ckanext/activity/model/activity.py index d3b61a6847a..5ae57e70bed 100644 --- a/ckanext/activity/model/activity.py +++ b/ckanext/activity/model/activity.py @@ -5,7 +5,7 @@ from typing import Any, Iterable, Optional, Type, TypeVar from typing_extensions import TypeAlias -from sqlalchemy.orm import relationship, backref +from sqlalchemy.orm import relationship, backref, Mapped from sqlalchemy import ( types, Column, @@ -13,7 +13,6 @@ or_, and_, not_, - union_all, text, ) @@ -46,15 +45,13 @@ class Activity(domain_object.DomainObject, BaseModel): # type: ignore ) timestamp = Column("timestamp", types.DateTime) user_id = Column("user_id", types.UnicodeText) - object_id = Column("object_id", types.UnicodeText) + object_id: Any = Column("object_id", types.UnicodeText) # legacy revision_id values are used by migrate_package_activity.py revision_id = Column("revision_id", types.UnicodeText) activity_type = Column("activity_type", types.UnicodeText) data = Column("data", _types.JsonDictType) permission_labels = Column("permission_labels", types.Text) - activity_detail: "ActivityDetail" - def __init__( self, user_id: str, @@ -157,7 +154,7 @@ def activity_list_dictize( # deprecated -class ActivityDetail(domain_object.DomainObject): +class ActivityDetail(domain_object.DomainObject, BaseModel): # type: ignore __tablename__ = "activity_detail" id = Column( "id", types.UnicodeText, primary_key=True, default=_types.make_uuid ) @@ -170,7 +167,7 @@ class ActivityDetail(domain_object.DomainObject): activity_type = Column("activity_type", types.UnicodeText) data = Column("data", _types.JsonDictType) - activity = relationship( + activity: Mapped[Activity] = relationship( # type: ignore Activity, backref=backref("activity_detail", cascade="all, delete-orphan"), ) @@ -232,11 +229,10 @@ def _activities_union_all(*qlist: QActivity) -> QActivity: Return union of two or more activity queries sorted by timestamp, and remove duplicates """ - q: QActivity = ( - model.Session.query(Activity) - .select_entity_from(union_all(*[q.subquery().select() for q in qlist])) - .distinct(Activity.timestamp) - ) + q, *rest = qlist + for query in rest: + q = q.union(query) + return q diff --git a/ckanext/datapusher/logic/action.py b/ckanext/datapusher/logic/action.py index 798db5d375a..adccccfdaf5 100644 --- a/ckanext/datapusher/logic/action.py +++ b/ckanext/datapusher/logic/action.py @@ -126,7 +126,9 @@ def datapusher_submit(context: Context, data_dict: dict[str, Any]): # Use a local session for task_status_update, so it can commit its own # results without interfering with the parent session that contains pending # updates of dataset/resource/etc.
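A rough sketch of that local-session pattern, assuming only that ``create_local_session`` is the sessionmaker exposed by ``ckan.model.meta``::

    import ckan.model as model

    # an independent session on the same engine: commits here leave the
    # pending state of the request-scoped model.Session untouched
    local_session = model.meta.create_local_session()
    try:
        # ... update task_status rows through local_session ...
        local_session.commit()
    except Exception:
        local_session.rollback()
        raise
    finally:
        local_session.close()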
- context['session'] = context['model'].meta.create_local_session() + context.update({ + 'session': context['model'].meta.create_local_session() # type: ignore + }) p.toolkit.get_action('task_status_update')(context, task) timeout = config.get('ckan.requests.timeout') diff --git a/ckanext/datapusher/tests/test_action.py b/ckanext/datapusher/tests/test_action.py index 69dff7d86f6..1d680079e8a 100644 --- a/ckanext/datapusher/tests/test_action.py +++ b/ckanext/datapusher/tests/test_action.py @@ -23,9 +23,23 @@ def _pending_task(resource_id): } +@pytest.fixture() +def with_datapusher_token(non_clean_db, ckan_config, monkeypatch): + """Set mandatory datapusher option. + + It must be applied before `datapusher` plugin is loaded via `with_plugins`, + but after DB initialization via `non_clean_db`. + + """ + monkeypatch.setitem( + ckan_config, + "ckan.datapusher.api_token", + get_api_token(), + ) + + @pytest.mark.ckan_config("ckan.plugins", "datapusher datastore") -@pytest.mark.ckan_config("ckan.datapusher.api_token", get_api_token()) -@pytest.mark.usefixtures("non_clean_db", "with_plugins") +@pytest.mark.usefixtures("non_clean_db", "with_datapusher_token", "with_plugins") class TestSubmit: def test_submit(self, monkeypatch): """Auto-submit when creating a resource with supported format. diff --git a/ckanext/datastore/backend/postgres.py b/ckanext/datastore/backend/postgres.py index 7d3918ddb82..f1242047b54 100644 --- a/ckanext/datastore/backend/postgres.py +++ b/ckanext/datastore/backend/postgres.py @@ -1,9 +1,12 @@ # -*- coding: utf-8 -*- from __future__ import annotations + +import itertools + from typing_extensions import TypeAlias -import sqlalchemy.exc from sqlalchemy.engine.base import Engine +from sqlalchemy.dialects.postgresql import REGCLASS from ckan.types import Context, ErrorDict import copy import logging @@ -11,7 +14,7 @@ from typing import ( Any, Callable, Container, Dict, Iterable, Optional, Set, Union, cast) -import sqlalchemy +import sqlalchemy as sa import os import pprint import sqlalchemy.engine.url as sa_url @@ -52,6 +55,7 @@ _pg_types: dict[str, str] = {} _type_names: Set[str] = set() _engines: Dict[str, Engine] = {} +WhereClauses: TypeAlias = "list[tuple[str, dict[str, Any]] | tuple[str]]" _TIMEOUT = 60000 # milliseconds @@ -126,9 +130,9 @@ def _get_engine_from_url(connection_url: str, **kwargs: Any) -> Engine: config.setdefault('ckan.datastore.sqlalchemy.pool_pre_ping', True) for key, value in kwargs.items(): config.setdefault(key, value) - engine = sqlalchemy.engine_from_config(config, - 'ckan.datastore.sqlalchemy.', - **extras) + engine = sa.engine_from_config(config, + 'ckan.datastore.sqlalchemy.', + **extras) _engines[connection_url] = engine # don't automatically convert to python objects @@ -136,7 +140,7 @@ def _get_engine_from_url(connection_url: str, **kwargs: Any) -> Engine: # http://initd.org/psycopg/docs/extras.html#adapt-json _loads: Callable[[Any], Any] = lambda x: x register_default_json( - conn_or_curs=engine.raw_connection().connection, + conn_or_curs=engine.raw_connection().connection, # type: ignore globally=False, loads=_loads) @@ -211,8 +215,8 @@ def _result_fields(fields_types: 'OrderedDict[str, str]', return result_fields -def _get_type(connection: Any, oid: str) -> str: - _cache_types(connection) +def _get_type(engine: Engine, oid: str) -> str: + _cache_types(engine) return _pg_types[oid] @@ -249,7 +253,7 @@ def _guess_type(field: Any): def _get_unique_key(context: Context, data_dict: dict[str, Any]) -> list[str]: - sql_get_unique_key = ''' + 
sql_get_unique_key = sa.text(""" SELECT a.attname AS column_names FROM @@ -263,17 +267,20 @@ def _get_unique_key(context: Context, data_dict: dict[str, Any]) -> list[str]: AND t.relkind = 'r' AND idx.indisunique = true AND idx.indisprimary = false - AND t.relname = %s - ''' - key_parts = context['connection'].execute(sql_get_unique_key, - data_dict['resource_id']) + AND t.relname = :relname + """) + + key_parts = context['connection'].execute( + sql_get_unique_key, + {"relname": data_dict['resource_id']}, + ) return [x[0] for x in key_parts] def _get_field_info(connection: Any, resource_id: str) -> dict[str, Any]: u'''return a dictionary mapping column names to their info data, when present''' - qtext = sqlalchemy.text(u''' + qtext = sa.text(u''' select pa.attname as name, pd.description as info from pg_class pc, pg_attribute pa, pg_description pd where pa.attrelid = pc.oid and pd.objoid = pc.oid @@ -283,7 +290,7 @@ def _get_field_info(connection: Any, resource_id: str) -> dict[str, Any]: try: return dict( (n, json.loads(v)) for (n, v) in - connection.execute(qtext, res_id=resource_id).fetchall()) + connection.execute(qtext, {"res_id": resource_id}).fetchall()) except ValueError: # don't die on non-json comments return {} @@ -294,50 +301,60 @@ def _get_fields(connection: Any, resource_id: str): for the passed resource_id, excluding '_'-prefixed columns. ''' fields: list[dict[str, Any]] = [] - all_fields = connection.execute( - u'SELECT * FROM "{0}" LIMIT 1'.format(resource_id) - ) + all_fields = connection.execute(sa.select( + sa.text("*") + ).select_from(sa.table(resource_id)).limit(1)) + for field in all_fields.cursor.description: if not field[0].startswith('_'): fields.append({ 'id': str(field[0]), - 'type': _get_type(connection, field[1]) + 'type': _get_type(connection.engine, field[1]) }) return fields -def _cache_types(connection: Any) -> None: +def _cache_types(engine: Engine) -> None: if not _pg_types: - results = connection.execute( - 'SELECT oid, typname FROM pg_type;' - ) + with engine.begin() as conn: + results = conn.execute(sa.text( + 'SELECT oid, typname FROM pg_type;' + )) for result in results: _pg_types[result[0]] = result[1] _type_names.add(result[1]) if 'nested' not in _type_names: - native_json = _pg_version_is_at_least(connection, '9.2') + with engine.begin() as conn: + native_json = _pg_version_is_at_least(conn, '9.2') log.info("Create nested type. Native JSON: {0!r}".format( native_json)) backend = DatastorePostgresqlBackend.get_active_backend() - engine: Engine = backend._get_write_engine() # type: ignore - with cast(Any, engine.begin()) as write_connection: - write_connection.execute( + write_engine: Engine = backend._get_write_engine() # type: ignore + with write_engine.begin() as write_connection: + write_connection.execute(sa.text( 'CREATE TYPE "nested" AS (json {0}, extra text)'.format( - 'json' if native_json else 'text')) + 'json' if native_json else 'text'))) _pg_types.clear() # redo cache types with json now available. 
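The rewritten ``_cache_types`` shows the connection discipline used throughout this file: no more implicit ``engine.execute(...)`` on a raw string; instead a short-lived transaction is opened explicitly and textual SQL goes through ``sa.text()`` with named parameters. A minimal sketch, with a made-up connection URL::

    import sqlalchemy as sa

    # hypothetical URL; the real engines come from the datastore config
    engine = sa.create_engine("postgresql:///datastore_default")

    # engine.begin() yields a connection inside a transaction that is
    # committed on success and rolled back on error
    with engine.begin() as conn:
        result = conn.execute(
            sa.text("SELECT oid, typname FROM pg_type WHERE typname = :name"),
            {"name": "nested"},
        )
        for oid, typname in result:
            print(oid, typname)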
- return _cache_types(connection) + return _cache_types(engine) - register_composite('nested', connection.connection.connection, True) + with engine.connect() as conn: + register_composite( + 'nested', + conn.connection.connection, + True + ) def _pg_version_is_at_least(connection: Any, version: Any): try: v = distutils.version.LooseVersion(version) - pg_version = connection.execute('select version();').fetchone() + pg_version = connection.execute( + sa.text('select version();') + ).fetchone() pg_version_number = pg_version[0].split()[1] pv = distutils.version.LooseVersion(pg_version_number) return v <= pv @@ -369,24 +386,32 @@ def _validate_record(record: Any, num: int, field_names: Iterable[str]): def _where_clauses( data_dict: dict[str, Any], fields_types: dict[str, Any] -) -> list[Any]: +) -> WhereClauses: filters = data_dict.get('filters', {}) - clauses: list[Any] = [] + clauses: WhereClauses = [] + + idx_gen = itertools.count() for field, value in filters.items(): if field not in fields_types: continue field_array_type = _is_array_type(fields_types[field]) - # "%" needs to be escaped as "%%" in any query going to - # connection.execute, otherwise it will think the "%" is for - # substituting a bind parameter - field = field.replace('%', '%%') + if isinstance(value, list) and not field_array_type: - clause_str = (u'"{0}" in ({1})'.format(field, - ','.join(['%s'] * len(value)))) - clause = (clause_str,) + tuple(value) + placeholders = [ + f"value_{next(idx_gen)}" for _ in value + ] + clause_str = ('{0} in ({1})'.format( + sa.column(field), + ','.join(f":{p}" for p in placeholders) + )) + clause = (clause_str, dict(zip(placeholders, value))) else: - clause: tuple[Any, ...] = (u'"{0}" = %s'.format(field), value) + placeholder = f"value_{next(idx_gen)}" + clause: tuple[Any, ...] = ( + f'{sa.column(field)} = :{placeholder}', + {placeholder: value} + ) clauses.append(clause) # add full-text search where clause @@ -437,7 +462,8 @@ def _where_clauses( def _update_where_clauses_on_q_dict( data_dict: dict[str, str], fields_types: dict[str, str], - q: dict[str, str], clauses: list[tuple[str]]) -> None: + q: dict[str, str], + clauses: WhereClauses) -> None: lang = _fts_lang(data_dict.get('language')) for field, _ in q.items(): if field not in fields_types: @@ -585,20 +611,24 @@ def _ts_query_alias(field: Optional[str] = None): def _get_aliases(context: Context, data_dict: dict[str, Any]): '''Get a list of aliases for a resource.''' res_id = data_dict['resource_id'] - alias_sql = sqlalchemy.text( + alias_sql = sa.text( u'SELECT name FROM "_table_metadata" WHERE alias_of = :id') - results = context['connection'].execute(alias_sql, id=res_id).fetchall() - return [x[0] for x in results] + return [ + item[0] for item in + context['connection'].execute(alias_sql, {"id": res_id}) + ] def _get_resources(context: Context, alias: str): '''Get a list of resources for an alias. 
There could be more than one alias in a resource_dict.''' - alias_sql = sqlalchemy.text( + alias_sql = sa.text( u'''SELECT alias_of FROM "_table_metadata" WHERE name = :alias AND alias_of IS NOT NULL''') - results = context['connection'].execute(alias_sql, alias=alias).fetchall() - return [x[0] for x in results] + return [ + item[0] for item in + context['connection'].execute(alias_sql, {"alias": alias}) + ] def create_alias(context: Context, data_dict: dict[str, Any]): @@ -610,7 +640,7 @@ def create_alias(context: Context, data_dict: dict[str, Any]): previous_aliases = _get_aliases(context, data_dict) for alias in previous_aliases: sql_alias_drop_string = u'DROP VIEW "{0}"'.format(alias) - context['connection'].execute(sql_alias_drop_string) + context['connection'].execute(sa.text(sql_alias_drop_string)) try: for alias in aliases: @@ -627,7 +657,9 @@ def create_alias(context: Context, data_dict: dict[str, Any]): alias)] }) - context['connection'].execute(sql_alias_string) + context['connection'].execute(sa.text( + sql_alias_string + )) except DBAPIError as e: if e.orig.pgcode in [_PG_ERR_CODE['duplicate_table'], _PG_ERR_CODE['duplicate_alias']]: @@ -683,7 +715,7 @@ def cast_as_text(x: str): def _drop_indexes(context: Context, data_dict: dict[str, Any], unique: bool = False): - sql_drop_index = u'DROP INDEX "{0}" CASCADE' + sql_drop_index = u'DROP INDEX {0} CASCADE' sql_get_index_string = u""" SELECT i.relname AS index_name @@ -697,17 +729,20 @@ def _drop_indexes(context: Context, data_dict: dict[str, Any], AND t.relkind = 'r' AND idx.indisunique = {unique} AND idx.indisprimary = false - AND t.relname = %s + AND t.relname = :relname """.format(unique='true' if unique else 'false') indexes_to_drop = context['connection'].execute( - sql_get_index_string, data_dict['resource_id']).fetchall() + sa.text(sql_get_index_string), + {"relname": data_dict['resource_id']} + ).fetchall() for index in indexes_to_drop: - context['connection'].execute( - sql_drop_index.format(index[0]).replace('%', '%%')) + context['connection'].execute(sa.text( + sql_drop_index.format(sa.column(index[0])) + )) def _get_index_names(connection: Any, resource_id: str): - sql = u""" + sql = sa.text(""" SELECT i.relname AS index_name FROM @@ -718,9 +753,9 @@ def _get_index_names(connection: Any, resource_id: str): t.oid = idx.indrelid AND i.oid = idx.indexrelid AND t.relkind = 'r' - AND t.relname = %s - """ - results = connection.execute(sql, resource_id).fetchall() + AND t.relname = :relname + """) + results = connection.execute(sql, {"relname": resource_id}).fetchall() return [result[0] for result in results] @@ -730,7 +765,10 @@ def _is_valid_pg_type(context: Context, type_name: str): else: connection = context['connection'] try: - connection.execute('SELECT %s::regtype', type_name) + connection.execute( + sa.text('SELECT cast(:type as regtype)'), + {"type": type_name} + ) except ProgrammingError as e: if e.orig.pgcode in [_PG_ERR_CODE['undefined_object'], _PG_ERR_CODE['syntax_error']]: @@ -741,13 +779,20 @@ def _is_valid_pg_type(context: Context, type_name: str): def _execute_single_statement( - context: Context, sql_string: str, where_values: Any): + context: Context, sql_string: str, where_values: list[dict[str, Any]]): if not datastore_helpers.is_single_statement(sql_string): raise ValidationError({ 'query': ['Query is not a single statement.'] }) - results = context['connection'].execute(sql_string, [where_values]) + params = {} + for chunk in where_values: + params.update(chunk) + + results = 
context['connection'].execute( + sa.text(sql_string), + params + ) return results @@ -791,20 +836,20 @@ def _insert_links(data_dict: dict[str, Any], limit: int, offset: int): def _where( - where_clauses_and_values: list[tuple[Any, ...]] -) -> tuple[str, list[Any]]: + where_clauses_and_values: WhereClauses +) -> tuple[str, list[dict[str, Any]]]: '''Return a SQL WHERE clause from list with clauses and values :param where_clauses_and_values: list of tuples with format - (where_clause, param1, ...) + (where_clause, {placeholder_name_1: param1, ...}) :type where_clauses_and_values: list of tuples :returns: SQL WHERE string with placeholders for the parameters, and list of parameters :rtype: string ''' - where_clauses = [] - values: list[Any] = [] + where_clauses: list[str] = [] + values: list[dict[str, Any]] = [] for clause_and_values in where_clauses_and_values: where_clauses.append('(' + clause_and_values[0] + ')') @@ -901,9 +946,9 @@ def create_indexes(context: Context, data_dict: dict[str, Any]): index)] }) fields_string = u', '.join( - ['(("{0}").json::text)'.format(field) + ['(({0}).json::text)'.format(identifier(field)) if field in json_fields else - '"%s"' % field + identifier(field) for field in index_fields]) sql_index_strings.append(sql_index_string.format( res_id=data_dict['resource_id'], @@ -911,14 +956,13 @@ def create_indexes(context: Context, data_dict: dict[str, Any]): name=_generate_index_name(data_dict['resource_id'], fields_string), fields=fields_string)) - sql_index_strings = [x.replace('%', '%%') for x in sql_index_strings] current_indexes = _get_index_names(context['connection'], data_dict['resource_id']) for sql_index_string in sql_index_strings: has_index = [c for c in current_indexes if sql_index_string.find(c) != -1] if not has_index: - connection.execute(sql_index_string) + connection.execute(sa.text(sql_index_string)) def create_table(context: Context, data_dict: dict[str, Any]): @@ -1012,8 +1056,8 @@ def create_table(context: Context, data_dict: dict[str, Any]): literal_string( json.dumps(info, ensure_ascii=False)))) - context['connection'].execute( - (sql_string + u';'.join(info_sql)).replace(u'%', u'%%')) + context['connection'].execute(sa.text( + sql_string + u';'.join(info_sql))) def alter_table(context: Context, data_dict: dict[str, Any]): @@ -1102,8 +1146,9 @@ def alter_table(context: Context, data_dict: dict[str, Any]): identifier(id_))) if alter_sql: - context['connection'].execute( - u';'.join(alter_sql).replace(u'%', u'%%')) + context['connection'].execute(sa.text( + ';'.join(alter_sql) + )) def insert_data(context: Context, data_dict: dict[str, Any]): @@ -1131,26 +1176,28 @@ def upsert_data(context: Context, data_dict: dict[str, Any]): for num, record in enumerate(records): _validate_record(record, num, field_names) - row = [] - for field in fields: + row = {} + for idx, field in enumerate(fields): value = record.get(field['id']) if value is not None and field['type'].lower() == 'nested': # a tuple with an empty second value value = (json.dumps(value), '') elif value == '' and field['type'] != 'text': value = None - row.append(value) + row[f"val_{idx}"] = value rows.append(row) - sql_string = u'''INSERT INTO {res_id} ({columns}) + sql_string = '''INSERT INTO {res_id} ({columns}) VALUES ({values});'''.format( res_id=identifier(data_dict['resource_id']), - columns=sql_columns.replace('%', '%%'), - values=', '.join(['%s' for _ in field_names]) + columns=sql_columns, + values=', '.join([ + f":val_{idx}" for idx in range(0, len(field_names)) + ]) ) try: - 
context['connection'].execute(sql_string, rows) + context['connection'].execute(sa.text(sql_string), rows) except (DatabaseError, DataError) as err: raise ValidationError({ 'records': [_programming_error_summary(err)], @@ -1195,21 +1242,36 @@ def upsert_data(context: Context, data_dict: dict[str, Any]): ', '.join(non_existing_field_names))] }) - if '_id' in record: - unique_values = [record['_id']] - pk_sql = '"_id"' - pk_values_sql = '%s' - else: - unique_values = [record[key] for key in unique_keys] - pk_sql = ','.join([identifier(part) for part in unique_keys]) - pk_values_sql = ','.join(['%s'] * len(unique_keys)) + idx_gen = itertools.count() used_fields = [field for field in fields if field['id'] in record] used_field_names = _pluck('id', used_fields) - used_values = [record[field] for field in used_field_names] + value_placeholders = [ + f"val_{next(idx_gen)}" for _ in used_field_names + ] + values = [":" + p for p in value_placeholders] + used_values = dict(zip( + value_placeholders, + [record[field] for field in used_field_names] + )) + + if '_id' in record: + placeholder = f'val_{next(idx_gen)}' + unique_values = {placeholder: record['_id']} + pk_sql = '"_id"' + pk_values_sql = ":" + placeholder + else: + placeholders = [ + f"val_{next(idx_gen)}" for _ in range(len(unique_keys)) + ] + unique_values = dict(zip( + placeholders, [record[key] for key in unique_keys] + )) + pk_sql = ','.join([identifier(part) for part in unique_keys]) + pk_values_sql = ','.join([":" + p for p in placeholders]) if method == _UPDATE: sql_string = u''' @@ -1220,16 +1282,16 @@ def upsert_data(context: Context, data_dict: dict[str, Any]): res_id=identifier(data_dict['resource_id']), columns=u', '.join( [identifier(field) - for field in used_field_names]).replace('%', '%%'), - values=u', '.join( - ['%s' for _ in used_field_names]), - primary_key=pk_sql.replace('%', '%%'), + for field in used_field_names]), + values=u', '.join(values), + primary_key=pk_sql, primary_value=pk_values_sql, ) try: results = context['connection'].execute( - sql_string, used_values + unique_values) - except sqlalchemy.exc.DatabaseError as err: + sa.text(sql_string), + {**used_values, **unique_values}) + except DatabaseError as err: raise ValidationError({ 'records': [_programming_error_summary(err)], 'records_row': num, @@ -1242,30 +1304,40 @@ def upsert_data(context: Context, data_dict: dict[str, Any]): }) elif method == _UPSERT: - sql_string = u''' - UPDATE {res_id} - SET ({columns}, "_full_text") = ({values}, NULL) - WHERE ({primary_key}) = ({primary_value}); - INSERT INTO {res_id} ({columns}) - SELECT {values} - WHERE NOT EXISTS (SELECT 1 FROM {res_id} - WHERE ({primary_key}) = ({primary_value})); - '''.format( + format_params = dict( res_id=identifier(data_dict['resource_id']), columns=u', '.join( [identifier(field) - for field in used_field_names]).replace('%', '%%'), - values=u', '.join(['%s::nested' - if field['type'] == 'nested' else '%s' - for field in used_fields]), - primary_key=pk_sql.replace('%', '%%'), + for field in used_field_names]), + values=u', '.join([ + f'cast(:{p} as nested)' + if field['type'] == 'nested' else ":" + p + for p, field in zip(value_placeholders, used_fields) + ]), + primary_key=pk_sql, primary_value=pk_values_sql, ) + + update_string = """ + UPDATE {res_id} + SET ({columns}, "_full_text") = ({values}, NULL) + WHERE ({primary_key}) = ({primary_value}) + """.format(**format_params) + + insert_string = """ + INSERT INTO {res_id} ({columns}) + SELECT {values} + WHERE NOT EXISTS (SELECT 1 FROM 
{res_id} + WHERE ({primary_key}) = ({primary_value})) + """.format(**format_params) + + values = {**used_values, **unique_values} try: context['connection'].execute( - sql_string, - (used_values + unique_values) * 2) - except sqlalchemy.exc.DatabaseError as err: + sa.text(update_string), values) + context['connection'].execute( + sa.text(insert_string), values) + except DatabaseError as err: raise ValidationError({ 'records': [_programming_error_summary(err)], 'records_row': num, @@ -1338,9 +1410,9 @@ def search_data(context: Context, data_dict: dict[str, Any]): where_clause, where_values = _where(query_dict['where']) # FIXME: Remove duplicates on select columns - select_columns = ', '.join(query_dict['select']).replace('%', '%%') - ts_query = cast(str, query_dict['ts_query']).replace('%', '%%') - resource_id = data_dict['resource_id'].replace('%', '%%') + select_columns = ', '.join(query_dict['select']) + ts_query = cast(str, query_dict['ts_query']) + resource_id = data_dict['resource_id'] sort = query_dict['sort'] limit = query_dict['limit'] offset = query_dict['offset'] @@ -1354,7 +1426,7 @@ def search_data(context: Context, data_dict: dict[str, Any]): sort = ['_id'] if sort: - sort_clause = 'ORDER BY %s' % (', '.join(sort)).replace('%', '%%') + sort_clause = 'ORDER BY {}'.format(', '.join(sort)) else: sort_clause = '' @@ -1369,7 +1441,7 @@ def search_data(context: Context, data_dict: dict[str, Any]): elif records_format == u'lists': select_columns = u" || ',' || ".join( s for s in query_dict['select'] - ).replace('%', '%%') + ) sql_fmt = u''' SELECT '[' || array_to_string(array_agg(j.v), ',') || ']' FROM ( SELECT {distinct} '[' || {select} || ']' v @@ -1436,13 +1508,15 @@ def search_data(context: Context, data_dict: dict[str, Any]): # See: https://wiki.postgresql.org/wiki/Count_estimate # (We also tried using the EXPLAIN to estimate filtered queries but # it didn't estimate well in tests) - analyze_count_sql = sqlalchemy.text(''' + analyze_count_sql = sa.text(''' SELECT reltuples::BIGINT AS approximate_row_count FROM pg_class WHERE relname=:resource; ''') - count_result = context['connection'].execute(analyze_count_sql, - resource=resource_id) + count_result = context['connection'].execute( + analyze_count_sql, + {"resource": resource_id}, + ) try: estimated_total = count_result.fetchall()[0][0] except ValueError: @@ -1479,7 +1553,7 @@ def search_data(context: Context, data_dict: dict[str, Any]): def _execute_single_statement_copy_to( context: Context, sql_string: str, - where_values: Any, buf: Any): + where_values: list[dict[str, Any]], buf: Any): if not datastore_helpers.is_single_statement(sql_string): raise ValidationError({ 'query': ['Query is not a single statement.'] @@ -1495,15 +1569,16 @@ def format_results(context: Context, results: Any, data_dict: dict[str, Any]): for field in results.cursor.description: result_fields.append({ 'id': str(field[0]), - 'type': _get_type(context['connection'], field[1]) + 'type': _get_type(context['connection'].engine, field[1]) }) records = [] + for row in results: converted_row = {} for field in result_fields: - converted_row[field['id']] = convert(row[field['id']], - field['type']) + converted_row[field['id']] = convert( + row._mapping[field['id']], field['type']) records.append(converted_row) data_dict['records'] = records if data_dict.get('records_truncated', False): @@ -1545,9 +1620,16 @@ def _create_triggers(connection: Any, resource_id: str, or "for_each" parameters from triggers list. 
''' existing = connection.execute( - u"""SELECT tgname FROM pg_trigger - WHERE tgrelid = %s::regclass AND tgname LIKE 't___'""", - resource_id) + sa.select( + sa.column("tgname") + ).select_from(sa.table("pg_trigger")).where( + sa.column("tgrelid") == sa.cast( + resource_id, # type: ignore + REGCLASS + ), + sa.column("tgname").like("t___") + ) + ) sql_list = ( [u'DROP TRIGGER {name} ON {table}'.format( name=identifier(r[0]), @@ -1563,17 +1645,17 @@ def _create_triggers(connection: Any, resource_id: str, for i, t in enumerate(triggers)]) try: if sql_list: - connection.execute(u';\n'.join(sql_list)) + connection.execute(sa.text(";\n".join(sql_list))) except ProgrammingError as pe: raise ValidationError({u'triggers': [_programming_error_summary(pe)]}) def _create_fulltext_trigger(connection: Any, resource_id: str): - connection.execute( + connection.execute(sa.text( u'''CREATE TRIGGER zfulltext BEFORE INSERT OR UPDATE ON {table} FOR EACH ROW EXECUTE PROCEDURE populate_full_text_trigger()'''.format( - table=identifier(resource_id))) + table=identifier(resource_id)))) def upsert(context: Context, data_dict: dict[str, Any]): @@ -1592,8 +1674,8 @@ def upsert(context: Context, data_dict: dict[str, Any]): trans: Any = context['connection'].begin() try: # check if table already existes - context['connection'].execute( - u'SET LOCAL statement_timeout TO {0}'.format(timeout)) + context['connection'].execute(sa.text( + f"SET LOCAL statement_timeout TO {timeout}")) upsert_data(context, data_dict) if data_dict.get(u'dry_run', False): trans.rollback() @@ -1633,13 +1715,14 @@ def upsert(context: Context, data_dict: dict[str, Any]): def search(context: Context, data_dict: dict[str, Any]): backend = DatastorePostgresqlBackend.get_active_backend() engine = backend._get_read_engine() # type: ignore + _cache_types(engine) context['connection'] = engine.connect() timeout = context.get('query_timeout', _TIMEOUT) - _cache_types(context['connection']) try: - context['connection'].execute( - u'SET LOCAL statement_timeout TO {0}'.format(timeout)) + context['connection'].execute(sa.text( + f"SET LOCAL statement_timeout TO {timeout}" + )) return search_data(context, data_dict) except DBAPIError as e: if e.orig.pgcode == _PG_ERR_CODE['query_canceled']: @@ -1661,12 +1744,12 @@ def search(context: Context, data_dict: dict[str, Any]): def search_sql(context: Context, data_dict: dict[str, Any]): backend = DatastorePostgresqlBackend.get_active_backend() engine = backend._get_read_engine() # type: ignore + _cache_types(engine) context['connection'] = engine.connect() timeout = context.get('query_timeout', _TIMEOUT) - _cache_types(context['connection']) - sql = data_dict['sql'].replace('%', '%%') + sql = data_dict['sql'] # limit the number of results to ckan.datastore.search.rows_max + 1 # (the +1 is so that we know if the results went over the limit or not) @@ -1675,8 +1758,9 @@ def search_sql(context: Context, data_dict: dict[str, Any]): try: - context['connection'].execute( - u'SET LOCAL statement_timeout TO {0}'.format(timeout)) + context['connection'].execute(sa.text( + f"SET LOCAL statement_timeout TO {timeout}" + )) get_names = datastore_helpers.get_table_and_function_names_from_sql table_names, function_names = get_names(context, sql) @@ -1700,7 +1784,7 @@ def search_sql(context: Context, data_dict: dict[str, Any]): 'Not authorized to call function {}'.format(f) ) - results: Any = context['connection'].execute(sql) + results: Any = context['connection'].execute(sa.text(sql)) if results.rowcount == rows_max + 1: 
data_dict['records_truncated'] = True @@ -1778,8 +1862,9 @@ def _is_read_only_database(self): for url in [self.ckan_url, self.write_url, self.read_url]: connection = _get_engine_from_url(url).connect() try: - sql = u"SELECT has_schema_privilege('public', 'CREATE')" - is_writable: bool = connection.execute(sql).one()[0] + is_writable: bool = connection.scalar(sa.select( + sa.func.has_schema_privilege("public", "CREATE") + )) finally: connection.close() if is_writable: @@ -1803,26 +1888,30 @@ def _read_connection_has_correct_privileges(self): only user. A table is created by the write user to test the read only user. ''' - write_connection = self._get_write_engine().connect() read_connection_user = sa_url.make_url(self.read_url).username - drop_foo_sql = u'DROP TABLE IF EXISTS _foo' - - write_connection.execute(drop_foo_sql) + drop_foo_sql = sa.text("DROP TABLE IF EXISTS _foo") + engine = self._get_write_engine() try: - write_connection.execute(u'CREATE TEMP TABLE _foo ()') - for privilege in ['INSERT', 'UPDATE', 'DELETE']: - privilege_sql = u"SELECT has_table_privilege(%s, '_foo', %s)" - have_privilege: bool = write_connection.execute( - privilege_sql, - (read_connection_user, privilege) - ).one()[0] - if have_privilege: - return False + with engine.begin() as conn: + conn.execute(drop_foo_sql) + conn.execute(sa.text("CREATE TEMP TABLE _foo ()")) + + for privilege in ['INSERT', 'UPDATE', 'DELETE']: + have_privilege: bool = conn.scalar(sa.select( + sa.func.has_table_privilege( + read_connection_user, + "_foo", + privilege + ) + )) + if have_privilege: + return False finally: - write_connection.execute(drop_foo_sql) - write_connection.close() + with engine.begin() as conn: + conn.execute(drop_foo_sql) + return True def configure(self, config: CKANConfig): @@ -1965,27 +2054,19 @@ def datastore_search( def delete(self, context: Context, data_dict: dict[str, Any]): engine = self._get_write_engine() - context['connection'] = engine.connect() - _cache_types(context['connection']) + _cache_types(engine) - trans = context['connection'].begin() - try: + with engine.begin() as conn: + context["connection"] = conn # check if table exists if 'filters' not in data_dict: - context['connection'].execute( - u'DROP TABLE "{0}" CASCADE'.format( - data_dict['resource_id']) - ) + conn.execute(sa.text('DROP TABLE {0} CASCADE'.format( + identifier(data_dict['resource_id']) + ))) else: delete_data(context, data_dict) - trans.commit() return _unrename_json_field(data_dict) - except Exception: - trans.rollback() - raise - finally: - context['connection'].close() def create(self, context: Context, data_dict: dict[str, Any]): ''' @@ -2004,21 +2085,22 @@ def create(self, context: Context, data_dict: dict[str, Any]): Should be transactional. 
''' engine = get_write_engine() + _cache_types(engine) + context['connection'] = engine.connect() timeout = context.get('query_timeout', _TIMEOUT) - _cache_types(context['connection']) _rename_json_field(data_dict) trans = context['connection'].begin() try: # check if table already exists - context['connection'].execute( - u'SET LOCAL statement_timeout TO {0}'.format(timeout)) - result = context['connection'].execute( - u'SELECT * FROM pg_tables WHERE tablename = %s', - data_dict['resource_id'] - ).fetchone() + context['connection'].execute(sa.text( + f"SET LOCAL statement_timeout TO {timeout}" + )) + result = context['connection'].execute(sa.text( + 'SELECT * FROM pg_tables WHERE tablename = :table' + ), {"table": data_dict['resource_id']}).fetchone() if not result: create_table(context, data_dict) _create_fulltext_trigger( @@ -2084,22 +2166,24 @@ def search_sql(self, context: Context, data_dict: dict[str, Any]): return search_sql(context, data_dict) def resource_exists(self, id: str) -> bool: - resources_sql = sqlalchemy.text( - u'''SELECT 1 FROM "_table_metadata" + resources_sql = sa.text( + '''SELECT 1 FROM "_table_metadata" WHERE name = :id AND alias_of IS NULL''') - results = self._get_read_engine().execute(resources_sql, id=id) + with self._get_read_engine().connect() as conn: + results = conn.execute(resources_sql, {"id": id}) res_exists = results.rowcount > 0 return res_exists def resource_id_from_alias(self, alias: str) -> tuple[bool, Optional[str]]: real_id: Optional[str] = None - resources_sql = sqlalchemy.text( + resources_sql = sa.text( u'''SELECT alias_of FROM "_table_metadata" WHERE name = :id''') - results = self._get_read_engine().execute(resources_sql, id=alias) + with self._get_read_engine().connect() as conn: + results = conn.execute(resources_sql, {"id": alias}) res_exists = results.rowcount > 0 - if res_exists: - real_id = results.fetchone()[0] # type: ignore + if res_exists and (row := results.fetchone()): + real_id = row[0] return res_exists, real_id # def resource_info(self, id): @@ -2116,54 +2200,60 @@ def resource_fields(self, id: str) -> dict[str, Any]: info['meta']['id'] = id # count of rows in table - meta_sql = sqlalchemy.text( - u'SELECT count(_id) FROM "{0}"'.format(id)) - meta_results = engine.execute(meta_sql) - info['meta']['count'] = meta_results.fetchone()[0] # type: ignore + meta_sql = sa.text( + u'SELECT count(_id) FROM {0}'.format(identifier(id))) + with engine.connect() as conn: + meta_results = conn.execute(meta_sql) + info['meta']['count'] = meta_results.one()[0] # table_type - BASE TABLE, VIEW, FOREIGN TABLE, MATVIEW - tabletype_sql = sqlalchemy.text(u''' + tabletype_sql = sa.text(f''' SELECT table_type FROM INFORMATION_SCHEMA.TABLES - WHERE table_name = '{0}' - '''.format(id)) - tabletype_results = engine.execute(tabletype_sql) + WHERE table_name = {literal_string(id)} + ''') + with engine.connect() as conn: + tabletype_results = conn.execute(tabletype_sql) info['meta']['table_type'] = \ - tabletype_results.fetchone()[0] # type: ignore + tabletype_results.one()[0] # MATERIALIZED VIEWS show as BASE TABLE, so # we check pg_matviews - matview_sql = sqlalchemy.text(u''' + matview_sql = sa.text(f''' SELECT count(*) FROM pg_matviews - WHERE matviewname = '{0}' - '''.format(id)) - matview_results = engine.execute(matview_sql) - if matview_results.fetchone()[0]: # type: ignore + WHERE matviewname = {literal_string(id)} + ''') + with engine.connect() as conn: + matview_results = conn.execute(matview_sql) + if matview_results.one()[0]: 
info['meta']['table_type'] = 'MATERIALIZED VIEW' # SIZE - size of table in bytes - size_sql = sqlalchemy.text( - u"SELECT pg_relation_size('{0}')".format(id)) - size_results = engine.execute(size_sql) - info['meta']['size'] = size_results.fetchone()[0] # type: ignore + size_sql = sa.text( + f"SELECT pg_relation_size({literal_string(id)})") + with engine.connect() as conn: + size_results = conn.execute(size_sql) + info['meta']['size'] = size_results.one()[0] # DB_SIZE - size of database in bytes - dbsize_sql = sqlalchemy.text( + dbsize_sql = sa.text( u"SELECT pg_database_size(current_database())") - dbsize_results = engine.execute(dbsize_sql) - info['meta']['db_size'] = \ - dbsize_results.fetchone()[0] # type: ignore + with engine.connect() as conn: + dbsize_results = conn.execute(dbsize_sql) + info['meta']['db_size'] = dbsize_results.one()[0] # IDXSIZE - size of all indices for table in bytes - idxsize_sql = sqlalchemy.text( - u"SELECT pg_indexes_size('{0}')".format(id)) - idxsize_results = engine.execute(idxsize_sql) - info['meta']['idx_size'] = \ - idxsize_results.fetchone()[0] # type: ignore + idxsize_sql = sa.text( + f"SELECT pg_indexes_size({literal_string(id)})") + with engine.connect() as conn: + idxsize_results = conn.execute(idxsize_sql) + info['meta']['idx_size'] = idxsize_results.one()[0] # all the aliases for this resource - alias_sql = sqlalchemy.text(u''' - SELECT name FROM "_table_metadata" WHERE alias_of = '{0}' - '''.format(id)) - alias_results = engine.execute(alias_sql) + alias_sql = sa.text(f''' + SELECT name FROM "_table_metadata" + WHERE alias_of = {literal_string(id)} + ''') + with engine.connect() as conn: + alias_results = conn.execute(alias_sql) aliases = [] for alias in alias_results.fetchall(): aliases.append(alias[0]) @@ -2172,7 +2262,7 @@ def resource_fields(self, id: str) -> dict[str, Any]: # get the data dictionary for the resource data_dictionary = datastore_helpers.datastore_dictionary(id) - schema_sql = sqlalchemy.text(u''' + schema_sql = sa.text(f''' SELECT f.attname AS column_name, pg_catalog.format_type(f.atttypid,f.atttypmod) AS native_type, @@ -2196,11 +2286,12 @@ def resource_fields(self, id: str) -> dict[str, Any]: AND c.oid = f.attrelid AND c.oid = ix.indrelid LEFT JOIN pg_class AS i ON ix.indexrelid = i.oid WHERE c.relkind = 'r'::char - AND c.relname = '{0}' + AND c.relname = {literal_string(id)} AND f.attnum > 0 ORDER BY c.relname,f.attnum; - '''.format(id)) - schema_results = engine.execute(schema_sql) + ''') + with engine.connect() as conn: + schema_results = conn.execute(schema_sql) schemainfo = {} for row in schema_results.fetchall(): row: Any # Row has incomplete type definition @@ -2222,12 +2313,15 @@ def resource_fields(self, id: str) -> dict[str, Any]: pass return info - def get_all_ids(self): - resources_sql = sqlalchemy.text( + def get_all_ids(self) -> list[str]: + resources_sql = sa.text( u'''SELECT name FROM "_table_metadata" WHERE alias_of IS NULL''') - query = self._get_read_engine().execute(resources_sql) - return [q[0] for q in query.fetchall()] + with self._get_read_engine().connect() as conn: + return [ + item for item in + conn.scalars(resources_sql) + ] def create_function(self, *args: Any, **kwargs: Any): return create_function(*args, **kwargs) @@ -2246,12 +2340,12 @@ def calculate_record_count(self, resource_id: str): Postgresql's pg_stat_user_tables. 
This number will be used when specifying `total_estimation_threshold` ''' - connection = get_write_engine().connect() - sql = 'ANALYZE "{}"'.format(resource_id) - try: - connection.execute(sql) - except sqlalchemy.exc.DatabaseError as err: - raise DatastoreException(err) + sql = f'ANALYZE {identifier(resource_id)}' + with get_write_engine().connect() as conn: + try: + conn.execute(sa.text(sql)) + except DatabaseError as err: + raise DatastoreException(err) def create_function(name: str, arguments: Iterable[dict[str, Any]], @@ -2295,18 +2389,8 @@ def drop_function(name: str, if_exists: bool): def _write_engine_execute(sql: str): - connection = get_write_engine().connect() - # No special meaning for '%' in sql parameter: - connection: Any = connection.execution_options(no_parameters=True) - trans = connection.begin() - try: - connection.execute(sql) - trans.commit() - except Exception: - trans.rollback() - raise - finally: - connection.close() + with get_write_engine().begin() as conn: + conn.execute(sa.text(sql)) def _programming_error_summary(pe: Any): diff --git a/ckanext/datastore/helpers.py b/ckanext/datastore/helpers.py index 949f4adce81..6df05080a22 100644 --- a/ckanext/datastore/helpers.py +++ b/ckanext/datastore/helpers.py @@ -9,7 +9,7 @@ from typing_extensions import Literal import sqlparse - +import sqlalchemy as sa import ckan.common as converters import ckan.plugins.toolkit as tk from ckan.types import Context @@ -113,12 +113,12 @@ def get_table_and_function_names_from_sql(context: Context, sql: str): function_names.extend(_get_function_names_from_sql(sql)) - result = context['connection'].execute( - 'EXPLAIN (VERBOSE, FORMAT JSON) {0}'.format( - str(sql))).fetchone() + result = context['connection'].scalar(sa.text( + f"EXPLAIN (VERBOSE, FORMAT JSON) {sql}" + )) try: - query_plan = json.loads(result['QUERY PLAN']) + query_plan = json.loads(result) plan = query_plan[0]['Plan'] t, q, f = _parse_query_plan(plan) diff --git a/ckanext/datastore/interfaces.py b/ckanext/datastore/interfaces.py index d706200621a..c1993ea8b5c 100644 --- a/ckanext/datastore/interfaces.py +++ b/ckanext/datastore/interfaces.py @@ -71,17 +71,23 @@ def datastore_search(self, context: Context, data_dict: dict[str, Any], The ``where`` key is a special case. Its elements are of the form: - (format_string, param1, param2, ...) + (format_string, {placeholder_for_param_1: param_1}) The ``format_string`` isn't escaped for SQL Injection attacks, so - everything coming from the user should be in the params list. With this + everything coming from the user should be in the params dict. With this format, you could do something like: - ('"age" BETWEEN %s AND %s', age_between[0], age_between[1]) + ( + '"age" BETWEEN :my_ext_min AND :my_ext_max', + {"my_ext_min": age_between[0], "my_ext_max": age_between[1]}, + ) This escapes the ``age_between[0]`` and ``age_between[1]``, making sure we're not vulnerable. + .. note:: Use a unique prefix for the parameter names to avoid conflicts + with other plugins + After finishing this, you should return your modified ``query_dict``.
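    A minimal sketch of a hook that appends a clause in this format (the
    ``my_ext_`` prefix and the ``age_between`` filter are illustrative,
    mirroring the sample plugin in the test suite)::

        def datastore_search(self, context, data_dict, fields_types,
                             query_dict):
            filters = data_dict.get('filters', {})
            if 'age_between' in filters:
                lower, upper = filters['age_between']
                # named, prefixed placeholders; the values travel in the dict
                query_dict['where'].append((
                    '"age" BETWEEN :my_ext_min AND :my_ext_max',
                    {"my_ext_min": lower, "my_ext_max": upper},
                ))
            return query_dict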
:param context: the context @@ -97,6 +103,7 @@ def datastore_search(self, context: Context, data_dict: dict[str, Any], :returns: the query_dict with your modifications :rtype: dictionary + ''' return query_dict diff --git a/ckanext/datastore/tests/helpers.py b/ckanext/datastore/tests/helpers.py index a8f1a058076..07f9c250e8e 100644 --- a/ckanext/datastore/tests/helpers.py +++ b/ckanext/datastore/tests/helpers.py @@ -1,5 +1,8 @@ # encoding: utf-8 +from __future__ import annotations +from typing import Any +import sqlalchemy as sa from sqlalchemy import orm import ckan.model as model @@ -18,9 +21,9 @@ def clear_db(Session): # noqa drop_tables = u"""select 'drop table "' || tablename || '" cascade;' from pg_tables where schemaname = 'public' """ c = Session.connection() - results = c.execute(drop_tables) + results = c.execute(sa.text(drop_tables)) for result in results: - c.execute(result[0]) + c.execute(sa.text(result[0])) drop_functions_sql = u""" SELECT 'drop function if exists ' || quote_ident(proname) || '();' FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid) WHERE ns.nspname = 'public' AND proname != 'populate_full_text_trigger' """ - drop_functions = u"".join(r[0] for r in c.execute(drop_functions_sql)) + drop_functions = u"".join( + r[0] for r in c.execute(sa.text(drop_functions_sql)) + ) if drop_functions: - c.execute(drop_functions) + c.execute(sa.text(drop_functions)) Session.commit() Session.remove() @@ -60,19 +65,19 @@ def set_url_type(resources, user): p.toolkit.get_action("resource_update")(context, resource) -def execute_sql(sql, *args): +def execute_sql(sql: str, params: dict[str, Any]): engine = db.get_write_engine() session = orm.scoped_session(orm.sessionmaker(bind=engine)) - return session.connection().execute(sql, *args) + return session.connection().execute(sa.text(sql), params) def when_was_last_analyze(resource_id): results = execute_sql( """SELECT last_analyze FROM pg_stat_user_tables - WHERE relname=%s; + WHERE relname=:relname; """, - resource_id, + {"relname": resource_id}, ).fetchall() return results[0][0] diff --git a/ckanext/datastore/tests/sample_datastore_plugin.py b/ckanext/datastore/tests/sample_datastore_plugin.py index d82b813784a..575d456284a 100644 --- a/ckanext/datastore/tests/sample_datastore_plugin.py +++ b/ckanext/datastore/tests/sample_datastore_plugin.py @@ -33,18 +33,16 @@ def _where(self, data_dict): age_between = filters["age_between"] clause = ( - '"age" >= %s AND "age" <= %s', - age_between[0], - age_between[1], + '"age" >= :age_between_min AND "age" <= :age_between_max', + {"age_between_min": age_between[0], "age_between_max": age_between[1]}, ) where_clauses.append(clause) if "age_not_between" in filters: age_not_between = filters["age_not_between"] clause = ( - '"age" < %s OR "age" > %s', - age_not_between[0], - age_not_between[1], + '"age" < :age_not_between_min OR "age" > :age_not_between_max', + {"age_not_between_min": age_not_between[0], "age_not_between_max": age_not_between[1]}, ) where_clauses.append(clause) if "insecure_filter" in filters: diff --git a/ckanext/datastore/tests/test_create.py b/ckanext/datastore/tests/test_create.py index 6133d879323..b70a791f067 100644 --- a/ckanext/datastore/tests/test_create.py +++ b/ckanext/datastore/tests/test_create.py @@ -2,6 +2,7 @@ import json import pytest +import sqlalchemy as sa import sqlalchemy.orm as orm import ckan.lib.create_test_data as ctd @@ -29,10 +30,10 @@ def _has_index_on_field(self, resource_id, field): FROM pg_class WHERE - pg_class.relname = %s + pg_class.relname = :relname """ index_name = 
db._generate_index_name(resource_id, field) - results = execute_sql(sql, index_name).fetchone() + results = execute_sql(sql, {"relname": index_name}).fetchone() return bool(results) def _get_index_names(self, resource_id): @@ -47,9 +48,9 @@ def _get_index_names(self, resource_id): t.oid = idx.indrelid AND i.oid = idx.indexrelid AND t.relkind = 'r' - AND t.relname = %s + AND t.relname = :relname """ - results = execute_sql(sql, resource_id).fetchall() + results = execute_sql(sql, {"relname": resource_id}).fetchall() return [result[0] for result in results] def test_create_works_with_empty_array_in_json_field(self): @@ -713,31 +714,34 @@ def test_create_basic(self, app): assert res["fields"] == data["fields"], res["fields"] c = self.Session.connection() - results = c.execute('select * from "{0}"'.format(resource.id)) + results = c.execute(sa.text( + 'select * from "{0}"'.format(resource.id) + )) assert results.rowcount == 3 for i, row in enumerate(results): - assert data["records"][i].get("boo%k") == row["boo%k"] + assert data["records"][i].get("boo%k") == getattr(row, "boo%k") + author = getattr(row, "author") assert data["records"][i].get("author") == ( - json.loads(row["author"][0]) if row["author"] else None + json.loads(author[0]) if author else None ) - results = c.execute( + results = c.execute(sa.text( """ select * from "{0}" where _full_text @@ to_tsquery('warandpeace') """.format( resource.id ) - ) + )) assert results.rowcount == 1, results.rowcount - results = c.execute( + results = c.execute(sa.text( """ select * from "{0}" where _full_text @@ to_tsquery('tolstoy') """.format( resource.id ) - ) + )) assert results.rowcount == 2 self.Session.remove() @@ -747,21 +751,23 @@ def test_create_basic(self, app): results = [ row - for row in c.execute( - u'select * from "{0}"'.format(resource.id) - ) + for row in c.execute(sa.text( + 'select * from "{0}"'.format(resource.id) + )) ] results_alias = [ - row for row in c.execute(u'select * from "{0}"'.format(alias)) + row for row in c.execute(sa.text( + 'select * from "{0}"'.format(alias) + )) ] assert results == results_alias sql = ( u"select * from _table_metadata " - "where alias_of=%s and name=%s" + "where alias_of=:alias and name=:name" ) - results = c.execute(sql, resource.id, alias) + results = c.execute(sa.text(sql), {"alias": resource.id, "name": alias}) assert results.rowcount == 1 self.Session.remove() @@ -790,26 +796,29 @@ def test_create_basic(self, app): assert res_dict["success"] is True c = self.Session.connection() - results = c.execute('select * from "{0}"'.format(resource.id)) + results = c.execute(sa.text( + 'select * from "{0}"'.format(resource.id) + )) self.Session.remove() assert results.rowcount == 4 all_data = data["records"] + data2["records"] for i, row in enumerate(results): - assert all_data[i].get("boo%k") == row["boo%k"] + assert all_data[i].get("boo%k") == getattr(row, "boo%k") + author = getattr(row, "author") assert all_data[i].get("author") == ( - json.loads(row["author"][0]) if row["author"] else None + json.loads(author[0]) if author else None ) c = self.Session.connection() - results = c.execute( + results = c.execute(sa.text( """ select * from "{0}" where _full_text @@ 'tolstoy' """.format( resource.id ) - ) + )) self.Session.remove() assert results.rowcount == 3 @@ -837,26 +846,31 @@ def test_create_basic(self, app): assert res_dict["success"] is True c = self.Session.connection() - results = c.execute('select * from "{0}"'.format(resource.id)) + results = c.execute(sa.text( + 'select * from 
"{0}"'.format(resource.id) + )) assert results.rowcount == 5 all_data = data["records"] + data2["records"] + data3["records"] for i, row in enumerate(results): - assert all_data[i].get("boo%k") == row["boo%k"], ( + book = getattr(row, "boo%k") + assert all_data[i].get("boo%k") == book, ( i, all_data[i].get("boo%k"), - row["boo%k"], + book, ) + + author = getattr(row, "author") assert all_data[i].get("author") == ( - json.loads(row["author"][0]) if row["author"] else None + json.loads(author[0]) if author else None ) - results = c.execute( + results = c.execute(sa.text( """select * from "{0}" where _full_text @@ to_tsquery('dostoevsky') """.format( resource.id ) - ) + )) self.Session.remove() assert results.rowcount == 2 @@ -902,13 +916,19 @@ def test_create_basic(self, app): for alias in aliases: sql = ( "select * from _table_metadata " - "where alias_of=%s and name=%s" + "where alias_of=:alias and name=:name" + ) + results = c.execute( + sa.text(sql), + {"alias": resource.id, "name": alias} ) - results = c.execute(sql, resource.id, alias) assert results.rowcount == 0 - sql = "select * from _table_metadata " "where alias_of=%s and name=%s" - results = c.execute(sql, resource.id, "another_alias") + sql = "select * from _table_metadata " "where alias_of=:alias and name=:name" + results = c.execute( + sa.text(sql), + {"alias": resource.id, "name": "another_alias"} + ) assert results.rowcount == 1 self.Session.remove() @@ -1069,7 +1089,9 @@ def test_guess_types(self, app): res_dict = json.loads(res.data) c = self.Session.connection() - results = c.execute("""select * from "{0}" """.format(resource.id)) + results = c.execute(sa.text( + """select * from "{0}" """.format(resource.id) + )) types = [ db._pg_types[field[1]] for field in results.cursor.description @@ -1087,9 +1109,9 @@ def test_guess_types(self, app): assert results.rowcount == 3 for i, row in enumerate(results): - assert data["records"][i].get("book") == row["book"] + assert data["records"][i].get("book") == row._mapping["book"] assert data["records"][i].get("author") == ( - json.loads(row["author"][0]) if row["author"] else None + json.loads(row._mapping["author"][0]) if row._mapping["author"] else None ) self.Session.remove() @@ -1128,7 +1150,9 @@ def test_guess_types(self, app): res_dict = json.loads(res.data) c = self.Session.connection() - results = c.execute("""select * from "{0}" """.format(resource.id)) + results = c.execute(sa.text( + """select * from "{0}" """.format(resource.id) + )) self.Session.remove() types = [ diff --git a/ckanext/datastore/tests/test_db.py b/ckanext/datastore/tests/test_db.py index 7ae19722174..a646ed8a32e 100644 --- a/ckanext/datastore/tests/test_db.py +++ b/ckanext/datastore/tests/test_db.py @@ -119,7 +119,8 @@ def _assert_created_index_on( ) calls = connection.execute.call_args_list - was_called = [call for call in calls if call[0][0].find(sql_str) != -1] + + was_called = any(sql_str in str(call.args[0]) for call in calls) assert was_called, ( "Expected 'connection.execute' to have been " diff --git a/ckanext/datastore/tests/test_delete.py b/ckanext/datastore/tests/test_delete.py index cc87d79c25d..7945df2fdd6 100644 --- a/ckanext/datastore/tests/test_delete.py +++ b/ckanext/datastore/tests/test_delete.py @@ -3,6 +3,7 @@ import json import pytest +import sqlalchemy as sa import sqlalchemy.orm as orm import ckan.lib.create_test_data as ctd @@ -55,7 +56,8 @@ def test_delete_basic(self): assert resobj.extras.get('datastore_active') is False results = execute_sql( - u"select 1 from pg_views where 
viewname = %s", u"b\xfck2" + u"select 1 from pg_views where viewname = :name", + {"name": "b\xfck2"} ) assert results.rowcount == 0 @@ -63,8 +65,8 @@ def test_delete_basic(self): results = execute_sql( u"""SELECT table_name FROM information_schema.tables - WHERE table_name=%s;""", - resource["id"], + WHERE table_name=:name;""", + {"name": resource["id"]}, ) assert results.rowcount == 0 @@ -155,7 +157,8 @@ def test_delete_records_basic(self): helpers.call_action("datastore_records_delete", **data) results = execute_sql( - u"select 1 from pg_views where viewname = %s", u"b\xfck2" + u"select 1 from pg_views where viewname = :name", + {"name": "b\xfck2"} ) assert results.rowcount == 1 @@ -163,8 +166,8 @@ def test_delete_records_basic(self): results = execute_sql( u"""SELECT table_name FROM information_schema.tables - WHERE table_name=%s;""", - resource["id"], + WHERE table_name=:name;""", + {"name": resource["id"]}, ) assert results.rowcount == 1 @@ -366,7 +369,9 @@ def test_delete_filters(self, app): assert res_dict["success"] is True c = self.Session.connection() - result = c.execute(u'select * from "{0}";'.format(resource_id)) + result = c.execute(sa.text( + u'select * from "{0}";'.format(resource_id) + )) results = [r for r in result] assert len(results) == 1 assert results[0].book == "annakarenina" @@ -388,7 +393,9 @@ def test_delete_filters(self, app): assert res_dict["success"] is True c = self.Session.connection() - result = c.execute(u'select * from "{0}";'.format(resource_id)) + result = c.execute(sa.text( + 'select * from "{0}";'.format(resource_id) + )) results = [r for r in result] assert len(results) == 1 assert results[0].book == "annakarenina" @@ -409,7 +416,9 @@ def test_delete_filters(self, app): assert res_dict["success"] is True c = self.Session.connection() - result = c.execute(u'select * from "{0}";'.format(resource_id)) + result = c.execute(sa.text( + 'select * from "{0}";'.format(resource_id) + )) results = [r for r in result] assert len(results) == 0 self.Session.remove() diff --git a/ckanext/datastore/tests/test_helpers.py b/ckanext/datastore/tests/test_helpers.py index 201efec4ba4..806c6188f65 100644 --- a/ckanext/datastore/tests/test_helpers.py +++ b/ckanext/datastore/tests/test_helpers.py @@ -2,6 +2,7 @@ import re import pytest +import sqlalchemy as sa import sqlalchemy.orm as orm from sqlalchemy.exc import ProgrammingError @@ -77,13 +78,13 @@ def test_get_table_names(self): engine = db.get_write_engine() session = orm.scoped_session(orm.sessionmaker(bind=engine)) create_tables = [ - u"CREATE TABLE test_a (id_a text)", - u"CREATE TABLE test_b (id_b text)", - u'CREATE TABLE "TEST_C" (id_c text)', - u'CREATE TABLE test_d ("α/α" integer)', + "CREATE TABLE test_a (id_a text)", + "CREATE TABLE test_b (id_b text)", + 'CREATE TABLE "TEST_C" (id_c text)', + 'CREATE TABLE test_d ("α/α" integer)', ] for create_table_sql in create_tables: - session.execute(create_table_sql) + session.execute(sa.text(create_table_sql)) test_cases = [ (u"SELECT * FROM test_a", ["test_a"]), @@ -134,7 +135,7 @@ def test_get_function_names(self): u"CREATE TABLE test_b (name text, subject_id text)", ] for create_table_sql in create_tables: - session.execute(create_table_sql) + session.execute(sa.text(create_table_sql)) test_cases = [ (u"SELECT max(id) from test_a", ["max"]), @@ -167,7 +168,7 @@ def test_get_function_names_custom_function(self): """ ] for create_table_sql in create_tables: - session.execute(create_table_sql) + session.execute(sa.text(create_table_sql)) context = {"connection": 
session.connection()} @@ -191,7 +192,7 @@ def test_get_function_names_crosstab(self): u"CREATE TABLE test_b (name text, subject_id text)", ] for create_table_sql in create_tables: - session.execute(create_table_sql) + session.execute(sa.text(create_table_sql)) test_cases = [ (u"""SELECT * diff --git a/ckanext/datastore/tests/test_search.py b/ckanext/datastore/tests/test_search.py index 4914ae5c349..720902732f4 100644 --- a/ckanext/datastore/tests/test_search.py +++ b/ckanext/datastore/tests/test_search.py @@ -2,6 +2,7 @@ import json import pytest +import sqlalchemy as sa import sqlalchemy.orm as orm import decimal @@ -129,7 +130,8 @@ def test_estimate_total(self): """.format( resource=resource["id"] ) - db.get_write_engine().execute(analyze_sql) + with db.get_write_engine().connect() as conn: + conn.execute(sa.text(analyze_sql)) search_data = { "resource_id": resource["id"], "total_estimation_threshold": 50, @@ -153,7 +155,8 @@ def test_estimate_total_with_filters(self): """.format( resource=resource["id"] ) - db.get_write_engine().execute(analyze_sql) + with db.get_write_engine().connect() as conn: + conn.execute(sa.text(analyze_sql)) search_data = { "resource_id": resource["id"], "filters": {u"the year": 1901}, @@ -179,7 +182,8 @@ def test_estimate_total_with_distinct(self): """.format( resource=resource["id"] ) - db.get_write_engine().execute(analyze_sql) + with db.get_write_engine().connect() as conn: + conn.execute(sa.text(analyze_sql)) search_data = { "resource_id": resource["id"], "fields": ["the year"], @@ -226,7 +230,8 @@ def test_estimate_total_with_zero_threshold(self): """.format( resource=resource["id"] ) - db.get_write_engine().execute(analyze_sql) + with db.get_write_engine().connect() as conn: + conn.execute(sa.text(analyze_sql)) search_data = { "resource_id": resource["id"], "total_estimation_threshold": 0, @@ -251,7 +256,8 @@ def test_estimate_total_off(self): """.format( resource=resource["id"] ) - db.get_write_engine().execute(analyze_sql) + with db.get_write_engine().connect() as conn: + conn.execute(sa.text(analyze_sql)) search_data = { "resource_id": resource["id"], "total_estimation_threshold": None, @@ -275,7 +281,9 @@ def test_estimate_total_default_off(self): """.format( resource=resource["id"] ) - db.get_write_engine().execute(analyze_sql) + with db.get_write_engine().connect() as conn: + conn.execute(sa.text(analyze_sql)) + search_data = { "resource_id": resource["id"], # don't specify total_estimation_threshold diff --git a/ckanext/datastore/tests/test_upsert.py b/ckanext/datastore/tests/test_upsert.py index 48c2294ae92..d66b7e579c1 100644 --- a/ckanext/datastore/tests/test_upsert.py +++ b/ckanext/datastore/tests/test_upsert.py @@ -12,11 +12,11 @@ def _search(resource_id): return helpers.call_action(u"datastore_search", resource_id=resource_id) +@pytest.mark.ckan_config("ckan.plugins", "datastore") +@pytest.mark.usefixtures("clean_datastore", "with_plugins", "with_request_context") class TestDatastoreUpsert(object): # Test action 'datastore_upsert' with 'method': 'upsert' - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_upsert_requires_auth(self): resource = factories.Resource(url_type=u"datastore") data = { @@ -43,8 +43,6 @@ def test_upsert_requires_auth(self): in str(context.value) ) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_upsert_empty_fails(self): resource = 
factories.Resource(url_type=u"datastore") data = { @@ -64,8 +62,6 @@ def test_upsert_empty_fails(self): helpers.call_action("datastore_upsert", **data) assert u"'Missing value'" in str(context.value) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_basic_as_update(self): resource = factories.Resource() data = { @@ -96,8 +92,6 @@ def test_basic_as_update(self): assert search_result["records"][0]["book"] == "The boy" assert search_result["records"][0]["author"] == "F Torres" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_basic_as_insert(self): resource = factories.Resource() data = { @@ -128,8 +122,6 @@ def test_basic_as_insert(self): assert search_result["records"][0]["book"] == u"El Niño" assert search_result["records"][1]["book"] == u"The boy" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_upsert_only_one_field(self): resource = factories.Resource() data = { @@ -160,8 +152,6 @@ def test_upsert_only_one_field(self): assert search_result["records"][0]["book"] == "The boy" assert search_result["records"][0]["author"] == "Torres" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_field_types(self): resource = factories.Resource(url_type="datastore") data = { @@ -220,8 +210,6 @@ def test_field_types(self): assert search_result["records"][2]["characters"] == ["Bob", "Marvin"] assert search_result["records"][2]["nested"] == {"baz": 3} - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_percent(self): resource = factories.Resource() data = { @@ -253,8 +241,6 @@ def test_percent(self): assert search_result["records"][0]["bo%ok"] == "The % boy" assert search_result["records"][1]["bo%ok"] == "Gu%ide" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_missing_key(self): resource = factories.Resource() data = { @@ -280,8 +266,6 @@ def test_missing_key(self): helpers.call_action("datastore_upsert", **data) assert u'fields "id" are missing' in str(context.value) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_non_existing_field(self): resource = factories.Resource(url_type="datastore") data = { @@ -307,8 +291,6 @@ def test_non_existing_field(self): helpers.call_action("datastore_upsert", **data) assert u'fields "dummy" do not exist' in str(context.value) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_upsert_works_with_empty_list_in_json_field(self): resource = factories.Resource() data = { @@ -335,8 +317,6 @@ def test_upsert_works_with_empty_list_in_json_field(self): assert search_result["total"] == 1 assert search_result["records"][0]["nested"] == [] - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_delete_field_value(self): resource = factories.Resource() data = { @@ -365,8 +345,6 @@ def test_delete_field_value(self): assert search_result["records"][0]["book"] is None assert search_result["records"][0]["author"] == "Torres" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - 
@pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_upsert_doesnt_crash_with_json_field(self): resource = factories.Resource() data = { @@ -394,8 +372,6 @@ def test_upsert_doesnt_crash_with_json_field(self): } helpers.call_action("datastore_upsert", **data) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_upsert_doesnt_crash_with_json_field_with_string_value(self): resource = factories.Resource() data = { @@ -417,8 +393,6 @@ def test_upsert_doesnt_crash_with_json_field_with_string_value(self): } helpers.call_action("datastore_upsert", **data) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_dry_run(self): ds = factories.Dataset() table = helpers.call_action( @@ -438,8 +412,6 @@ def test_dry_run(self): ) assert result["records"] == [] - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_dry_run_type_error(self): ds = factories.Dataset() table = helpers.call_action( @@ -462,8 +434,6 @@ def test_dry_run_type_error(self): else: assert 0, "error not raised" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_dry_run_trigger_error(self): ds = factories.Dataset() helpers.call_action( @@ -497,8 +467,6 @@ def test_dry_run_trigger_error(self): else: assert 0, "error not raised" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_calculate_record_count_is_false(self): resource = factories.Resource() data = { @@ -523,8 +491,6 @@ def test_calculate_record_count_is_false(self): last_analyze = when_was_last_analyze(resource["id"]) assert last_analyze is None - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") @pytest.mark.flaky(reruns=2) # because analyze is sometimes delayed def test_calculate_record_count(self): resource = factories.Resource() @@ -551,8 +517,6 @@ def test_calculate_record_count(self): last_analyze = when_was_last_analyze(resource["id"]) assert last_analyze is not None - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_no_pk_update(self): resource = factories.Resource() data = { @@ -579,8 +543,6 @@ def test_no_pk_update(self): assert search_result["total"] == 1 assert search_result["records"][0]["book"] == "The boy" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_id_instead_of_pk_update(self): resource = factories.Resource() data = { @@ -612,8 +574,6 @@ def test_id_instead_of_pk_update(self): assert search_result["records"][0]["book"] == "The boy" assert search_result["records"][0]["author"] == "F Torres" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_empty_string_instead_of_null(self): resource = factories.Resource() data = { @@ -645,11 +605,11 @@ def test_empty_string_instead_of_null(self): assert rec == {'_id': 1, 'pk': '1000', 'n': None, 'd': None} +@pytest.mark.ckan_config("ckan.plugins", "datastore") +@pytest.mark.usefixtures("clean_datastore", "with_plugins", "with_request_context") class TestDatastoreInsert(object): # Test action 'datastore_upsert' with 'method': 'insert' - 
@pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_basic_insert(self): resource = factories.Resource() data = { @@ -687,8 +647,6 @@ def test_basic_insert(self): u"author": u"Torres", } - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_non_existing_field(self): resource = factories.Resource(url_type="datastore") data = { @@ -714,8 +672,6 @@ def test_non_existing_field(self): helpers.call_action("datastore_upsert", **data) assert u'row "1" has extra keys "dummy"' in str(context.value) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_key_already_exists(self): resource = factories.Resource() data = { @@ -749,8 +705,6 @@ def test_key_already_exists(self): context.value ) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_empty_string_instead_of_null(self): resource = factories.Resource() data = { @@ -807,11 +761,11 @@ def test_insert_wrong_type(self): assert u'invalid input syntax for integer: "notanumber"' in str(context.value) +@pytest.mark.ckan_config("ckan.plugins", "datastore") +@pytest.mark.usefixtures("clean_datastore", "with_plugins", "with_request_context") class TestDatastoreUpdate(object): # Test action 'datastore_upsert' with 'method': 'update' - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_basic(self): resource = factories.Resource(url_type="datastore") data = { @@ -839,8 +793,6 @@ def test_basic(self): assert search_result["records"][0]["book"] == "The boy" assert search_result["records"][0]["author"] == "Torres" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_field_types(self): resource = factories.Resource(url_type="datastore") data = { @@ -896,8 +848,6 @@ def test_field_types(self): assert search_result["records"][2]["characters"] == ["Bob", "Marvin"] assert search_result["records"][2]["nested"] == {"baz": 3} - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_update_unspecified_key(self): resource = factories.Resource(url_type="datastore") data = { @@ -923,8 +873,6 @@ def test_update_unspecified_key(self): helpers.call_action("datastore_upsert", **data) assert u'fields "id" are missing' in str(context.value) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_update_unknown_key(self): resource = factories.Resource(url_type="datastore") data = { @@ -946,12 +894,9 @@ def test_update_unknown_key(self): "records": [{"id": "1", "author": "tolkien"}], # unknown } - with pytest.raises(ValidationError) as context: + with pytest.raises(ValidationError, match=r"key .*\\\'1\\\'.* not found"): helpers.call_action("datastore_upsert", **data) - assert u"key \"[\\'1\\']\" not found" in str(context.value) - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_update_non_existing_field(self): resource = factories.Resource(url_type="datastore") data = { @@ -977,8 +922,6 @@ def test_update_non_existing_field(self): helpers.call_action("datastore_upsert", **data) assert u'fields "dummy" do not exist' in str(context.value) - 
@pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_no_pk_update(self): resource = factories.Resource() data = { @@ -1005,8 +948,6 @@ def test_no_pk_update(self): assert search_result["total"] == 1 assert search_result["records"][0]["book"] == "The boy" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_id_instead_of_pk_update(self): resource = factories.Resource() data = { @@ -1038,8 +979,6 @@ def test_id_instead_of_pk_update(self): assert search_result["records"][0]["book"] == "The boy" assert search_result["records"][0]["author"] == "F Torres" - @pytest.mark.ckan_config("ckan.plugins", "datastore") - @pytest.mark.usefixtures("clean_datastore", "with_plugins") def test_empty_string_instead_of_null(self): resource = factories.Resource() data = { diff --git a/ckanext/example_idatastorebackend/example_sqlite.py b/ckanext/example_idatastorebackend/example_sqlite.py index 5c5129d0422..945339de331 100644 --- a/ckanext/example_idatastorebackend/example_sqlite.py +++ b/ckanext/example_idatastorebackend/example_sqlite.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- # type: ignore +from __future__ import annotations import logging -from sqlalchemy import create_engine - +from typing import Any +import sqlalchemy as sa from ckanext.datastore.backend import DatastoreBackend @@ -15,23 +15,22 @@ class DatastoreExampleSqliteBackend(DatastoreBackend): def __init__(self): self._engine = None + def execute(self, sql: Any, params: dict[str, Any] | None = None): + # `sql` may be a plain string or an executable SQLAlchemy construct + # (e.g. the sa.insert() built in _insert_records below) + if isinstance(sql, str): + sql = sa.text(sql) + with self._get_engine().begin() as conn: + return conn.execute(sql, params) + def _get_engine(self): if not self._engine: - self._engine = create_engine(self.write_url) + self._engine = sa.create_engine(self.write_url) return self._engine def _insert_records(self, table, records): if len(records): for record in records: - self._get_engine().execute( - u'INSERT INTO "{0}"({1}) VALUES({2})'.format( - table, - u', '.join(record.keys()), - u', '.join(['?'] * len(record.keys())) - ), - list(record.values()) - ) - pass + sql = sa.insert( + sa.table(table, *map(sa.column, record)) + ).values(record) + self.execute(sql) def configure(self, config): self.write_url = config.get( @@ -42,13 +41,14 @@ def configure(self, config): def create(self, context, data_dict): columns = str(u', '.join( - [e['id'] + u' text' for e in data_dict['fields']])) - engine = self._get_engine() - engine.execute( - u' CREATE TABLE IF NOT EXISTS "{name}"({columns});'.format( - name=data_dict['resource_id'], + [str(sa.column(e['id'])) + " text" for e in data_dict['fields']])) + + self.execute(sa.text( + 'CREATE TABLE IF NOT EXISTS "{name}"({columns});'.format( + name=sa.table(data_dict['resource_id']), columns=columns - )) + ) + )) self._insert_records(data_dict['resource_id'], data_dict['records']) return data_dict @@ -56,15 +56,13 @@ def upsert(self, context, data_dict): raise NotImplementedError() def delete(self, context, data_dict): - engine = self._get_engine() - engine.execute(u'DROP TABLE IF EXISTS "{0}"'.format( + self.execute('DROP TABLE IF EXISTS "{0}"'.format( data_dict['resource_id'] )) return data_dict def search(self, context, data_dict): - engine = self._get_engine() - result = engine.execute(u'SELECT * FROM "{0}" LIMIT {1}'.format( + result = self.execute('SELECT * FROM "{0}" LIMIT {1}'.format( data_dict['resource_id'], data_dict.get(u'limit', 10) )) @@ -92,17 +90,15 @@ def make_public(self, context, data_dict): 
pass def resource_exists(self, id): - return self._get_engine().execute( - u''' - select name from sqlite_master - where type = "table" and name = "{0}"'''.format( - id) + return self.execute( + 'select name from sqlite_master where ' + + f'type = "table" and name = "{id}"' ).fetchone() def resource_fields(self, id): - engine = self._get_engine() - info = engine.execute( - u'PRAGMA table_info("{0}")'.format(id)).fetchall() + info = self.execute( + 'PRAGMA table_info("{0}")'.format(id) + ).fetchall() schema = {} for col in info: @@ -115,8 +111,7 @@ def resource_id_from_alias(self, alias): return False, alias def get_all_ids(self): - return [t.name for t in self._get_engine().execute( - u''' - select name from sqlite_master + return [t.name for t in self.execute( + '''select name from sqlite_master where type = "table"''' ).fetchall()] diff --git a/ckanext/example_idatastorebackend/test/test_plugin.py b/ckanext/example_idatastorebackend/test/test_plugin.py index 209b1ae7b0d..b9550b3125d 100644 --- a/ckanext/example_idatastorebackend/test/test_plugin.py +++ b/ckanext/example_idatastorebackend/test/test_plugin.py @@ -1,9 +1,9 @@ # encoding: utf-8 -from unittest.mock import patch, Mock, call +from unittest.mock import patch, Mock import pytest - +import sqlalchemy as sa from ckan.common import config import ckan.tests.factories as factories import ckan.tests.helpers as helpers @@ -56,44 +56,47 @@ def test_sqlite_engine(self): @pytest.mark.usefixtures("with_request_context") @pytest.mark.ckan_config(u"ckan.datastore.write_url", u"sqlite://x") @pytest.mark.ckan_config(u"ckan.datastore.read_url", u"sqlite://x") - @patch(class_to_patch + u"._get_engine") - def test_backend_functionality(self, get_engine): - engine = get_engine() - execute = engine.execute + @patch(class_to_patch + u".execute") + def test_backend_functionality(self, execute): fetchall = execute().fetchall execute.reset_mock() + COLUMN = "a;\"\' x" DatastoreExampleSqliteBackend.resource_fields = Mock( - return_value={u"meta": {}, u"schema": {u"a": u"text"}} + return_value={u"meta": {}, u"schema": {COLUMN: u"text"}} ) records = [ - {u"a": u"x"}, - {u"a": u"y"}, - {u"a": u"z"}, + {COLUMN: u"x"}, + {COLUMN: u"y"}, + {COLUMN: u"z"}, ] DatastoreBackend.set_active_backend(config) res = factories.Resource(url_type=u"datastore") helpers.call_action( u"datastore_create", resource_id=res["id"], - fields=[{u"id": u"a"}], + fields=[{u"id": COLUMN}], records=records, ) # check, create and 3 inserts assert 4 == execute.call_count - insert_query = u'INSERT INTO "{0}"(a) VALUES(?)'.format(res["id"]) - execute.assert_has_calls( - [ - call( - u' CREATE TABLE IF NOT EXISTS "{0}"(a text);'.format( - res["id"] - ) - ), - call(insert_query, ["x"]), - call(insert_query, ["y"]), - call(insert_query, ["z"]), - ] - ) + insert_query = sa.insert(sa.table( + res["id"], sa.column(COLUMN) + )) + + call_args = [ + str(call.args[0]) + for call in execute.call_args_list + ] + assert call_args == [ + 'CREATE TABLE IF NOT EXISTS "{}"({} text);'.format( + res["id"], + sa.column(COLUMN) + ), + str(insert_query.values({COLUMN: "x"})), + str(insert_query.values({COLUMN: "y"})), + str(insert_query.values({COLUMN: "z"})), + ] execute.reset_mock() fetchall.return_value = records @@ -112,9 +115,7 @@ def test_backend_functionality(self, get_engine): execute.reset_mock() helpers.call_action(u"datastore_info", id=res["id"]) # check - c = u''' - select name from sqlite_master - where type = "table" and name = "{0}"'''.format( + c = u'''select name from sqlite_master where type = 
"table" and name = "{0}"'''.format( res["id"] ) execute.assert_called_with(c) diff --git a/ckanext/stats/stats.py b/ckanext/stats/stats.py index b8538cf6858..e1e714eb2c2 100644 --- a/ckanext/stats/stats.py +++ b/ckanext/stats/stats.py @@ -15,7 +15,11 @@ def table(name: str): - return Table(name, model.meta.metadata, autoload=True) + return Table( + name, + model.meta.metadata, + autoload_with=model.meta.engine + ) def datetime2date(datetime_: datetime.datetime): @@ -33,7 +37,10 @@ def largest_groups(cls, limit: int = 10) -> list[tuple[Optional[model.Group], in j = join(activity, package, activity.c["object_id"] == package.c["id"]) s = ( - select([package.c["owner_org"], func.count(package.c["id"])]) + select( + package.c["owner_org"], + func.count(package.c["id"]) + ) .select_from(j) .group_by(package.c["owner_org"]) .where( @@ -64,21 +71,25 @@ def top_tags(cls, limit: int = 10, tag = table("tag") package_tag = table("package_tag") package = table("package") + if returned_tag_info == "name": - from_obj = [package_tag.join(tag)] tag_column = tag.c["name"] + s = select( + tag_column, + func.count(package_tag.c["package_id"]) + ).join(package_tag) else: - from_obj = None tag_column = package_tag.c["tag_id"] + s = select( + tag_column, + func.count(package_tag.c["package_id"]) + ) + j = join( package_tag, package, package_tag.c["package_id"] == package.c["id"] ) s = ( - select( - [tag_column, func.count(package_tag.c["package_id"])], - from_obj=from_obj, - ) - .select_from(j) + s.select_from(j) .where( and_( package_tag.c["state"] == "active", @@ -130,12 +141,10 @@ def most_edited_packages(cls, limit: int = 10) -> list[tuple[model.Package, int] s = ( select( - [package.c["id"], func.count(activity.c["id"])], - from_obj=[ - activity.join( - package, activity.c["object_id"] == package.c["id"] - ) - ], + package.c["id"], + func.count(activity.c["id"]) + ).select_from(activity).join( + package, activity.c["object_id"] == package.c["id"] ) .where( and_( @@ -167,11 +176,13 @@ def get_package_revisions(cls) -> list[Any]: package = table("package") activity = table("activity") s = select( - [package.c["id"], activity.c["timestamp"]], - from_obj=[ - activity.join(package, activity.c["object_id"] == package.c["id"]) - ], - ).order_by(activity.c["timestamp"]) + package.c["id"], + activity.c["timestamp"] + ).select_from(activity).join( + package, activity.c["object_id"] == package.c["id"] + ).order_by( + activity.c["timestamp"] + ) res = model.Session.execute(s).fetchall() return res @@ -256,14 +267,15 @@ def new_packages() -> list[tuple[str, int]]: activity = table("activity") s = ( select( - [package.c["id"], func.min(activity.c["timestamp"])], - from_obj=[ - activity.join( - package, activity.c["object_id"] == package.c["id"] - ) - ], + package.c["id"], + func.min(activity.c["timestamp"]) + ).select_from(activity) + .join( + package, activity.c["object_id"] == package.c["id"] + ) + .group_by( + package.c["id"] ) - .group_by(package.c["id"]) .order_by(func.min(activity.c["timestamp"])) ) res = model.Session.execute(s).fetchall() @@ -385,14 +397,15 @@ def deleted_packages() -> list[tuple[str, int]]: s = ( select( - [package.c["id"], func.min(activity.c["timestamp"])], - from_obj=[ - activity.join( - package, activity.c["object_id"] == package.c["id"] - ) - ], + package.c["id"], + func.min(activity.c["timestamp"]) + ).select_from(activity) + .join( + package, activity.c["object_id"] == package.c["id"] + ) + .where( + activity.c["activity_type"] == "deleted package" ) - .where(activity.c["activity_type"] 
== "deleted package") .group_by(package.c["id"]) .order_by(func.min(activity.c["timestamp"])) ) diff --git a/ckanext/tracking/cli/tracking.py b/ckanext/tracking/cli/tracking.py index a2ff44c5fb4..e9a5c6d3b54 100644 --- a/ckanext/tracking/cli/tracking.py +++ b/ckanext/tracking/cli/tracking.py @@ -6,6 +6,7 @@ from typing import NamedTuple, Optional import click +import sqlalchemy as sa import ckan.model as model import ckan.logic as logic @@ -49,11 +50,14 @@ def update_all(engine: model.Engine, start_date: Optional[str] = None): # No date given. See when we last have data for and get data # from 2 days before then in case new data is available. # If no date here then use 2011-01-01 as the start date - sql = '''SELECT tracking_date from tracking_summary - ORDER BY tracking_date DESC LIMIT 1;''' - result = engine.execute(sql).fetchall() + sql = sa.text(""" + SELECT tracking_date from tracking_summary + ORDER BY tracking_date DESC LIMIT 1; + """) + with engine.connect() as conn: + result = conn.execute(sql).scalar() if result: - date = result[0]['tracking_date'] + date = result date += datetime.timedelta(-2) # convert date to datetime combine = datetime.datetime.combine @@ -73,34 +77,30 @@ def update_all(engine: model.Engine, start_date: Optional[str] = None): def _total_views(engine: model.Engine): - sql = ''' - SELECT p.id, - p.name, - COALESCE(SUM(s.count), 0) AS total_views - FROM package AS p - LEFT OUTER JOIN tracking_summary AS s ON s.package_id = p.id - GROUP BY p.id, p.name - ORDER BY total_views DESC - ''' - return [ViewCount(*t) for t in engine.execute(sql).fetchall()] + sql = sa.text(""" + SELECT p.id, p.name, COALESCE(SUM(s.count), 0) AS total_views + FROM package AS p + LEFT OUTER JOIN tracking_summary AS s ON s.package_id = p.id + GROUP BY p.id, p.name + ORDER BY total_views DESC + """) + with engine.connect() as conn: + return [ViewCount(*t) for t in conn.execute(sql)] def _recent_views(engine: model.Engine, measure_from: datetime.date): - sql = ''' - SELECT p.id, - p.name, - COALESCE(SUM(s.count), 0) AS total_views - FROM package AS p - LEFT OUTER JOIN tracking_summary AS s ON s.package_id = p.id - WHERE s.tracking_date >= %(measure_from)s - GROUP BY p.id, p.name - ORDER BY total_views DESC - ''' - return [ - ViewCount(*t) for t in engine.execute( - sql, measure_from=str(measure_from) - ).fetchall() - ] + sql = sa.text(""" + SELECT p.id, p.name, COALESCE(SUM(s.count), 0) AS total_views + FROM package AS p + LEFT OUTER JOIN tracking_summary AS s ON s.package_id = p.id + WHERE s.tracking_date >= :from + GROUP BY p.id, p.name + ORDER BY total_views DESC + """) + with engine.connect() as conn: + return [ + ViewCount(*t) for t in conn.execute(sql, {"from": measure_from}) + ] def export_tracking(engine: model.Engine, output_filename: str): @@ -130,85 +130,106 @@ def export_tracking(engine: model.Engine, output_filename: str): def update_tracking(engine: model.Engine, summary_date: datetime.datetime): package_url = '/dataset/' # clear out existing data before adding new - sql = '''DELETE FROM tracking_summary - WHERE tracking_date='%s'; ''' % summary_date - engine.execute(sql) - - sql = '''SELECT DISTINCT url, user_key, - CAST(access_timestamp AS Date) AS tracking_date, - tracking_type INTO tracking_tmp - FROM tracking_raw - WHERE CAST(access_timestamp as Date)=%s; - - INSERT INTO tracking_summary - (url, count, tracking_date, tracking_type) - SELECT url, count(user_key), tracking_date, tracking_type - FROM tracking_tmp - GROUP BY url, tracking_date, tracking_type; - - DROP TABLE 
tracking_tmp; - COMMIT;''' - engine.execute(sql, summary_date) - - # get ids for dataset urls - sql = '''UPDATE tracking_summary t - SET package_id = COALESCE( - (SELECT id FROM package p - WHERE p.name = regexp_replace - (' ' || t.url, '^[ ]{1}(/\\w{2}){0,1}' || %s, '')) - ,'~~not~found~~') - WHERE t.package_id IS NULL - AND tracking_type = 'page';''' - engine.execute(sql, package_url) - - # update summary totals for resources - sql = '''UPDATE tracking_summary t1 - SET running_total = ( - SELECT sum(count) - FROM tracking_summary t2 - WHERE t1.url = t2.url - AND t2.tracking_date <= t1.tracking_date - ) - ,recent_views = ( - SELECT sum(count) - FROM tracking_summary t2 - WHERE t1.url = t2.url - AND t2.tracking_date <= t1.tracking_date - AND t2.tracking_date >= t1.tracking_date - 14 - ) - WHERE t1.running_total = 0 AND tracking_type = 'resource';''' - engine.execute(sql) - - # update summary totals for pages - sql = '''UPDATE tracking_summary t1 - SET running_total = ( - SELECT sum(count) - FROM tracking_summary t2 - WHERE t1.package_id = t2.package_id - AND t2.tracking_date <= t1.tracking_date - ) - ,recent_views = ( - SELECT sum(count) - FROM tracking_summary t2 - WHERE t1.package_id = t2.package_id - AND t2.tracking_date <= t1.tracking_date - AND t2.tracking_date >= t1.tracking_date - 14 - ) - WHERE t1.running_total = 0 AND tracking_type = 'page' - AND t1.package_id IS NOT NULL - AND t1.package_id != '~~not~found~~';''' - engine.execute(sql) + with engine.begin() as conn: + conn.execute( + sa.text(""" + DELETE FROM tracking_summary + WHERE tracking_date=:date; + """), + {"date": summary_date} + ) + + conn.execute( + sa.text(""" + SELECT DISTINCT url, user_key, + CAST(access_timestamp AS Date) AS tracking_date, + tracking_type INTO tracking_tmp + FROM tracking_raw + WHERE CAST(access_timestamp as Date)=:date + """), {"date": summary_date} + ) + + conn.execute( + sa.text(""" + INSERT INTO tracking_summary + (url, count, tracking_date, tracking_type) + SELECT url, count(user_key), tracking_date, tracking_type + FROM tracking_tmp + GROUP BY url, tracking_date, tracking_type; + """) + ) + conn.execute( + sa.text(""" + DROP TABLE tracking_tmp; + """) + ) + + with engine.begin() as conn: + conn.execute(sa.text(""" + UPDATE tracking_summary t + SET package_id = + COALESCE( + (SELECT id FROM package p WHERE p.name = + regexp_replace (' ' || t.url, '^[ ]{1}(/\\w{2}){0,1}' || :url, '')), + '~~not~found~~' + ) + WHERE t.package_id IS NULL + AND tracking_type = 'page' + """), {"url": package_url}) + + # update summary totals for resources + conn.execute(sa.text(""" + UPDATE tracking_summary t1 + SET running_total = ( + SELECT sum(count) + FROM tracking_summary t2 + WHERE t1.url = t2.url + AND t2.tracking_date <= t1.tracking_date + ) + ,recent_views = ( + SELECT sum(count) + FROM tracking_summary t2 + WHERE t1.url = t2.url + AND t2.tracking_date <= t1.tracking_date + AND t2.tracking_date >= t1.tracking_date - 14 + ) + WHERE t1.running_total = 0 AND tracking_type = 'resource' + """)) + + # update summary totals for pages + conn.execute(sa.text(""" + UPDATE tracking_summary t1 + SET running_total = ( + SELECT sum(count) + FROM tracking_summary t2 + WHERE t1.package_id = t2.package_id + AND t2.tracking_date <= t1.tracking_date + ) + ,recent_views = ( + SELECT sum(count) + FROM tracking_summary t2 + WHERE t1.package_id = t2.package_id + AND t2.tracking_date <= t1.tracking_date + AND t2.tracking_date >= t1.tracking_date - 14 + ) + WHERE t1.running_total = 0 AND tracking_type = 'page' + AND t1.package_id 
IS NOT NULL + AND t1.package_id != '~~not~found~~' + """)) def update_tracking_solr(engine: model.Engine, start_date: datetime.datetime): - sql = '''SELECT package_id FROM tracking_summary - where package_id!='~~not~found~~' - and tracking_date >= %s;''' - results = engine.execute(sql, start_date) + sql = sa.text(""" + SELECT package_id FROM tracking_summary + where package_id!='~~not~found~~' + and tracking_date >= :date + """) + with engine.connect() as conn: + results = conn.execute(sql, {"date": start_date}) package_ids: set[str] = set() for row in results: - package_ids.add(row['package_id']) + package_ids.add(row[0]) total = len(package_ids) not_found = 0 diff --git a/ckanext/tracking/middleware.py b/ckanext/tracking/middleware.py index eebe190cbc8..ea6b854f255 100644 --- a/ckanext/tracking/middleware.py +++ b/ckanext/tracking/middleware.py @@ -1,7 +1,10 @@ import hashlib +from typing import cast from urllib.parse import unquote +import sqlalchemy as sa + from ckan.model.meta import engine from ckan.common import request from ckan.types import Response @@ -33,8 +36,13 @@ def track_request(response: Response) -> Response: # store key/data here sql = '''INSERT INTO tracking_raw (user_key, url, tracking_type) - VALUES (%s, %s, %s)''' - engine.execute( # type: ignore - sql, key, data.get('url'), data.get('type') - ) + VALUES (:key, :url, :type)''' + + with cast(sa.engine.Engine, engine).begin() as conn: + conn.execute(sa.text(sql), { + "key": key, + "url": data.get("url"), + "type": data.get("type"), + }) + return response diff --git a/ckanext/tracking/model.py b/ckanext/tracking/model.py index 1f3fe60f459..6469d59ac0e 100644 --- a/ckanext/tracking/model.py +++ b/ckanext/tracking/model.py @@ -18,8 +18,8 @@ import datetime from sqlalchemy import types, Column, Table, text +from sqlalchemy.orm import Mapped -from ckan.plugins.toolkit import BaseModel from ckan.model import meta from ckan.model import domain_object @@ -49,16 +49,15 @@ ) -class TrackingSummary(domain_object.DomainObject, BaseModel): # type: ignore - __tablename__ = 'tracking_summary' - url: str - package_id: str - tracking_type: str +class TrackingSummary(domain_object.DomainObject): + url: Mapped[str] + package_id: Mapped[str] + tracking_type: Mapped[str] # count attribute shadows DomainObject.count() - count: int - running_total: int - recent_views: int - tracking_date: datetime.datetime + count: Mapped[int] + running_total: Mapped[int] + recent_views: Mapped[int] + tracking_date: Mapped[datetime.datetime] @classmethod def get_for_package(cls, package_id: str) -> dict[str, int]: @@ -80,3 +79,6 @@ def get_for_resource(cls, url: str) -> dict[str, int]: return {'total': data.running_total, 'recent': data.recent_views} return {'total': 0, 'recent': 0} + + +meta.registry.map_imperatively(TrackingSummary, tracking_summary_table) diff --git a/dev-requirements.txt b/dev-requirements.txt index 1ea7c6e7b73..ffbe77339e9 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -13,7 +13,6 @@ pip-tools==7.3.0 Pillow==10.0.1 responses==0.23.3 sphinx-rtd-theme==1.3.0 -sqlalchemy-stubs==0.4 sphinx==7.1.2 toml==0.10.2 towncrier==22.12.0 diff --git a/pyproject.toml b/pyproject.toml index 5c9ddf79d63..1f8528e8960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,10 +37,9 @@ preview = true [tool.pytest.ini_options] filterwarnings = [ - "ignore::sqlalchemy.exc.SADeprecationWarning", - "ignore::sqlalchemy.exc.SAWarning", "ignore::DeprecationWarning", - "ignore::bs4.GuessedAtParserWarning" # using lxml as default parser + 
"ignore::bs4.GuessedAtParserWarning", # using lxml as default parser + "error::sqlalchemy.exc.RemovedIn20Warning", ] markers = [ "ckan_config: patch configuration used by other fixtures via (key, value) pair.", diff --git a/requirements.txt b/requirements.txt index d913217cf6c..3f7c2ea7b2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -137,7 +137,7 @@ sqlalchemy[mypy]==1.4.49 # via # -r requirements.in # alembic -sqlalchemy2-stubs==0.0.2a27 +sqlalchemy2-stubs==0.0.2a36 # via sqlalchemy sqlparse==0.4.4 # via -r requirements.in diff --git a/scripts/4042_fix_resource_extras.py b/scripts/4042_fix_resource_extras.py deleted file mode 100644 index 90cadedca40..00000000000 --- a/scripts/4042_fix_resource_extras.py +++ /dev/null @@ -1,95 +0,0 @@ -# encoding: utf-8 -u''' -This script fixes resource extras affected by a bug introduced in #3425 and -raised in #4042 - -#3422 (implemented in #3425) introduced a major bug where if a resource was -deleted and the DataStore was active, extras from all resources on the site -where changed. This is now fixed starting from version 2.7.3 but if your -database is already affected you will need to run this script to restore -the extras to their previous state. - -Remember, you only need to run this script if all the following are true: - - 1. You are currently running CKAN 2.7.0 or 2.7.2, and - 2. You have enabled the DataStore, and - 3. One or more resources with data on the DataStore have been deleted - (or your suspect they might have been) - -If all these are true you can run this script like this: - - python fix_resource_extras.py -c path/to/the/ini/file - -As ever when making changes in the database please do a backup before running -this script. - -Note that it requires SQLAlchemy, so you should run it with the virtualenv -activated. -''' - -import json -from configparser import ConfigParser -from argparse import ArgumentParser - -from sqlalchemy import create_engine -from sqlalchemy.sql import text - -config = ConfigParser() -parser = ArgumentParser() -parser.add_argument( - u'-c', u'--config', help=u'Configuration file', required=True) - -SIMPLE_Q = ( - u"SELECT id, r.extras, rr.extras revision " - u"FROM resource r JOIN resource_revision rr " - u"USING(id, revision_id) WHERE r.extras != rr.extras" -) -UPDATE_Q = text(u"UPDATE resource SET extras = :extras WHERE id = :id") - - -def main(): - args = parser.parse_args() - config.read(args.config) - engine = create_engine(config.get(u'app:main', u'sqlalchemy.url')) - records = engine.execute(SIMPLE_Q) - - total = records.rowcount - print(u'Found {} datasets with inconsistent extras.'.format(total)) - - skip_confirmation = False - i = 0 - - while True: - row = records.fetchone() - if row is None: - break - - id, current, rev = row - current_json = json.loads(current) - rev_json = json.loads(rev) - if (dict(current_json, datastore_active=None) == - dict(rev_json, datastore_active=None)): - continue - i += 1 - - print(u'[{:{}}/{}] Resource <{}>'.format( - i, len(str(total)), total, id)) - print(u'\tCurrent extras state in DB: {}'.format(current)) - print(u'\tAccording to the revision: {}'.format(rev)) - if not skip_confirmation: - choice = input( - u'\tRequired action: ' - u'y - rewrite; n - skip; ! - rewrite all; q - skip all: ' - ).strip().lower() - if choice == u'q': - break - elif choice == u'n': - continue - elif choice == u'!': - skip_confirmation = True - engine.execute(UPDATE_Q, id=id, extras=rev) - print(u'\tUpdated') - - -if __name__ == u'__main__': - main()