Skip to content

Commit

Permalink
Remove unnecessary * parameter in views, create CommonParameters
Browse files Browse the repository at this point in the history
…and `CurrentIncident` Annotated types (#3200)

* use Annotated type for common_parameters across all views

* remove unnecessary * syntax for k,v parameters since we use Annotated types

* correctly set default values for Annotated types in common_parameters

* Pydantic does not support Annotated types with default values, until 2.0

* fully define CommonParameters type and remove usage of Any

* remove more * syntax and create Annotated type for CurrentIncident

* switch case view include arg back to non-Annotated
  • Loading branch information
wssheldon authored Apr 3, 2023
1 parent e533215 commit 368c776
Show file tree
Hide file tree
Showing 42 changed files with 225 additions and 277 deletions.
7 changes: 3 additions & 4 deletions src/dispatch/auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
InvalidUsernameError,
)
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.enums import UserRoles
from dispatch.models import OrganizationSlug, PrimaryKey
from dispatch.organization.models import OrganizationRead
Expand Down Expand Up @@ -49,7 +49,7 @@
],
response_model=UserPagination,
)
def get_users(*, organization: OrganizationSlug, common: dict = Depends(common_parameters)):
def get_users(organization: OrganizationSlug, common: CommonParameters):
"""Get all users."""
common["filter_spec"] = {
"and": [{"model": "Organization", "op": "==", "field": "slug", "value": organization}]
Expand All @@ -73,7 +73,7 @@ def get_users(*, organization: OrganizationSlug, common: dict = Depends(common_p


@user_router.get("/{user_id}", response_model=UserRead)
def get_user(*, db_session: DbSession, user_id: PrimaryKey):
def get_user(db_session: DbSession, user_id: PrimaryKey):
"""Get a user."""
user = get(db_session=db_session, user_id=user_id)
if not user:
Expand All @@ -90,7 +90,6 @@ def get_user(*, db_session: DbSession, user_id: PrimaryKey):
response_model=UserRead,
)
def update_user(
*,
db_session: DbSession,
user_id: PrimaryKey,
organization: OrganizationSlug,
Expand Down
7 changes: 3 additions & 4 deletions src/dispatch/case/priority/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status

from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency
from dispatch.models import PrimaryKey

Expand All @@ -18,7 +18,7 @@


@router.get("", response_model=CasePriorityPagination, tags=["case_priorities"])
def get_case_priorities(*, common: dict = Depends(common_parameters)):
def get_case_priorities(common: CommonParameters):
"""Returns all case priorities."""
return search_filter_sort_paginate(model="CasePriority", **common)

Expand All @@ -29,7 +29,6 @@ def get_case_priorities(*, common: dict = Depends(common_parameters)):
dependencies=[Depends(PermissionsDependency([SensitiveProjectActionPermission]))],
)
def create_case_priority(
*,
db_session: DbSession,
case_priority_in: CasePriorityCreate,
):
Expand Down Expand Up @@ -66,7 +65,7 @@ def update_case_priority(


@router.get("/{case_priority_id}", response_model=CasePriorityRead)
def get_case_priority(*, db_session: DbSession, case_priority_id: PrimaryKey):
def get_case_priority(db_session: DbSession, case_priority_id: PrimaryKey):
"""Gets a case priority."""
case_priority = get(db_session=db_session, case_priority_id=case_priority_id)
if not case_priority:
Expand Down
6 changes: 3 additions & 3 deletions src/dispatch/case/severity/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status

from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency
from dispatch.models import PrimaryKey

Expand All @@ -18,7 +18,7 @@


@router.get("", response_model=CaseSeverityPagination, tags=["case_severities"])
def get_case_severities(*, common: dict = Depends(common_parameters)):
def get_case_severities(common: CommonParameters):
"""Returns all case severities."""
return search_filter_sort_paginate(model="CaseSeverity", **common)

Expand Down Expand Up @@ -66,7 +66,7 @@ def update_case_severity(


@router.get("/{case_severity_id}", response_model=CaseSeverityRead)
def get_case_severity(*, db_session: DbSession, case_severity_id: PrimaryKey):
def get_case_severity(db_session: DbSession, case_severity_id: PrimaryKey):
"""Gets a case severity."""
case_severity = get(db_session=db_session, case_severity_id=case_severity_id)
if not case_severity:
Expand Down
6 changes: 3 additions & 3 deletions src/dispatch/case/type/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.models import PrimaryKey

from .models import CaseTypeCreate, CaseTypePagination, CaseTypeRead, CaseTypeUpdate
Expand All @@ -13,7 +13,7 @@


@router.get("", response_model=CaseTypePagination, tags=["case_types"])
def get_case_types(*, common: dict = Depends(common_parameters)):
def get_case_types(common: CommonParameters):
"""Returns all case types."""
return search_filter_sort_paginate(model="CaseType", **common)

Expand Down Expand Up @@ -54,7 +54,7 @@ def update_case_type(


@router.get("/{case_type_id}", response_model=CaseTypeRead)
def get_case_type(*, db_session: DbSession, case_type_id: PrimaryKey):
def get_case_type(db_session: DbSession, case_type_id: PrimaryKey):
"""Gets a case type."""
case_type = get(db_session=db_session, case_type_id=case_type_id)
if not case_type:
Expand Down
8 changes: 2 additions & 6 deletions src/dispatch/case/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dispatch.case.enums import CaseStatus
from dispatch.common.utils.views import create_pydantic_include
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.models import OrganizationSlug, PrimaryKey
from dispatch.incident.models import IncidentCreate, IncidentRead
from dispatch.incident import service as incident_service
Expand Down Expand Up @@ -74,11 +74,7 @@ def get_case(


@router.get("", summary="Retrieves a list of cases.")
def get_cases(
*,
common: dict = Depends(common_parameters),
include: List[str] = Query([], alias="include[]"),
):
def get_cases(common: CommonParameters, include: List[str] = Query([], alias="include[]")):
"""Retrieves all cases."""
pagination = search_filter_sort_paginate(model="Case", **common)

Expand Down
8 changes: 4 additions & 4 deletions src/dispatch/data/alert/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@router.get("/{alert_id}", response_model=AlertRead)
def get_alert(*, db_session: DbSession, alert_id: PrimaryKey):
def get_alert(db_session: DbSession, alert_id: PrimaryKey):
"""Given its unique id, retrieve details about a single alert."""
alert = get(db_session=db_session, alert_id=alert_id)
if not alert:
Expand All @@ -26,13 +26,13 @@ def get_alert(*, db_session: DbSession, alert_id: PrimaryKey):


@router.post("", response_model=AlertRead)
def create_alert(*, db_session: DbSession, alert_in: AlertCreate):
def create_alert(db_session: DbSession, alert_in: AlertCreate):
"""Creates a new alert."""
return create(db_session=db_session, alert_in=alert_in)


@router.put("/{alert_id}", response_model=AlertRead)
def update_alert(*, db_session: DbSession, alert_id: PrimaryKey, alert_in: AlertUpdate):
def update_alert(db_session: DbSession, alert_id: PrimaryKey, alert_in: AlertUpdate):
"""Updates an alert."""
alert = get(db_session=db_session, alert_id=alert_id)
if not alert:
Expand All @@ -44,7 +44,7 @@ def update_alert(*, db_session: DbSession, alert_id: PrimaryKey, alert_in: Alert


@router.delete("/{alert_id}", response_model=None)
def delete_alert(*, db_session: DbSession, alert_id: PrimaryKey):
def delete_alert(db_session: DbSession, alert_id: PrimaryKey):
"""Deletes an alert, returning only an HTTP 200 OK if successful."""
alert = get(db_session=db_session, alert_id=alert_id)
if not alert:
Expand Down
14 changes: 7 additions & 7 deletions src/dispatch/data/query/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, HTTPException, status

from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.models import PrimaryKey

from .models import (
Expand All @@ -16,13 +16,13 @@


@router.get("", response_model=QueryPagination)
def get_queries(*, common: dict = Depends(common_parameters)):
def get_queries(common: CommonParameters):
"""Get all queries, or only those matching a given search term."""
return search_filter_sort_paginate(model="Query", **common)


@router.get("/{query_id}", response_model=QueryRead)
def get_query(*, db_session: DbSession, query_id: PrimaryKey):
def get_query(db_session: DbSession, query_id: PrimaryKey):
"""Given its unique ID, retrieve details about a single query."""
query = get(db_session=db_session, query_id=query_id)
if not query:
Expand All @@ -34,13 +34,13 @@ def get_query(*, db_session: DbSession, query_id: PrimaryKey):


@router.post("", response_model=QueryRead)
def create_query(*, db_session: DbSession, query_in: QueryCreate):
def create_query(db_session: DbSession, query_in: QueryCreate):
"""Creates a new data query."""
return create(db_session=db_session, query_in=query_in)


@router.put("/{query_id}", response_model=QueryRead)
def update_query(*, db_session: DbSession, query_id: PrimaryKey, query_in: QueryUpdate):
def update_query(db_session: DbSession, query_id: PrimaryKey, query_in: QueryUpdate):
"""Updates a data query."""
query = get(db_session=db_session, query_id=query_id)
if not query:
Expand All @@ -52,7 +52,7 @@ def update_query(*, db_session: DbSession, query_id: PrimaryKey, query_in: Query


@router.delete("/{query_id}", response_model=None)
def delete_query(*, db_session: DbSession, query_id: PrimaryKey):
def delete_query(db_session: DbSession, query_id: PrimaryKey):
"""Deletes a data query, returning only an HTTP 200 OK if successful."""
query = get(db_session=db_session, query_id=query_id)
if not query:
Expand Down
15 changes: 6 additions & 9 deletions src/dispatch/data/source/data_format/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, HTTPException, status


from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.models import PrimaryKey

from .models import (
Expand All @@ -18,13 +18,13 @@


@router.get("", response_model=SourceDataFormatPagination)
def get_source_data_formats(*, common: dict = Depends(common_parameters)):
def get_source_data_formats(common: CommonParameters):
"""Get all source data formats, or only those matching a given search term."""
return search_filter_sort_paginate(model="SourceDataFormat", **common)


@router.get("/{source_data_format_id}", response_model=SourceDataFormatRead)
def get_source_data_format(*, db_session: DbSession, source_data_format_id: PrimaryKey):
def get_source_data_format(db_session: DbSession, source_data_format_id: PrimaryKey):
"""Given its unique id, retrieve details about a source data format."""
source_data_format = get(db_session=db_session, source_data_format_id=source_data_format_id)
if not source_data_format:
Expand All @@ -36,16 +36,13 @@ def get_source_data_format(*, db_session: DbSession, source_data_format_id: Prim


@router.post("", response_model=SourceDataFormatRead)
def create_source_data_format(
*, db_session: DbSession, source_data_format_in: SourceDataFormatCreate
):
def create_source_data_format(db_session: DbSession, source_data_format_in: SourceDataFormatCreate):
"""Creates a new source data format."""
return create(db_session=db_session, source_data_format_in=source_data_format_in)


@router.put("/{source_data_format_id}", response_model=SourceDataFormatRead)
def update_source_data_format(
*,
db_session: DbSession,
source_data_format_id: PrimaryKey,
source_data_format_in: SourceDataFormatUpdate,
Expand All @@ -65,7 +62,7 @@ def update_source_data_format(


@router.delete("/{source_data_format_id}", response_model=None)
def delete_source_data_format(*, db_session: DbSession, source_data_format_id: PrimaryKey):
def delete_source_data_format(db_session: DbSession, source_data_format_id: PrimaryKey):
"""Delete a source data format, returning only an HTTP 200 OK if successful."""
source_data_format = get(db_session=db_session, source_data_format_id=source_data_format_id)
if not source_data_format:
Expand Down
13 changes: 6 additions & 7 deletions src/dispatch/data/source/environment/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, HTTPException, status


from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.models import PrimaryKey

from .models import (
Expand All @@ -18,13 +18,13 @@


@router.get("", response_model=SourceEnvironmentPagination)
def get_source_environments(*, common: dict = Depends(common_parameters)):
def get_source_environments(common: CommonParameters):
"""Get all source_environment environments, or only those matching a given search term."""
return search_filter_sort_paginate(model="SourceEnvironment", **common)


@router.get("/{source_environment_id}", response_model=SourceEnvironmentRead)
def get_source_environment(*, db_session: DbSession, source_environment_id: PrimaryKey):
def get_source_environment(db_session: DbSession, source_environment_id: PrimaryKey):
"""Given its unique id, retrieve details about a single source_environment environment."""
source_environment = get(db_session=db_session, source_environment_id=source_environment_id)
if not source_environment:
Expand All @@ -37,15 +37,14 @@ def get_source_environment(*, db_session: DbSession, source_environment_id: Prim

@router.post("", response_model=SourceEnvironmentRead)
def create_source_environment(
*, db_session: DbSession, source_environment_in: SourceEnvironmentCreate
db_session: DbSession, source_environment_in: SourceEnvironmentCreate
):
"""Creates a new source environment."""
return create(db_session=db_session, source_environment_in=source_environment_in)


@router.put("/{source_environment_id}", response_model=SourceEnvironmentRead)
def update_source_environment(
*,
db_session: DbSession,
source_environment_id: PrimaryKey,
source_environment_in: SourceEnvironmentUpdate,
Expand All @@ -65,7 +64,7 @@ def update_source_environment(


@router.delete("/{source_environment_id}", response_model=None)
def delete_source_environment(*, db_session: DbSession, source_environment_id: PrimaryKey):
def delete_source_environment(db_session: DbSession, source_environment_id: PrimaryKey):
"""Delete a source environment, returning only an HTTP 200 OK if successful."""
source_environment = get(db_session=db_session, source_environment_id=source_environment_id)
if not source_environment:
Expand Down
13 changes: 6 additions & 7 deletions src/dispatch/data/source/status/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, HTTPException


from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
from dispatch.models import PrimaryKey

from .models import (
Expand All @@ -18,13 +18,13 @@


@router.get("", response_model=SourceStatusPagination)
def get_source_statuses(*, common: dict = Depends(common_parameters)):
def get_source_statuses(common: CommonParameters):
"""Get all source statuses, or only those matching a given search term."""
return search_filter_sort_paginate(model="SourceStatus", **common)


@router.get("/{source_status_id}", response_model=SourceStatusRead)
def get_source_status(*, db_session: DbSession, source_status_id: PrimaryKey):
def get_source_status(db_session: DbSession, source_status_id: PrimaryKey):
"""Given its unique id, retrieve details about a single source status."""
status = get(db_session=db_session, source_status_id=source_status_id)
if not status:
Expand All @@ -36,14 +36,13 @@ def get_source_status(*, db_session: DbSession, source_status_id: PrimaryKey):


@router.post("", response_model=SourceStatusRead)
def create_source_status(*, db_session: DbSession, source_status_in: SourceStatusCreate):
def create_source_status(db_session: DbSession, source_status_in: SourceStatusCreate):
"""Creates a new source status."""
return create(db_session=db_session, source_status_in=source_status_in)


@router.put("/{source_status_id}", response_model=SourceStatusRead)
def update_source_status(
*,
db_session: DbSession,
source_status_id: PrimaryKey,
source_status_in: SourceStatusUpdate,
Expand All @@ -59,7 +58,7 @@ def update_source_status(


@router.delete("/{source_status_id}", response_model=None)
def delete_source_status(*, db_session: DbSession, source_status_id: PrimaryKey):
def delete_source_status(db_session: DbSession, source_status_id: PrimaryKey):
"""Deletes a source status, returning only an HTTP 200 OK if successful."""
status = get(db_session=db_session, source_status_id=source_status_id)
if not status:
Expand Down
Loading

0 comments on commit 368c776

Please sign in to comment.