Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy: analyze types in all possible libraries #658

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ jobs:
- name: Add micromamba to GITHUB_PATH
run: echo "${HOME}/micromamba-bin" >> "$GITHUB_PATH"
- run: ln -s "${CONDA_PREFIX}" .venv # Necessary for pyright.
- run: pip install -e .[mypy]
- name: Add mypy to GITHUB_PATH
run: echo "${GITHUB_WORKSPACE}/.venv/bin" >> "$GITHUB_PATH"
- uses: pre-commit/action@v3.0.0
with:
extra_args: --all-files --show-diff-on-failure
Expand Down
14 changes: 1 addition & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,7 @@ repos:
hooks:
- id: mypy
files: ^quetz/
additional_dependencies:
- sqlalchemy-stubs
- types-click
- types-Jinja2
- types-mock
- types-orjson
- types-pkg-resources
- types-redis
- types-requests
- types-six
- types-toml
- types-ujson
- types-aiofiles
language: system
args: [--show-error-codes]
- repo: https://github.com/Quantco/pre-commit-mirrors-prettier
rev: 2.7.1
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ dependencies:
- pre-commit
- pytest
- pytest-mock
- rq
- libcflib
- mamba
- conda-content-trust
Expand Down
21 changes: 20 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,33 @@ venv = ".venv"
venvPath= "."

[tool.mypy]
ignore_missing_imports = true
packages = [
"quetz"
]
plugins = [
"pydantic.mypy",
"sqlmypy"
]
disable_error_code = [
"annotation-unchecked",
"misc"
]

[[tool.mypy.overrides]]
module = [
"adlfs",
"authlib",
"authlib.*",
"fsspec",
"gcsfs",
"pamela",
"sqlalchemy_utils",
"sqlalchemy_utils.*",
"s3fs",
"xattr"
]
ignore_missing_imports = true

[tool.coverage.run]
omit = [
"quetz/tests/*",
Expand Down
24 changes: 11 additions & 13 deletions quetz/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,17 @@ def _alembic_config(db_url: str) -> AlembicConfig:


def _run_migrations(
db_url: Optional[str] = None,
alembic_config: Optional[AlembicConfig] = None,
db_url: str,
branch_name: str = "heads",
) -> None:
if db_url:
if db_url.startswith("postgre"):
db_engine = "PostgreSQL"
elif db_url.startswith("sqlite"):
db_engine = "SQLite"
else:
db_engine = db_url.split("/")[0]
logger.info('Running DB migrations on %s', db_engine)
if not alembic_config:
alembic_config = _alembic_config(db_url)
if db_url.startswith("postgre"):
db_engine = "PostgreSQL"
elif db_url.startswith("sqlite"):
db_engine = "SQLite"
else:
db_engine = db_url.split("/")[0]
logger.info('Running DB migrations on %s', db_engine)
alembic_config = _alembic_config(db_url)
command.upgrade(alembic_config, branch_name)


Expand Down Expand Up @@ -135,6 +132,7 @@ def _make_migrations(
logger.info('Making DB migrations on %r for %r', db_url, plugin_name)
if not alembic_config and db_url:
alembic_config = _alembic_config(db_url)
assert alembic_config is not None

# find path
if plugin_name == "quetz":
Expand Down Expand Up @@ -594,7 +592,7 @@ def start(
uvicorn.run(
"quetz.main:app",
reload=reload,
reload_dirs=(quetz_src,),
reload_dirs=[quetz_src],
port=port,
proxy_headers=proxy_headers,
host=host,
Expand Down
4 changes: 2 additions & 2 deletions quetz/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class Config:

_instances: Dict[Optional[str], "Config"] = {}

def __new__(cls, deployment_config: str = None):
def __new__(cls, deployment_config: Optional[str] = None):
if not deployment_config and None in cls._instances:
return cls._instances[None]

Expand All @@ -254,7 +254,7 @@ def __getattr__(self, name: str) -> Any:
super().__getattr__(self, name)

@classmethod
def find_file(cls, deployment_config: str = None):
def find_file(cls, deployment_config: Optional[str] = None):
config_file_env = os.getenv(f"{_env_prefix}{_env_config_file}")
deployment_config_files = []
for f in (deployment_config, config_file_env):
Expand Down
4 changes: 2 additions & 2 deletions quetz/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,8 @@ def create_version(
def get_package_versions(
self,
package,
time_created_ge: datetime = None,
version_match_str: str = None,
time_created_ge: Optional[datetime] = None,
version_match_str: Optional[str] = None,
skip: int = 0,
limit: int = -1,
):
Expand Down
2 changes: 1 addition & 1 deletion quetz/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@hookspec
def register_router() -> 'fastapi.APIRouter':
def register_router() -> 'fastapi.APIRouter': # type: ignore[empty-body]
"""add extra endpoints to the url tree.

It should return an :py:class:`fastapi.APIRouter` with new endpoints definitions.
Expand Down
26 changes: 13 additions & 13 deletions quetz/jobs/rest_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def parse_job_name(v):
class JobBase(BaseModel):
"""New job spec"""

manifest: str = Field(None, title='Name of the function')
manifest: str = Field(title='Name of the function')

start_at: Optional[datetime] = Field(
None, title="date and time the job should start, if None it starts immediately"
Expand All @@ -110,35 +110,35 @@ def validate_job_name(cls, function_name):
class JobCreate(JobBase):
"""Create job spec"""

items_spec: str = Field(..., title='Item selector spec')
items_spec: str = Field(title='Item selector spec')


class JobUpdateModel(BaseModel):
"""Modify job spec items (status and items_spec)"""

items_spec: str = Field(None, title='Item selector spec')
status: JobStatus = Field(None, title='Change status')
items_spec: Optional[str] = Field(None, title='Item selector spec')
status: JobStatus = Field(title='Change status')
force: bool = Field(False, title="force re-running job on all matching packages")


class Job(JobBase):
id: int = Field(None, title='Unique id for job')
owner_id: uuid.UUID = Field(None, title='User id of the owner')
id: int = Field(title='Unique id for job')
owner_id: uuid.UUID = Field(title='User id of the owner')

created: datetime = Field(None, title='Created at')
created: datetime = Field(title='Created at')

status: JobStatus = Field(None, title='Status of the job (running, paused, ...)')
status: JobStatus = Field(title='Status of the job (running, paused, ...)')

items_spec: Optional[str] = Field(None, title='Item selector spec')
model_config = ConfigDict(from_attributes=True)


class Task(BaseModel):
id: int = Field(None, title='Unique id for task')
job_id: int = Field(None, title='ID of the parent job')
package_version: dict = Field(None, title='Package version')
created: datetime = Field(None, title='Created at')
status: TaskStatus = Field(None, title='Status of the task (running, paused, ...)')
id: int = Field(title='Unique id for task')
job_id: int = Field(title='ID of the parent job')
package_version: dict = Field(title='Package version')
created: datetime = Field(title='Created at')
status: TaskStatus = Field(title='Status of the task (running, paused, ...)')

@field_validator("package_version", mode="before")
@classmethod
Expand Down
18 changes: 9 additions & 9 deletions quetz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def get_users_handler(dao, q, auth, skip, limit):
@api_router.get("/users", response_model=List[rest_models.User], tags=["users"])
def get_users(
dao: Dao = Depends(get_dao),
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
return get_users_handler(dao, q, auth, 0, -1)
Expand All @@ -341,7 +341,7 @@ def get_paginated_users(
dao: Dao = Depends(get_dao),
skip: int = 0,
limit: int = PAGINATION_LIMIT,
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
return get_users_handler(dao, q, auth, skip, limit)
Expand Down Expand Up @@ -521,7 +521,7 @@ def set_user_role(
def get_channels(
public: bool = True,
dao: Dao = Depends(get_dao),
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
"""List all channels"""
Expand All @@ -540,7 +540,7 @@ def get_paginated_channels(
skip: int = 0,
limit: int = PAGINATION_LIMIT,
public: bool = True,
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
"""List all channels, as a paginated response"""
Expand Down Expand Up @@ -780,7 +780,7 @@ def post_channel(
response_model=rest_models.ChannelBase,
)
def patch_channel(
channel_data: rest_models.Channel,
channel_data: rest_models.ChannelWithOptionalName,
dao: Dao = Depends(get_dao),
auth: authorization.Rules = Depends(get_rules),
channel: db_models.Channel = Depends(get_channel_or_fail),
Expand Down Expand Up @@ -1054,8 +1054,8 @@ def post_package_member(
def get_package_versions(
package: db_models.Package = Depends(get_package_or_fail),
dao: Dao = Depends(get_dao),
time_created__ge: datetime.datetime = None,
version_match_str: str = None,
time_created__ge: Optional[datetime.datetime] = None,
version_match_str: Optional[str] = None,
):
version_profile_list = dao.get_package_versions(
package, time_created__ge, version_match_str
Expand All @@ -1079,8 +1079,8 @@ def get_paginated_package_versions(
dao: Dao = Depends(get_dao),
skip: int = 0,
limit: int = PAGINATION_LIMIT,
time_created__ge: datetime.datetime = None,
version_match_str: str = None,
time_created__ge: Optional[datetime.datetime] = None,
version_match_str: Optional[str] = None,
):
version_profile_list = dao.get_package_versions(
package, time_created__ge, version_match_str, skip, limit
Expand Down
4 changes: 2 additions & 2 deletions quetz/metrics/view.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from fastapi import FastAPI
from prometheus_client import (
CONTENT_TYPE_LATEST,
REGISTRY,
Expand All @@ -9,7 +10,6 @@
from prometheus_client.multiprocess import MultiProcessCollector
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from .middleware import PrometheusMiddleware

Expand All @@ -24,6 +24,6 @@ def metrics(request: Request) -> Response:
return Response(generate_latest(registry), media_type=CONTENT_TYPE_LATEST)


def init(app: ASGIApp):
def init(app: FastAPI):
app.add_middleware(PrometheusMiddleware)
app.add_route("/metricsp", metrics)
2 changes: 1 addition & 1 deletion quetz/pkgstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def file_exists(self, channel: str, destination: str):
def get_filemetadata(self, channel: str, src: str) -> Tuple[int, int, str]:
"""get file metadata: returns (file size, last modified time, etag)"""

@abc.abstractclassmethod
@abc.abstractmethod
def cleanup_temp_files(self, channel: str, dry_run: bool = False):
"""clean up temporary `*.json{HASH}.[bz2|gz]` files from pkgstore"""

Expand Down
24 changes: 15 additions & 9 deletions quetz/rest_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class User(BaseUser):
Profile.model_rebuild()


Role = Field(None, pattern='owner|maintainer|member')
Role = Field(pattern='owner|maintainer|member')


class Member(BaseModel):
Expand All @@ -58,7 +58,7 @@ class MirrorMode(str, Enum):


class ChannelBase(BaseModel):
name: str = Field(None, title='The name of the channel', max_length=50)
name: str = Field(title='The name of the channel', max_length=50)
description: Optional[str] = Field(
None, title='The description of the channel', max_length=300
)
Expand Down Expand Up @@ -134,7 +134,7 @@ class ChannelMetadata(BaseModel):

class Channel(ChannelBase):
metadata: ChannelMetadata = Field(
default_factory=ChannelMetadata, title="channel metadata", examples={}
default_factory=ChannelMetadata, title="channel metadata", examples=[]
)

actions: Optional[List[ChannelActionEnum]] = Field(
Expand All @@ -160,8 +160,14 @@ def check_mirror_params(self) -> "Channel":
return self


class ChannelWithOptionalName(Channel):
name: Optional[str] = Field( # type: ignore[assignment]
None, title='The name of the channel', max_length=50
)


class ChannelMirrorBase(BaseModel):
url: str = Field(None, pattern="^(http|https)://.+")
url: str = Field(pattern="^(http|https)://.+")
api_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+")
metrics_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+")
model_config = ConfigDict(from_attributes=True)
Expand All @@ -173,7 +179,7 @@ class ChannelMirror(ChannelMirrorBase):

class Package(BaseModel):
name: str = Field(
None, title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$'
title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$'
)
summary: Optional[str] = Field(None, title='The summary of the package')
description: Optional[str] = Field(None, title='The description of the package')
Expand Down Expand Up @@ -201,18 +207,18 @@ class PackageRole(BaseModel):


class PackageSearch(Package):
channel_name: str = Field(None, title='The channel this package belongs to')
channel_name: str = Field(title='The channel this package belongs to')


class ChannelSearch(BaseModel):
name: str = Field(None, title='The name of the channel', max_length=1500)
name: str = Field(title='The name of the channel', max_length=1500)
description: Optional[str] = Field(None, title='The description of the channel')
private: bool = Field(None, title='The visibility of the channel')
private: bool = Field(title='The visibility of the channel')
model_config = ConfigDict(from_attributes=True)


class PaginatedResponse(BaseModel, Generic[T]):
pagination: Pagination = Field(None, title="Pagination object")
pagination: Pagination = Field(title="Pagination object")
result: List[T] = Field([], title="Result objects")


Expand Down
Loading