Skip to content

Commit

Permalink
improved some mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMrSheldon committed Jul 30, 2024
1 parent 9f8bb80 commit ae86298
Show file tree
Hide file tree
Showing 18 changed files with 263 additions and 208 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/linters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ jobs:
working-directory: ${{github.workspace}}/application
run: |
mkdir .mypy_cache
mypy . --disallow-untyped-calls --explicit-package-bases --ignore-missing-imports --install-types --non-interactive --cache-dir=.mypy_cache/
mypy . --non-interactive --cache-dir=.mypy_cache/
- name: Run mypy on python-client
working-directory: ${{github.workspace}}/python-client
run: |
mkdir .mypy_cache
mypy . --disallow-untyped-calls --explicit-package-bases --ignore-missing-imports --install-types --non-interactive --cache-dir=.mypy_cache/
mypy .--non-interactive --cache-dir=.mypy_cache/
flake8:
runs-on: ubuntu-latest
Expand Down
20 changes: 18 additions & 2 deletions application/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
[tool.black]
line-length = 120
extend-exclude = "src/tira/proto"
exclude = '''/(
src/tira/migrations
| src/tira/proto
)'''

[tool.isort]
profile = "black"
multi_line_output = 3
line_length = 120
include_trailing_comma = true
skip = "src/tira/proto"
skip = [
"src/tira/migrations",
"src/tira/proto",
]

[tool.mypy]
disallow_untyped_calls = true
explicit_package_bases = true
ignore_missing_imports = true
install_types = true
exclude = [
"^src/tira/proto/.*\\.py$",
"^src/tira/migrations/.*\\.py$",
]
1 change: 1 addition & 0 deletions application/src/django_admin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def logger_config(log_dir: Path):
}
}

# FIXME: I don't close my file handle :((((((((
TIREX_COMPONENTS = yaml.load(open(BASE_DIR / "tirex-components.yml").read(), Loader=yaml.FullLoader)

# Logging
Expand Down
2 changes: 1 addition & 1 deletion application/src/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys


def main():
def main() -> None:
"""Run administrative tasks."""
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_admin.settings")
try:
Expand Down
57 changes: 35 additions & 22 deletions application/src/tira/authentication.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
""" """

import json
import logging
import os
from functools import wraps
from typing import Optional

from django.conf import settings
from django.http import HttpResponseNotAllowed
from django.http import HttpRequest, HttpResponseNotAllowed
from slugify import slugify

import tira.tira_model as model
Expand All @@ -19,7 +18,7 @@
class Authentication(object):
"""Base class for Authentication and Role Management"""

subclasses = {}
subclasses: dict[str, type] = {}
_AUTH_SOURCE = "superclass"
ROLE_TIRA = "tira" # super admin if we ever need it
ROLE_ADMIN = "admin" # is admin for the requested resource, so all permissions
Expand All @@ -46,10 +45,16 @@ def __init__(self, **kwargs):
pass

@staticmethod
def get_default_vm_id(user_id):
def get_default_vm_id(user_id: str) -> str:
return f"{user_id}-default"

def get_role(self, request, user_id: str = None, vm_id: str = None, task_id: str = None):
def get_role(
self,
request: HttpRequest,
user_id: Optional[str] = None,
vm_id: Optional[str] = None,
task_id: Optional[str] = None,
):
"""Determine the role of the user on the requested page (determined by the given directives).
@param request: djangos request object associated to the http request
Expand All @@ -66,25 +71,25 @@ def get_role(self, request, user_id: str = None, vm_id: str = None, task_id: str
def get_auth_source(self):
return self._AUTH_SOURCE

def get_user_id(self, request):
def get_user_id(self, request: HttpRequest):
return None

def get_vm_id(self, request, user_id):
def get_vm_id(self, request: HttpRequest, user_id):
return "None"

def login(self, request, **kwargs):
def login(self, request: HttpRequest, **kwargs):
pass

def logout(self, request, **kwargs):
def logout(self, request: HttpRequest, **kwargs):
pass

def create_group(self, vm_id):
return {"status": 0, "message": f"create_group is not implemented for {self._AUTH_SOURCE}"}

def get_organizer_ids(self, request, user_id=None):
def get_organizer_ids(self, request: HttpRequest, user_id=None):
pass

def get_vm_ids(self, request, user_id=None):
def get_vm_ids(self, request: HttpRequest, user_id=None):
pass

def user_is_organizer_for_endpoint(
Expand Down Expand Up @@ -154,19 +159,19 @@ def __init__(self, **kwargs):
super(DisraptorAuthentication, self).__init__(**kwargs)
self.discourse_client = model.discourse_api_client()

def _get_user_id(self, request):
def _get_user_id(self, request: HttpRequest) -> Optional[str]:
"""Return the content of the X-Disraptor-User header set in the http request"""
user_id = request.headers.get("X-Disraptor-User", None)
if user_id:
if user_id is not None:
vm_id = Authentication.get_default_vm_id(user_id)
_ = model.get_vm(vm_id, create_if_none=True)
return user_id

def _is_in_group(self, request, group_name="tira_reviewer") -> bool:
def _is_in_group(self, request: HttpRequest, group_name="tira_reviewer") -> bool:
"""return True if the user is in the given disraptor group"""
return group_name in request.headers.get("X-Disraptor-Groups", "").split(",")

def _parse_tira_groups(self, groups: list) -> list:
def _parse_tira_groups(self, groups: list[str]) -> dict[str, str]:
"""find all groups with 'tira_' prefix and return key and value of the group.
Note: Groupnames should be in the format '[tira_]key[_value]'
"""
Expand All @@ -183,7 +188,7 @@ def _parse_tira_groups(self, groups: list) -> list:
value = None
yield {"key": key, "value": value}

def _get_user_groups(self, request, group_type: str = "vm") -> list:
def _get_user_groups(self, request: HttpRequest, group_type: str = "vm") -> list:
"""read groups from the disraptor groups header.
@param group_type: {"vm", "org"}, indicate the class of groups.
"""
Expand All @@ -201,8 +206,16 @@ def _get_user_groups(self, request, group_type: str = "vm") -> list:
if group_type == "org": # if we check for organizer groups of a user
return [group["value"] for group in self._parse_tira_groups(all_groups) if group["key"] == "org"]

raise ValueError(f"Can't handle group type {group_type}")

@check_disraptor_token
def get_role(self, request, user_id: str = None, vm_id: str = None, task_id: str = None):
def get_role(
self,
request: HttpRequest,
user_id: Optional[str] = None,
vm_id: Optional[str] = None,
task_id: Optional[str] = None,
):
"""Determine the role of the user on the requested page (determined by the given directives).
This is a minimalistic implementation that suffices for the current features of TIRA.
Expand Down Expand Up @@ -230,28 +243,28 @@ def get_role(self, request, user_id: str = None, vm_id: str = None, task_id: str
return self.ROLE_GUEST

@check_disraptor_token
def get_user_id(self, request):
def get_user_id(self, request: HttpRequest):
"""public wrapper of _get_user_id that checks conditions"""
return self._get_user_id(request)

@check_disraptor_token
def get_vm_id(self, request, user_id=None):
def get_vm_id(self, request: HttpRequest, user_id=None):
"""return the vm_id of the first vm_group ("tira-vm-<vm_id>") found.
If there is no vm-group, return "no-vm-assigned"
"""

return self.get_vm_ids(request, user_id)[0]

@check_disraptor_token
def get_organizer_ids(self, request, user_id=None):
def get_organizer_ids(self, request: HttpRequest, user_id=None):
"""return the organizer ids of all organizer teams that the user is found in ("tira-org-<vm_id>").
If there is no vm-group, return the empty list
"""

return self._get_user_groups(request, group_type="org")

@check_disraptor_token
def get_vm_ids(self, request, user_id=None):
def get_vm_ids(self, request: HttpRequest, user_id=None):
"""returns a list of all vm_ids of the all vm_groups ("tira-vm-<vm_id>") found.
If there is no vm-group, a list with "no-vm-assigned" is returned
"""
Expand Down
8 changes: 4 additions & 4 deletions application/src/tira/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import wraps

from django.conf import settings
from django.http import Http404, HttpResponseNotAllowed, HttpResponseRedirect, JsonResponse
from django.http import Http404, HttpRequest, HttpResponseNotAllowed, HttpResponseRedirect, JsonResponse
from django.shortcuts import redirect
from django.urls import resolve

Expand Down Expand Up @@ -34,7 +34,7 @@ def check_permissions(func):
"""

@wraps(func)
def func_wrapper(request, *args, **kwargs):
def func_wrapper(request: HttpRequest, *args, **kwargs):
vm_id = kwargs.get("vm_id", None)
user_id = kwargs.get("user_id", None)
if vm_id is None and user_id is not None: # some endpoints say user_id instead of vm_id
Expand Down Expand Up @@ -275,12 +275,12 @@ def run_is_public(run_id, vm_id, dataset_id):
return dataset_is_public(dataset_id)


def dataset_is_public(dataset_id):
def dataset_is_public(dataset_id: str) -> bool:
if not dataset_id or (dataset_id not in settings.PUBLIC_TRAINING_DATA and not dataset_id.endswith("-training")):
return False

i = model.get_dataset(dataset_id)
return i and "is_confidential" in i and not i["is_confidential"] and "is_deprecated" in i and not i["is_deprecated"]
return ("is_confidential" in i) and not i["is_confidential"] and ("is_deprecated" in i) and not i["is_deprecated"]


def check_resources_exist(reply_as="json"):
Expand Down
Loading

0 comments on commit ae86298

Please sign in to comment.