diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4b6b4c401e..e68f1d3be1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -233,7 +233,7 @@ jobs: pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl pip install -r __app__/requirements.txt pip install -r requirements-dev.txt - pytest + pytest tests flake8 . bandit -r ./__app__/ black ./__app__/ ./tests --check diff --git a/docs/webhook_events.md b/docs/webhook_events.md index 38946433b9..b24aaf5078 100644 --- a/docs/webhook_events.md +++ b/docs/webhook_events.md @@ -663,6 +663,31 @@ Each event will be submitted via HTTP POST to the user provided URL. ```json { "definitions": { + "ApiAccessRule": { + "properties": { + "allowed_groups": { + "items": { + "format": "uuid", + "type": "string" + }, + "title": "Allowed Groups", + "type": "array" + }, + "methods": { + "items": { + "type": "string" + }, + "title": "Methods", + "type": "array" + } + }, + "required": [ + "methods", + "allowed_groups" + ], + "title": "ApiAccessRule", + "type": "object" + }, "AzureMonitorExtensionConfig": { "properties": { "config_version": { @@ -757,9 +782,27 @@ Each event will be submitted via HTTP POST to the user provided URL. "title": "Allowed Aad Tenants", "type": "array" }, + "api_access_rules": { + "additionalProperties": { + "$ref": "#/definitions/ApiAccessRule" + }, + "title": "Api Access Rules", + "type": "object" + }, "extensions": { "$ref": "#/definitions/AzureVmExtensionConfig" }, + "group_membership": { + "additionalProperties": { + "items": { + "format": "uuid", + "type": "string" + }, + "type": "array" + }, + "title": "Group Membership", + "type": "object" + }, "network_config": { "$ref": "#/definitions/NetworkConfig" }, @@ -4933,6 +4976,31 @@ Each event will be submitted via HTTP POST to the user provided URL. ```json { "definitions": { + "ApiAccessRule": { + "properties": { + "allowed_groups": { + "items": { + "format": "uuid", + "type": "string" + }, + "title": "Allowed Groups", + "type": "array" + }, + "methods": { + "items": { + "type": "string" + }, + "title": "Methods", + "type": "array" + } + }, + "required": [ + "methods", + "allowed_groups" + ], + "title": "ApiAccessRule", + "type": "object" + }, "Architecture": { "description": "An enumeration.", "enum": [ @@ -5856,9 +5924,27 @@ Each event will be submitted via HTTP POST to the user provided URL. "title": "Allowed Aad Tenants", "type": "array" }, + "api_access_rules": { + "additionalProperties": { + "$ref": "#/definitions/ApiAccessRule" + }, + "title": "Api Access Rules", + "type": "object" + }, "extensions": { "$ref": "#/definitions/AzureVmExtensionConfig" }, + "group_membership": { + "additionalProperties": { + "items": { + "format": "uuid", + "type": "string" + }, + "type": "array" + }, + "title": "Group Membership", + "type": "object" + }, "network_config": { "$ref": "#/definitions/NetworkConfig" }, diff --git a/src/api-service/__app__/onefuzzlib/azure/creds.py b/src/api-service/__app__/onefuzzlib/azure/creds.py index d946b6de5c..38eff823a0 100644 --- a/src/api-service/__app__/onefuzzlib/azure/creds.py +++ b/src/api-service/__app__/onefuzzlib/azure/creds.py @@ -168,14 +168,6 @@ def query_microsoft_graph_list( raise GraphQueryError("Expected data containing a list of values", None) -def is_member_of(group_id: str, member_id: str) -> bool: - body = {"groupIds": [group_id]} - response = query_microsoft_graph_list( - method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body - ) - return group_id in response - - @cached def get_scaleset_identity_resource_path() -> str: scaleset_id_name = "%s-scalesetid" % get_instance_name() diff --git a/src/api-service/__app__/onefuzzlib/azure/group_membership.py b/src/api-service/__app__/onefuzzlib/azure/group_membership.py new file mode 100644 index 0000000000..0b24914367 --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/azure/group_membership.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Dict, List, Protocol +from uuid import UUID + +from ..config import InstanceConfig +from .creds import query_microsoft_graph_list + + +class GroupMembershipChecker(Protocol): + def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool: + """Check if member is part of at least one of the groups""" + if member_id in group_ids: + return True + + groups = self.get_groups(member_id) + for g in group_ids: + if g in groups: + return True + + return False + + def get_groups(self, member_id: UUID) -> List[UUID]: + """Gets all the groups of the provided member""" + + +def create_group_membership_checker() -> GroupMembershipChecker: + config = InstanceConfig.fetch() + if config.group_membership: + return StaticGroupMembership(config.group_membership) + else: + return AzureADGroupMembership() + + +class AzureADGroupMembership(GroupMembershipChecker): + def get_groups(self, member_id: UUID) -> List[UUID]: + response = query_microsoft_graph_list( + method="GET", resource=f"users/{member_id}/transitiveMemberOf" + ) + return response + + +class StaticGroupMembership(GroupMembershipChecker): + def __init__(self, memberships: Dict[str, List[UUID]]): + self.memberships = memberships + + def get_groups(self, member_id: UUID) -> List[UUID]: + return self.memberships.get(str(member_id), []) diff --git a/src/api-service/__app__/onefuzzlib/request.py b/src/api-service/__app__/onefuzzlib/request.py index 470ac429f0..45faa75d3c 100644 --- a/src/api-service/__app__/onefuzzlib/request.py +++ b/src/api-service/__app__/onefuzzlib/request.py @@ -5,36 +5,75 @@ import json import logging -import os +import urllib from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union from uuid import UUID from azure.functions import HttpRequest, HttpResponse +from memoization import cached from onefuzztypes.enums import ErrorCode from onefuzztypes.models import Error from onefuzztypes.responses import BaseResponse +from pydantic import BaseModel # noqa: F401 from pydantic import ValidationError +from .azure.group_membership import create_group_membership_checker +from .config import InstanceConfig from .orm import ModelMixin +from .request_access import RequestAccess # We don't actually use these types at runtime at this time. Rather, # these are used in a bound TypeVar. MyPy suggests to only import these # types during type checking. if TYPE_CHECKING: from onefuzztypes.requests import BaseRequest # noqa: F401 - from pydantic import BaseModel # noqa: F401 + + +@cached(ttl=60) +def get_rules() -> Optional[RequestAccess]: + config = InstanceConfig.fetch() + if config.api_access_rules: + return RequestAccess.build(config.api_access_rules) + else: + return None def check_access(req: HttpRequest) -> Optional[Error]: - if "ONEFUZZ_AAD_GROUP_ID" in os.environ: - message = "ONEFUZZ_AAD_GROUP_ID configuration not supported" - logging.error(message) + rules = get_rules() + + # Noting to enforce if there are no rules. + if not rules: + return None + + path = urllib.parse.urlparse(req.url).path + rule = rules.get_matching_rules(req.method, path) + + # No restriction defined on this endpoint. + if not rule: + return None + + member_id = UUID(req.headers["x-ms-client-principal-id"]) + + try: + membership_checker = create_group_membership_checker() + allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id) + if not allowed: + logging.error( + "unauthorized access: %s is not authorized to access in %s", + member_id, + req.url, + ) + return Error( + code=ErrorCode.UNAUTHORIZED, + errors=["not approved to use this endpoint"], + ) + except Exception as e: return Error( - code=ErrorCode.INVALID_CONFIGURATION, - errors=[message], + code=ErrorCode.UNAUTHORIZED, + errors=["unable to interact with graph", str(e)], ) - else: - return None + + return None def ok( diff --git a/src/api-service/__app__/onefuzzlib/request_access.py b/src/api-service/__app__/onefuzzlib/request_access.py index 6dc0c4a635..1699deef43 100644 --- a/src/api-service/__app__/onefuzzlib/request_access.py +++ b/src/api-service/__app__/onefuzzlib/request_access.py @@ -1,8 +1,7 @@ -from typing import Dict, List +from typing import Dict, List, Optional from uuid import UUID from onefuzztypes.models import ApiAccessRule -from pydantic import parse_raw_as class RuleConflictError(Exception): @@ -41,7 +40,7 @@ def __init__(self) -> None: def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None: methods = list(map(lambda m: m.upper(), methods)) - segments = path.split("/") + segments = [s for s in path.split("/") if s != ""] if len(segments) == 0: return @@ -71,15 +70,14 @@ def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None: for method in methods: current_node.rules[method] = rules - def get_matching_rules(self, method: str, path: str) -> Rules: + def get_matching_rules(self, method: str, path: str) -> Optional[Rules]: method = method.upper() - segments = path.split("/") + segments = [s for s in path.split("/") if s != ""] current_node = self.root + current_rule = None if method in current_node.rules: current_rule = current_node.rules[method] - else: - current_rule = RequestAccess.Rules() current_segment_index = 0 @@ -98,17 +96,13 @@ def get_matching_rules(self, method: str, path: str) -> Rules: return current_rule @classmethod - def parse_rules(cls, rules_data: str) -> "RequestAccess": - rules = parse_raw_as(List[ApiAccessRule], rules_data) - return cls.build(rules) - - @classmethod - def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess": + def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess": request_access = RequestAccess() - for rule in rules: + for endpoint in rules: + rule = rules[endpoint] request_access.__add_url__( rule.methods, - rule.endpoint, + endpoint, RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups), ) diff --git a/src/api-service/functional_tests/api_restriction_test.py b/src/api-service/functional_tests/api_restriction_test.py new file mode 100644 index 0000000000..3bb13ba772 --- /dev/null +++ b/src/api-service/functional_tests/api_restriction_test.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os +import time +import uuid +from typing import Any, List +from urllib.parse import urlparse +from uuid import UUID + +from azure.cli.core import get_default_cli +from onefuzz.api import Onefuzz +from onefuzztypes.models import ApiAccessRule + + +def az_cli(args: List[str]) -> Any: + cli = get_default_cli() + cli.logging_cls + cli.invoke(args, out_file=open(os.devnull, "w")) + if cli.result.result: + return cli.result.result + elif cli.result.error: + raise cli.result.error + + +class APIRestrictionTests: + def __init__( + self, resource_group: str = None, onefuzz_config_path: str = None + ) -> None: + self.onefuzz = Onefuzz(config_path=onefuzz_config_path) + self.intial_config = self.onefuzz.instance_config.get() + + self.instance_name = urlparse(self.onefuzz.config().endpoint).netloc.split(".")[ + 0 + ] + if resource_group: + self.resource_group = resource_group + else: + self.resource_group = self.instance_name + + def restore_config(self) -> None: + self.onefuzz.instance_config.update(self.intial_config) + + def assign(self, group_id: UUID, member_id: UUID) -> None: + instance_config = self.onefuzz.instance_config.get() + if instance_config.group_membership is None: + instance_config.group_membership = {} + + if member_id not in instance_config.group_membership: + instance_config.group_membership[member_id] = [] + + if group_id not in instance_config.group_membership[member_id]: + instance_config.group_membership[member_id].append(group_id) + + self.onefuzz.instance_config.update(instance_config) + + def assign_current_user(self, group_id: UUID) -> None: + onefuzz_service_appId = az_cli( + [ + "ad", + "signed-in-user", + "show", + ] + ) + member_id = UUID(onefuzz_service_appId["objectId"]) + print(f"adding user {member_id}") + self.assign(group_id, member_id) + + def test_restriction_on_current_user(self) -> None: + + print("Checking that the current user can get jobs") + self.onefuzz.jobs.list() + + print("Creating test group") + group_id = uuid.uuid4() + + print("Adding restriction to the jobs endpoint") + instance_config = self.onefuzz.instance_config.get() + if instance_config.api_access_rules is None: + instance_config.api_access_rules = {} + + instance_config.api_access_rules["/api/jobs"] = ApiAccessRule( + allowed_groups=[group_id], + methods=["GET"], + ) + + self.onefuzz.instance_config.update(instance_config) + restart_instance(self.instance_name, self.resource_group) + time.sleep(20) + print("Checking that the current user cannot get jobs") + + try: + self.onefuzz.jobs.list() + failed = False + except Exception: + failed = True + pass + + if not failed: + raise Exception("Current user was able to get jobs") + + print("Assigning current user to test group") + self.assign_current_user(group_id) + restart_instance(self.instance_name, self.resource_group) + time.sleep(20) + + print("Checking that the current user can get jobs") + self.onefuzz.jobs.list() + + +def restart_instance(instance_name: str, resource_group: str) -> None: + print("Restarting instance") + az_cli( + [ + "functionapp", + "restart", + "--name", + f"{instance_name}", + "--resource-group", + f"{resource_group}", + ] + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", default=None) + parser.add_argument("--resource_group", default=None) + args = parser.parse_args() + tester = APIRestrictionTests(args.resource_group, args.config_path) + + try: + print("test current user restriction") + tester.test_restriction_on_current_user() + finally: + tester.restore_config() + pass + + +if __name__ == "__main__": + main() diff --git a/src/api-service/functional_tests/requirements.txt b/src/api-service/functional_tests/requirements.txt new file mode 100644 index 0000000000..02f6e63732 --- /dev/null +++ b/src/api-service/functional_tests/requirements.txt @@ -0,0 +1,4 @@ +../../cli +../../pytypes +azure-cli-core==2.27.2 +azure-cli==2.27.2 \ No newline at end of file diff --git a/src/api-service/tests/test_group_membership.py b/src/api-service/tests/test_group_membership.py new file mode 100644 index 0000000000..bdf8274097 --- /dev/null +++ b/src/api-service/tests/test_group_membership.py @@ -0,0 +1,38 @@ +import unittest +import uuid + +from __app__.onefuzzlib.azure.group_membership import ( + GroupMembershipChecker, + StaticGroupMembership, +) + + +class TestRequestAccess(unittest.TestCase): + def test_empty(self) -> None: + group_id = uuid.uuid4() + user_id = uuid.uuid4() + checker: GroupMembershipChecker = StaticGroupMembership({}) + + self.assertFalse(checker.is_member([group_id], user_id)) + self.assertTrue(checker.is_member([user_id], user_id)) + + def test_matching_user_id(self) -> None: + group_id = uuid.uuid4() + user_id1 = uuid.uuid4() + user_id2 = uuid.uuid4() + + checker: GroupMembershipChecker = StaticGroupMembership( + {str(user_id1): [group_id]} + ) + self.assertTrue(checker.is_member([user_id1], user_id1)) + self.assertFalse(checker.is_member([user_id1], user_id2)) + + def test_user_in_group(self) -> None: + group_id1 = uuid.uuid4() + group_id2 = uuid.uuid4() + user_id = uuid.uuid4() + checker: GroupMembershipChecker = StaticGroupMembership( + {str(user_id): [group_id1]} + ) + self.assertTrue(checker.is_member([group_id1], user_id)) + self.assertFalse(checker.is_member([group_id2], user_id)) diff --git a/src/api-service/tests/test_request_access.py b/src/api-service/tests/test_request_access.py index c75ccc933c..3021a9c4b5 100644 --- a/src/api-service/tests/test_request_access.py +++ b/src/api-service/tests/test_request_access.py @@ -8,61 +8,60 @@ class TestRequestAccess(unittest.TestCase): def test_empty(self) -> None: - request_access1 = RequestAccess.build([]) + request_access1 = RequestAccess.build({}) rules1 = request_access1.get_matching_rules("get", "a/b/c") - self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing") + self.assertEqual(rules1, None, "expected nothing") guid2 = uuid.uuid4() request_access1 = RequestAccess.build( - [ - ApiAccessRule( + { + "a/b/c": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[guid2], ) - ] + } ) rules1 = request_access1.get_matching_rules("get", "") - self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing") + self.assertEqual(rules1, None, "expected nothing") def test_exact_match(self) -> None: guid1 = uuid.uuid4() request_access = RequestAccess.build( - [ - ApiAccessRule( + { + "a/b/c": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[guid1], ) - ] + } ) rules1 = request_access.get_matching_rules("get", "a/b/c") rules2 = request_access.get_matching_rules("get", "b/b/e") + assert rules1 is not None self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups") self.assertEqual(rules1.allowed_groups_ids[0], guid1) - self.assertEqual(len(rules2.allowed_groups_ids), 0, "expected nothing") + self.assertEqual(rules2, None, "expected nothing") def test_wildcard(self) -> None: guid1 = uuid.uuid4() request_access = RequestAccess.build( - [ - ApiAccessRule( + { + "b/*/c": ApiAccessRule( methods=["get"], - endpoint="b/*/c", allowed_groups=[guid1], ) - ] + } ) rules = request_access.get_matching_rules("get", "b/b/c") + assert rules is not None self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups") self.assertEqual(rules.allowed_groups_ids[0], guid1) @@ -71,18 +70,16 @@ def test_adding_rule_on_same_path(self) -> None: try: RequestAccess.build( - [ - ApiAccessRule( + { + "a/b/c": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[guid1], ), - ApiAccessRule( + "a/b/c/": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[], ), - ] + } ) self.fail("this is expected to fail") @@ -95,22 +92,21 @@ def test_priority(self) -> None: guid2 = uuid.uuid4() request_access = RequestAccess.build( - [ - ApiAccessRule( + { + "a/*/c": ApiAccessRule( methods=["get"], - endpoint="a/*/c", allowed_groups=[guid1], ), - ApiAccessRule( + "a/b/c": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[guid2], ), - ] + } ) rules = request_access.get_matching_rules("get", "a/b/c") + assert rules is not None self.assertEqual( rules.allowed_groups_ids[0], guid2, @@ -125,36 +121,36 @@ def test_inherit_rule(self) -> None: guid3 = uuid.uuid4() request_access = RequestAccess.build( - [ - ApiAccessRule( + { + "a/b/c": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[guid1], ), - ApiAccessRule( + "f/*/c": ApiAccessRule( methods=["get"], - endpoint="f/*/c", allowed_groups=[guid2], ), - ApiAccessRule( + "a/b": ApiAccessRule( methods=["post"], - endpoint="a/b", allowed_groups=[guid3], ), - ] + } ) rules1 = request_access.get_matching_rules("get", "a/b/c/d") + assert rules1 is not None self.assertEqual( rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c" ) rules2 = request_access.get_matching_rules("get", "f/b/c/d") + assert rules2 is not None self.assertEqual( rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c" ) rules3 = request_access.get_matching_rules("post", "a/b/c/d") + assert rules3 is not None self.assertEqual( rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b" ) @@ -165,26 +161,26 @@ def test_override_rule(self) -> None: guid2 = uuid.uuid4() request_access = RequestAccess.build( - [ - ApiAccessRule( + { + "a/b/c": ApiAccessRule( methods=["get"], - endpoint="a/b/c", allowed_groups=[guid1], ), - ApiAccessRule( + "a/b/c/d": ApiAccessRule( methods=["get"], - endpoint="a/b/c/d", allowed_groups=[guid2], ), - ] + } ) rules1 = request_access.get_matching_rules("get", "a/b/c") + assert rules1 is not None self.assertEqual( rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c" ) rules2 = request_access.get_matching_rules("get", "a/b/c/d") + assert rules2 is not None self.assertEqual( rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d" ) diff --git a/src/deployment/azuredeploy.json b/src/deployment/azuredeploy.json index f29d3f28dd..9930c7dd22 100644 --- a/src/deployment/azuredeploy.json +++ b/src/deployment/azuredeploy.json @@ -266,7 +266,7 @@ "value": "[parameters('owner')]" } ], - "linuxFxVersion": "Python|3.7", + "linuxFxVersion": "Python|3.8", "alwaysOn": true, "defaultDocuments": [], "httpLoggingEnabled": true, diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index f3a47b32f4..0d891a33e7 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -838,10 +838,15 @@ class AzureVmExtensionConfig(BaseModel): class ApiAccessRule(BaseModel): methods: List[str] - endpoint: str allowed_groups: List[UUID] +Endpoint = str +# json dumps doesn't support UUID as dictionary key +PrincipalID = str +GroupId = UUID + + class InstanceConfig(BaseModel): # initial set of admins can only be set during deployment. # if admins are set, only admins can update instance configs. @@ -857,6 +862,8 @@ class InstanceConfig(BaseModel): ) extensions: Optional[AzureVmExtensionConfig] proxy_vm_sku: str = Field(default="Standard_B2s") + api_access_rules: Optional[Dict[Endpoint, ApiAccessRule]] = None + group_membership: Optional[Dict[PrincipalID, List[GroupId]]] = None def update(self, config: "InstanceConfig") -> None: for field in config.__fields__: