Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port the performance improvements of fastbin back to pycasbin #285

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v2
with:
node-version: '16'
node-version: '18.14.0'

- name: Setup
run: npm install -g semantic-release @semantic-release/github @semantic-release/changelog @semantic-release/commit-analyzer @semantic-release/git @semantic-release/release-notes-generator semantic-release-pypi
Expand Down
23 changes: 19 additions & 4 deletions casbin/core_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

import logging
import copy
from typing import Sequence

from casbin.effect import Effector, get_effector, effect_to_bool
from casbin.model import Model, FunctionMap
from casbin.model import Model, FastModel, FunctionMap, fast_policy_filter
from casbin.persist import Adapter
from casbin.persist.adapters import FileAdapter
from casbin.rbac import default_role_manager
Expand Down Expand Up @@ -51,10 +52,13 @@ class CoreEnforcer:
auto_save = False
auto_build_role_links = False

def __init__(self, model=None, adapter=None):
_cache_key_order: Sequence[int] = None

def __init__(self, model=None, adapter=None, cache_key_order: Sequence[int] = None):
self.logger = logging.getLogger(__name__)
# if want to see more detail logs, change log level to info or debug
self.logger.setLevel(logging.WARNING)
CoreEnforcer._cache_key_order = cache_key_order
if isinstance(model, str):
if isinstance(adapter, str):
self.init_with_file(model, adapter)
Expand Down Expand Up @@ -113,7 +117,11 @@ def _initialize(self):
def new_model(path="", text=""):
"""creates a model."""

m = Model()
if CoreEnforcer._cache_key_order == None:
m = Model()
else:
m = FastModel(CoreEnforcer._cache_key_order)

if len(path) > 0:
m.load_model(path)
else:
Expand Down Expand Up @@ -320,7 +328,14 @@ def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
"""
result, _ = self.enforce_ex(*rvals)

if CoreEnforcer._cache_key_order == None:
result, _ = self.enforce_ex(*rvals)
else:
keys = [rvals[x] for x in self._cache_key_order]
with fast_policy_filter(self.model.model["p"]["p"].policy, *keys):
result, _ = self.enforce_ex(*rvals)

return result

def enforce_ex(self, *rvals):
Expand Down
2 changes: 2 additions & 0 deletions casbin/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@

from .assertion import Assertion
from .model import Model
from .model_fast import FastModel
from .policy import Policy
from .policy_fast import FastPolicy, fast_policy_filter
from .function import FunctionMap
35 changes: 35 additions & 0 deletions casbin/model/model_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2021 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .policy_fast import FastPolicy
from typing import Any, Sequence
from .model import Model


class FastModel(Model):
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
super().__init__()
self._cache_key_order = cache_key_order

def add_def(self, sec: str, key: str, value: Any) -> None:
super().add_def(sec, key, value)
if sec == "p" and key == "p":
self.model[sec][key].policy = FastPolicy(self._cache_key_order)

def clear_policy(self) -> None:
"""clears all current policy."""
super().clear_policy()
self.model["p"]["p"].policy = FastPolicy(self._cache_key_order)
101 changes: 101 additions & 0 deletions casbin/model/policy_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2021 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast


def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]:
if keys[0] in cache:
if len(keys) > 1:
return in_cache(cache[keys[-0]], keys[1:])
return cast(Set[Sequence[str]], cache[keys[0]])
else:
return None


class FastPolicy(Container[Sequence[str]]):
_cache: Dict[str, Any]
_current_filter: Optional[Set[Sequence[str]]]
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
self._cache = {}
self._current_filter = None
self._cache_key_order = cache_key_order

def __iter__(self) -> Iterator[Sequence[str]]:
yield from self.__get_policy()

def __len__(self) -> int:
return len(list(self.__get_policy()))

def __contains__(self, item: object) -> bool:
if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item):
return False
keys = [item[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return False
return tuple(item) in exists

def __getitem__(self, item: int) -> Sequence[str]:
for i, entry in enumerate(self):
if i == item:
return entry
raise KeyError("No such value exists")

def append(self, item: Sequence[str]) -> None:
cache = self._cache
keys = [item[x] for x in self._cache_key_order]

for key in keys[:-1]:
if key not in cache:
cache[key] = dict()
cache = cache[key]
if keys[-1] not in cache:
cache[keys[-1]] = set()

cache[keys[-1]].add(tuple(item))

def remove(self, policy: Sequence[str]) -> bool:
keys = [policy[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return True

exists.remove(tuple(policy))
return True

def __get_policy(self) -> Iterable[Sequence[str]]:
if self._current_filter is not None:
return (list(x) for x in self._current_filter)
else:
return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1)

def apply_filter(self, *keys: str) -> None:
value = in_cache(self._cache, keys)
self._current_filter = value or set()

def clear_filter(self) -> None:
self._current_filter = None


@contextmanager
def fast_policy_filter(policy: FastPolicy, *keys: str) -> Iterator[None]:
try:
policy.apply_filter(*keys)
yield
finally:
policy.clear_filter()
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .test_frontend import TestFrontend
from .test_management_api import TestManagementApi, TestManagementApiSynced
from .test_rbac_api import TestRbacApi, TestRbacApiSynced
from .test_enforcer_fast import TestFastEnforcer
from . import benchmarks
from . import config
from . import model
Expand Down
1 change: 1 addition & 0 deletions tests/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.

from .test_policy import TestPolicy
from .test_policy_fast import TestContextManager, TestFastPolicy
101 changes: 101 additions & 0 deletions tests/model/test_policy_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2021 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import TestCase

from casbin.model import FastPolicy, fast_policy_filter


class TestFastPolicy(TestCase):
def test_able_to_add_rules(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_does_not_add_duplicates(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_can_remove_rules(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.remove(["sub", "obj", "read"])

assert list(policy) == []

def test_returns_lengtt(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert len(policy) == 1

def test_supports_in_keyword(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert ["sub", "obj", "read"] in policy

def test_supports_filters(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")

assert list(policy) == [["sub", "obj2", "read2"]]

def test_clears_filters(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")
policy.clear_filter()

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]


class TestContextManager:
def test_fast_policy_filter(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

with fast_policy_filter(policy, "read2", "obj2"):
assert list(policy) == [["sub", "obj2", "read2"]]

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]
4 changes: 3 additions & 1 deletion tests/test_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import time
from unittest import TestCase
from typing import Sequence

import casbin

Expand All @@ -31,10 +32,11 @@ def __init__(self, name, age):


class TestCaseBase(TestCase):
def get_enforcer(self, model=None, adapter=None):
def get_enforcer(self, model=None, adapter=None, cache_key_order: Sequence[int] = None):
return casbin.Enforcer(
model,
adapter,
cache_key_order,
)


Expand Down
Loading