Skip to content

Commit

Permalink
Merge pull request #1141 from ethho/dev-tests-cleanup
Browse files Browse the repository at this point in the history
PLAT-107: Clean up
  • Loading branch information
A-Baji authored Jan 2, 2024
2 parents 03db252 + a0a4a96 commit 88c96e6
Show file tree
Hide file tree
Showing 34 changed files with 1,526 additions and 1,599 deletions.
2 changes: 1 addition & 1 deletion LNX-docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ services:
interval: 15s
fakeservices.datajoint.io:
<<: *net
image: datajoint/nginx:v0.2.7
image: datajoint/nginx:v0.2.8
environment:
- ADD_db_TYPE=DATABASE
- ADD_db_ENDPOINT=db:3306
Expand Down
26 changes: 0 additions & 26 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +0,0 @@
import datajoint as dj
from packaging import version
import pytest
import os

PREFIX = os.environ.get("DJ_TEST_DB_PREFIX", "djtest")

# Connection for testing
CONN_INFO = dict(
host=os.environ.get("DJ_TEST_HOST", "fakeservices.datajoint.io"),
user=os.environ.get("DJ_TEST_USER", "datajoint"),
password=os.environ.get("DJ_TEST_PASSWORD", "datajoint"),
)

CONN_INFO_ROOT = dict(
host=os.environ.get("DJ_HOST", "fakeservices.datajoint.io"),
user=os.environ.get("DJ_USER", "root"),
password=os.environ.get("DJ_PASS", "simple"),
)

S3_CONN_INFO = dict(
endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"),
access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"),
secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"),
bucket=os.environ.get("S3_BUCKET", "datajoint.test"),
)
147 changes: 117 additions & 30 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datajoint as dj
from packaging import version
from typing import Dict
from typing import Dict, List
import os
from os import environ, remove
import minio
Expand All @@ -18,9 +18,6 @@
DataJointError,
)
from . import (
PREFIX,
CONN_INFO,
S3_CONN_INFO,
schema,
schema_simple,
schema_advanced,
Expand All @@ -30,6 +27,11 @@
)


@pytest.fixture(scope="session")
def prefix():
return os.environ.get("DJ_TEST_DB_PREFIX", "djtest")


@pytest.fixture(scope="session")
def monkeysession():
with pytest.MonkeyPatch.context() as mp:
Expand Down Expand Up @@ -81,7 +83,7 @@ def connection_root_bare(db_creds_root):


@pytest.fixture(scope="session")
def connection_root(connection_root_bare):
def connection_root(connection_root_bare, prefix):
"""Root user database connection."""
dj.config["safemode"] = False
conn_root = connection_root_bare
Expand Down Expand Up @@ -136,7 +138,7 @@ def connection_root(connection_root_bare):

# Teardown
conn_root.query("SET FOREIGN_KEY_CHECKS=0")
cur = conn_root.query('SHOW DATABASES LIKE "{}\\_%%"'.format(PREFIX))
cur = conn_root.query('SHOW DATABASES LIKE "{}\\_%%"'.format(prefix))
for db in cur.fetchall():
conn_root.query("DROP DATABASE `{}`".format(db[0]))
conn_root.query("SET FOREIGN_KEY_CHECKS=1")
Expand All @@ -151,9 +153,9 @@ def connection_root(connection_root_bare):


@pytest.fixture(scope="session")
def connection_test(connection_root, db_creds_test):
def connection_test(connection_root, prefix, db_creds_test):
"""Test user database connection."""
database = f"{PREFIX}%%"
database = f"{prefix}%%"
permission = "ALL PRIVILEGES"

# Create MySQL users
Expand Down Expand Up @@ -191,7 +193,17 @@ def connection_test(connection_root, db_creds_test):


@pytest.fixture(scope="session")
def stores_config(tmpdir_factory):
def s3_creds() -> Dict:
return dict(
endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"),
access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"),
secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"),
bucket=os.environ.get("S3_BUCKET", "datajoint.test"),
)


@pytest.fixture(scope="session")
def stores_config(s3_creds, tmpdir_factory):
stores_config = {
"raw": dict(protocol="file", location=tmpdir_factory.mktemp("raw")),
"repo": dict(
Expand All @@ -200,7 +212,7 @@ def stores_config(tmpdir_factory):
location=tmpdir_factory.mktemp("repo"),
),
"repo-s3": dict(
S3_CONN_INFO,
s3_creds,
protocol="s3",
location="dj/repo",
stage=tmpdir_factory.mktemp("repo-s3"),
Expand All @@ -209,7 +221,7 @@ def stores_config(tmpdir_factory):
protocol="file", location=tmpdir_factory.mktemp("local"), subfolding=(1, 1)
),
"share": dict(
S3_CONN_INFO, protocol="s3", location="dj/store/repo", subfolding=(2, 4)
s3_creds, protocol="s3", location="dj/store/repo", subfolding=(2, 4)
),
}
return stores_config
Expand Down Expand Up @@ -238,9 +250,9 @@ def mock_cache(tmpdir_factory):


@pytest.fixture
def schema_any(connection_test):
def schema_any(connection_test, prefix):
schema_any = dj.Schema(
PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test
prefix + "_test1", schema.LOCALS_ANY, connection=connection_test
)
assert schema.LOCALS_ANY, "LOCALS_ANY is empty"
try:
Expand Down Expand Up @@ -292,9 +304,9 @@ def schema_any(connection_test):


@pytest.fixture
def schema_simp(connection_test):
def schema_simp(connection_test, prefix):
schema = dj.Schema(
PREFIX + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test
prefix + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test
)
schema(schema_simple.IJ)
schema(schema_simple.JI)
Expand All @@ -319,9 +331,9 @@ def schema_simp(connection_test):


@pytest.fixture
def schema_adv(connection_test):
def schema_adv(connection_test, prefix):
schema = dj.Schema(
PREFIX + "_advanced",
prefix + "_advanced",
schema_advanced.LOCALS_ADVANCED,
connection=connection_test,
)
Expand All @@ -339,9 +351,11 @@ def schema_adv(connection_test):


@pytest.fixture
def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache):
def schema_ext(
connection_test, enable_filepath_feature, mock_stores, mock_cache, prefix
):
schema = dj.Schema(
PREFIX + "_extern",
prefix + "_extern",
context=schema_external.LOCALS_EXTERNAL,
connection=connection_test,
)
Expand All @@ -358,9 +372,9 @@ def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache


@pytest.fixture
def schema_uuid(connection_test):
def schema_uuid(connection_test, prefix):
schema = dj.Schema(
PREFIX + "_test1",
prefix + "_test1",
context=schema_uuid_module.LOCALS_UUID,
connection=connection_test,
)
Expand All @@ -386,37 +400,110 @@ def http_client():


@pytest.fixture(scope="session")
def minio_client_bare(http_client):
def minio_client_bare(s3_creds, http_client):
"""Initialize MinIO with an endpoint and access/secret keys."""
client = minio.Minio(
S3_CONN_INFO["endpoint"],
access_key=S3_CONN_INFO["access_key"],
secret_key=S3_CONN_INFO["secret_key"],
s3_creds["endpoint"],
access_key=s3_creds["access_key"],
secret_key=s3_creds["secret_key"],
secure=True,
http_client=http_client,
)
return client


@pytest.fixture(scope="session")
def minio_client(minio_client_bare):
def minio_client(s3_creds, minio_client_bare):
"""Initialize a MinIO client and create buckets for testing session."""
# Setup MinIO bucket
aws_region = "us-east-1"
try:
minio_client_bare.make_bucket(S3_CONN_INFO["bucket"], location=aws_region)
minio_client_bare.make_bucket(s3_creds["bucket"], location=aws_region)
except minio.error.S3Error as e:
if e.code != "BucketAlreadyOwnedByYou":
raise e

yield minio_client_bare

# Teardown S3
objs = list(minio_client_bare.list_objects(S3_CONN_INFO["bucket"], recursive=True))
objs = list(minio_client_bare.list_objects(s3_creds["bucket"], recursive=True))
objs = [
minio_client_bare.remove_object(
S3_CONN_INFO["bucket"], o.object_name.encode("utf-8")
s3_creds["bucket"], o.object_name.encode("utf-8")
)
for o in objs
]
minio_client_bare.remove_bucket(S3_CONN_INFO["bucket"])
minio_client_bare.remove_bucket(s3_creds["bucket"])


@pytest.fixture
def test(schema_any):
yield schema.TTest()


@pytest.fixture
def test2(schema_any):
yield schema.TTest2()


@pytest.fixture
def test_extra(schema_any):
yield schema.TTestExtra()


@pytest.fixture
def test_no_extra(schema_any):
yield schema.TTestNoExtra()


@pytest.fixture
def user(schema_any):
return schema.User()


@pytest.fixture
def lang(schema_any):
yield schema.Language()


@pytest.fixture
def languages(lang) -> List:
og_contents = lang.contents
languages = og_contents.copy()
yield languages
lang.contents = og_contents


@pytest.fixture
def subject(schema_any):
yield schema.Subject()


@pytest.fixture
def experiment(schema_any):
return schema.Experiment()


@pytest.fixture
def ephys(schema_any):
return schema.Ephys()


@pytest.fixture
def img(schema_any):
return schema.Image()


@pytest.fixture
def trial(schema_any):
return schema.Trial()


@pytest.fixture
def channel(schema_any):
return schema.Ephys.Channel()


@pytest.fixture
def trash(schema_any):
return schema.UberTrash()
56 changes: 56 additions & 0 deletions tests/schema_alter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import datajoint as dj
import inspect


class Experiment(dj.Imported):
original_definition = """ # information about experiments
-> Subject
experiment_id :smallint # experiment number for this subject
---
experiment_date :date # date when experiment was started
-> [nullable] User
data_path="" :varchar(255) # file path to recorded data
notes="" :varchar(2048) # e.g. purpose of experiment
entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp
"""

definition1 = """ # Experiment
-> Subject
experiment_id :smallint # experiment number for this subject
---
data_path : int # some number
extra=null : longblob # just testing
-> [nullable] User
subject_notes=null :varchar(2048) # {notes} e.g. purpose of experiment
entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp
"""


class Parent(dj.Manual):
definition = """
parent_id: int
"""

class Child(dj.Part):
definition = """
-> Parent
"""
definition_new = """
-> master
---
child_id=null: int
"""

class Grandchild(dj.Part):
definition = """
-> master.Child
"""
definition_new = """
-> master.Child
---
grandchild_id=null: int
"""


LOCALS_ALTER = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_ALTER)
1 change: 0 additions & 1 deletion tests/schema_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tempfile
import inspect
import datajoint as dj
from . import PREFIX, CONN_INFO, S3_CONN_INFO
import numpy as np


Expand Down
1 change: 0 additions & 1 deletion tests/schema_uuid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid
import inspect
import datajoint as dj
from . import PREFIX, CONN_INFO

top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000")

Expand Down
Loading

0 comments on commit 88c96e6

Please sign in to comment.