Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make quetz compatible with SQLAlchemy 2.0 #598

Merged
merged 13 commits into from
Feb 15, 2023
2 changes: 1 addition & 1 deletion plugins/quetz_conda_suggest/quetz_conda_suggest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def post_add_package_version(version, condainfo):
)
db.add(metadata)
else:
metadata = db.query(db_models.CondaSuggestMetadata).get(version.id)
metadata = db.get(db_models.CondaSuggestMetadata, version.id)
metadata.data = json.dumps(suggest_map)
db.commit()
generate_channel_suggest_map(db, version.channel_name, subdir)
Expand Down
6 changes: 3 additions & 3 deletions plugins/quetz_content_trust/quetz_content_trust/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ContentTrustRole(Base):

# delegator created by 'role_delegations.consumers' relationship backref

time_created = Column(Date, nullable=False, server_default=func.now())
time_created = Column(Date, nullable=False, server_default=func.current_date())


class RoleDelegation(Base):
Expand All @@ -65,14 +65,14 @@ class RoleDelegation(Base):
keys = relationship(
"SigningKey", secondary=association_table, backref="delegations"
)
time_created = Column(Date, nullable=False, server_default=func.now())
time_created = Column(Date, nullable=False, server_default=func.current_date())


class SigningKey(Base):
__tablename__ = 'signing_keys'

public_key = Column(String, primary_key=True)
private_key = Column(String)
time_created = Column(Date, nullable=False, server_default=func.now())
time_created = Column(Date, nullable=False, server_default=func.current_date())
user_id = Column(UUID, ForeignKey('users.id'))
channel_name = Column(String, ForeignKey('channels.name'))
2 changes: 1 addition & 1 deletion plugins/quetz_runexports/quetz_runexports/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ def post_add_package_version(version, condainfo):
)
db.add(metadata)
else:
metadata = db.query(db_models.PackageVersionMetadata).get(version.id)
metadata = db.get(db_models.PackageVersionMetadata, version.id)
metadata.data = run_exports
db.commit()
2 changes: 1 addition & 1 deletion quetz/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def create_api_key(self, user_id, api_key: rest_models.BaseApiKey, key):
return db_api_key

def get_api_key(self, key):
return self.db.query(ApiKey).get(key)
return self.db.get(ApiKey, key)

def create_version(
self,
Expand Down
18 changes: 13 additions & 5 deletions quetz/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@
func,
select,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, column_property, relationship
from sqlalchemy.schema import ForeignKeyConstraint

Base = declarative_base()
try:
from sqlalchemy.orm import DeclarativeBase # type: ignore

class Base(DeclarativeBase):
pass

except ImportError:
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

UUID = LargeBinary(length=16)

Expand Down Expand Up @@ -209,7 +217,7 @@ class Channel(Base):
mirrors = relationship("ChannelMirror", cascade="all, delete", uselist=True)

members_count = column_property(
select([func.count(ChannelMember.user_id)])
select(func.count(ChannelMember.user_id))
.where(ChannelMember.channel_name == name)
.scalar_subquery(), # type: ignore
deferred=True,
Expand All @@ -223,7 +231,7 @@ def load_channel_metadata(self):
return {}

packages_count = column_property(
select([func.count(Package.name)])
select(func.count(Package.name))
.where(Package.channel_name == name)
.scalar_subquery(), # type: ignore
deferred=True,
Expand Down Expand Up @@ -281,7 +289,7 @@ class ApiKey(Base):

key = Column(String, primary_key=True, index=True)
description = Column(String)
time_created = Column(Date, nullable=False, server_default=func.now())
time_created = Column(Date, nullable=False, server_default=func.current_date())
expire_at = Column(Date)
deleted = Column(Boolean, default=False)
user_id = Column(UUID, ForeignKey('users.id'))
Expand Down
24 changes: 12 additions & 12 deletions quetz/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,25 @@ def run_migrations_online():
"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = get_url()
connectable = config.attributes.get('connection', None)
connection = config.attributes.get('connection', None)

if connectable is None:
connectable = engine_from_config(
if connection is None:
engine = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connection = engine.connect()

with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
render_as_batch=True,
)
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
render_as_batch=True,
)

with context.begin_transaction():
context.run_migrations()
with context.begin_transaction():
context.run_migrations()


if context.is_offline_mode():
Expand Down
22 changes: 12 additions & 10 deletions quetz/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def get_db(_):
def refresh_db(engine, database_url):
Base.metadata.drop_all(engine)
try:
engine.execute("DROP TABLE alembic_version")
with engine.connect() as connection:
connection.execute(sa.text("DROP TABLE alembic_version"))
except sa.exc.DatabaseError:
pass

Expand All @@ -151,11 +152,11 @@ def test_run_migrations(
):
db = sql_connection
with pytest.raises(sa.exc.DatabaseError):
db.execute("SELECT * FROM users")
db.execute(sa.text("SELECT * FROM users"))

cli._run_migrations(alembic_config=alembic_config)

db.execute("SELECT * FROM users")
db.execute(sa.text("SELECT * FROM users"))


def test_make_migrations_quetz(mocker, config, config_dir):
Expand Down Expand Up @@ -296,7 +297,8 @@ class TestPluginModel(Base):

Base.metadata.drop_all(engine)
try:
engine.execute("DROP TABLE alembic_version")
with engine.connect() as connection:
connection.execute(sa.text("DROP TABLE alembic_version"))
except sa.exc.DatabaseError:
pass

Expand All @@ -309,12 +311,11 @@ class TestPluginModel(Base):
from sqlalchemy import MetaData
target_metadata = MetaData()

connectable = config.attributes.get('connection')
connection = config.attributes.get('connection')

with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
"""

script_mako = """
Expand Down Expand Up @@ -425,7 +426,8 @@ def test_multi_head(
os.remove(p)

try:
engine.execute("DROP TABLE alembic_version")
with engine.connect() as connection:
connection.execute(sa.text("DROP TABLE alembic_version"))
except sa.exc.DatabaseError:
pass

Expand Down
4 changes: 2 additions & 2 deletions quetz/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def test_post_new_job_from_plugin(
)
assert response.status_code == 201
job_id = response.json()['id']
job = db.query(Job).get(job_id)
job = db.get(Job, job_id)
assert job.manifest.decode('ascii') == manifest

sync_supervisor.run_once()
Expand All @@ -654,7 +654,7 @@ def test_post_new_job_with_handler(
)
assert response.status_code == 201
job_id = response.json()['id']
job = db.query(Job).get(job_id)
job = db.get(Job, job_id)
assert job.status == JobStatus.pending
assert job.manifest.decode('ascii') == "test_action"

Expand Down
2 changes: 1 addition & 1 deletion quetz/tests/test_mirror.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def package_version(user, mirror_channel, db, dao):

@pytest.fixture
def owner(user, db):
user = db.query(User).get(user.id)
user = db.get(User, user.id)
user.role = "owner"
db.commit()
yield user
Expand Down