Skip to content

Commit

Permalink
(add) support for managing multiple API keys with scope permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
danh91 committed Dec 11, 2023
1 parent 715c40a commit e47bb61
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 98 deletions.
36 changes: 6 additions & 30 deletions modules/core/karrio/server/core/authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import yaml # type: ignore
import yaml # type: ignore
import pydoc
import logging
import functools
Expand Down Expand Up @@ -201,39 +201,15 @@ def process_request(self, request):
request = authenticate_user(request)

if hasattr(request, "user") and getattr(request, "org", None) is None:
request.org = self._get_organization(request)
request.org = get_request_org(
request,
request.user,
org_id=request.META.get("HTTP_X_ORG_ID"),
)

if not hasattr(request, "test_mode"):
request.test_mode = get_request_test_mode(request)

def _get_organization(self, request):
"""
Attempts to find and return an organization using the given validated token.
"""
if settings.MULTI_ORGANIZATIONS:
try:
from karrio.server.orgs.models import Organization

org_id = request.META.get("HTTP_X_ORG_ID")
orgs = Organization.objects.filter(users__id=request.user.id)
org = (
orgs.filter(id=org_id).first()
if org_id is not None
else orgs.filter(is_active=True).first()
)

# org was found but is not active
if (org is not None) and (not org.is_active):
raise exceptions.AuthenticationFailed(
_("Organization is inactive"), code="organization_inactive"
)

return org
except ProgrammingError:
pass

return None


def authenticate_user(request):
def authenticate(request, authenticator):
Expand Down
8 changes: 4 additions & 4 deletions modules/core/karrio/server/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def error_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
return func(*args, **kwargs)
except Exception as e:
logger.exception(e)
raise e
Expand All @@ -66,7 +66,7 @@ def async_wrapper(func):
@functools.wraps(func)
def wrapper(*args, run_synchronous: bool = False, **kwargs):
def _run():
func(*args, **kwargs)
return func(*args, **kwargs)

if run_synchronous:
return _run()
Expand Down Expand Up @@ -211,8 +211,8 @@ def compute_tracking_status(
return serializers.TrackerStatus.pending

if (
any(details.status or "") and
serializers.TrackerStatus.map(details.status).value is not None
any(details.status or "")
and serializers.TrackerStatus.map(details.status).value is not None
):
return serializers.TrackerStatus.map(details.status)

Expand Down
20 changes: 10 additions & 10 deletions modules/core/karrio/server/user/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,31 @@ def permissions(self):
import karrio.server.iam.models as iam

_permissions = []
if iam.ContextPermission.objects.filter(object_pk=self.pk).exists():

if conf.settings.MULTI_ORGANIZATIONS and self.org.exists():
org_user = self.org.first().organization_users.filter(user_id=self.user_id)
_permissions = (
iam.ContextPermission.objects.get(
object_pk=self.pk,
content_type=ContentType.objects.get_for_model(Token),
object_pk=org_user.first().pk,
content_type=ContentType.objects.get_for_model(org_user.first()),
)
.groups.all()
.values_list("name", flat=True)
if org_user.exists()
else []
)

if (
not any(_permissions)
and conf.settings.MULTI_ORGANIZATIONS
and self.org.exists()
and iam.ContextPermission.objects.filter(object_pk=self.pk).exists()
):
org_user = self.org.first().organization_users.filter(user_id=self.user_id)
_permissions = (
iam.ContextPermission.objects.get(
object_pk=org_user.first().pk,
content_type=ContentType.objects.get_for_model(org_user.first()),
object_pk=self.pk,
content_type=ContentType.objects.get_for_model(Token),
)
.groups.all()
.values_list("name", flat=True)
if org_user.exists()
else []
)

return _permissions if any(_permissions) else self.user.permissions
Expand Down
72 changes: 35 additions & 37 deletions modules/core/karrio/server/user/serializers.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,46 @@
from karrio.server.serializers import (
owned_model_serializer,
Serializer,
Context,
)
from karrio.server.user.models import Token


@owned_model_serializer
class TokenSerializer(Serializer):
def create(self, validated_data: dict, context: Context) -> Token:
extra = (
dict(org__id=getattr(context.org, "id", None))
if hasattr(Token, "org")
else {}
)
token = Token.objects.filter(
user=context.user,
test_mode=context.test_mode,
**extra,
).first()
import karrio.server.conf as conf
import karrio.server.user.models as models
import karrio.server.serializers as serializers

if token:
return token

return Token.objects.create(user=context.user, test_mode=context.test_mode)
@serializers.owned_model_serializer
class TokenSerializer(serializers.Serializer):
label = serializers.CharField(required=False)

def create(
self, validated_data: dict, context: serializers.Context
) -> models.Token:
return models.Token.objects.create(
user=context.user,
test_mode=context.test_mode,
label=validated_data.get("label") or "Default API Key",
)

@staticmethod
def retrieve_token(context, org_id: str = None):
user = getattr(context, "user", None)
test_mode = getattr(context, "test_mode", None)
org = getattr(context, "org", None)
org_id = org_id or getattr(org, "id", None)

queyset = models.Token.objects.filter(
**{
"test_mode": getattr(context, "test_mode", None),
"user__id": getattr(getattr(context, "user", None), "id", None),
**({"org__id": org_id} if org_id is not None else {}),
}
)

if org_id is not None and hasattr(Token, "org"):
import karrio.server.orgs.models as orgs
if queyset.exists():
return queyset.first()

org = orgs.Organization.objects.get(
id=org_id, users__id=getattr(user, "id", None)
)
else:
org = getattr(context, "org", None)
if org_id is not None and conf.settings.MULTI_ORGANIZATIONS:
import karrio.server.orgs.models as orgs

ctx = Context(user, org, test_mode)
tokens = Token.access_by(ctx).filter(user__id=getattr(user, "id", None))
org = orgs.Organization.objects.get(id=org_id, users__id=context.user.id)

if tokens.exists():
return tokens.first()
ctx = serializers.Context(
org=org,
user=getattr(context, "user", None),
test_mode=getattr(context, "test_mode", None),
)

return TokenSerializer.map(data={}, context=ctx).save().instance
15 changes: 15 additions & 0 deletions modules/graph/karrio/server/graph/schemas/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
class Query:
user: types.UserType = strawberry.field(resolver=types.UserType.resolve)
token: types.TokenType = strawberry.field(resolver=types.TokenType.resolve)
api_keys: typing.List[types.APIKeyType] = strawberry.field(
resolver=types.APIKeyType.resolve_list
)

user_connections: typing.List[types.CarrierConnectionType] = strawberry.field(
resolver=types.ConnectionType.resolve_list
Expand Down Expand Up @@ -85,6 +88,18 @@ def mutate_token(
) -> mutations.TokenMutation:
return mutations.TokenMutation.mutate(info, **input.to_dict())

@strawberry.mutation
def create_api_key(
self, info: Info, input: inputs.CreateAPIKeyMutationInput
) -> mutations.CreateAPIKeyMutation:
return mutations.CreateAPIKeyMutation.mutate(info, **input.to_dict())

@strawberry.mutation
def delete_api_key(
self, info: Info, input: inputs.DeleteAPIKeyMutationInput
) -> mutations.DeleteAPIKeyMutation:
return mutations.DeleteAPIKeyMutation.mutate(info, **input.to_dict())

@strawberry.mutation
def request_email_change(
self, info: Info, input: inputs.RequestEmailChangeMutationInput
Expand Down
13 changes: 13 additions & 0 deletions modules/graph/karrio/server/graph/schemas/base/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ class TokenMutationInput(utils.BaseInput):
refresh: typing.Optional[bool] = strawberry.UNSET


@strawberry.input
class CreateAPIKeyMutationInput(utils.BaseInput):
password: str
label: str
permissions: typing.Optional[typing.List[str]] = strawberry.UNSET


@strawberry.input
class DeleteAPIKeyMutationInput(utils.BaseInput):
password: str
key: str


@strawberry.input
class RequestEmailChangeMutationInput(utils.BaseInput):
email: str
Expand Down
70 changes: 65 additions & 5 deletions modules/graph/karrio/server/graph/schemas/base/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,31 @@
from strawberry.types import Info
from rest_framework import exceptions
from django.utils.http import urlsafe_base64_decode
from django.contrib.contenttypes.models import ContentType
from django_email_verification import confirm as email_verification
from django_otp.plugins.otp_email import models as otp
from django.utils.translation import gettext_lazy as _
from django.db import transaction

import karrio.lib as lib
from karrio.server.core.utils import ConfirmationToken, send_email
from karrio.server.user.serializers import TokenSerializer
from karrio.server.conf import settings
from karrio.server.serializers import (
save_many_to_many_data,
process_dictionaries_mutations,
)
from karrio.server.core.utils import ConfirmationToken, send_email
from karrio.server.user.serializers import TokenSerializer, Token
import karrio.server.providers.models as providers
import karrio.server.manager.serializers as manager_serializers
import karrio.server.graph.schemas.base.inputs as inputs
import karrio.server.graph.schemas.base.types as types
import karrio.server.graph.serializers as serializers
import karrio.server.providers.models as providers
import karrio.server.user.models as user_models
import karrio.server.manager.models as manager
import karrio.server.graph.models as graph
import karrio.server.graph.forms as forms
import karrio.server.graph.utils as utils
import karrio.server.iam.models as iam

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,7 +66,7 @@ class TokenMutation(utils.BaseMutation):
def mutate(
info: Info, refresh: bool = None, password: str = None
) -> "UserUpdateMutation":
tokens = Token.access_by(info.context.request)
tokens = user_models.Token.access_by(info.context.request)

if refresh:
if len(password or "") == 0:
Expand All @@ -84,6 +87,61 @@ def mutate(
return TokenMutation(token=token) # type:ignore


@strawberry.type
class CreateAPIKeyMutation(utils.BaseMutation):
api_key: typing.Optional[types.APIKeyType] = None

@staticmethod
@transaction.atomic
@utils.authentication_required
@utils.authorization_required()
@utils.password_required
def mutate(
info: Info, password: str, **input: inputs.CreateAPIKeyMutationInput
) -> "CreateAPIKeyMutation":
context = info.context.request
data = input.copy()
permissions = data.pop("permissions", [])
api_key = TokenSerializer.map(data=data, context=context).save().instance

if any(permissions):
_auth_ctx = getattr(context, "token", context.user)
_ctx_permissions = getattr(_auth_ctx, "permissions", [])
_invalid_permissions = [_ for _ in permissions if _ not in _ctx_permissions]

if any(_invalid_permissions):
raise exceptions.ValidationError({"permissions": "Invalid permissions"})

_ctx = iam.ContextPermission.objects.create(
object_pk=api_key.pk,
content_object=api_key,
content_type=ContentType.objects.get_for_model(api_key),
)
_ctx.groups.set(user_models.Group.objects.filter(name__in=permissions))

return CreateAPIKeyMutation(
api_key=user_models.Token.access_by(context).get(key=api_key.key)
) # type:ignore


@strawberry.type
class DeleteAPIKeyMutation(utils.BaseMutation):
label: typing.Optional[str] = None

@staticmethod
@utils.authentication_required
@utils.authorization_required()
@utils.password_required
def mutate(
info: Info, password: str, **input: inputs.DeleteAPIKeyMutationInput
) -> "DeleteAPIKeyMutation":
api_key = user_models.Token.access_by(info.context.request).get(**input)
label = api_key.label
api_key.delete()

return DeleteAPIKeyMutation(label=label) # type:ignore


@strawberry.type
class RequestEmailChangeMutation(utils.BaseMutation):
user: typing.Optional[types.UserType] = None
Expand Down Expand Up @@ -131,7 +189,9 @@ def mutate(info: Info, token: str) -> "ConfirmEmailChangeMutation":
user = info.context.request.user

if user.id != validated_token["user_id"]:
raise exceptions.ValidationError({"token": "Token is invalid or expired"})
raise exceptions.ValidationError(
{"token": "user_models.Token is invalid or expired"}
)

if user.email == validated_token["new_email"]:
raise exceptions.APIException("Email is already confirmed")
Expand Down
Loading

0 comments on commit e47bb61

Please sign in to comment.