diff --git a/LNX-docker-compose.yml b/LNX-docker-compose.yml index eaf3a48c..3b2e15e1 100644 --- a/LNX-docker-compose.yml +++ b/LNX-docker-compose.yml @@ -45,7 +45,7 @@ services: interval: 15s fakeservices.datajoint.io: <<: *net - image: datajoint/nginx:v0.2.7 + image: datajoint/nginx:v0.2.8 environment: - ADD_db_TYPE=DATABASE - ADD_db_ENDPOINT=db:3306 diff --git a/tests/__init__.py b/tests/__init__.py index 219f7f5c..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,26 +0,0 @@ -import datajoint as dj -from packaging import version -import pytest -import os - -PREFIX = os.environ.get("DJ_TEST_DB_PREFIX", "djtest") - -# Connection for testing -CONN_INFO = dict( - host=os.environ.get("DJ_TEST_HOST", "fakeservices.datajoint.io"), - user=os.environ.get("DJ_TEST_USER", "datajoint"), - password=os.environ.get("DJ_TEST_PASSWORD", "datajoint"), -) - -CONN_INFO_ROOT = dict( - host=os.environ.get("DJ_HOST", "fakeservices.datajoint.io"), - user=os.environ.get("DJ_USER", "root"), - password=os.environ.get("DJ_PASS", "simple"), -) - -S3_CONN_INFO = dict( - endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"), - access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"), - secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"), - bucket=os.environ.get("S3_BUCKET", "datajoint.test"), -) diff --git a/tests/conftest.py b/tests/conftest.py index a9474b50..cc2c8062 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import datajoint as dj from packaging import version -from typing import Dict +from typing import Dict, List import os from os import environ, remove import minio @@ -18,9 +18,6 @@ DataJointError, ) from . import ( - PREFIX, - CONN_INFO, - S3_CONN_INFO, schema, schema_simple, schema_advanced, @@ -30,6 +27,11 @@ ) +@pytest.fixture(scope="session") +def prefix(): + return os.environ.get("DJ_TEST_DB_PREFIX", "djtest") + + @pytest.fixture(scope="session") def monkeysession(): with pytest.MonkeyPatch.context() as mp: @@ -81,7 +83,7 @@ def connection_root_bare(db_creds_root): @pytest.fixture(scope="session") -def connection_root(connection_root_bare): +def connection_root(connection_root_bare, prefix): """Root user database connection.""" dj.config["safemode"] = False conn_root = connection_root_bare @@ -136,7 +138,7 @@ def connection_root(connection_root_bare): # Teardown conn_root.query("SET FOREIGN_KEY_CHECKS=0") - cur = conn_root.query('SHOW DATABASES LIKE "{}\\_%%"'.format(PREFIX)) + cur = conn_root.query('SHOW DATABASES LIKE "{}\\_%%"'.format(prefix)) for db in cur.fetchall(): conn_root.query("DROP DATABASE `{}`".format(db[0])) conn_root.query("SET FOREIGN_KEY_CHECKS=1") @@ -151,9 +153,9 @@ def connection_root(connection_root_bare): @pytest.fixture(scope="session") -def connection_test(connection_root, db_creds_test): +def connection_test(connection_root, prefix, db_creds_test): """Test user database connection.""" - database = f"{PREFIX}%%" + database = f"{prefix}%%" permission = "ALL PRIVILEGES" # Create MySQL users @@ -191,7 +193,17 @@ def connection_test(connection_root, db_creds_test): @pytest.fixture(scope="session") -def stores_config(tmpdir_factory): +def s3_creds() -> Dict: + return dict( + endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"), + access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"), + secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"), + bucket=os.environ.get("S3_BUCKET", "datajoint.test"), + ) + + +@pytest.fixture(scope="session") +def stores_config(s3_creds, tmpdir_factory): stores_config = { 
"raw": dict(protocol="file", location=tmpdir_factory.mktemp("raw")), "repo": dict( @@ -200,7 +212,7 @@ def stores_config(tmpdir_factory): location=tmpdir_factory.mktemp("repo"), ), "repo-s3": dict( - S3_CONN_INFO, + s3_creds, protocol="s3", location="dj/repo", stage=tmpdir_factory.mktemp("repo-s3"), @@ -209,7 +221,7 @@ def stores_config(tmpdir_factory): protocol="file", location=tmpdir_factory.mktemp("local"), subfolding=(1, 1) ), "share": dict( - S3_CONN_INFO, protocol="s3", location="dj/store/repo", subfolding=(2, 4) + s3_creds, protocol="s3", location="dj/store/repo", subfolding=(2, 4) ), } return stores_config @@ -238,9 +250,9 @@ def mock_cache(tmpdir_factory): @pytest.fixture -def schema_any(connection_test): +def schema_any(connection_test, prefix): schema_any = dj.Schema( - PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test + prefix + "_test1", schema.LOCALS_ANY, connection=connection_test ) assert schema.LOCALS_ANY, "LOCALS_ANY is empty" try: @@ -292,9 +304,9 @@ def schema_any(connection_test): @pytest.fixture -def schema_simp(connection_test): +def schema_simp(connection_test, prefix): schema = dj.Schema( - PREFIX + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test + prefix + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test ) schema(schema_simple.IJ) schema(schema_simple.JI) @@ -319,9 +331,9 @@ def schema_simp(connection_test): @pytest.fixture -def schema_adv(connection_test): +def schema_adv(connection_test, prefix): schema = dj.Schema( - PREFIX + "_advanced", + prefix + "_advanced", schema_advanced.LOCALS_ADVANCED, connection=connection_test, ) @@ -339,9 +351,11 @@ def schema_adv(connection_test): @pytest.fixture -def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache): +def schema_ext( + connection_test, enable_filepath_feature, mock_stores, mock_cache, prefix +): schema = dj.Schema( - PREFIX + "_extern", + prefix + "_extern", context=schema_external.LOCALS_EXTERNAL, connection=connection_test, ) @@ -358,9 +372,9 @@ def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache @pytest.fixture -def schema_uuid(connection_test): +def schema_uuid(connection_test, prefix): schema = dj.Schema( - PREFIX + "_test1", + prefix + "_test1", context=schema_uuid_module.LOCALS_UUID, connection=connection_test, ) @@ -386,12 +400,12 @@ def http_client(): @pytest.fixture(scope="session") -def minio_client_bare(http_client): +def minio_client_bare(s3_creds, http_client): """Initialize MinIO with an endpoint and access/secret keys.""" client = minio.Minio( - S3_CONN_INFO["endpoint"], - access_key=S3_CONN_INFO["access_key"], - secret_key=S3_CONN_INFO["secret_key"], + s3_creds["endpoint"], + access_key=s3_creds["access_key"], + secret_key=s3_creds["secret_key"], secure=True, http_client=http_client, ) @@ -399,12 +413,12 @@ def minio_client_bare(http_client): @pytest.fixture(scope="session") -def minio_client(minio_client_bare): +def minio_client(s3_creds, minio_client_bare): """Initialize a MinIO client and create buckets for testing session.""" # Setup MinIO bucket aws_region = "us-east-1" try: - minio_client_bare.make_bucket(S3_CONN_INFO["bucket"], location=aws_region) + minio_client_bare.make_bucket(s3_creds["bucket"], location=aws_region) except minio.error.S3Error as e: if e.code != "BucketAlreadyOwnedByYou": raise e @@ -412,11 +426,84 @@ def minio_client(minio_client_bare): yield minio_client_bare # Teardown S3 - objs = list(minio_client_bare.list_objects(S3_CONN_INFO["bucket"], 
recursive=True)) + objs = list(minio_client_bare.list_objects(s3_creds["bucket"], recursive=True)) objs = [ minio_client_bare.remove_object( - S3_CONN_INFO["bucket"], o.object_name.encode("utf-8") + s3_creds["bucket"], o.object_name.encode("utf-8") ) for o in objs ] - minio_client_bare.remove_bucket(S3_CONN_INFO["bucket"]) + minio_client_bare.remove_bucket(s3_creds["bucket"]) + + +@pytest.fixture +def test(schema_any): + yield schema.TTest() + + +@pytest.fixture +def test2(schema_any): + yield schema.TTest2() + + +@pytest.fixture +def test_extra(schema_any): + yield schema.TTestExtra() + + +@pytest.fixture +def test_no_extra(schema_any): + yield schema.TTestNoExtra() + + +@pytest.fixture +def user(schema_any): + return schema.User() + + +@pytest.fixture +def lang(schema_any): + yield schema.Language() + + +@pytest.fixture +def languages(lang) -> List: + og_contents = lang.contents + languages = og_contents.copy() + yield languages + lang.contents = og_contents + + +@pytest.fixture +def subject(schema_any): + yield schema.Subject() + + +@pytest.fixture +def experiment(schema_any): + return schema.Experiment() + + +@pytest.fixture +def ephys(schema_any): + return schema.Ephys() + + +@pytest.fixture +def img(schema_any): + return schema.Image() + + +@pytest.fixture +def trial(schema_any): + return schema.Trial() + + +@pytest.fixture +def channel(schema_any): + return schema.Ephys.Channel() + + +@pytest.fixture +def trash(schema_any): + return schema.UberTrash() diff --git a/tests/schema_alter.py b/tests/schema_alter.py new file mode 100644 index 00000000..d607bc7c --- /dev/null +++ b/tests/schema_alter.py @@ -0,0 +1,56 @@ +import datajoint as dj +import inspect + + +class Experiment(dj.Imported): + original_definition = """ # information about experiments + -> Subject + experiment_id :smallint # experiment number for this subject + --- + experiment_date :date # date when experiment was started + -> [nullable] User + data_path="" :varchar(255) # file path to recorded data + notes="" :varchar(2048) # e.g. purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + definition1 = """ # Experiment + -> Subject + experiment_id :smallint # experiment number for this subject + --- + data_path : int # some number + extra=null : longblob # just testing + -> [nullable] User + subject_notes=null :varchar(2048) # {notes} e.g. purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + +class Parent(dj.Manual): + definition = """ + parent_id: int + """ + + class Child(dj.Part): + definition = """ + -> Parent + """ + definition_new = """ + -> master + --- + child_id=null: int + """ + + class Grandchild(dj.Part): + definition = """ + -> master.Child + """ + definition_new = """ + -> master.Child + --- + grandchild_id=null: int + """ + + +LOCALS_ALTER = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_ALTER) diff --git a/tests/schema_external.py b/tests/schema_external.py index 294ecb07..ce51af9c 100644 --- a/tests/schema_external.py +++ b/tests/schema_external.py @@ -5,7 +5,6 @@ import tempfile import inspect import datajoint as dj -from . import PREFIX, CONN_INFO, S3_CONN_INFO import numpy as np diff --git a/tests/schema_uuid.py b/tests/schema_uuid.py index 6bf994b5..00b45ee7 100644 --- a/tests/schema_uuid.py +++ b/tests/schema_uuid.py @@ -1,7 +1,6 @@ import uuid import inspect import datajoint as dj -from . 
import PREFIX, CONN_INFO top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000") diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index bbe8456f..714da8a6 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -6,9 +6,11 @@ from itertools import zip_longest from . import schema_adapted from .schema_adapted import Connectivity, Layout -from . import PREFIX, S3_CONN_INFO -SCHEMA_NAME = PREFIX + "_test_custom_datatype" + +@pytest.fixture +def schema_name(prefix): + return prefix + "_test_custom_datatype" @pytest.fixture @@ -22,24 +24,21 @@ def schema_ad( adapted_graph_instance, enable_adapted_types, enable_filepath_feature, + s3_creds, + tmpdir, + schema_name, ): - stores_config = { + dj.config["stores"] = { "repo-s3": dict( - S3_CONN_INFO, - protocol="s3", - location="adapted/repo", - stage=tempfile.mkdtemp(), + s3_creds, protocol="s3", location="adapted/repo", stage=str(tmpdir) ) } - dj.config["stores"] = stores_config - layout_to_filepath = schema_adapted.LayoutToFilepath() context = { **schema_adapted.LOCALS_ADAPTED, "graph": adapted_graph_instance, - "layout_to_filepath": layout_to_filepath, + "layout_to_filepath": schema_adapted.LayoutToFilepath(), } - schema = dj.schema(SCHEMA_NAME, context=context, connection=connection_test) - graph = adapted_graph_instance + schema = dj.schema(schema_name, context=context, connection=connection_test) schema(schema_adapted.Connectivity) schema(schema_adapted.Layout) yield schema @@ -47,19 +46,19 @@ def schema_ad( @pytest.fixture -def local_schema(schema_ad): +def local_schema(schema_ad, schema_name): """Fixture for testing spawned classes""" - local_schema = dj.Schema(SCHEMA_NAME) + local_schema = dj.Schema(schema_name) local_schema.spawn_missing_classes() yield local_schema local_schema.drop() @pytest.fixture -def schema_virtual_module(schema_ad, adapted_graph_instance): +def schema_virtual_module(schema_ad, adapted_graph_instance, schema_name): """Fixture for testing virtual modules""" schema_virtual_module = dj.VirtualModule( - "virtual_module", SCHEMA_NAME, add_objects={"graph": adapted_graph_instance} + "virtual_module", schema_name, add_objects={"graph": adapted_graph_instance} ) return schema_virtual_module @@ -93,7 +92,6 @@ def test_adapted_filepath_type(schema_ad, minio_client): t = Layout() t.insert1((0, layout)) result = t.fetch1("layout") - # TODO: may fail, used to be assert_dict_equal assert result == layout t.delete() c.delete() diff --git a/tests/test_admin.py b/tests/test_admin.py index 1ab89c1a..43b418f8 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -7,19 +7,17 @@ import pymysql import pytest -from . 
import CONN_INFO_ROOT - @pytest.fixture() -def user_alice() -> dict: +def user_alice(db_creds_root) -> dict: # set up - reset config, log in as root, and create a new user alice # reset dj.config manually because its state may be changed by these tests if os.path.exists(dj.settings.LOCALCONFIG): os.remove(dj.settings.LOCALCONFIG) dj.config["database.password"] = os.getenv("DJ_PASS") - root_conn = dj.conn(**CONN_INFO_ROOT, reset=True) + root_conn = dj.conn(**db_creds_root, reset=True) new_credentials = dict( - host=CONN_INFO_ROOT["host"], + host=db_creds_root["host"], user="alice", password="oldpass", ) diff --git a/tests/test_aggr_regressions.py b/tests/test_aggr_regressions.py index b4d4e080..7cc5119e 100644 --- a/tests/test_aggr_regressions.py +++ b/tests/test_aggr_regressions.py @@ -4,18 +4,16 @@ import pytest import datajoint as dj -from . import PREFIX import uuid from .schema_uuid import Topic, Item, top_level_namespace_id from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS @pytest.fixture(scope="function") -def schema_aggr_reg(connection_test): - context = LOCALS_AGGR_REGRESS +def schema_aggr_reg(connection_test, prefix): schema = dj.Schema( - PREFIX + "_aggr_regress", - context=context, + prefix + "_aggr_regress", + context=LOCALS_AGGR_REGRESS, connection=connection_test, ) schema(R) @@ -26,11 +24,10 @@ def schema_aggr_reg(connection_test): @pytest.fixture(scope="function") -def schema_aggr_reg_with_abx(connection_test): - context = LOCALS_AGGR_REGRESS +def schema_aggr_reg_with_abx(connection_test, prefix): schema = dj.Schema( - PREFIX + "_aggr_regress_with_abx", - context=context, + prefix + "_aggr_regress_with_abx", + context=LOCALS_AGGR_REGRESS, connection=connection_test, ) schema(R) diff --git a/tests/test_alter.py b/tests/test_alter.py index a78a07f2..5146d626 100644 --- a/tests/test_alter.py +++ b/tests/test_alter.py @@ -1,60 +1,9 @@ import pytest import re import datajoint as dj -from . import schema as schema_any_module, PREFIX +from . import schema as schema_any_module +from .schema_alter import Experiment, Parent, LOCALS_ALTER - -class Experiment(dj.Imported): - original_definition = """ # information about experiments - -> Subject - experiment_id :smallint # experiment number for this subject - --- - experiment_date :date # date when experiment was started - -> [nullable] User - data_path="" :varchar(255) # file path to recorded data - notes="" :varchar(2048) # e.g. purpose of experiment - entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp - """ - - definition1 = """ # Experiment - -> Subject - experiment_id :smallint # experiment number for this subject - --- - data_path : int # some number - extra=null : longblob # just testing - -> [nullable] User - subject_notes=null :varchar(2048) # {notes} e.g. 
purpose of experiment - entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp - """ - - -class Parent(dj.Manual): - definition = """ - parent_id: int - """ - - class Child(dj.Part): - definition = """ - -> Parent - """ - definition_new = """ - -> master - --- - child_id=null: int - """ - - class Grandchild(dj.Part): - definition = """ - -> master.Child - """ - definition_new = """ - -> master.Child - --- - grandchild_id=null: int - """ - - -LOCALS_ALTER = {"Experiment": Experiment, "Parent": Parent} COMBINED_CONTEXT = { **schema_any_module.LOCALS_ANY, **LOCALS_ALTER, @@ -71,6 +20,19 @@ def schema_alter(connection_test, schema_any): class TestAlter: + def verify_alter(self, schema_alter, table, attribute_sql): + definition_original = schema_alter.connection.query( + f"SHOW CREATE TABLE {table.full_table_name}" + ).fetchone()[1] + table.definition = table.definition_new + table.alter(prompt=False) + definition_new = schema_alter.connection.query( + f"SHOW CREATE TABLE {table.full_table_name}" + ).fetchone()[1] + assert ( + re.sub(f"{attribute_sql},\n ", "", definition_new) == definition_original + ) + def test_alter(self, schema_alter): original = schema_alter.connection.query( "SHOW CREATE TABLE " + Experiment.full_table_name @@ -89,19 +51,6 @@ def test_alter(self, schema_alter): assert altered != restored assert original == restored - def verify_alter(self, schema_alter, table, attribute_sql): - definition_original = schema_alter.connection.query( - f"SHOW CREATE TABLE {table.full_table_name}" - ).fetchone()[1] - table.definition = table.definition_new - table.alter(prompt=False) - definition_new = schema_alter.connection.query( - f"SHOW CREATE TABLE {table.full_table_name}" - ).fetchone()[1] - assert ( - re.sub(f"{attribute_sql},\n ", "", definition_new) == definition_original - ) - def test_alter_part(self, schema_alter): """ https://github.com/datajoint/datajoint-python/issues/936 diff --git a/tests/test_attach.py b/tests/test_attach.py index 654feef5..b3ecea04 100644 --- a/tests/test_attach.py +++ b/tests/test_attach.py @@ -1,15 +1,14 @@ import pytest -import tempfile from pathlib import Path import os from .schema_external import Attach -def test_attach_attributes(schema_ext, minio_client): +def test_attach_attributes(schema_ext, minio_client, tmpdir_factory): """Test saving files in attachments""" # create a mock file table = Attach() - source_folder = tempfile.mkdtemp() + source_folder = tmpdir_factory.mktemp("source") for i in range(2): attach1 = Path(source_folder, "attach1.img") data1 = os.urandom(100) @@ -21,7 +20,7 @@ def test_attach_attributes(schema_ext, minio_client): f.write(data2) table.insert1(dict(attach=i, img=attach1, txt=attach2)) - download_folder = Path(tempfile.mkdtemp()) + download_folder = Path(tmpdir_factory.mktemp("download")) keys, path1, path2 = table.fetch( "KEY", "img", "txt", download_path=download_folder, order_by="KEY" ) @@ -43,11 +42,11 @@ def test_attach_attributes(schema_ext, minio_client): assert p2 == path2[0] -def test_return_string(schema_ext, minio_client): +def test_return_string(schema_ext, minio_client, tmpdir_factory): """Test returning string on fetch""" # create a mock file table = Attach() - source_folder = tempfile.mkdtemp() + source_folder = tmpdir_factory.mktemp("source") attach1 = Path(source_folder, "attach1.img") data1 = os.urandom(100) @@ -59,7 +58,7 @@ def test_return_string(schema_ext, minio_client): f.write(data2) table.insert1(dict(attach=2, img=attach1, txt=attach2)) - download_folder = Path(tempfile.mkdtemp()) + 
download_folder = Path(tmpdir_factory.mktemp("download")) keys, path1, path2 = table.fetch( "KEY", "img", "txt", download_path=download_folder, order_by="KEY" ) diff --git a/tests/test_autopopulate.py b/tests/test_autopopulate.py index 25f8e16e..d1f0726e 100644 --- a/tests/test_autopopulate.py +++ b/tests/test_autopopulate.py @@ -1,168 +1,127 @@ import pytest -from . import schema, PREFIX from datajoint import DataJointError import datajoint as dj import pymysql - - -class TestPopulate: - """ - Test base relations: insert, delete - """ - - @classmethod - def setup_class(cls): - cls.user = schema.User() - cls.subject = schema.Subject() - cls.experiment = schema.Experiment() - cls.trial = schema.Trial() - cls.ephys = schema.Ephys() - cls.channel = schema.Ephys.Channel() - - @classmethod - def teardown_class(cls): - """Delete automatic tables just in case""" - for autopop_table in ( - cls.channel, - cls.ephys, - cls.trial.Condition, - cls.trial, - cls.experiment, - ): - try: - autopop_table.delete_quick() - except (pymysql.err.OperationalError, dj.errors.MissingTableError): - # Table doesn't exist - pass - - def test_populate(self, schema_any): - # test simple populate - assert self.subject, "root tables are empty" - assert not self.experiment, "table already filled?" - self.experiment.populate() - assert ( - len(self.experiment) - == len(self.subject) * self.experiment.fake_experiments_per_subject - ) - - # test restricted populate - assert not self.trial, "table already filled?" - restriction = self.subject.proj(animal="subject_id").fetch("KEY")[0] - d = self.trial.connection.dependencies - d.load() - self.trial.populate(restriction) - assert self.trial, "table was not populated" - key_source = self.trial.key_source - assert len(key_source & self.trial) == len(key_source & restriction) - assert len(key_source - self.trial) == len(key_source - restriction) - - # test subtable populate - assert not self.ephys - assert not self.channel - self.ephys.populate() - assert self.ephys - assert self.channel - - def test_populate_with_success_count(self, schema_any): - # test simple populate - assert self.subject, "root tables are empty" - assert not self.experiment, "table already filled?" - ret = self.experiment.populate() - success_count = ret["success_count"] - assert len(self.experiment.key_source & self.experiment) == success_count - - # test restricted populate - assert not self.trial, "table already filled?" - restriction = self.subject.proj(animal="subject_id").fetch("KEY")[0] - d = self.trial.connection.dependencies - d.load() - ret = self.trial.populate(restriction, suppress_errors=True) - success_count = ret["success_count"] - assert len(self.trial.key_source & self.trial) == success_count - - def test_populate_exclude_error_and_ignore_jobs(self, schema_any): - # test simple populate - assert self.subject, "root tables are empty" - assert not self.experiment, "table already filled?" 
- - keys = self.experiment.key_source.fetch("KEY", limit=2) - for idx, key in enumerate(keys): - if idx == 0: - schema_any.jobs.ignore(self.experiment.table_name, key) - else: - schema_any.jobs.error(self.experiment.table_name, key, "") - - self.experiment.populate(reserve_jobs=True) - assert ( - len(self.experiment.key_source & self.experiment) - == len(self.experiment.key_source) - 2 - ) - - def test_allow_direct_insert(self, schema_any): - assert self.subject, "root tables are empty" - key = self.subject.fetch("KEY", limit=1)[0] - key["experiment_id"] = 1000 - key["experiment_date"] = "2018-10-30" - self.experiment.insert1(key, allow_direct_insert=True) - - def test_multi_processing(self, schema_any): - assert self.subject, "root tables are empty" - assert not self.experiment, "table already filled?" - self.experiment.populate(processes=2) - assert ( - len(self.experiment) - == len(self.subject) * self.experiment.fake_experiments_per_subject - ) - - def test_max_multi_processing(self, schema_any): - assert self.subject, "root tables are empty" - assert not self.experiment, "table already filled?" - self.experiment.populate(processes=None) - assert ( - len(self.experiment) - == len(self.subject) * self.experiment.fake_experiments_per_subject - ) - - def test_allow_insert(self, schema_any): - assert self.subject, "root tables are empty" - key = self.subject.fetch("KEY")[0] - key["experiment_id"] = 1001 - key["experiment_date"] = "2018-10-30" - with pytest.raises(DataJointError): - self.experiment.insert1(key) - - def test_load_dependencies(self): - schema = dj.Schema(f"{PREFIX}_load_dependencies_populate") - - @schema - class ImageSource(dj.Lookup): - definition = """ - image_source_id: int - """ - contents = [(0,)] - - @schema - class Image(dj.Imported): - definition = """ - -> ImageSource - --- - image_data: longblob - """ - - def make(self, key): - self.insert1(dict(key, image_data=dict())) - - Image.populate() - - @schema - class Crop(dj.Computed): - definition = """ - -> Image - --- - crop_image: longblob - """ - - def make(self, key): - self.insert1(dict(key, crop_image=dict())) - - Crop.populate() +from . import schema + + +def test_populate(trial, subject, experiment, ephys, channel): + # test simple populate + assert subject, "root tables are empty" + assert not experiment, "table already filled?" + experiment.populate() + assert len(experiment) == len(subject) * experiment.fake_experiments_per_subject + + # test restricted populate + assert not trial, "table already filled?" + restriction = subject.proj(animal="subject_id").fetch("KEY")[0] + d = trial.connection.dependencies + d.load() + trial.populate(restriction) + assert trial, "table was not populated" + key_source = trial.key_source + assert len(key_source & trial) == len(key_source & restriction) + assert len(key_source - trial) == len(key_source - restriction) + + # test subtable populate + assert not ephys + assert not channel + ephys.populate() + assert ephys + assert channel + + +def test_populate_with_success_count(subject, experiment, trial): + # test simple populate + assert subject, "root tables are empty" + assert not experiment, "table already filled?" + ret = experiment.populate() + success_count = ret["success_count"] + assert len(experiment.key_source & experiment) == success_count + + # test restricted populate + assert not trial, "table already filled?" 
+ restriction = subject.proj(animal="subject_id").fetch("KEY")[0] + d = trial.connection.dependencies + d.load() + ret = trial.populate(restriction, suppress_errors=True) + success_count = ret["success_count"] + assert len(trial.key_source & trial) == success_count + + +def test_populate_exclude_error_and_ignore_jobs(schema_any, subject, experiment): + # test simple populate + assert subject, "root tables are empty" + assert not experiment, "table already filled?" + + keys = experiment.key_source.fetch("KEY", limit=2) + for idx, key in enumerate(keys): + if idx == 0: + schema_any.jobs.ignore(experiment.table_name, key) + else: + schema_any.jobs.error(experiment.table_name, key, "") + + experiment.populate(reserve_jobs=True) + assert len(experiment.key_source & experiment) == len(experiment.key_source) - 2 + + +def test_allow_direct_insert(subject, experiment): + assert subject, "root tables are empty" + key = subject.fetch("KEY", limit=1)[0] + key["experiment_id"] = 1000 + key["experiment_date"] = "2018-10-30" + experiment.insert1(key, allow_direct_insert=True) + + +@pytest.mark.parametrize("processes", [None, 2]) +def test_multi_processing(subject, experiment, processes): + assert subject, "root tables are empty" + assert not experiment, "table already filled?" + experiment.populate(processes=None) + assert len(experiment) == len(subject) * experiment.fake_experiments_per_subject + + +def test_allow_insert(subject, experiment): + assert subject, "root tables are empty" + key = subject.fetch("KEY")[0] + key["experiment_id"] = 1001 + key["experiment_date"] = "2018-10-30" + with pytest.raises(DataJointError): + experiment.insert1(key) + + +def test_load_dependencies(prefix): + schema = dj.Schema(f"{prefix}_load_dependencies_populate") + + @schema + class ImageSource(dj.Lookup): + definition = """ + image_source_id: int + """ + contents = [(0,)] + + @schema + class Image(dj.Imported): + definition = """ + -> ImageSource + --- + image_data: longblob + """ + + def make(self, key): + self.insert1(dict(key, image_data=dict())) + + Image.populate() + + @schema + class Crop(dj.Computed): + definition = """ + -> Image + --- + crop_image: longblob + """ + + def make(self, key): + self.insert1(dict(key, crop_image=dict())) + + Crop.populate() diff --git a/tests/test_blob.py b/tests/test_blob.py index e5548898..12039f7f 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,3 +1,4 @@ +import pytest import datajoint as dj import timeit import numpy as np @@ -10,6 +11,13 @@ from .schema import Longblob +@pytest.fixture +def enable_feature_32bit_dims(): + dj.blob.use_32bit_dims = True + yield + dj.blob.use_32bit_dims = False + + def test_pack(): for x in ( 32, @@ -180,6 +188,8 @@ def test_insert_longblob(schema_any): assert (Longblob & "id=1").fetch1()["data"].all() == query_mym_blob["data"].all() (Longblob & "id=1").delete() + +def test_insert_longblob_32bit(schema_any, enable_feature_32bit_dims): query_32_blob = ( "INSERT INTO djtest_test1.longblob (id, data) VALUES (1, " "X'6D596D00530200000001000000010000000400000068697473007369646573007461736B73007374" @@ -190,8 +200,8 @@ def test_insert_longblob(schema_any): "00000041020000000100000008000000040000000000000053007400610067006500200031003000')" ) dj.conn().query(query_32_blob).fetchall() - dj.blob.use_32bit_dims = True - assert (Longblob & "id=1").fetch1() == { + fetched = (Longblob & "id=1").fetch1() + expected = { "id": 1, "data": np.rec.array( [ @@ -207,26 +217,34 @@ def test_insert_longblob(schema_any): dtype=[("hits", "O"), ("sides", "O"), 
("tasks", "O"), ("stage", "O")], ), } + assert fetched["id"] == expected["id"] + assert np.array_equal(fetched["data"], expected["data"]) (Longblob & "id=1").delete() - dj.blob.use_32bit_dims = False def test_datetime_serialization_speed(): # If this fails that means for some reason deserializing/serializing # np arrays of np.datetime64 types is now slower than regular arrays of datetime + assert not dj.blob.use_32bit_dims, "32 bit dims should be off for this test" + context = dict( + np=np, + datetime=datetime, + pack=pack, + unpack=unpack, + ) optimized_exe_time = timeit.timeit( setup="myarr=pack(np.array([np.datetime64('2022-10-13 03:03:13') for _ in range(0, 10000)]))", stmt="unpack(myarr)", number=10, - globals=globals(), + globals=context, ) print(f"np time {optimized_exe_time}") baseline_exe_time = timeit.timeit( setup="myarr2=pack(np.array([datetime(2022,10,13,3,3,13) for _ in range (0, 10000)]))", stmt="unpack(myarr2)", number=10, - globals=globals(), + globals=context, ) print(f"python time {baseline_exe_time}") diff --git a/tests/test_blob_matlab.py b/tests/test_blob_matlab.py index 575e6b0b..8e467cf0 100644 --- a/tests/test_blob_matlab.py +++ b/tests/test_blob_matlab.py @@ -4,8 +4,6 @@ from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal -from . import PREFIX - class Blob(dj.Manual): definition = """ # diverse types of blobs @@ -16,168 +14,158 @@ class Blob(dj.Manual): """ -@pytest.fixture -def schema(connection_test): - schema = dj.Schema(PREFIX + "_test1", dict(Blob=Blob), connection=connection_test) - schema(Blob) - yield schema - schema.drop() - +def insert_blobs(schema): + """ + This function inserts blobs resulting from the following datajoint-matlab code: + + self.insert({ + 1 'simple string' 'character string' + 2 '1D vector' 1:15:180 + 3 'string array' {'string1' 'string2'} + 4 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + 5 '3D double array' reshape(1:24, [2,3,4]) + 6 '3D uint8 array' reshape(uint8(1:24), [2,3,4]) + 7 '3D complex array' fftn(reshape(1:24, [2,3,4])) + }) + + and then dumped using the command + mysqldump -u username -p --hex-blob test_schema blob_table > blob.sql + """ -@pytest.fixture -def insert_blobs_func(schema): - def insert_blobs(): - """ - This function inserts blobs resulting from the following datajoint-matlab code: - - self.insert({ - 1 'simple string' 'character string' - 2 '1D vector' 1:15:180 - 3 'string array' {'string1' 'string2'} - 4 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - 5 '3D double array' reshape(1:24, [2,3,4]) - 6 '3D uint8 array' reshape(uint8(1:24), [2,3,4]) - 7 '3D complex array' fftn(reshape(1:24, [2,3,4])) - }) - - and then dumped using the command - mysqldump -u username -p --hex-blob test_schema blob_table > blob.sql + schema.connection.query( """ - - schema.connection.query( - """ - INSERT INTO {table_name} VALUES - (1,'simple string',0x6D596D00410200000000000000010000000000000010000000000000000400000000000000630068006100720061006300740065007200200073007400720069006E006700), - (2,'1D vector',0x6D596D0041020000000000000001000000000000000C000000000000000600000000000000000000000000F03F00000000000030400000000000003F4000000000000047400000000000804E4000000000000053400000000000C056400000000000805A400000000000405E4000000000000061400000000000E062400000000000C06440), - (3,'string 
array',0x6D596D00430200000000000000010000000000000002000000000000002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E00670031002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E0067003200), - (4,'struct array',0x6D596D005302000000000000000100000000000000020000000000000002000000610062002900000000000000410200000000000000010000000000000001000000000000000600000000000000000000000000F03F9000000000000000530200000000000000010000000000000001000000000000000100000063006900000000000000410200000000000000030000000000000003000000000000000600000000000000000000000000204000000000000008400000000000001040000000000000F03F0000000000001440000000000000224000000000000018400000000000001C40000000000000004029000000000000004102000000000000000100000000000000010000000000000006000000000000000000000000000040100100000000000053020000000000000001000000000000000100000000000000010000004300E9000000000000004102000000000000000500000000000000050000000000000006000000000000000000000000003140000000000000374000000000000010400000000000002440000000000000264000000000000038400000000000001440000000000000184000000000000028400000000000003240000000000000F03F0000000000001C400000000000002A400000000000003340000000000000394000000000000020400000000000002C400000000000003440000000000000354000000000000000400000000000002E400000000000003040000000000000364000000000000008400000000000002240), - (5,'3D double array',0x6D596D004103000000000000000200000000000000030000000000000004000000000000000600000000000000000000000000F03F000000000000004000000000000008400000000000001040000000000000144000000000000018400000000000001C40000000000000204000000000000022400000000000002440000000000000264000000000000028400000000000002A400000000000002C400000000000002E40000000000000304000000000000031400000000000003240000000000000334000000000000034400000000000003540000000000000364000000000000037400000000000003840), - (6,'3D uint8 array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000009000000000000000102030405060708090A0B0C0D0E0F101112131415161718), - (7,'3D complex array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000006000000010000000000000000C0724000000000000028C000000000000038C0000000000000000000000000000038C0000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000AA4C58E87AB62B400000000000000000AA4C58E87AB62BC0000000000000008000000000000052400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000008000000000000052C000000000000000800000000000000080000000000000008000000000000000800000000000000080 - ); - """.format( - table_name=Blob.full_table_name - ) + INSERT INTO {table_name} VALUES + (1,'simple string',0x6D596D00410200000000000000010000000000000010000000000000000400000000000000630068006100720061006300740065007200200073007400720069006E006700), + (2,'1D vector',0x6D596D0041020000000000000001000000000000000C000000000000000600000000000000000000000000F03F00000000000030400000000000003F4000000000000047400000000000804E4000000000000053400000000000C056400000000000805A400000000000405E4000000000000061400000000000E062400000000000C06440), + 
(3,'string array',0x6D596D00430200000000000000010000000000000002000000000000002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E00670031002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E0067003200), + (4,'struct array',0x6D596D005302000000000000000100000000000000020000000000000002000000610062002900000000000000410200000000000000010000000000000001000000000000000600000000000000000000000000F03F9000000000000000530200000000000000010000000000000001000000000000000100000063006900000000000000410200000000000000030000000000000003000000000000000600000000000000000000000000204000000000000008400000000000001040000000000000F03F0000000000001440000000000000224000000000000018400000000000001C40000000000000004029000000000000004102000000000000000100000000000000010000000000000006000000000000000000000000000040100100000000000053020000000000000001000000000000000100000000000000010000004300E9000000000000004102000000000000000500000000000000050000000000000006000000000000000000000000003140000000000000374000000000000010400000000000002440000000000000264000000000000038400000000000001440000000000000184000000000000028400000000000003240000000000000F03F0000000000001C400000000000002A400000000000003340000000000000394000000000000020400000000000002C400000000000003440000000000000354000000000000000400000000000002E400000000000003040000000000000364000000000000008400000000000002240), + (5,'3D double array',0x6D596D004103000000000000000200000000000000030000000000000004000000000000000600000000000000000000000000F03F000000000000004000000000000008400000000000001040000000000000144000000000000018400000000000001C40000000000000204000000000000022400000000000002440000000000000264000000000000028400000000000002A400000000000002C400000000000002E40000000000000304000000000000031400000000000003240000000000000334000000000000034400000000000003540000000000000364000000000000037400000000000003840), + (6,'3D uint8 array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000009000000000000000102030405060708090A0B0C0D0E0F101112131415161718), + (7,'3D complex array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000006000000010000000000000000C0724000000000000028C000000000000038C0000000000000000000000000000038C0000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000AA4C58E87AB62B400000000000000000AA4C58E87AB62BC0000000000000008000000000000052400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000008000000000000052C000000000000000800000000000000080000000000000008000000000000000800000000000000080 + ); + """.format( + table_name=Blob.full_table_name ) + ) - yield insert_blobs + +@pytest.fixture +def schema_blob(connection_test, prefix): + schema = dj.Schema(prefix + "_test1", dict(Blob=Blob), connection=connection_test) + schema(Blob) + yield schema + schema.drop() @pytest.fixture -def setup_class(schema, insert_blobs_func): +def schema_blob_pop(schema_blob): assert not dj.config["safemode"], "safemode must be disabled" Blob().delete() - insert_blobs_func() + insert_blobs(schema_blob) + return schema_blob -class TestFetch: - 
@staticmethod - def test_complex_matlab_blobs(setup_class): - """ - test correct de-serialization of various blob types - """ - blobs = Blob().fetch("blob", order_by="KEY") - - blob = blobs[0] # 'simple string' 'character string' - assert blob[0] == "character string" - - blob = blobs[1] # '1D vector' 1:15:180 - assert_array_equal(blob, np.r_[1:180:15][None, :]) - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[2] # 'string array' {'string1' 'string2'} - assert isinstance(blob, dj.MatCell) - assert_array_equal(blob, np.array([["string1", "string2"]])) - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[ - 3 - ] # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - assert isinstance(blob, dj.MatStruct) - assert tuple(blob.dtype.names) == ("a", "b") - assert_array_equal(blob.a[0, 0], np.array([[1.0]])) - assert_array_equal(blob.a[0, 1], np.array([[2.0]])) - assert isinstance(blob.b[0, 1], dj.MatStruct) - assert tuple(blob.b[0, 1].C[0, 0].shape) == (5, 5) - b = unpack(pack(blob)) - assert_array_equal(b[0, 0].b[0, 0].c, blob[0, 0].b[0, 0].c) - assert_array_equal(b[0, 1].b[0, 0].C, blob[0, 1].b[0, 0].C) - - blob = blobs[4] # '3D double array' reshape(1:24, [2,3,4]) - assert_array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) - assert blob.dtype == "float64" - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[5] # reshape(uint8(1:24), [2,3,4]) - assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) - assert blob.dtype == "uint8" - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[6] # fftn(reshape(1:24, [2,3,4])) - assert tuple(blob.shape) == (2, 3, 4) - assert blob.dtype == "complex128" - assert_array_equal(blob, unpack(pack(blob))) - - @staticmethod - def test_complex_matlab_squeeze(setup_class): - """ - test correct de-serialization of various blob types - """ - blob = (Blob & "id=1").fetch1( - "blob", squeeze=True - ) # 'simple string' 'character string' - assert blob == "character string" - - blob = (Blob & "id=2").fetch1( - "blob", squeeze=True - ) # '1D vector' 1:15:180 - assert_array_equal(blob, np.r_[1:180:15]) - - blob = (Blob & "id=3").fetch1( - "blob", squeeze=True - ) # 'string array' {'string1' 'string2'} - assert isinstance(blob, dj.MatCell) - assert_array_equal(blob, np.array(["string1", "string2"])) - - blob = (Blob & "id=4").fetch1( - "blob", squeeze=True - ) # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - assert isinstance(blob, dj.MatStruct) - assert tuple(blob.dtype.names) == ("a", "b") - assert_array_equal( - blob.a, - np.array( - [ - 1.0, - 2, - ] - ), - ) - assert isinstance(blob[1].b, dj.MatStruct) - assert tuple(blob[1].b.C.item().shape) == (5, 5) - - blob = (Blob & "id=5").fetch1( - "blob", squeeze=True - ) # '3D double array' reshape(1:24, [2,3,4]) - assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) - assert blob.dtype == "float64" - - blob = (Blob & "id=6").fetch1( - "blob", squeeze=True - ) # reshape(uint8(1:24), [2,3,4]) - assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) - assert blob.dtype == "uint8" - - blob = (Blob & "id=7").fetch1( - "blob", squeeze=True - ) # fftn(reshape(1:24, [2,3,4])) - assert tuple(blob.shape) == (2, 3, 4) - assert blob.dtype == "complex128" - - def test_iter(self, setup_class): - """ - test iterator over the entity set - """ - from_iter = {d["id"]: d for d in Blob()} - assert len(from_iter) == len(Blob()) - assert from_iter[1]["blob"] == 
"character string" +def test_complex_matlab_blobs(schema_blob_pop): + """ + test correct de-serialization of various blob types + """ + blobs = Blob().fetch("blob", order_by="KEY") + + blob = blobs[0] # 'simple string' 'character string' + assert blob[0] == "character string" + + blob = blobs[1] # '1D vector' 1:15:180 + assert_array_equal(blob, np.r_[1:180:15][None, :]) + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[2] # 'string array' {'string1' 'string2'} + assert isinstance(blob, dj.MatCell) + assert_array_equal(blob, np.array([["string1", "string2"]])) + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[ + 3 + ] # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + assert isinstance(blob, dj.MatStruct) + assert tuple(blob.dtype.names) == ("a", "b") + assert_array_equal(blob.a[0, 0], np.array([[1.0]])) + assert_array_equal(blob.a[0, 1], np.array([[2.0]])) + assert isinstance(blob.b[0, 1], dj.MatStruct) + assert tuple(blob.b[0, 1].C[0, 0].shape) == (5, 5) + b = unpack(pack(blob)) + assert_array_equal(b[0, 0].b[0, 0].c, blob[0, 0].b[0, 0].c) + assert_array_equal(b[0, 1].b[0, 0].C, blob[0, 1].b[0, 0].C) + + blob = blobs[4] # '3D double array' reshape(1:24, [2,3,4]) + assert_array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "float64" + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[5] # reshape(uint8(1:24), [2,3,4]) + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "uint8" + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[6] # fftn(reshape(1:24, [2,3,4])) + assert tuple(blob.shape) == (2, 3, 4) + assert blob.dtype == "complex128" + assert_array_equal(blob, unpack(pack(blob))) + + +def test_complex_matlab_squeeze(schema_blob_pop): + """ + test correct de-serialization of various blob types + """ + blob = (Blob & "id=1").fetch1( + "blob", squeeze=True + ) # 'simple string' 'character string' + assert blob == "character string" + + blob = (Blob & "id=2").fetch1("blob", squeeze=True) # '1D vector' 1:15:180 + assert_array_equal(blob, np.r_[1:180:15]) + + blob = (Blob & "id=3").fetch1( + "blob", squeeze=True + ) # 'string array' {'string1' 'string2'} + assert isinstance(blob, dj.MatCell) + assert_array_equal(blob, np.array(["string1", "string2"])) + + blob = (Blob & "id=4").fetch1( + "blob", squeeze=True + ) # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + assert isinstance(blob, dj.MatStruct) + assert tuple(blob.dtype.names) == ("a", "b") + assert_array_equal( + blob.a, + np.array( + [ + 1.0, + 2, + ] + ), + ) + assert isinstance(blob[1].b, dj.MatStruct) + assert tuple(blob[1].b.C.item().shape) == (5, 5) + + blob = (Blob & "id=5").fetch1( + "blob", squeeze=True + ) # '3D double array' reshape(1:24, [2,3,4]) + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "float64" + + blob = (Blob & "id=6").fetch1("blob", squeeze=True) # reshape(uint8(1:24), [2,3,4]) + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "uint8" + + blob = (Blob & "id=7").fetch1("blob", squeeze=True) # fftn(reshape(1:24, [2,3,4])) + assert tuple(blob.shape) == (2, 3, 4) + assert blob.dtype == "complex128" + + +def test_iter(schema_blob_pop): + """ + test iterator over the entity set + """ + from_iter = {d["id"]: d for d in Blob()} + assert len(from_iter) == len(Blob()) + assert from_iter[1]["blob"] == "character string" diff 
--git a/tests/test_bypass_serialization.py b/tests/test_bypass_serialization.py index 5f73e1d2..90fc3509 100644 --- a/tests/test_bypass_serialization.py +++ b/tests/test_bypass_serialization.py @@ -1,7 +1,6 @@ import pytest import datajoint as dj import numpy as np -from . import PREFIX from numpy.testing import assert_array_equal test_blob = np.array([1, 2, 3]) @@ -25,9 +24,9 @@ class Output(dj.Manual): @pytest.fixture -def schema_in(connection_test): +def schema_in(connection_test, prefix): schema = dj.Schema( - PREFIX + "_test_bypass_serialization_in", + prefix + "_test_bypass_serialization_in", context=dict(Input=Input), connection=connection_test, ) @@ -37,9 +36,9 @@ def schema_in(connection_test): @pytest.fixture -def schema_out(connection_test): +def schema_out(connection_test, prefix): schema = dj.Schema( - PREFIX + "_test_blob_bypass_serialization_out", + prefix + "_test_blob_bypass_serialization_out", context=dict(Output=Output), connection=connection_test, ) diff --git a/tests/test_cascading_delete.py b/tests/test_cascading_delete.py index 8646edec..70fedf68 100644 --- a/tests/test_cascading_delete.py +++ b/tests/test_cascading_delete.py @@ -14,106 +14,110 @@ def schema_simp_pop(schema_simp): yield schema_simp -class TestDelete: - def test_delete_tree(self, schema_simp_pop): - assert not dj.config["safemode"], "safemode must be off for testing" - assert ( - L() and A() and B() and B.C() and D() and E() and E.F(), - "schema is not populated", - ) - A().delete() - assert not A() or B() or B.C() or D() or E() or E.F(), "incomplete delete" - - def test_stepwise_delete(self, schema_simp_pop): - assert not dj.config["safemode"], "safemode must be off for testing" - assert L() and A() and B() and B.C(), "schema population failed" - B.C().delete(force=True) - assert not B.C(), "failed to delete child tables" - B().delete() - assert ( - not B() - ), "failed to delete from the parent table following child table deletion" - - def test_delete_tree_restricted(self, schema_simp_pop): - assert not dj.config["safemode"], "safemode must be off for testing" - assert ( - L() and A() and B() and B.C() and D() and E() and E.F() - ), "schema is not populated" - cond = "cond_in_a" - rel = A() & cond - rest = dict( - A=len(A()) - len(rel), - B=len(B() - rel), - C=len(B.C() - rel), - D=len(D() - rel), - E=len(E() - rel), - F=len(E.F() - rel), - ) - rel.delete() - assert not ( - rel or B() & rel or B.C() & rel or D() & rel or E() & rel or (E.F() & rel) - ), "incomplete delete" - assert len(A()) == rest["A"], "invalid delete restriction" - assert len(B()) == rest["B"], "invalid delete restriction" - assert len(B.C()) == rest["C"], "invalid delete restriction" - assert len(D()) == rest["D"], "invalid delete restriction" - assert len(E()) == rest["E"], "invalid delete restriction" - assert len(E.F()) == rest["F"], "invalid delete restriction" - - def test_delete_lookup(self, schema_simp_pop): - assert not dj.config["safemode"], "safemode must be off for testing" - assert ( - bool(L() and A() and B() and B.C() and D() and E() and E.F()), - "schema is not populated", - ) - L().delete() - assert not bool(L() or D() or E() or E.F()), "incomplete delete" - A().delete() # delete all is necessary because delete L deletes from subtables. 
- - def test_delete_lookup_restricted(self, schema_simp_pop): - assert not dj.config["safemode"], "safemode must be off for testing" - assert ( - L() and A() and B() and B.C() and D() and E() and E.F(), - "schema is not populated", - ) - rel = L() & "cond_in_l" - original_count = len(L()) - deleted_count = len(rel) - rel.delete() - assert len(L()) == original_count - deleted_count - - def test_delete_complex_keys(self, schema_any): - """ - https://github.com/datajoint/datajoint-python/issues/883 - https://github.com/datajoint/datajoint-python/issues/886 - """ - assert not dj.config["safemode"], "safemode must be off for testing" - parent_key_count = 8 - child_key_count = 1 - restriction = dict( - {"parent_id_{}".format(i + 1): i for i in range(parent_key_count)}, - **{ - "child_id_{}".format(i + 1): (i + parent_key_count) - for i in range(child_key_count) - } - ) - assert len(ComplexParent & restriction) == 1, "Parent record missing" - assert len(ComplexChild & restriction) == 1, "Child record missing" - (ComplexParent & restriction).delete() - assert len(ComplexParent & restriction) == 0, "Parent record was not deleted" - assert len(ComplexChild & restriction) == 0, "Child record was not deleted" - - def test_delete_master(self, schema_simp_pop): +def test_delete_tree(schema_simp_pop): + assert not dj.config["safemode"], "safemode must be off for testing" + assert ( + L() and A() and B() and B.C() and D() and E() and E.F() + ), "schema is not populated" + A().delete() + assert not A() or B() or B.C() or D() or E() or E.F(), "incomplete delete" + + +def test_stepwise_delete(schema_simp_pop): + assert not dj.config["safemode"], "safemode must be off for testing" + assert L() and A() and B() and B.C(), "schema population failed" + B.C().delete(force=True) + assert not B.C(), "failed to delete child tables" + B().delete() + assert ( + not B() + ), "failed to delete from the parent table following child table deletion" + + +def test_delete_tree_restricted(schema_simp_pop): + assert not dj.config["safemode"], "safemode must be off for testing" + assert ( + L() and A() and B() and B.C() and D() and E() and E.F() + ), "schema is not populated" + cond = "cond_in_a" + rel = A() & cond + rest = dict( + A=len(A()) - len(rel), + B=len(B() - rel), + C=len(B.C() - rel), + D=len(D() - rel), + E=len(E() - rel), + F=len(E.F() - rel), + ) + rel.delete() + assert not ( + rel or B() & rel or B.C() & rel or D() & rel or E() & rel or (E.F() & rel) + ), "incomplete delete" + assert len(A()) == rest["A"], "invalid delete restriction" + assert len(B()) == rest["B"], "invalid delete restriction" + assert len(B.C()) == rest["C"], "invalid delete restriction" + assert len(D()) == rest["D"], "invalid delete restriction" + assert len(E()) == rest["E"], "invalid delete restriction" + assert len(E.F()) == rest["F"], "invalid delete restriction" + + +def test_delete_lookup(schema_simp_pop): + assert not dj.config["safemode"], "safemode must be off for testing" + assert bool( + L() and A() and B() and B.C() and D() and E() and E.F() + ), "schema is not populated" + L().delete() + assert not bool(L() or D() or E() or E.F()), "incomplete delete" + A().delete() # delete all is necessary because delete L deletes from subtables. 
+ + +def test_delete_lookup_restricted(schema_simp_pop): + assert not dj.config["safemode"], "safemode must be off for testing" + assert ( + L() and A() and B() and B.C() and D() and E() and E.F() + ), "schema is not populated" + rel = L() & "cond_in_l" + original_count = len(L()) + deleted_count = len(rel) + rel.delete() + assert len(L()) == original_count - deleted_count + + +def test_delete_complex_keys(schema_any): + """ + https://github.com/datajoint/datajoint-python/issues/883 + https://github.com/datajoint/datajoint-python/issues/886 + """ + assert not dj.config["safemode"], "safemode must be off for testing" + parent_key_count = 8 + child_key_count = 1 + restriction = dict( + {"parent_id_{}".format(i + 1): i for i in range(parent_key_count)}, + **{ + "child_id_{}".format(i + 1): (i + parent_key_count) + for i in range(child_key_count) + } + ) + assert len(ComplexParent & restriction) == 1, "Parent record missing" + assert len(ComplexChild & restriction) == 1, "Child record missing" + (ComplexParent & restriction).delete() + assert len(ComplexParent & restriction) == 0, "Parent record was not deleted" + assert len(ComplexChild & restriction) == 0, "Child record was not deleted" + + +def test_delete_master(schema_simp_pop): + Profile().populate_random() + Profile().delete() + + +def test_delete_parts(schema_simp_pop): + """test issue #151""" + with pytest.raises(dj.DataJointError): Profile().populate_random() - Profile().delete() - - def test_delete_parts(self, schema_simp_pop): - """test issue #151""" - with pytest.raises(dj.DataJointError): - Profile().populate_random() - Website().delete() - - def test_drop_part(self, schema_simp_pop): - """test issue #374""" - with pytest.raises(dj.DataJointError): - Website().drop() + Website().delete() + + +def test_drop_part(schema_simp_pop): + """test issue #374""" + with pytest.raises(dj.DataJointError): + Website().drop() diff --git a/tests/test_connection.py b/tests/test_connection.py index 8cdbbbff..49725575 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,41 +5,36 @@ import datajoint as dj from datajoint import DataJointError import numpy as np -from . import CONN_INFO_ROOT -from . import PREFIX import pytest +class Subjects(dj.Manual): + definition = """ + #Basic subject + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + + @pytest.fixture -def schema(connection_test): +def schema_tx(connection_test, prefix): schema = dj.Schema( - PREFIX + "_transactions", context=dict(), connection=connection_test + prefix + "_transactions", + context=dict(Subjects=Subjects), + connection=connection_test, ) + schema(Subjects) yield schema schema.drop() -@pytest.fixture -def Subjects(schema): - @schema - class Subjects(dj.Manual): - definition = """ - #Basic subject - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - - yield Subjects - Subjects.drop() - - -def test_dj_conn(): +def test_dj_conn(db_creds_root): """ Should be able to establish a connection as root user """ - c = dj.conn(**CONN_INFO_ROOT) + c = dj.conn(**db_creds_root) assert c.is_connected @@ -50,24 +45,24 @@ def test_dj_connection_class(connection_test): assert connection_test.is_connected -def test_persistent_dj_conn(): +def test_persistent_dj_conn(db_creds_root): """ conn() method should provide persistent connection across calls. 
Setting reset=True should create a new persistent connection. """ - c1 = dj.conn(**CONN_INFO_ROOT) + c1 = dj.conn(**db_creds_root) c2 = dj.conn() - c3 = dj.conn(**CONN_INFO_ROOT) - c4 = dj.conn(reset=True, **CONN_INFO_ROOT) - c5 = dj.conn(**CONN_INFO_ROOT) + c3 = dj.conn(**db_creds_root) + c4 = dj.conn(reset=True, **db_creds_root) + c5 = dj.conn(**db_creds_root) assert c1 is c2 assert c1 is c3 assert c1 is not c4 assert c4 is c5 -def test_repr(): - c1 = dj.conn(**CONN_INFO_ROOT) +def test_repr(db_creds_root): + c1 = dj.conn(**db_creds_root) assert "disconnected" not in repr(c1) and "connected" in repr(c1) @@ -76,7 +71,7 @@ def test_active(connection_test): assert conn.in_transaction, "Transaction is not active" -def test_transaction_rollback(connection_test, Subjects): +def test_transaction_rollback(schema_tx, connection_test): """Test transaction cancellation using a with statement""" tmp = np.array( [(1, "Peter", "mouse"), (2, "Klara", "monkey")], @@ -101,13 +96,13 @@ def test_transaction_rollback(connection_test, Subjects): ), "Length is not 0. Expected because rollback should have happened." -def test_cancel(connection_test, Subjects): +def test_cancel(schema_tx, connection_test): """Tests cancelling a transaction explicitly""" tmp = np.array( [(1, "Peter", "mouse"), (2, "Klara", "monkey")], - Subjects.heading.as_dtype, + Subjects().heading.as_dtype, ) - Subjects.delete_quick() + Subjects().delete_quick() Subjects.insert1(tmp[0]) connection_test.start_transaction() Subjects.insert1(tmp[1]) diff --git a/tests/test_declare.py b/tests/test_declare.py index a88d396e..dfca54c2 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -5,320 +5,335 @@ from datajoint.declare import declare -class TestDeclare: - def test_schema_decorator(self, schema_any): - assert issubclass(Subject, dj.Lookup) - assert not issubclass(Subject, dj.Part) - - def test_class_help(self, schema_any): - help(TTest) - help(TTest2) - assert TTest.definition in TTest.__doc__ - assert TTest.definition in TTest2.__doc__ - - def test_instance_help(self, schema_any): - help(TTest()) - help(TTest2()) - assert TTest().definition in TTest().__doc__ - assert TTest2().definition in TTest2().__doc__ - - def test_describe(self, schema_any): - """real_definition should match original definition""" - rel = Experiment() - context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 - - def test_describe_indexes(self, schema_any): - """real_definition should match original definition""" - rel = IndexRich() - context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 - - def test_describe_dependencies(self, schema_any): - """real_definition should match original definition""" - rel = ThingC() - context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 - - def test_part(self, schema_any): +def test_schema_decorator(schema_any): + assert issubclass(Subject, dj.Lookup) + assert not issubclass(Subject, dj.Part) + + +def test_class_help(schema_any): + help(TTest) + help(TTest2) + assert TTest.definition in TTest.__doc__ + assert TTest.definition in TTest2.__doc__ + + +def test_instance_help(schema_any): + help(TTest()) + help(TTest2()) + assert TTest().definition in 
TTest().__doc__ + assert TTest2().definition in TTest2().__doc__ + + +def test_describe(schema_any): + """real_definition should match original definition""" + rel = Experiment() + context = inspect.currentframe().f_globals + s1 = declare(rel.full_table_name, rel.definition, context) + s2 = declare(rel.full_table_name, rel.describe(), context) + assert s1 == s2 + + +def test_describe_indexes(schema_any): + """real_definition should match original definition""" + rel = IndexRich() + context = inspect.currentframe().f_globals + s1 = declare(rel.full_table_name, rel.definition, context) + s2 = declare(rel.full_table_name, rel.describe(), context) + assert s1 == s2 + + +def test_describe_dependencies(schema_any): + """real_definition should match original definition""" + rel = ThingC() + context = inspect.currentframe().f_globals + s1 = declare(rel.full_table_name, rel.definition, context) + s2 = declare(rel.full_table_name, rel.describe(), context) + assert s1 == s2 + + +def test_part(schema_any): + """ + Lookup and part with the same name. See issue #365 + """ + local_schema = dj.Schema(schema_any.database) + + @local_schema + class Type(dj.Lookup): + definition = """ + type : varchar(255) """ - Lookup and part with the same name. See issue #365 + contents = zip(("Type1", "Type2", "Type3")) + + @local_schema + class TypeMaster(dj.Manual): + definition = """ + master_id : int """ - local_schema = dj.Schema(schema_any.database) - @local_schema - class Type(dj.Lookup): + class Type(dj.Part): definition = """ - type : varchar(255) + -> TypeMaster + -> Type """ - contents = zip(("Type1", "Type2", "Type3")) - @local_schema - class TypeMaster(dj.Manual): - definition = """ - master_id : int - """ - class Type(dj.Part): - definition = """ - -> TypeMaster - -> Type - """ +def test_attributes(schema_any): + """ + Test autoincrement declaration + """ + auto = Auto() + auto.fill() + subject = Subject() + experiment = Experiment() + trial = Trial() + ephys = Ephys() + channel = Ephys.Channel() + + assert auto.heading.names == ["id", "name"] + assert auto.heading.attributes["id"].autoincrement + + # test attribute declarations + assert subject.heading.names == [ + "subject_id", + "real_id", + "species", + "date_of_birth", + "subject_notes", + ] + assert subject.primary_key == ["subject_id"] + assert subject.heading.attributes["subject_id"].numeric + assert not subject.heading.attributes["real_id"].numeric + + assert experiment.heading.names == [ + "subject_id", + "experiment_id", + "experiment_date", + "username", + "data_path", + "notes", + "entry_time", + ] + assert experiment.primary_key == ["subject_id", "experiment_id"] + + assert trial.heading.names == [ # tests issue #516 + "animal", + "experiment_id", + "trial_id", + "start_time", + ] + assert trial.primary_key == ["animal", "experiment_id", "trial_id"] + + assert ephys.heading.names == [ + "animal", + "experiment_id", + "trial_id", + "sampling_frequency", + "duration", + ] + assert ephys.primary_key == ["animal", "experiment_id", "trial_id"] + + assert channel.heading.names == [ + "animal", + "experiment_id", + "trial_id", + "channel", + "voltage", + "current", + ] + assert channel.primary_key == ["animal", "experiment_id", "trial_id", "channel"] + assert channel.heading.attributes["voltage"].is_blob + + +def test_dependencies(schema_any): + user = User() + subject = Subject() + experiment = Experiment() + trial = Trial() + ephys = Ephys() + channel = Ephys.Channel() + + assert experiment.full_table_name in user.children(primary=False) + assert 
set(experiment.parents(primary=False)) == {user.full_table_name} + assert experiment.full_table_name in user.children(primary=False) + assert set(experiment.parents(primary=False)) == {user.full_table_name} + assert set( + s.full_table_name for s in experiment.parents(primary=False, as_objects=True) + ) == {user.full_table_name} + + assert experiment.full_table_name in subject.descendants() + assert experiment.full_table_name in { + s.full_table_name for s in subject.descendants(as_objects=True) + } + assert subject.full_table_name in experiment.ancestors() + assert subject.full_table_name in { + s.full_table_name for s in experiment.ancestors(as_objects=True) + } + + assert trial.full_table_name in experiment.descendants() + assert trial.full_table_name in { + s.full_table_name for s in experiment.descendants(as_objects=True) + } + assert experiment.full_table_name in trial.ancestors() + assert experiment.full_table_name in { + s.full_table_name for s in trial.ancestors(as_objects=True) + } + + assert set(trial.children(primary=True)) == { + ephys.full_table_name, + trial.Condition.full_table_name, + } + assert set(trial.parts()) == {trial.Condition.full_table_name} + assert set(s.full_table_name for s in trial.parts(as_objects=True)) == { + trial.Condition.full_table_name + } + assert set(ephys.parents(primary=True)) == {trial.full_table_name} + assert set( + s.full_table_name for s in ephys.parents(primary=True, as_objects=True) + ) == {trial.full_table_name} + assert set(ephys.children(primary=True)) == {channel.full_table_name} + assert set( + s.full_table_name for s in ephys.children(primary=True, as_objects=True) + ) == {channel.full_table_name} + assert set(channel.parents(primary=True)) == {ephys.full_table_name} + assert set( + s.full_table_name for s in channel.parents(primary=True, as_objects=True) + ) == {ephys.full_table_name} + + +def test_descendants_only_contain_part_table(schema_any): + """issue #927""" + + class A(dj.Manual): + definition = """ + a: int + """ - def test_attributes(self, schema_any): + class B(dj.Manual): + definition = """ + -> A + b: int """ - Test autoincrement declaration + + class Master(dj.Manual): + definition = """ + table_master: int """ - auto = Auto() - auto.fill() - subject = Subject() - experiment = Experiment() - trial = Trial() - ephys = Ephys() - channel = Ephys.Channel() - - assert auto.heading.names == ["id", "name"] - assert auto.heading.attributes["id"].autoincrement - - # test attribute declarations - assert subject.heading.names == [ - "subject_id", - "real_id", - "species", - "date_of_birth", - "subject_notes", - ] - assert subject.primary_key == ["subject_id"] - assert subject.heading.attributes["subject_id"].numeric - assert not subject.heading.attributes["real_id"].numeric - - assert experiment.heading.names == [ - "subject_id", - "experiment_id", - "experiment_date", - "username", - "data_path", - "notes", - "entry_time", - ] - assert experiment.primary_key == ["subject_id", "experiment_id"] - - assert trial.heading.names == [ # tests issue #516 - "animal", - "experiment_id", - "trial_id", - "start_time", - ] - assert trial.primary_key == ["animal", "experiment_id", "trial_id"] - - assert ephys.heading.names == [ - "animal", - "experiment_id", - "trial_id", - "sampling_frequency", - "duration", - ] - assert ephys.primary_key == ["animal", "experiment_id", "trial_id"] - - assert channel.heading.names == [ - "animal", - "experiment_id", - "trial_id", - "channel", - "voltage", - "current", - ] - assert channel.primary_key == 
["animal", "experiment_id", "trial_id", "channel"] - assert channel.heading.attributes["voltage"].is_blob - - def test_dependencies(self, schema_any): - user = User() - subject = Subject() - experiment = Experiment() - trial = Trial() - ephys = Ephys() - channel = Ephys.Channel() - - assert experiment.full_table_name in user.children(primary=False) - assert set(experiment.parents(primary=False)) == {user.full_table_name} - assert experiment.full_table_name in user.children(primary=False) - assert set(experiment.parents(primary=False)) == {user.full_table_name} - assert set( - s.full_table_name - for s in experiment.parents(primary=False, as_objects=True) - ) == {user.full_table_name} - - assert experiment.full_table_name in subject.descendants() - assert experiment.full_table_name in { - s.full_table_name for s in subject.descendants(as_objects=True) - } - assert subject.full_table_name in experiment.ancestors() - assert subject.full_table_name in { - s.full_table_name for s in experiment.ancestors(as_objects=True) - } - - assert trial.full_table_name in experiment.descendants() - assert trial.full_table_name in { - s.full_table_name for s in experiment.descendants(as_objects=True) - } - assert experiment.full_table_name in trial.ancestors() - assert experiment.full_table_name in { - s.full_table_name for s in trial.ancestors(as_objects=True) - } - - assert set(trial.children(primary=True)) == { - ephys.full_table_name, - trial.Condition.full_table_name, - } - assert set(trial.parts()) == {trial.Condition.full_table_name} - assert set(s.full_table_name for s in trial.parts(as_objects=True)) == { - trial.Condition.full_table_name - } - assert set(ephys.parents(primary=True)) == {trial.full_table_name} - assert set( - s.full_table_name for s in ephys.parents(primary=True, as_objects=True) - ) == {trial.full_table_name} - assert set(ephys.children(primary=True)) == {channel.full_table_name} - assert set( - s.full_table_name for s in ephys.children(primary=True, as_objects=True) - ) == {channel.full_table_name} - assert set(channel.parents(primary=True)) == {ephys.full_table_name} - assert set( - s.full_table_name for s in channel.parents(primary=True, as_objects=True) - ) == {ephys.full_table_name} - - def test_descendants_only_contain_part_table(self, schema_any): - """issue #927""" - - class A(dj.Manual): - definition = """ - a: int - """ - class B(dj.Manual): + class Part(dj.Part): definition = """ - -> A - b: int + -> master + -> B """ - class Master(dj.Manual): - definition = """ - table_master: int - """ + context = dict(A=A, B=B, Master=Master) + schema_any(A, context=context) + schema_any(B, context=context) + schema_any(Master, context=context) + assert A.descendants() == [ + "`djtest_test1`.`a`", + "`djtest_test1`.`b`", + "`djtest_test1`.`master__part`", + ] + + +def test_bad_attribute_name(schema_any): + class BadName(dj.Manual): + definition = """ + Bad_name : int + """ - class Part(dj.Part): - definition = """ - -> master - -> B - """ - - context = dict(A=A, B=B, Master=Master) - schema_any(A, context=context) - schema_any(B, context=context) - schema_any(Master, context=context) - assert A.descendants() == [ - "`djtest_test1`.`a`", - "`djtest_test1`.`b`", - "`djtest_test1`.`master__part`", - ] - - def test_bad_attribute_name(self, schema_any): - class BadName(dj.Manual): - definition = """ - Bad_name : int - """ + with pytest.raises(dj.DataJointError): + schema_any(BadName) - with pytest.raises(dj.DataJointError): - schema_any(BadName) - def test_bad_fk_rename(self, schema_any): 
- """issue #381""" +def test_bad_fk_rename(schema_any): + """issue #381""" - class A(dj.Manual): - definition = """ - a : int - """ + class A(dj.Manual): + definition = """ + a : int + """ - class B(dj.Manual): - definition = """ - b -> A # invalid, the new syntax is (b) -> A - """ + class B(dj.Manual): + definition = """ + b -> A # invalid, the new syntax is (b) -> A + """ - schema_any(A) - with pytest.raises(dj.DataJointError): - schema_any(B) + schema_any(A) + with pytest.raises(dj.DataJointError): + schema_any(B) - def test_primary_nullable_foreign_key(self, schema_any): - class Q(dj.Manual): - definition = """ - -> [nullable] Experiment - """ - with pytest.raises(dj.DataJointError): - schema_any(Q) +def test_primary_nullable_foreign_key(schema_any): + class Q(dj.Manual): + definition = """ + -> [nullable] Experiment + """ - def test_invalid_foreign_key_option(self, schema_any): - class R(dj.Manual): - definition = """ - -> Experiment - ---- - -> [optional] User - """ + with pytest.raises(dj.DataJointError): + schema_any(Q) - with pytest.raises(dj.DataJointError): - schema_any(R) - def test_unsupported_datatype(self, schema_any): - class Q(dj.Manual): - definition = """ - experiment : int - --- - description : text - """ +def test_invalid_foreign_key_option(schema_any): + class R(dj.Manual): + definition = """ + -> Experiment + ---- + -> [optional] User + """ - with pytest.raises(dj.DataJointError): - schema_any(Q) + with pytest.raises(dj.DataJointError): + schema_any(R) - def test_int_datatype(self, schema_any): - @schema_any - class Owner(dj.Manual): - definition = """ - ownerid : int - --- - car_count : integer - """ - def test_unsupported_int_datatype(self, schema_any): - class Driver(dj.Manual): - definition = """ - driverid : tinyint - --- - car_count : tinyinteger - """ +def test_unsupported_datatype(schema_any): + class Q(dj.Manual): + definition = """ + experiment : int + --- + description : text + """ + + with pytest.raises(dj.DataJointError): + schema_any(Q) + + +def test_int_datatype(schema_any): + @schema_any + class Owner(dj.Manual): + definition = """ + ownerid : int + --- + car_count : integer + """ - with pytest.raises(dj.DataJointError): - schema_any(Driver) - def test_long_table_name(self, schema_any): +def test_unsupported_int_datatype(schema_any): + class Driver(dj.Manual): + definition = """ + driverid : tinyint + --- + car_count : tinyinteger """ - test issue #205 -- reject table names over 64 characters in length + + with pytest.raises(dj.DataJointError): + schema_any(Driver) + + +def test_long_table_name(schema_any): + """ + test issue #205 -- reject table names over 64 characters in length + """ + + class WhyWouldAnyoneCreateATableNameThisLong(dj.Manual): + definition = """ + master : int """ - class WhyWouldAnyoneCreateATableNameThisLong(dj.Manual): + class WithSuchALongPartNameThatItCrashesMySQL(dj.Part): definition = """ - master : int + -> (master) """ - class WithSuchALongPartNameThatItCrashesMySQL(dj.Part): - definition = """ - -> (master) - """ - - with pytest.raises(dj.DataJointError): - schema_any(WhyWouldAnyoneCreateATableNameThisLong) + with pytest.raises(dj.DataJointError): + schema_any(WhyWouldAnyoneCreateATableNameThisLong) diff --git a/tests/test_fetch.py b/tests/test_fetch.py index b1480fa7..4f45ae9e 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -13,33 +13,281 @@ import io -@pytest.fixture -def lang(): - yield schema.Language() +def test_getattribute(subject): + """Testing Fetch.__call__ with attributes""" + list1 = 
sorted(subject.proj().fetch(as_dict=True), key=itemgetter("subject_id")) + list2 = sorted(subject.fetch(dj.key), key=itemgetter("subject_id")) + for l1, l2 in zip(list1, list2): + assert l1 == l2, "Primary key is not returned correctly" + + tmp = subject.fetch(order_by="subject_id") + + subject_notes, key, real_id = subject.fetch("subject_notes", dj.key, "real_id") + + np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp["subject_notes"])) + np.testing.assert_array_equal(sorted(real_id), sorted(tmp["real_id"])) + list1 = sorted(key, key=itemgetter("subject_id")) + for l1, l2 in zip(list1, list2): + assert l1 == l2, "Primary key is not returned correctly" + + +def test_getattribute_for_fetch1(subject): + """Testing Fetch1.__call__ with attributes""" + assert (subject & "subject_id=10").fetch1("subject_id") == 10 + assert (subject & "subject_id=10").fetch1("subject_id", "species") == ( + 10, + "monkey", + ) + + +def test_order_by(lang, languages): + """Tests order_by sorting order""" + for ord_name, ord_lang in itertools.product(*2 * [["ASC", "DESC"]]): + cur = lang.fetch(order_by=("name " + ord_name, "language " + ord_lang)) + languages.sort(key=itemgetter(1), reverse=ord_lang == "DESC") + languages.sort(key=itemgetter(0), reverse=ord_name == "DESC") + for c, l in zip(cur, languages): + assert np.all( + cc == ll for cc, ll in zip(c, l) + ), "Sorting order is different" + + +def test_order_by_default(lang, languages): + """Tests order_by sorting order with defaults""" + cur = lang.fetch(order_by=("language", "name DESC")) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + for c, l in zip(cur, languages): + assert np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" + + +def test_limit(lang): + """Test the limit kwarg""" + limit = 4 + cur = lang.fetch(limit=limit) + assert len(cur) == limit, "Length is not correct" + + +def test_order_by_limit(lang, languages): + """Test the combination of order by and limit kwargs""" + cur = lang.fetch(limit=4, order_by=["language", "name DESC"]) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + assert len(cur) == 4, "Length is not correct" + for c, l in list(zip(cur, languages))[:4]: + assert np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" + + +def test_head_tail(schema_any): + query = schema.User * schema.Language + n = 5 + frame = query.head(n, format="frame") + assert isinstance(frame, pandas.DataFrame) + array = query.head(n, format="array") + assert array.size == n + assert len(frame) == n + assert query.primary_key == frame.index.names + + n = 4 + frame = query.tail(n, format="frame") + array = query.tail(n, format="array") + assert array.size == n + assert len(frame) == n + assert query.primary_key == frame.index.names + + +def test_limit_offset(lang, languages): + """Test the limit and offset kwargs together""" + cur = lang.fetch(offset=2, limit=4, order_by=["language", "name DESC"]) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + assert len(cur) == 4, "Length is not correct" + for c, l in list(zip(cur, languages[2:6])): + assert np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" + + +def test_iter(lang, languages): + """Test iterator""" + cur = lang.fetch(order_by=["language", "name DESC"]) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + for (name, lang_val), 
(tname, tlang) in list(zip(cur, languages)): + assert name == tname and lang_val == tlang, "Values are not the same" + # now as dict + cur = lang.fetch(as_dict=True, order_by=("language", "name DESC")) + for row, (tname, tlang) in list(zip(cur, languages)): + assert ( + row["name"] == tname and row["language"] == tlang + ), "Values are not the same" + + +def test_keys(lang, languages): + """test key fetch""" + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + + lang = schema.Language() + cur = lang.fetch("name", "language", order_by=("language", "name DESC")) + cur2 = list(lang.fetch("KEY", order_by=["language", "name DESC"])) + + for c, c2 in zip(zip(*cur), cur2): + assert c == tuple(c2.values()), "Values are not the same" + + +def test_attributes_as_dict(subject): + """ + Issue #595 + """ + attrs = ("species", "date_of_birth") + result = subject.fetch(*attrs, as_dict=True) + assert bool(result) and len(result) == len(subject) + assert set(result[0]) == set(attrs) + + +def test_fetch1_step1(lang, languages): + assert ( + lang.contents + == languages + == [ + ("Fabian", "English"), + ("Edgar", "English"), + ("Dimitri", "English"), + ("Dimitri", "Ukrainian"), + ("Fabian", "German"), + ("Edgar", "Japanese"), + ] + ), "Unexpected contents in Language table" + key = {"name": "Edgar", "language": "Japanese"} + true = languages[-1] + dat = (lang & key).fetch1() + for k, (ke, c) in zip(true, dat.items()): + assert k == c == (lang & key).fetch1(ke), "Values are not the same" + + +def test_misspelled_attribute(schema_any): + with pytest.raises(dj.DataJointError): + f = (schema.Language & 'lang = "ENGLISH"').fetch() + + +def test_repr(subject): + """Test string representation of fetch, returning table preview""" + repr = subject.fetch.__repr__() + n = len(repr.strip().split("\n")) + limit = dj.config["display.limit"] + # 3 lines are used for headers (2) and summary statement (1) + assert n - 3 <= limit + + +def test_fetch_none(lang): + """Test preparing attributes for getitem""" + with pytest.raises(dj.DataJointError): + lang.fetch(None) + +def test_asdict(lang): + """Test returns as dictionaries""" + d = lang.fetch(as_dict=True) + for dd in d: + assert isinstance(dd, dict) -@pytest.fixture -def languages(lang) -> List: - og_contents = lang.contents - languages = og_contents.copy() - yield languages - lang.contents = og_contents +def test_offset(lang, languages): + """Tests offset""" + cur = lang.fetch(limit=4, offset=1, order_by=["language", "name DESC"]) -@pytest.fixture -def subject(): - yield schema.Subject() + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + assert len(cur) == 4, "Length is not correct" + for c, l in list(zip(cur, languages[1:]))[:4]: + assert np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" -class TestFetch: - def test_getattribute(self, schema_any, subject): - """Testing Fetch.__call__ with attributes""" +def test_limit_warning(lang): + """Tests whether warning is raised if offset is used without limit.""" + logger = logging.getLogger("datajoint") + log_capture = io.StringIO() + stream_handler = logging.StreamHandler(log_capture) + log_format = logging.Formatter( + "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s" + ) + stream_handler.setFormatter(log_format) + stream_handler.set_name("test_limit_warning") + logger.addHandler(stream_handler) + lang.fetch(offset=1) + + log_contents = log_capture.getvalue() + log_capture.close() + + for handler 
in logger.handlers: # Clean up handler + if handler.name == "test_limit_warning": + logger.removeHandler(handler) + assert "[WARNING]: Offset set, but no limit." in log_contents + + +def test_len(lang): + """Tests __len__""" + assert len(lang.fetch()) == len(lang), "__len__ is not behaving properly" + + +def test_fetch1_step2(lang): + """Tests whether fetch1 raises error""" + with pytest.raises(dj.DataJointError): + lang.fetch1() + + +def test_fetch1_step3(lang): + """Tests whether fetch1 raises error""" + with pytest.raises(dj.DataJointError): + lang.fetch1("name") + + +def test_decimal(schema_any): + """Tests that decimal fields are correctly fetched and used in restrictions, see issue #334""" + rel = schema.DecimalPrimaryKey() + assert len(rel.fetch()), "Table DecimalPrimaryKey contents are empty" + rel.insert1([decimal.Decimal("3.1415926")]) + keys = rel.fetch() + assert len(keys) > 0 + assert len(rel & keys[0]) == 1 + keys = rel.fetch(dj.key) + assert len(keys) >= 2 + assert len(rel & keys[1]) == 1 + + +def test_nullable_numbers(schema_any): + """test mixture of values and nulls in numeric attributes""" + table = schema.NullableNumbers() + table.insert( + ( + ( + k, + np.random.randn(), + np.random.randint(-1000, 1000), + np.random.randn(), + ) + for k in range(10) + ) + ) + table.insert1((100, None, None, None)) + f, d, i = table.fetch("fvalue", "dvalue", "ivalue") + assert None in i + assert any(np.isnan(d)) + assert any(np.isnan(f)) + + +def test_fetch_format(subject): + """test fetch_format='frame'""" + with dj.config(fetch_format="frame"): + # test if lists are both dicts list1 = sorted(subject.proj().fetch(as_dict=True), key=itemgetter("subject_id")) list2 = sorted(subject.fetch(dj.key), key=itemgetter("subject_id")) for l1, l2 in zip(list1, list2): assert l1 == l2, "Primary key is not returned correctly" + # tests if pandas dataframe tmp = subject.fetch(order_by="subject_id") + assert isinstance(tmp, pandas.DataFrame) + tmp = tmp.to_records() subject_notes, key, real_id = subject.fetch("subject_notes", dj.key, "real_id") @@ -51,349 +299,95 @@ def test_getattribute(self, schema_any, subject): for l1, l2 in zip(list1, list2): assert l1 == l2, "Primary key is not returned correctly" - def test_getattribute_for_fetch1(self, schema_any, subject): - """Testing Fetch1.__call__ with attributes""" - assert (subject & "subject_id=10").fetch1("subject_id") == 10 - assert (subject & "subject_id=10").fetch1("subject_id", "species") == ( - 10, - "monkey", - ) - - def test_order_by(self, schema_any, lang, languages): - """Tests order_by sorting order""" - for ord_name, ord_lang in itertools.product(*2 * [["ASC", "DESC"]]): - cur = lang.fetch(order_by=("name " + ord_name, "language " + ord_lang)) - languages.sort(key=itemgetter(1), reverse=ord_lang == "DESC") - languages.sort(key=itemgetter(0), reverse=ord_name == "DESC") - for c, l in zip(cur, languages): - assert np.all( - cc == ll for cc, ll in zip(c, l) - ), "Sorting order is different" - - def test_order_by_default(self, schema_any, lang, languages): - """Tests order_by sorting order with defaults""" - cur = lang.fetch(order_by=("language", "name DESC")) - languages.sort(key=itemgetter(0), reverse=True) - languages.sort(key=itemgetter(1), reverse=False) - for c, l in zip(cur, languages): - assert np.all( - [cc == ll for cc, ll in zip(c, l)] - ), "Sorting order is different" - - def test_limit(self, schema_any, lang): - """Test the limit kwarg""" - limit = 4 - cur = lang.fetch(limit=limit) - assert len(cur) == limit, "Length is not 
correct" - - def test_order_by_limit(self, schema_any, lang, languages): - """Test the combination of order by and limit kwargs""" - cur = lang.fetch(limit=4, order_by=["language", "name DESC"]) - languages.sort(key=itemgetter(0), reverse=True) - languages.sort(key=itemgetter(1), reverse=False) - assert len(cur) == 4, "Length is not correct" - for c, l in list(zip(cur, languages))[:4]: - assert np.all( - [cc == ll for cc, ll in zip(c, l)] - ), "Sorting order is different" - - def test_head_tail(self, schema_any): - query = schema.User * schema.Language - n = 5 - frame = query.head(n, format="frame") - assert isinstance(frame, pandas.DataFrame) - array = query.head(n, format="array") - assert array.size == n - assert len(frame) == n - assert query.primary_key == frame.index.names - - n = 4 - frame = query.tail(n, format="frame") - array = query.tail(n, format="array") - assert array.size == n - assert len(frame) == n - assert query.primary_key == frame.index.names - - def test_limit_offset(self, schema_any, lang, languages): - """Test the limit and offset kwargs together""" - cur = lang.fetch(offset=2, limit=4, order_by=["language", "name DESC"]) - languages.sort(key=itemgetter(0), reverse=True) - languages.sort(key=itemgetter(1), reverse=False) - assert len(cur) == 4, "Length is not correct" - for c, l in list(zip(cur, languages[2:6])): - assert np.all( - [cc == ll for cc, ll in zip(c, l)] - ), "Sorting order is different" - - def test_iter(self, schema_any, lang, languages): - """Test iterator""" - cur = lang.fetch(order_by=["language", "name DESC"]) - languages.sort(key=itemgetter(0), reverse=True) - languages.sort(key=itemgetter(1), reverse=False) - for (name, lang_val), (tname, tlang) in list(zip(cur, languages)): - assert name == tname and lang_val == tlang, "Values are not the same" - # now as dict - cur = lang.fetch(as_dict=True, order_by=("language", "name DESC")) - for row, (tname, tlang) in list(zip(cur, languages)): - assert ( - row["name"] == tname and row["language"] == tlang - ), "Values are not the same" - - def test_keys(self, schema_any, lang, languages): - """test key fetch""" - languages.sort(key=itemgetter(0), reverse=True) - languages.sort(key=itemgetter(1), reverse=False) - - lang = schema.Language() - cur = lang.fetch("name", "language", order_by=("language", "name DESC")) - cur2 = list(lang.fetch("KEY", order_by=["language", "name DESC"])) - - for c, c2 in zip(zip(*cur), cur2): - assert c == tuple(c2.values()), "Values are not the same" - - def test_attributes_as_dict(self, schema_any, subject): - """ - Issue #595 - """ - attrs = ("species", "date_of_birth") - result = subject.fetch(*attrs, as_dict=True) - assert bool(result) and len(result) == len(subject) - assert set(result[0]) == set(attrs) - - def test_fetch1_step1(self, schema_any, lang, languages): - assert ( - lang.contents - == languages - == [ - ("Fabian", "English"), - ("Edgar", "English"), - ("Dimitri", "English"), - ("Dimitri", "Ukrainian"), - ("Fabian", "German"), - ("Edgar", "Japanese"), - ] - ), "Unexpected contents in Language table" - key = {"name": "Edgar", "language": "Japanese"} - true = languages[-1] - dat = (lang & key).fetch1() - for k, (ke, c) in zip(true, dat.items()): - assert k == c == (lang & key).fetch1(ke), "Values are not the same" - - def test_misspelled_attribute(self, schema_any): - with pytest.raises(dj.DataJointError): - f = (schema.Language & 'lang = "ENGLISH"').fetch() - - def test_repr(self, schema_any, subject): - """Test string representation of fetch, returning table 
preview""" - repr = subject.fetch.__repr__() - n = len(repr.strip().split("\n")) - limit = dj.config["display.limit"] - # 3 lines are used for headers (2) and summary statement (1) - assert n - 3 <= limit - - def test_fetch_none(self, schema_any, lang): - """Test preparing attributes for getitem""" - with pytest.raises(dj.DataJointError): - lang.fetch(None) - - def test_asdict(self, schema_any, lang): - """Test returns as dictionaries""" - d = lang.fetch(as_dict=True) - for dd in d: - assert isinstance(dd, dict) - - def test_offset(self, schema_any, lang, languages): - """Tests offset""" - cur = lang.fetch(limit=4, offset=1, order_by=["language", "name DESC"]) - - languages.sort(key=itemgetter(0), reverse=True) - languages.sort(key=itemgetter(1), reverse=False) - assert len(cur) == 4, "Length is not correct" - for c, l in list(zip(cur, languages[1:]))[:4]: - assert np.all( - [cc == ll for cc, ll in zip(c, l)] - ), "Sorting order is different" - - def test_limit_warning(self, schema_any, lang): - """Tests whether warning is raised if offset is used without limit.""" - logger = logging.getLogger("datajoint") - log_capture = io.StringIO() - stream_handler = logging.StreamHandler(log_capture) - log_format = logging.Formatter( - "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s" - ) - stream_handler.setFormatter(log_format) - stream_handler.set_name("test_limit_warning") - logger.addHandler(stream_handler) - lang.fetch(offset=1) - - log_contents = log_capture.getvalue() - log_capture.close() - - for handler in logger.handlers: # Clean up handler - if handler.name == "test_limit_warning": - logger.removeHandler(handler) - assert "[WARNING]: Offset set, but no limit." in log_contents - - def test_len(self, schema_any, lang): - """Tests __len__""" - assert len(lang.fetch()) == len(lang), "__len__ is not behaving properly" - - def test_fetch1_step2(self, schema_any, lang): - """Tests whether fetch1 raises error""" - with pytest.raises(dj.DataJointError): - lang.fetch1() - - def test_fetch1_step3(self, schema_any, lang): - """Tests whether fetch1 raises error""" - with pytest.raises(dj.DataJointError): - lang.fetch1("name") - - def test_decimal(self, schema_any): - """Tests that decimal fields are correctly fetched and used in restrictions, see issue #334""" - rel = schema.DecimalPrimaryKey() - assert len(rel.fetch()), "Table DecimalPrimaryKey contents are empty" - rel.insert1([decimal.Decimal("3.1415926")]) - keys = rel.fetch() - assert len(keys) > 0 - assert len(rel & keys[0]) == 1 - keys = rel.fetch(dj.key) - assert len(keys) >= 2 - assert len(rel & keys[1]) == 1 - - def test_nullable_numbers(self, schema_any): - """test mixture of values and nulls in numeric attributes""" - table = schema.NullableNumbers() - table.insert( - ( - ( - k, - np.random.randn(), - np.random.randint(-1000, 1000), - np.random.randn(), - ) - for k in range(10) - ) - ) - table.insert1((100, None, None, None)) - f, d, i = table.fetch("fvalue", "dvalue", "ivalue") - assert None in i - assert any(np.isnan(d)) - assert any(np.isnan(f)) - - def test_fetch_format(self, schema_any, subject): - """test fetch_format='frame'""" - with dj.config(fetch_format="frame"): - # test if lists are both dicts - list1 = sorted( - subject.proj().fetch(as_dict=True), key=itemgetter("subject_id") - ) - list2 = sorted(subject.fetch(dj.key), key=itemgetter("subject_id")) - for l1, l2 in zip(list1, list2): - assert l1 == l2, "Primary key is not returned correctly" - - # tests if pandas dataframe - tmp = subject.fetch(order_by="subject_id") - 
assert isinstance(tmp, pandas.DataFrame) - tmp = tmp.to_records() - - subject_notes, key, real_id = subject.fetch( - "subject_notes", dj.key, "real_id" - ) - np.testing.assert_array_equal( - sorted(subject_notes), sorted(tmp["subject_notes"]) - ) - np.testing.assert_array_equal(sorted(real_id), sorted(tmp["real_id"])) - list1 = sorted(key, key=itemgetter("subject_id")) - for l1, l2 in zip(list1, list2): - assert l1 == l2, "Primary key is not returned correctly" - - def test_key_fetch1(self, schema_any, subject): - """test KEY fetch1 - issue #976""" - with dj.config(fetch_format="array"): - k1 = (subject & "subject_id=10").fetch1("KEY") - with dj.config(fetch_format="frame"): - k2 = (subject & "subject_id=10").fetch1("KEY") - assert k1 == k2 - - def test_same_secondary_attribute(self, schema_any): - children = (schema.Child * schema.Parent().proj()).fetch()["name"] - assert len(children) == 1 - assert children[0] == "Dan" - - def test_query_caching(self, schema_any): - # initialize cache directory - os.mkdir(os.path.expanduser("~/dj_query_cache")) - - with dj.config(query_cache=os.path.expanduser("~/dj_query_cache")): - conn = schema.TTest3.connection - # insert sample data and load cache - schema.TTest3.insert([dict(key=100 + i, value=200 + i) for i in range(2)]) - conn.set_query_cache(query_cache="main") - cached_res = schema.TTest3().fetch() - # attempt to insert while caching enabled - try: - schema.TTest3.insert( - [dict(key=200 + i, value=400 + i) for i in range(2)] - ) - assert False, "Insert allowed while query caching enabled" - except dj.DataJointError: - conn.set_query_cache() - # insert new data - schema.TTest3.insert([dict(key=600 + i, value=800 + i) for i in range(2)]) - # re-enable cache to access old results - conn.set_query_cache(query_cache="main") - previous_cache = schema.TTest3().fetch() - # verify properly cached and how to refresh results - assert all([c == p for c, p in zip(cached_res, previous_cache)]) +def test_key_fetch1(subject): + """test KEY fetch1 - issue #976""" + with dj.config(fetch_format="array"): + k1 = (subject & "subject_id=10").fetch1("KEY") + with dj.config(fetch_format="frame"): + k2 = (subject & "subject_id=10").fetch1("KEY") + assert k1 == k2 + + +def test_same_secondary_attribute(schema_any): + children = (schema.Child * schema.Parent().proj()).fetch()["name"] + assert len(children) == 1 + assert children[0] == "Dan" + + +def test_query_caching(schema_any): + # initialize cache directory + os.mkdir(os.path.expanduser("~/dj_query_cache")) + + with dj.config(query_cache=os.path.expanduser("~/dj_query_cache")): + conn = schema.TTest3.connection + # insert sample data and load cache + schema.TTest3.insert([dict(key=100 + i, value=200 + i) for i in range(2)]) + conn.set_query_cache(query_cache="main") + cached_res = schema.TTest3().fetch() + # attempt to insert while caching enabled + try: + schema.TTest3.insert([dict(key=200 + i, value=400 + i) for i in range(2)]) + assert False, "Insert allowed while query caching enabled" + except dj.DataJointError: conn.set_query_cache() - uncached_res = schema.TTest3().fetch() - assert len(uncached_res) > len(cached_res) - # purge query cache - conn.purge_query_cache() - - # reset cache directory state (will fail if purge was unsuccessful) - os.rmdir(os.path.expanduser("~/dj_query_cache")) - - def test_fetch_group_by(self, schema_any): - """ - https://github.com/datajoint/datajoint-python/issues/914 - """ - - assert schema.Parent().fetch("KEY", order_by="name") == [{"parent_id": 1}] - - def 
test_dj_u_distinct(self, schema_any): - """ - Test developed to see if removing DISTINCT from the select statement - generation breaks the dj.U universal set implementation - """ - - # Contents to be inserted - contents = [(1, 2, 3), (2, 2, 3), (3, 3, 2), (4, 5, 5)] - schema.Stimulus.insert(contents) - - # Query the whole table - test_query = schema.Stimulus() - - # Use dj.U to create a list of unique contrast and brightness combinations - result = dj.U("contrast", "brightness") & test_query - expected_result = [ - {"contrast": 2, "brightness": 3}, - {"contrast": 3, "brightness": 2}, - {"contrast": 5, "brightness": 5}, - ] - - fetched_result = result.fetch(as_dict=True, order_by=("contrast", "brightness")) - schema.Stimulus.delete_quick() - assert fetched_result == expected_result - - def test_backslash(self, schema_any): - """ - https://github.com/datajoint/datajoint-python/issues/999 - """ - expected = "She\\Hulk" - schema.Parent.insert([(2, expected)]) - q = schema.Parent & dict(name=expected) - assert q.fetch1("name") == expected - q.delete() + # insert new data + schema.TTest3.insert([dict(key=600 + i, value=800 + i) for i in range(2)]) + # re-enable cache to access old results + conn.set_query_cache(query_cache="main") + previous_cache = schema.TTest3().fetch() + # verify properly cached and how to refresh results + assert all([c == p for c, p in zip(cached_res, previous_cache)]) + conn.set_query_cache() + uncached_res = schema.TTest3().fetch() + assert len(uncached_res) > len(cached_res) + # purge query cache + conn.purge_query_cache() + + # reset cache directory state (will fail if purge was unsuccessful) + os.rmdir(os.path.expanduser("~/dj_query_cache")) + + +def test_fetch_group_by(schema_any): + """ + https://github.com/datajoint/datajoint-python/issues/914 + """ + + assert schema.Parent().fetch("KEY", order_by="name") == [{"parent_id": 1}] + + +def test_dj_u_distinct(schema_any): + """ + Test developed to see if removing DISTINCT from the select statement + generation breaks the dj.U universal set implementation + """ + + # Contents to be inserted + contents = [(1, 2, 3), (2, 2, 3), (3, 3, 2), (4, 5, 5)] + schema.Stimulus.insert(contents) + + # Query the whole table + test_query = schema.Stimulus() + + # Use dj.U to create a list of unique contrast and brightness combinations + result = dj.U("contrast", "brightness") & test_query + expected_result = [ + {"contrast": 2, "brightness": 3}, + {"contrast": 3, "brightness": 2}, + {"contrast": 5, "brightness": 5}, + ] + + fetched_result = result.fetch(as_dict=True, order_by=("contrast", "brightness")) + schema.Stimulus.delete_quick() + assert fetched_result == expected_result + + +def test_backslash(schema_any): + """ + https://github.com/datajoint/datajoint-python/issues/999 + """ + expected = "She\\Hulk" + schema.Parent.insert([(2, expected)]) + q = schema.Parent & dict(name=expected) + assert q.fetch1("name") == expected + q.delete() diff --git a/tests/test_fetch_same.py b/tests/test_fetch_same.py index 4935bb03..32d04134 100644 --- a/tests/test_fetch_same.py +++ b/tests/test_fetch_same.py @@ -1,5 +1,4 @@ import pytest -from . 
import PREFIX, CONN_INFO import numpy as np import datajoint as dj @@ -16,11 +15,11 @@ class ProjData(dj.Manual): @pytest.fixture -def schema_fetch_same(connection_root): +def schema_fetch_same(connection_test, prefix): schema = dj.Schema( - PREFIX + "_fetch_same", + prefix + "_fetch_same", context=dict(ProjData=ProjData), - connection=connection_root, + connection=connection_test, ) schema(ProjData) ProjData().insert( @@ -46,27 +45,24 @@ def schema_fetch_same(connection_root): schema.drop() -@pytest.fixture -def projdata(): - yield ProjData() +def test_object_conversion_one(schema_fetch_same): + new = ProjData().proj(sub="resp").fetch("sub") + assert new.dtype == np.float64 + +def test_object_conversion_two(schema_fetch_same): + [sub, add] = ProjData().proj(sub="resp", add="sim").fetch("sub", "add") + assert sub.dtype == np.float64 + assert add.dtype == np.float64 -class TestFetchSame: - def test_object_conversion_one(self, schema_fetch_same, projdata): - new = projdata.proj(sub="resp").fetch("sub") - assert new.dtype == np.float64 - def test_object_conversion_two(self, schema_fetch_same, projdata): - [sub, add] = projdata.proj(sub="resp", add="sim").fetch("sub", "add") - assert sub.dtype == np.float64 - assert add.dtype == np.float64 +def test_object_conversion_all(schema_fetch_same): + new = ProjData().proj(sub="resp", add="sim").fetch() + assert new["sub"].dtype == np.float64 + assert new["add"].dtype == np.float64 - def test_object_conversion_all(self, schema_fetch_same, projdata): - new = projdata.proj(sub="resp", add="sim").fetch() - assert new["sub"].dtype == np.float64 - assert new["add"].dtype == np.float64 - def test_object_no_convert(self, schema_fetch_same, projdata): - new = projdata.fetch() - assert new["big"].dtype == "object" - assert new["blah"].dtype == "object" +def test_object_no_convert(schema_fetch_same): + new = ProjData().fetch() + assert new["big"].dtype == "object" + assert new["blah"].dtype == "object" diff --git a/tests/test_jobs.py b/tests/test_jobs.py index 37974ac8..9d1d4636 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -6,40 +6,35 @@ import datajoint as dj -@pytest.fixture -def subjects(): - yield schema.Subject() - - -def test_reserve_job(schema_any, subjects): - assert subjects +def test_reserve_job(subject, schema_any): + assert subject table_name = "fake_table" # reserve jobs - for key in subjects.fetch("KEY"): + for key in subject.fetch("KEY"): assert schema_any.jobs.reserve(table_name, key), "failed to reserve a job" # refuse jobs - for key in subjects.fetch("KEY"): + for key in subject.fetch("KEY"): assert not schema_any.jobs.reserve( table_name, key ), "failed to respect reservation" # complete jobs - for key in subjects.fetch("KEY"): + for key in subject.fetch("KEY"): schema_any.jobs.complete(table_name, key) assert not schema_any.jobs, "failed to free jobs" # reserve jobs again - for key in subjects.fetch("KEY"): + for key in subject.fetch("KEY"): assert schema_any.jobs.reserve(table_name, key), "failed to reserve new jobs" # finish with error - for key in subjects.fetch("KEY"): + for key in subject.fetch("KEY"): schema_any.jobs.error(table_name, key, "error message") # refuse jobs with errors - for key in subjects.fetch("KEY"): + for key in subject.fetch("KEY"): assert not schema_any.jobs.reserve( table_name, key ), "failed to ignore error jobs" @@ -95,7 +90,7 @@ def test_suppress_dj_errors(schema_any): assert len(schema.DjExceptionName()) == len(schema_any.jobs) > 0 -def test_long_error_message(schema_any, subjects): +def 
test_long_error_message(subject, schema_any): # create long error message long_error_message = "".join( random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH + 100) @@ -103,10 +98,10 @@ def test_long_error_message(schema_any, subjects): short_error_message = "".join( random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH // 2) ) - assert subjects + assert subject table_name = "fake_table" - key = subjects.fetch("KEY")[0] + key = subject.fetch("KEY", limit=1)[0] # test long error message schema_any.jobs.reserve(table_name, key) @@ -131,7 +126,7 @@ def test_long_error_message(schema_any, subjects): schema_any.jobs.delete() -def test_long_error_stack(schema_any, subjects): +def test_long_error_stack(subject, schema_any): # create long error stack STACK_SIZE = ( 89942 # Does not fit into small blob (should be 64k, but found to be higher) @@ -139,10 +134,10 @@ def test_long_error_stack(schema_any, subjects): long_error_stack = "".join( random.choice(string.ascii_letters) for _ in range(STACK_SIZE) ) - assert subjects + assert subject table_name = "fake_table" - key = subjects.fetch("KEY")[0] + key = subject.fetch("KEY", limit=1)[0] # test long error stack schema_any.jobs.reserve(table_name, key) diff --git a/tests/test_json.py b/tests/test_json.py index c1caaeed..53016505 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -4,10 +4,9 @@ import datajoint as dj import numpy as np from packaging.version import Version -from . import PREFIX if Version(dj.conn().query("select @@version;").fetchone()[0]) < Version("8.0.0"): - pytest.skip("skipping windows-only tests", allow_module_level=True) + pytest.skip("These tests require MySQL >= v8.0.0", allow_module_level=True) class Team(dj.Lookup): @@ -65,14 +64,16 @@ class Team(dj.Lookup): @pytest.fixture -def schema(connection_test): - schema = dj.Schema(PREFIX + "_json", context=dict(), connection=connection_test) +def schema_json(connection_test, prefix): + schema = dj.Schema( + prefix + "_json", context=dict(Team=Team), connection=connection_test + ) schema(Team) yield schema schema.drop() -def test_insert_update(schema): +def test_insert_update(schema_json): car = { "name": "Discovery", "length": 22.9, @@ -108,7 +109,7 @@ def test_insert_update(schema): assert not q -def test_describe(schema): +def test_describe(schema_json): rel = Team() context = inspect.currentframe().f_globals s1 = declare(rel.full_table_name, rel.definition, context) @@ -116,7 +117,7 @@ def test_describe(schema): assert s1 == s2 -def test_restrict(schema): +def test_restrict(schema_json): # dict assert (Team & {"car.name": "Chaching"}).fetch1("name") == "business" @@ -176,7 +177,7 @@ def test_restrict(schema): ).fetch1("name") == "business", "2nd `headlight` object did not match" -def test_proj(schema): +def test_proj(schema_json): # proj necessary since we need to rename indexed value into a proper attribute name assert Team.proj(car_length="car.length").fetch( as_dict=True, order_by="car_length" diff --git a/tests/test_nan.py b/tests/test_nan.py index 299c0d9f..68a28079 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -1,6 +1,5 @@ import numpy as np import datajoint as dj -from . 
import PREFIX import pytest @@ -12,36 +11,41 @@ class NanTest(dj.Manual): """ -@pytest.fixture(scope="module") -def schema(connection_test): - schema = dj.Schema(PREFIX + "_nantest", connection=connection_test) +@pytest.fixture +def schema_nan(connection_test, prefix): + schema = dj.Schema( + prefix + "_nantest", context=dict(NanTest=NanTest), connection=connection_test + ) schema(NanTest) yield schema schema.drop() -@pytest.fixture(scope="class") -def setup_class(request, schema): +@pytest.fixture +def arr_a(): + return np.array([0, 1 / 3, np.nan, np.pi, np.nan]) + + +@pytest.fixture +def schema_nan_pop(schema_nan, arr_a): rel = NanTest() with dj.config(safemode=False): rel.delete() - a = np.array([0, 1 / 3, np.nan, np.pi, np.nan]) - rel.insert(((i, value) for i, value in enumerate(a))) - request.cls.rel = rel - request.cls.a = a - - -class TestNaNInsert: - def test_insert_nan(self, setup_class): - """Test fetching of null values""" - b = self.rel.fetch("value", order_by="id") - assert (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" - assert np.allclose( - self.a[np.logical_not(np.isnan(self.a))], b[np.logical_not(np.isnan(b))] - ), "incorrect storage of floats" - - def test_nulls_do_not_affect_primary_keys(self, setup_class): - """Test against a case that previously caused a bug when skipping existing entries.""" - self.rel.insert( - ((i, value) for i, value in enumerate(self.a)), skip_duplicates=True - ) + rel.insert(((i, value) for i, value in enumerate(arr_a))) + return schema_nan + + +def test_insert_nan(schema_nan_pop, arr_a): + """Test fetching of null values""" + b = NanTest().fetch("value", order_by="id") + assert (np.isnan(arr_a) == np.isnan(b)).all(), "incorrect handling of Nans" + assert np.allclose( + arr_a[np.logical_not(np.isnan(arr_a))], b[np.logical_not(np.isnan(b))] + ), "incorrect storage of floats" + + +def test_nulls_do_not_affect_primary_keys(schema_nan_pop, arr_a): + """Test against a case that previously caused a bug when skipping existing entries.""" + NanTest().insert( + ((i, value) for i, value in enumerate(arr_a)), skip_duplicates=True + ) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index ddb8b3bf..95933d2f 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -23,8 +23,7 @@ def test_normal_djerror(): assert e.__cause__ is None -@pytest.mark.parametrize("category", ("connection",)) -def test_verified_djerror(category): +def test_verified_djerror(category="connection"): try: curr_plugins = getattr(p, "{}_plugins".format(category)) setattr( @@ -42,8 +41,7 @@ def test_verified_djerror_type(): test_verified_djerror(category="type") -@pytest.mark.parametrize("category", ("connection",)) -def test_unverified_djerror(category): +def test_unverified_djerror(category="connection"): try: curr_plugins = getattr(p, "{}_plugins".format(category)) setattr( diff --git a/tests/test_privileges.py b/tests/test_privileges.py index 949dbc8a..57880081 100644 --- a/tests/test_privileges.py +++ b/tests/test_privileges.py @@ -1,32 +1,31 @@ import os import pytest import datajoint as dj -from . import schema, CONN_INFO_ROOT, PREFIX -from . import schema_privileges +from . 
import schema, schema_privileges
 
 namespace = locals()
 
 
 @pytest.fixture
 def schema_priv(connection_test):
-    schema_priv = dj.Schema(
+    schema = dj.Schema(
         context=schema_privileges.LOCALS_PRIV,
         connection=connection_test,
     )
-    schema_priv(schema_privileges.Parent)
-    schema_priv(schema_privileges.Child)
-    schema_priv(schema_privileges.NoAccess)
-    schema_priv(schema_privileges.NoAccessAgain)
-    yield schema_priv
-    if schema_priv.is_activated():
-        schema_priv.drop()
+    schema(schema_privileges.Parent)
+    schema(schema_privileges.Child)
+    schema(schema_privileges.NoAccess)
+    schema(schema_privileges.NoAccessAgain)
+    yield schema
+    if schema.is_activated():
+        schema.drop()
 
 
 @pytest.fixture
-def connection_djsubset(connection_root, db_creds_root, schema_priv):
+def connection_djsubset(connection_root, db_creds_root, schema_priv, prefix):
     user = "djsubset"
     conn = dj.conn(**db_creds_root, reset=True)
-    schema_priv.activate(f"{PREFIX}_schema_privileges")
+    schema_priv.activate(f"{prefix}_schema_privileges")
     conn.query(
         f"""
         CREATE USER IF NOT EXISTS '{user}'@'%%'
@@ -36,14 +35,14 @@ def connection_djsubset(connection_root, db_creds_root, schema_priv):
     conn.query(
         f"""
         GRANT SELECT, INSERT, UPDATE, DELETE
-        ON `{PREFIX}_schema_privileges`.`#parent`
+        ON `{prefix}_schema_privileges`.`#parent`
         TO '{user}'@'%%'
         """
     )
     conn.query(
         f"""
         GRANT SELECT, INSERT, UPDATE, DELETE
-        ON `{PREFIX}_schema_privileges`.`__child`
+        ON `{prefix}_schema_privileges`.`__child`
         TO '{user}'@'%%'
         """
     )
@@ -55,7 +54,7 @@ def connection_djsubset(connection_root, db_creds_root, schema_priv):
     )
     yield conn_djsubset
     conn.query(f"DROP USER {user}")
-    conn.query(f"DROP DATABASE {PREFIX}_schema_privileges")
+    conn.query(f"DROP DATABASE {prefix}_schema_privileges")
 
 
 @pytest.fixture
@@ -110,9 +109,9 @@ class Try(dj.Manual):
 
 
 class TestSubset:
-    def test_populate_activate(self, connection_djsubset, schema_priv):
+    def test_populate_activate(self, connection_djsubset, schema_priv, prefix):
         schema_priv.activate(
-            f"{PREFIX}_schema_privileges", create_schema=True, create_tables=False
+            f"{prefix}_schema_privileges", create_schema=True, create_tables=False
         )
         schema_privileges.Child.populate()
         assert schema_privileges.Child.progress(display=False)[0] == 0
diff --git a/tests/test_reconnection.py b/tests/test_reconnection.py
index 26253124..5eea4af1 100644
--- a/tests/test_reconnection.py
+++ b/tests/test_reconnection.py
@@ -5,32 +5,28 @@
 import pytest
 import datajoint as dj
 from datajoint import DataJointError
-from . import CONN_INFO
 
 
 @pytest.fixture
-def conn(connection_root):
-    return dj.conn(reset=True, **CONN_INFO)
+def conn(connection_root, db_creds_root):
+    return dj.conn(reset=True, **db_creds_root)
 
 
-class TestReconnect:
-    """
-    Test reconnection
-    """
+def test_close(conn):
+    assert conn.is_connected, "Connection should be alive"
+    conn.close()
+    assert not conn.is_connected, "Connection should now be closed"
 
-    def test_close(self, conn):
-        assert conn.is_connected, "Connection should be alive"
-        conn.close()
-        assert not conn.is_connected, "Connection should now be closed"
 
-    def test_reconnect(self, conn):
-        assert conn.is_connected, "Connection should be alive"
+def test_reconnect(conn):
+    assert conn.is_connected, "Connection should be alive"
+    conn.close()
+    conn.query("SHOW DATABASES;", reconnect=True).fetchall()
+    assert conn.is_connected, "Connection should be alive"
+
+
+def test_reconnect_throws_error_in_transaction(conn):
+    assert conn.is_connected, "Connection should be alive"
+    with conn.transaction, pytest.raises(DataJointError):
         conn.close()
         conn.query("SHOW DATABASES;", reconnect=True).fetchall()
-        assert conn.is_connected, "Connection should be alive"
-
-    def test_reconnect_throws_error_in_transaction(self, conn):
-        assert conn.is_connected, "Connection should be alive"
-        with conn.transaction, pytest.raises(DataJointError):
-            conn.close()
-            conn.query("SHOW DATABASES;", reconnect=True).fetchall()
diff --git a/tests/test_relation.py b/tests/test_relation.py
index 2011a190..169ffc29 100644
--- a/tests/test_relation.py
+++ b/tests/test_relation.py
@@ -6,60 +6,9 @@
 import datajoint as dj
 from datajoint.table import Table
 from unittest.mock import patch
-
 from . import schema
 
 
-@pytest.fixture
-def test(schema_any):
-    yield schema.TTest()
-
-
-@pytest.fixture
-def test2(schema_any):
-    yield schema.TTest2()
-
-
-@pytest.fixture
-def test_extra(schema_any):
-    yield schema.TTestExtra()
-
-
-@pytest.fixture
-def test_no_extra(schema_any):
-    yield schema.TTestNoExtra()
-
-
-@pytest.fixture
-def user(schema_any):
-    return schema.User()
-
-
-@pytest.fixture
-def subject(schema_any):
-    return schema.Subject()
-
-
-@pytest.fixture
-def experiment(schema_any):
-    return schema.Experiment()
-
-
-@pytest.fixture
-def ephys(schema_any):
-    return schema.Ephys()
-
-
-@pytest.fixture
-def img(schema_any):
-    return schema.Image()
-
-
-@pytest.fixture
-def trash(schema_any):
-    return schema.UberTrash()
-
-
 def test_contents(user, subject):
     """
     test the ability of tables to self-populate using the contents property
diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py
index 50997662..dbb3b673 100644
--- a/tests/test_relation_u.py
+++ b/tests/test_relation_u.py
@@ -5,83 +5,74 @@
 from .schema_simple import *
 
 
-class TestU:
-    """
-    Test tables: insert, delete
-    """
-
-    @classmethod
-    def setup_class(cls):
-        cls.user = User()
-        cls.language = Language()
-        cls.subject = Subject()
-        cls.experiment = Experiment()
-        cls.trial = Trial()
-        cls.ephys = Ephys()
-        cls.channel = Ephys.Channel()
-        cls.img = Image()
-        cls.trash = UberTrash()
-
-    def test_restriction(self, schema_any):
-        language_set = {s[1] for s in self.language.contents}
-        rel = dj.U("language") & self.language
-        assert list(rel.heading.names) == ["language"]
-        assert len(rel) == len(language_set)
-        assert set(rel.fetch("language")) == language_set
-        # Test for issue #342
-        rel = self.trial * dj.U("start_time")
-        assert list(rel.primary_key) == self.trial.primary_key + ["start_time"]
-        assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key)
-        assert list((dj.U("start_time") & self.trial).primary_key) == ["start_time"]
-
-    def test_invalid_restriction(self, schema_any):
-        with raises(dj.DataJointError):
-            result = dj.U("color") & dict(color="red")
-
-    def test_ineffective_restriction(self, schema_any):
-        rel = self.language & dj.U("language")
-        assert rel.make_sql() == self.language.make_sql()
-
-    def test_join(self, schema_any):
-        rel = self.experiment * dj.U("experiment_date")
-        assert self.experiment.primary_key == ["subject_id", "experiment_id"]
-        assert rel.primary_key == self.experiment.primary_key + ["experiment_date"]
-
-        rel = dj.U("experiment_date") * self.experiment
-        assert self.experiment.primary_key == ["subject_id", "experiment_id"]
-        assert rel.primary_key == self.experiment.primary_key + ["experiment_date"]
-
-    def test_invalid_join(self, schema_any):
-        with raises(dj.DataJointError):
-            rel = dj.U("language") * dict(language="English")
-
-    def test_repr_without_attrs(self, schema_any):
-        """test dj.U() display"""
-        query = dj.U().aggr(Language, n="count(*)")
-        repr(query)
-
-    def test_aggregations(self, schema_any):
-        lang = Language()
-        # test total aggregation on expression object
-        n1 = dj.U().aggr(lang, n="count(*)").fetch1("n")
-        assert n1 == len(lang.fetch())
-        # test total aggregation on expression class
-        n2 = dj.U().aggr(Language, n="count(*)").fetch1("n")
-        assert n1 == n2
-        rel = dj.U("language").aggr(Language, number_of_speakers="count(*)")
-        assert len(rel) == len(set(l[1] for l in Language.contents))
-        assert (rel & 'language="English"').fetch1("number_of_speakers") == 3
-
-    def test_argmax(self, schema_any):
-        rel = TTest()
-        # get the tuples corresponding to the maximum value
-        mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value"
-        assert mx.fetch("value")[0] == max(rel.fetch("value"))
-
-    def test_aggr(self, schema_any, schema_simp):
-        rel = ArgmaxTest()
-        amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)")
-        amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)")
-        assert (
-            len(amax1) == len(amax2) == rel.n
-        ), "Aggregated argmax with join and restriction does not yield the same length."
+def test_restriction(lang, languages, trial):
+    language_set = {s[1] for s in languages}
+    rel = dj.U("language") & lang
+    assert list(rel.heading.names) == ["language"]
+    assert len(rel) == len(language_set)
+    assert set(rel.fetch("language")) == language_set
+    # Test for issue #342
+    rel = trial * dj.U("start_time")
+    assert list(rel.primary_key) == trial.primary_key + ["start_time"]
+    assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key)
+    assert list((dj.U("start_time") & trial).primary_key) == ["start_time"]
+
+
+def test_invalid_restriction(schema_any):
+    with raises(dj.DataJointError):
+        result = dj.U("color") & dict(color="red")
+
+
+def test_ineffective_restriction(lang):
+    rel = lang & dj.U("language")
+    assert rel.make_sql() == lang.make_sql()
+
+
+def test_join(experiment):
+    rel = experiment * dj.U("experiment_date")
+    assert experiment.primary_key == ["subject_id", "experiment_id"]
+    assert rel.primary_key == experiment.primary_key + ["experiment_date"]
+
+    rel = dj.U("experiment_date") * experiment
+    assert experiment.primary_key == ["subject_id", "experiment_id"]
+    assert rel.primary_key == experiment.primary_key + ["experiment_date"]
+
+
+def test_invalid_join(schema_any):
+    with raises(dj.DataJointError):
+        rel = dj.U("language") * dict(language="English")
+
+
+def test_repr_without_attrs(schema_any):
+    """test dj.U() display"""
+    query = dj.U().aggr(Language, n="count(*)")
+    repr(query)
+
+
+def test_aggregations(schema_any):
+    lang = Language()
+    # test total aggregation on expression object
+    n1 = dj.U().aggr(lang, n="count(*)").fetch1("n")
+    assert n1 == len(lang.fetch())
+    # test total aggregation on expression class
+    n2 = dj.U().aggr(Language, n="count(*)").fetch1("n")
+    assert n1 == n2
+    rel = dj.U("language").aggr(Language, number_of_speakers="count(*)")
+    assert len(rel) == len(set(l[1] for l in Language.contents))
+    assert (rel & 'language="English"').fetch1("number_of_speakers") == 3
+
+
+def test_argmax(schema_any):
+    rel = TTest()
+    # get the tuples corresponding to the maximum value
+    mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value"
+    assert mx.fetch("value")[0] == max(rel.fetch("value"))
+
+
+def test_aggr(schema_any, schema_simp):
+    rel = ArgmaxTest()
+    amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)")
+    amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)")
+    assert (
+        len(amax1) == len(amax2) == rel.n
+    ), "Aggregated argmax with join and restriction does not yield the same length."
diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py
index 06adee5c..65c6a5d7 100644
--- a/tests/test_relational_operand.py
+++ b/tests/test_relational_operand.py
@@ -5,34 +5,8 @@
 import datetime
 import numpy as np
 import datajoint as dj
-from .schema_simple import (
-    A,
-    B,
-    D,
-    E,
-    F,
-    L,
-    DataA,
-    DataB,
-    TTestUpdate,
-    IJ,
-    JI,
-    ReservedWord,
-    OutfitLaunch,
-)
-from .schema import (
-    Experiment,
-    TTest3,
-    Trial,
-    Ephys,
-    Child,
-    Parent,
-    SubjectA,
-    SessionA,
-    SessionStatusA,
-    SessionDateA,
-)
-from . import PREFIX, CONN_INFO
+from .schema_simple import *
+from .schema import *
 
 
 @pytest.fixture
@@ -214,8 +188,10 @@ def test_project(schema_simp_pop):
     )
 
 
-def test_rename_non_dj_attribute(connection_test, schema_simp_pop, schema_any_pop):
-    schema = PREFIX + "_test1"
+def test_rename_non_dj_attribute(
+    connection_test, schema_simp_pop, schema_any_pop, prefix
+):
+    schema = prefix + "_test1"
     connection_test.query(
         f"CREATE TABLE {schema}.test_table (oldID int PRIMARY KEY)"
     ).fetchall()
diff --git a/tests/test_s3.py b/tests/test_s3.py
index 090d6acf..b5babdd8 100644
--- a/tests/test_s3.py
+++ b/tests/test_s3.py
@@ -1,50 +1,48 @@
 import pytest
-import urllib3
-import certifi
 from .schema_external import SimpleRemote
 from datajoint.errors import DataJointError
 from datajoint.hash import uuid_from_buffer
 from datajoint.blob import pack
-from . import S3_CONN_INFO
 from minio import Minio
 
 
-class TestS3:
-    def test_connection(self, http_client, minio_client):
-        assert minio_client.bucket_exists(S3_CONN_INFO["bucket"])
+def test_connection(http_client, minio_client, s3_creds):
+    assert minio_client.bucket_exists(s3_creds["bucket"])
 
-    def test_connection_secure(self, minio_client):
-        assert minio_client.bucket_exists(S3_CONN_INFO["bucket"])
 
-    def test_remove_object_exception(self, schema_ext):
-        # https://github.com/datajoint/datajoint-python/issues/952
+def test_connection_secure(minio_client, s3_creds):
+    assert minio_client.bucket_exists(s3_creds["bucket"])
 
-        # Insert some test data and remove it so that the external table is populated
-        test = [1, [1, 2, 3]]
-        SimpleRemote.insert1(test)
-        SimpleRemote.delete()
 
-        # Save the old external table minio client
-        old_client = schema_ext.external["share"].s3.client
+def test_remove_object_exception(schema_ext, s3_creds):
+    # https://github.com/datajoint/datajoint-python/issues/952
 
-        # Apply our new minio client which has a user that does not exist
-        schema_ext.external["share"].s3.client = Minio(
-            S3_CONN_INFO["endpoint"],
-            access_key="jeffjeff",
-            secret_key="jeffjeff",
-            secure=False,
-        )
+    # Insert some test data and remove it so that the external table is populated
+    test = [1, [1, 2, 3]]
+    SimpleRemote.insert1(test)
+    SimpleRemote.delete()
 
-        # This method returns a list of errors
-        error_list = schema_ext.external["share"].delete(
-            delete_external_files=True, errors_as_string=False
-        )
+    # Save the old external table minio client
+    old_client = schema_ext.external["share"].s3.client
 
-        # Teardown
-        schema_ext.external["share"].s3.client = old_client
-        schema_ext.external["share"].delete(delete_external_files=True)
+    # Apply our new minio client which has a user that does not exist
+    schema_ext.external["share"].s3.client = Minio(
+        s3_creds["endpoint"],
+        access_key="jeffjeff",
+        secret_key="jeffjeff",
+        secure=False,
+    )
 
-        with pytest.raises(DataJointError):
-            # Raise the error we want if the error matches the expected uuid
-            if str(error_list[0][0]) == str(uuid_from_buffer(pack(test[1]))):
-                raise error_list[0][2]
+    # This method returns a list of errors
+    error_list = schema_ext.external["share"].delete(
+        delete_external_files=True, errors_as_string=False
+    )
+
+    # Teardown
+    schema_ext.external["share"].s3.client = old_client
+    schema_ext.external["share"].delete(delete_external_files=True)
+
+    with pytest.raises(DataJointError):
+        # Raise the error we want if the error matches the expected uuid
+        if str(error_list[0][0]) == str(uuid_from_buffer(pack(test[1]))):
+            raise error_list[0][2]
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 7b262204..d9e22089 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -2,10 +2,8 @@
 import pytest
 import inspect
 import datajoint as dj
-from unittest.mock import patch
 from inspect import getmembers
 from . import schema
-from . import PREFIX
 
 
 class Ephys(dj.Imported):
@@ -49,10 +47,10 @@ def schema_empty_module(schema_any, schema_empty):
 
 
 @pytest.fixture
-def schema_empty(connection_test, schema_any):
+def schema_empty(connection_test, schema_any, prefix):
     context = {**schema.LOCALS_ANY, "Ephys": Ephys}
     schema_empty = dj.Schema(
-        PREFIX + "_test1", context=context, connection=connection_test
+        prefix + "_test1", context=context, connection=connection_test
     )
     schema_empty(Ephys)
     # load the rest of the classes
@@ -145,9 +143,9 @@ def test_unauthorized_database(db_creds_test):
     )
 
 
-def test_drop_database(db_creds_test):
+def test_drop_database(db_creds_test, prefix):
     schema = dj.Schema(
-        PREFIX + "_drop_test", connection=dj.conn(reset=True, **db_creds_test)
+        prefix + "_drop_test", connection=dj.conn(reset=True, **db_creds_test)
    )
     assert schema.exists
     schema.drop()
@@ -155,8 +153,8 @@ def test_drop_database(db_creds_test):
     schema.drop()  # should do nothing
 
 
-def test_overlapping_name(connection_test):
-    test_schema = dj.Schema(PREFIX + "_overlapping_schema", connection=connection_test)
+def test_overlapping_name(connection_test, prefix):
+    test_schema = dj.Schema(prefix + "_overlapping_schema", connection=connection_test)
 
     @test_schema
     class Unit(dj.Manual):
diff --git a/tests/test_schema_keywords.py b/tests/test_schema_keywords.py
index 1cad98ef..22ed1c2a 100644
--- a/tests/test_schema_keywords.py
+++ b/tests/test_schema_keywords.py
@@ -1,6 +1,5 @@
-from . import PREFIX
-import datajoint as dj
 import pytest
+import datajoint as dj
 
 
 class A(dj.Manual):
@@ -34,15 +33,15 @@ class D(B):
 
 
 @pytest.fixture
-def schema(connection_test):
-    schema = dj.Schema(PREFIX + "_keywords", connection=connection_test)
+def schema_kwd(connection_test, prefix):
+    schema = dj.Schema(prefix + "_keywords", connection=connection_test)
     schema(A)
     schema(D)
     yield schema
     schema.drop()
 
 
-def test_inherited_part_table(schema):
+def test_inherited_part_table(schema_kwd):
     assert "a_id" in D().heading.attributes
     assert "b_id" in D().heading.attributes
     assert "a_id" in D.C().heading.attributes
diff --git a/tests/test_university.py b/tests/test_university.py
index 956cc506..800ee7cd 100644
--- a/tests/test_university.py
+++ b/tests/test_university.py
@@ -4,7 +4,7 @@
 from datajoint import DataJointError
 import datajoint as dj
 from .schema_university import *
-from . import PREFIX, schema_university
+from . import schema_university
 
 
 def _hash4(table):
@@ -32,10 +32,10 @@ def schema_uni_inactive():
 
 
 @pytest.fixture
-def schema_uni(db_creds_test, schema_uni_inactive, connection_test):
+def schema_uni(db_creds_test, schema_uni_inactive, connection_test, prefix):
     # Deferred activation
     schema_uni_inactive.activate(
-        PREFIX + "_university", connection=dj.conn(**db_creds_test)
+        prefix + "_university", connection=dj.conn(**db_creds_test)
     )
     # --------------- Fill University -------------------
     test_data_dir = Path(__file__).parent / "data"
diff --git a/tests/test_update1.py b/tests/test_update1.py
index 07e0e5b8..f29d2ab0 100644
--- a/tests/test_update1.py
+++ b/tests/test_update1.py
@@ -4,7 +4,6 @@
 from pathlib import Path
 import tempfile
 import datajoint as dj
-from . import PREFIX
 from datajoint import DataJointError
 
 
@@ -42,9 +41,9 @@ def mock_stores_update(tmpdir_factory):
 
 
 @pytest.fixture
-def schema_update1(connection_test):
+def schema_update1(connection_test, prefix):
     schema = dj.Schema(
-        PREFIX + "_update1", context=dict(Thing=Thing), connection=connection_test
+        prefix + "_update1", context=dict(Thing=Thing), connection=connection_test
     )
     schema(Thing)
     yield schema