diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e088b9b2..876f0471 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -63,7 +63,7 @@ jobs: VALIDATE_PYTHON_BLACK: true DEFAULT_BRANCH: master GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - LINTER_RULES_PATH: /casbin + LINTER_RULES_PATH: / PYTHON_BLACK_CONFIG_FILE: pyproject.toml coveralls: diff --git a/casbin/config/config.py b/casbin/config/config.py index ced352c0..3250943c 100644 --- a/casbin/config/config.py +++ b/casbin/config/config.py @@ -108,9 +108,7 @@ def _write(self, section, line_num, b): if len(option_val) != 2: raise RuntimeError( - "parse the content error : line {} , {} = ?".format( - line_num, option_val[0] - ) + "parse the content error : line {} , {} = ?".format(line_num, option_val[0]) ) option = option_val[0].strip() diff --git a/casbin/core_enforcer.py b/casbin/core_enforcer.py index dc8086f7..eb917cde 100644 --- a/casbin/core_enforcer.py +++ b/casbin/core_enforcer.py @@ -79,11 +79,7 @@ def init_with_adapter(self, model_path, adapter=None): def init_with_model_and_adapter(self, m, adapter=None): """initializes an enforcer with a model and a database adapter.""" - if ( - not isinstance(m, Model) - or adapter is not None - and not isinstance(adapter, Adapter) - ): + if not isinstance(m, Model) or adapter is not None and not isinstance(adapter, Adapter): raise RuntimeError("Invalid parameters for enforcer.") self.adapter = adapter @@ -388,8 +384,7 @@ def enforce_ex(self, *rvals): if exp_has_eval: rule_names = util.get_eval_value(exp_string) rules = [ - util.escape_assertion(p_parameters[rule_name]) - for rule_name in rule_names + util.escape_assertion(p_parameters[rule_name]) for rule_name in rule_names ] exp_with_rule = util.replace_eval(exp_string, rules) expression = self._get_expression(exp_with_rule, functions) @@ -419,10 +414,7 @@ def enforce_ex(self, *rvals): else: policy_effects.add(Effector.ALLOW) - if ( - self.eft.intermediate_effect(policy_effects) - != Effector.INDETERMINATE - ): + if self.eft.intermediate_effect(policy_effects) != Effector.INDETERMINATE: explain_index = i break diff --git a/casbin/distributed_enforcer.py b/casbin/distributed_enforcer.py index b01778d3..64cb3233 100644 --- a/casbin/distributed_enforcer.py +++ b/casbin/distributed_enforcer.py @@ -49,9 +49,7 @@ def add_policy_self(self, should_persist, sec, ptype, rules): if sec == "g": try: - self.build_incremental_role_links( - PolicyOp.Policy_add, ptype, no_exists_policy - ) + self.build_incremental_role_links(PolicyOp.Policy_add, ptype, no_exists_policy) except Exception as e: self.logger.log("An exception occurred: " + e) return no_exists_policy @@ -81,9 +79,7 @@ def remove_policy_self(self, should_persist, sec, ptype, rules): return effected - def remove_filtered_policy_self( - self, should_persist, sec, ptype, field_index, *field_values - ): + def remove_filtered_policy_self(self, should_persist, sec, ptype, field_index, *field_values): """ remove_filtered_policy_self provides a method for dispatcher to remove an authorization rule from the current policy,field filters can be specified. @@ -91,9 +87,7 @@ def remove_filtered_policy_self( """ if should_persist: try: - self.adapter.remove_filtered_policy( - sec, ptype, field_index, field_values - ) + self.adapter.remove_filtered_policy(sec, ptype, field_index, field_values) except Exception as e: self.logger.log("An exception occurred: " + e) @@ -103,9 +97,7 @@ def remove_filtered_policy_self( if sec == "g": try: - self.build_incremental_role_links( - PolicyOp.Policy_remove, ptype, effects - ) + self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, effects) except Exception as e: self.logger.log("An exception occurred: " + e) return effects @@ -143,16 +135,12 @@ def update_policy_self(self, should_persist, sec, ptype, old_rule, new_rule): if sec == "g": try: - self.build_incremental_role_links( - PolicyOp.Policy_remove, ptype, [old_rule] - ) + self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, [old_rule]) except Exception as e: return False try: - self.build_incremental_role_links( - PolicyOp.Policy_add, ptype, [new_rule] - ) + self.build_incremental_role_links(PolicyOp.Policy_add, ptype, [new_rule]) except Exception as e: return False diff --git a/casbin/enforcer.py b/casbin/enforcer.py index 8c3d01f3..19a32c5b 100644 --- a/casbin/enforcer.py +++ b/casbin/enforcer.py @@ -152,9 +152,7 @@ def get_implicit_roles_for_user(self, name, domain=""): return res - def get_implicit_permissions_for_user( - self, user, domain="", filter_policy_dom=True - ): + def get_implicit_permissions_for_user(self, user, domain="", filter_policy_dom=True): """ gets implicit permissions for a user or role. Compared to get_permissions_for_user(), this function retrieves permissions for inherited roles. diff --git a/casbin/internal_enforcer.py b/casbin/internal_enforcer.py index a038a517..0df0abbe 100644 --- a/casbin/internal_enforcer.py +++ b/casbin/internal_enforcer.py @@ -88,14 +88,10 @@ def _update_policies(self, sec, ptype, old_rules, new_rules): return rules_updated - def _update_filtered_policies( - self, sec, ptype, new_rules, field_index, *field_values - ): + def _update_filtered_policies(self, sec, ptype, new_rules, field_index, *field_values): """_update_filtered_policies deletes old rules and adds new rules.""" - old_rules = self.model.get_filtered_policy( - sec, ptype, field_index, *field_values - ) + old_rules = self.model.get_filtered_policy(sec, ptype, field_index, *field_values) if self.adapter and self.auto_save: try: @@ -154,19 +150,12 @@ def _remove_policies(self, sec, ptype, rules): def _remove_filtered_policy(self, sec, ptype, field_index, *field_values): """removes rules based on field filters from the current policy.""" - rule_removed = self.model.remove_filtered_policy( - sec, ptype, field_index, *field_values - ) + rule_removed = self.model.remove_filtered_policy(sec, ptype, field_index, *field_values) if not rule_removed: return rule_removed if self.adapter and self.auto_save: - if ( - self.adapter.remove_filtered_policy( - sec, ptype, field_index, *field_values - ) - is False - ): + if self.adapter.remove_filtered_policy(sec, ptype, field_index, *field_values) is False: return False if self.watcher: @@ -174,9 +163,7 @@ def _remove_filtered_policy(self, sec, ptype, field_index, *field_values): return rule_removed - def _remove_filtered_policy_returns_effects( - self, sec, ptype, field_index, *field_values - ): + def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values): """removes rules based on field filters from the current policy.""" rule_removed = self.model.remove_filtered_policy_returns_effects( sec, ptype, field_index, *field_values @@ -185,12 +172,7 @@ def _remove_filtered_policy_returns_effects( return rule_removed if self.adapter and self.auto_save: - if ( - self.adapter.remove_filtered_policy( - sec, ptype, field_index, *field_values - ) - is False - ): + if self.adapter.remove_filtered_policy(sec, ptype, field_index, *field_values) is False: return False if self.watcher: diff --git a/casbin/management_enforcer.py b/casbin/management_enforcer.py index e90b286b..7f8d4d43 100644 --- a/casbin/management_enforcer.py +++ b/casbin/management_enforcer.py @@ -153,17 +153,11 @@ def update_named_policies(self, ptype, old_rules, new_rules): def update_filtered_policies(self, new_rules, field_index, *field_values): """update_filtered_policies deletes old rules and adds new rules.""" - return self.update_filtered_named_policies( - "p", new_rules, field_index, *field_values - ) + return self.update_filtered_named_policies("p", new_rules, field_index, *field_values) - def update_filtered_named_policies( - self, ptype, new_rules, field_index, *field_values - ): + def update_filtered_named_policies(self, ptype, new_rules, field_index, *field_values): """update_filtered_named_policies deletes old rules and adds new rules.""" - return self._update_filtered_policies( - "p", ptype, new_rules, field_index, *field_values - ) + return self._update_filtered_policies("p", ptype, new_rules, field_index, *field_values) def remove_policy(self, *params): """removes an authorization rule from the current policy.""" @@ -271,9 +265,7 @@ def remove_grouping_policies(self, rules): def remove_filtered_grouping_policy(self, field_index, *field_values): """removes a role inheritance rule from the current policy, field filters can be specified.""" - return self.remove_filtered_named_grouping_policy( - "g", field_index, *field_values - ) + return self.remove_filtered_named_grouping_policy("g", field_index, *field_values) def remove_named_grouping_policy(self, ptype, *params): """removes a role inheritance rule from the current named policy.""" diff --git a/casbin/model/assertion.py b/casbin/model/assertion.py index 174cbe31..faf23976 100644 --- a/casbin/model/assertion.py +++ b/casbin/model/assertion.py @@ -31,15 +31,11 @@ def build_role_links(self, rm): self.rm = rm count = self.value.count("_") if count < 2: - raise RuntimeError( - 'the number of "_" in role definition should be at least 2' - ) + raise RuntimeError('the number of "_" in role definition should be at least 2') for rule in self.policy: if len(rule) < count: - raise RuntimeError( - "grouping policy elements do not meet role definition" - ) + raise RuntimeError("grouping policy elements do not meet role definition") if len(rule) > count: rule = rule[:count] @@ -52,9 +48,7 @@ def build_incremental_role_links(self, rm, op, rules): self.rm = rm count = self.value.count("_") if count < 2: - raise RuntimeError( - 'the number of "_" in role definition should be at least 2' - ) + raise RuntimeError('the number of "_" in role definition should be at least 2') for rule in rules: if len(rule) < count: raise TypeError("grouping policy elements do not meet role definition") diff --git a/casbin/model/model.py b/casbin/model/model.py index c0b8d9c5..a65509b6 100644 --- a/casbin/model/model.py +++ b/casbin/model/model.py @@ -107,9 +107,7 @@ def sort_policies_by_priority(self): if assertion.priority_index == -1: continue - assertion.policy = sorted( - assertion.policy, key=lambda x: x[assertion.priority_index] - ) + assertion.policy = sorted(assertion.policy, key=lambda x: x[assertion.priority_index]) for i, policy in enumerate(assertion.policy): assertion.policy_map[",".join(policy)] = i @@ -128,9 +126,7 @@ def sort_policies_by_subject_hierarchy(self): domain_index = index break - subject_hierarchy_map = self.get_subject_hierarchy_map( - self["g"]["g"].policy - ) + subject_hierarchy_map = self.get_subject_hierarchy_map(self["g"]["g"].policy) def compare_policy(policy): domain = DEFAULT_DOMAIN @@ -139,9 +135,7 @@ def compare_policy(policy): name = self.get_name_with_domain(domain, policy[sub_index]) return subject_hierarchy_map[name] - assertion.policy = sorted( - assertion.policy, key=compare_policy, reverse=True - ) + assertion.policy = sorted(assertion.policy, key=compare_policy, reverse=True) for i, policy in enumerate(assertion.policy): assertion.policy_map[",".join(policy)] = i @@ -194,11 +188,7 @@ def to_text(self): def write_string(sec): for p_type in self[sec]: value = self[sec][p_type].value - s.append( - "{} = {}\n".format( - sec, value.replace("p_", "p.").replace("r_", "r.") - ) - ) + s.append("{} = {}\n".format(sec, value.replace("p_", "p.").replace("r_", "r."))) s.append("[request_definition]\n") write_string("r") diff --git a/casbin/model/policy.py b/casbin/model/policy.py index 7c77cc74..2807450b 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -195,9 +195,7 @@ def update_policies(self, sec, ptype, old_rules, new_rules): if old_rule[priority_index] == new_rule[priority_index]: ast.policy[idx] = new_rule else: - raise Exception( - "New rule should have the same priority with old rule." - ) + raise Exception("New rule should have the same priority with old rule.") else: for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): ast.policy[idx] = new_rule @@ -234,9 +232,7 @@ def remove_policies_with_effected(self, sec, ptype, rules): return effected - def remove_filtered_policy_returns_effects( - self, sec, ptype, field_index, *field_values - ): + def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values): """ remove_filtered_policy_returns_effects removes policy rules based on field filters from the model. """ diff --git a/casbin/persist/adapters/update_adapter.py b/casbin/persist/adapters/update_adapter.py index 1ac09a3f..94a8c016 100644 --- a/casbin/persist/adapters/update_adapter.py +++ b/casbin/persist/adapters/update_adapter.py @@ -29,9 +29,7 @@ def update_policies(self, sec, ptype, old_rules, new_rules): """ pass - def update_filtered_policies( - self, sec, ptype, new_rules, field_index, *field_values - ): + def update_filtered_policies(self, sec, ptype, new_rules, field_index, *field_values): """ update_filtered_policies deletes old rules and adds new rules. """ diff --git a/casbin/rbac/default_role_manager/role_manager.py b/casbin/rbac/default_role_manager/role_manager.py index ffc48688..4a4547e8 100644 --- a/casbin/rbac/default_role_manager/role_manager.py +++ b/casbin/rbac/default_role_manager/role_manager.py @@ -97,9 +97,9 @@ def _matching_fn(self, str1, str2, match_order=MatchOrder.STR_PATTERN): if match_order == MatchOrder.PATTERN_STR: return match_error_handler(self.matching_func, str2, str1) elif match_order == MatchOrder.PATTERN_PATTERN: - return match_error_handler( - self.matching_func, str1, str2 - ) or match_error_handler(self.matching_func, str2, str1) + return match_error_handler(self.matching_func, str1, str2) or match_error_handler( + self.matching_func, str2, str1 + ) else: # match_order == MatchOrder.STR_PATTERN return match_error_handler(self.matching_func, str1, str2) @@ -153,9 +153,7 @@ def add_link(self, name1, name2, *domain): def delete_link(self, name1, name2, *domain): if Link(name1, name2) not in self.all_links: - raise RuntimeError( - f"error: link between {name1} and {name2} does not exist" - ) + raise RuntimeError(f"error: link between {name1} and {name2} does not exist") self.all_links.remove(Link(name1, name2)) user = self._get_role(name1) @@ -163,13 +161,9 @@ def delete_link(self, name1, name2, *domain): user.remove_role(role) for r in self.all_roles.values(): - if r.name != user.name and self._matching_fn( - user.name, r.name, MatchOrder.PATTERN_STR - ): + if r.name != user.name and self._matching_fn(user.name, r.name, MatchOrder.PATTERN_STR): r.remove_role(role) - if r.name != role.name and self._matching_fn( - role.name, r.name, MatchOrder.PATTERN_STR - ): + if r.name != role.name and self._matching_fn(role.name, r.name, MatchOrder.PATTERN_STR): role.remove_role(r) def has_link(self, name1, name2, *domain): @@ -270,9 +264,7 @@ def add_link(self, name1, name2, *domain): def delete_link(self, name1, name2, *domain): links = self._get_links(*domain) if Link(name1, name2) not in links: - raise RuntimeError( - f"error: link between {name1} and {name2} does not exist" - ) + raise RuntimeError(f"error: link between {name1} and {name2} does not exist") links.remove(Link(name1, name2)) def has_link(self, name1, name2, *domain): @@ -313,14 +305,10 @@ def _affected_role_managers(self, *domain): return [ self.rm_map[domain_str] for domain_str in self.rm_map.keys() - if match_error_handler( - self.domain_matching_func, domain_str, domain_pattern - ) + if match_error_handler(self.domain_matching_func, domain_str, domain_pattern) ] else: - return ( - [self.rm_map[domain_pattern]] if domain_pattern in self.rm_map else [] - ) + return [self.rm_map[domain_pattern]] if domain_pattern in self.rm_map else [] def add_matching_func(self, fn): super().add_matching_func(fn) diff --git a/casbin/synced_enforcer.py b/casbin/synced_enforcer.py index 716e9680..fd4cbd29 100644 --- a/casbin/synced_enforcer.py +++ b/casbin/synced_enforcer.py @@ -235,9 +235,7 @@ def get_named_grouping_policy(self, ptype): def get_filtered_named_grouping_policy(self, ptype, field_index, *field_values): """gets all the role inheritance rules in the policy, field filters can be specified.""" with self._rl: - return self._e.get_filtered_named_grouping_policy( - ptype, field_index, *field_values - ) + return self._e.get_filtered_named_grouping_policy(ptype, field_index, *field_values) def has_policy(self, *params): """determines whether an authorization rule exists.""" @@ -283,9 +281,7 @@ def remove_named_policy(self, ptype, *params): def remove_filtered_named_policy(self, ptype, field_index, *field_values): """removes an authorization rule from the current named policy, field filters can be specified.""" with self._wl: - return self._e.remove_filtered_named_policy( - ptype, field_index, *field_values - ) + return self._e.remove_filtered_named_policy(ptype, field_index, *field_values) def has_grouping_policy(self, *params): """determines whether a role inheritance rule exists.""" @@ -331,9 +327,7 @@ def remove_named_grouping_policy(self, ptype, *params): def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_values): """removes a role inheritance rule from the current named policy, field filters can be specified.""" with self._wl: - return self._e.remove_filtered_named_grouping_policy( - ptype, field_index, *field_values - ) + return self._e.remove_filtered_named_grouping_policy(ptype, field_index, *field_values) def add_function(self, name, func): """adds a customized function.""" diff --git a/casbin/pyproject.toml b/pyproject.toml similarity index 100% rename from casbin/pyproject.toml rename to pyproject.toml diff --git a/tests/benchmarks/benchmark_model.py b/tests/benchmarks/benchmark_model.py index 5fb17a16..7d65c98a 100644 --- a/tests/benchmarks/benchmark_model.py +++ b/tests/benchmarks/benchmark_model.py @@ -62,12 +62,8 @@ def benchmark_rbac_model(): def test_benchmark_rbac_model_small(benchmark): e = get_enforcer(get_examples("rbac_model.conf")) - e.add_policies( - {("group" + str(i), "data" + str(int(i / 10)), "read") for i in range(100)} - ) - e.add_grouping_policies( - {("user" + str(i), "group" + str(int(i / 10))) for i in range(1000)} - ) + e.add_policies({("group" + str(i), "data" + str(int(i / 10)), "read") for i in range(100)}) + e.add_grouping_policies({("user" + str(i), "group" + str(int(i / 10))) for i in range(1000)}) @benchmark def benchmark_rbac_model(): @@ -77,12 +73,8 @@ def benchmark_rbac_model(): def test_benchmark_rbac_model_medium(benchmark): e = get_enforcer(get_examples("rbac_model.conf")) - e.add_policies( - {("group" + str(i), "data" + str(int(i / 10)), "read") for i in range(1000)} - ) - e.add_grouping_policies( - {("user" + str(i), "group" + str(int(i / 10))) for i in range(10000)} - ) + e.add_policies({("group" + str(i), "data" + str(int(i / 10)), "read") for i in range(1000)}) + e.add_grouping_policies({("user" + str(i), "group" + str(int(i / 10))) for i in range(10000)}) @benchmark def benchmark_rbac_model(): @@ -92,12 +84,8 @@ def benchmark_rbac_model(): def test_benchmark_rbac_model_large(benchmark): e = get_enforcer(get_examples("rbac_model.conf")) - e.add_policies( - {("group" + str(i), "data" + str(int(i / 10)), "read") for i in range(10000)} - ) - e.add_grouping_policies( - {("user" + str(i), "group" + str(int(i / 10))) for i in range(100000)} - ) + e.add_policies({("group" + str(i), "data" + str(int(i / 10)), "read") for i in range(10000)}) + e.add_grouping_policies({("user" + str(i), "group" + str(int(i / 10))) for i in range(100000)}) @benchmark def benchmark_rbac_model(): @@ -148,9 +136,7 @@ def benchmark_rbac_with_deny(): def test_benchmark_prioriry(benchmark): - e = get_enforcer( - get_examples("priority_model.conf"), get_examples("priority_policy.csv") - ) + e = get_enforcer(get_examples("priority_model.conf"), get_examples("priority_policy.csv")) @benchmark def benchmark_rbac_with_deny(): @@ -158,9 +144,7 @@ def benchmark_rbac_with_deny(): def test_benchmark_keymatch(benchmark): - e = get_enforcer( - get_examples("keymatch_model.conf"), get_examples("keymatch_policy.csv") - ) + e = get_enforcer(get_examples("keymatch_model.conf"), get_examples("keymatch_policy.csv")) @benchmark def benchmark_keymatch(): diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 96c41bec..3d9d80e2 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -54,12 +54,8 @@ def test_new_config(self): self.assertEqual(config.get("multi5::name"), "r.sub==p.sub && r.obj==p.obj") self.assertEqual(config.get_bool("multi5::name"), False) - self.assertEqual( - config.get_string("multi5::name"), "r.sub==p.sub && r.obj==p.obj" - ) - self.assertEqual( - config.get_strings("multi5::name"), ["r.sub==p.sub && r.obj==p.obj"] - ) + self.assertEqual(config.get_string("multi5::name"), "r.sub==p.sub && r.obj==p.obj") + self.assertEqual(config.get_strings("multi5::name"), ["r.sub==p.sub && r.obj==p.obj"]) with self.assertRaises(ValueError): config.get_int("multi5::name") with self.assertRaises(ValueError): diff --git a/tests/rbac/test_role_manager.py b/tests/rbac/test_role_manager.py index 9e0327ef..40c2c1f0 100644 --- a/tests/rbac/test_role_manager.py +++ b/tests/rbac/test_role_manager.py @@ -95,9 +95,7 @@ def test_role(self): rm.clear() - match_fn = ( - lambda name1, name2: True if re.match("^" + name2 + "$", name1) else False - ) + match_fn = lambda name1, name2: True if re.match("^" + name2 + "$", name1) else False rm.add_matching_func(match_fn) @@ -125,9 +123,7 @@ def test_role(self): self.assertTrue(rm.has_link("g2", "any_group")) self.assertEqual(sorted(rm.get_roles("u1")), sorted(["g1", "any_user"])) - self.assertEqual( - sorted(rm.get_roles("u2")), sorted(["g2", "g1", r"g\d+", "any_user"]) - ) + self.assertEqual(sorted(rm.get_roles("u2")), sorted(["g2", "g1", r"g\d+", "any_user"])) self.assertEqual(rm.get_roles(r"u\d+"), ["any_user"]) self.assertEqual(rm.get_roles("u3"), ["any_user"]) self.assertEqual(rm.get_roles("g1"), ["any_group"]) @@ -307,9 +303,7 @@ def test_domain_role(self): self.assertTrue(rm.has_link("u4", "admin", "domain2")) rm.clear() - match_fn = ( - lambda name1, name2: True if re.match("^" + name2 + "$", name1) else False - ) + match_fn = lambda name1, name2: True if re.match("^" + name2 + "$", name1) else False rm.add_domain_matching_func(match_fn) rm.add_link("alice", "user", ".*") diff --git a/tests/test_distributed_api.py b/tests/test_distributed_api.py index 84ef9900..f93358f3 100644 --- a/tests/test_distributed_api.py +++ b/tests/test_distributed_api.py @@ -24,9 +24,7 @@ def get_enforcer(self, model=None, adapter=None): ) def test(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) e.add_policy_self( False, @@ -53,9 +51,7 @@ def test(self): e.update_policy_self( False, "p", "p", ["alice", "data1", "read"], ["alice", "data1", "write"] ) - e.update_policy_self( - False, "g", "g", ["alice", "data2_admin"], ["tom", "alice"] - ) + e.update_policy_self(False, "g", "g", ["alice", "data2_admin"], ["tom", "alice"]) self.assertFalse(e.enforce("alice", "data1", "read")) self.assertTrue(e.enforce("alice", "data1", "write")) diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py index 45f9022c..58e337ad 100644 --- a/tests/test_enforcer.py +++ b/tests/test_enforcer.py @@ -148,14 +148,10 @@ def test_enforce_key_match_custom_model(self): ) def custom_function(key1, key2): - if ( - key1 == "/alice_data2/myid/using/res_id" - and key2 == "/alice_data/:resource" - ): + if key1 == "/alice_data2/myid/using/res_id" and key2 == "/alice_data/:resource": return True elif ( - key1 == "/alice_data2/myid/using/res_id" - and key2 == "/alice_data2/:id/using/:resId" + key1 == "/alice_data2/myid/using/res_id" and key2 == "/alice_data2/:id/using/:resId" ): return True return False @@ -248,22 +244,16 @@ def test_multiple_policy_definitions(self): self.assertFalse(e.enforce(enforce_context, sub1, "/data2", "read")) def test_enforce_rbac(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) self.assertTrue(e.enforce("alice", "data1", "read")) self.assertFalse(e.enforce("bob", "data1", "read")) self.assertTrue(e.enforce("bob", "data2", "write")) self.assertTrue(e.enforce("alice", "data2", "read")) self.assertTrue(e.enforce("alice", "data2", "write")) - self.assertFalse( - e.enforce("bogus", "data2", "write") - ) # test non-existant subject + self.assertFalse(e.enforce("bogus", "data2", "write")) # test non-existant subject def test_enforce_rbac_empty_policy(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("empty_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("empty_policy.csv")) self.assertFalse(e.enforce("alice", "data1", "read")) self.assertFalse(e.enforce("bob", "data1", "read")) self.assertFalse(e.enforce("bob", "data2", "write")) diff --git a/tests/test_filter.py b/tests/test_filter.py index 6633b229..26eeb333 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -145,9 +145,7 @@ def test_filtered_adapter_empty_filepath(self): e.load_filtered_policy(None) def test_filtered_adapter_invalid_filepath(self): - adapter = casbin.persist.adapters.FilteredAdapter( - get_examples("does_not_exist_policy.csv") - ) + adapter = casbin.persist.adapters.FilteredAdapter(get_examples("does_not_exist_policy.csv")) e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) with self.assertRaises(RuntimeError): diff --git a/tests/test_management_api.py b/tests/test_management_api.py index 3e61d204..f785b345 100644 --- a/tests/test_management_api.py +++ b/tests/test_management_api.py @@ -52,17 +52,13 @@ def test_get_policy_api(self): ], ) - self.assertEqual( - e.get_filtered_policy(0, "alice"), [["alice", "data1", "read"]] - ) + self.assertEqual(e.get_filtered_policy(0, "alice"), [["alice", "data1", "read"]]) self.assertEqual(e.get_filtered_policy(0, "bob"), [["bob", "data2", "write"]]) self.assertEqual( e.get_filtered_policy(0, "data2_admin"), [["data2_admin", "data2", "read"], ["data2_admin", "data2", "write"]], ) - self.assertEqual( - e.get_filtered_policy(1, "data1"), [["alice", "data1", "read"]] - ) + self.assertEqual(e.get_filtered_policy(1, "data1"), [["alice", "data1", "read"]]) self.assertEqual( e.get_filtered_policy(1, "data2"), [ @@ -99,9 +95,7 @@ def test_get_policy_api(self): self.assertFalse(e.has_policy(["alice", "data2", "read"])) self.assertFalse(e.has_policy(["bob", "data3", "write"])) self.assertEqual(e.get_grouping_policy(), [["alice", "data2_admin"]]) - self.assertEqual( - e.get_filtered_grouping_policy(0, "alice"), [["alice", "data2_admin"]] - ) + self.assertEqual(e.get_filtered_grouping_policy(0, "alice"), [["alice", "data2_admin"]]) self.assertEqual(e.get_filtered_grouping_policy(0, "bob"), []) self.assertEqual(e.get_filtered_grouping_policy(1, "data1_admin"), []) self.assertEqual( @@ -198,11 +192,7 @@ def test_get_policy_multiple_matching_functions(self): km2_fn = casbin.util.key_match2_func self.assertEqual( - sorted( - e.get_filtered_policy( - 1, partial(km2_fn, "domain.2"), lambda a: "data" in a - ) - ), + sorted(e.get_filtered_policy(1, partial(km2_fn, "domain.2"), lambda a: "data" in a)), sorted( [ ["admin", "domain.*", "data1", "read"], @@ -213,9 +203,7 @@ def test_get_policy_multiple_matching_functions(self): self.assertEqual( sorted( - e.get_filtered_policy( - 1, partial(km2_fn, "domain.1"), lambda a: "data" in a, "read" - ) + e.get_filtered_policy(1, partial(km2_fn, "domain.1"), lambda a: "data" in a, "read") ), sorted( [ @@ -227,11 +215,7 @@ def test_get_policy_multiple_matching_functions(self): ) self.assertEqual( - sorted( - e.get_filtered_policy( - 1, partial(km2_fn, "domain.1"), "", "reading".startswith - ) - ), + sorted(e.get_filtered_policy(1, partial(km2_fn, "domain.1"), "", "reading".startswith)), sorted( [ ["admin", "domain.*", "data1", "read"], diff --git a/tests/test_rbac_api.py b/tests/test_rbac_api.py index 62c18b30..1d0a78a4 100644 --- a/tests/test_rbac_api.py +++ b/tests/test_rbac_api.py @@ -18,9 +18,7 @@ class TestRbacApi(TestCaseBase): def test_get_roles_for_user(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) self.assertEqual(e.get_roles_for_user("alice"), ["data2_admin"]) self.assertEqual(e.get_roles_for_user("bob"), []) @@ -28,24 +26,18 @@ def test_get_roles_for_user(self): self.assertEqual(e.get_roles_for_user("non_exist"), []) def test_get_users_for_role(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) self.assertEqual(e.get_users_for_role("data2_admin"), ["alice"]) def test_has_role_for_user(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) self.assertTrue(e.has_role_for_user("alice", "data2_admin")) self.assertFalse(e.has_role_for_user("alice", "data1_admin")) def test_add_role_for_user(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) e.add_role_for_user("alice", "data1_admin") self.assertEqual( sorted(e.get_roles_for_user("alice")), @@ -53,9 +45,7 @@ def test_add_role_for_user(self): ) def test_delete_role_for_user(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) e.add_role_for_user("alice", "data1_admin") self.assertEqual( sorted(e.get_roles_for_user("alice")), @@ -66,23 +56,17 @@ def test_delete_role_for_user(self): self.assertEqual(e.get_roles_for_user("alice"), ["data2_admin"]) def test_delete_roles_for_user(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) e.delete_roles_for_user("alice") self.assertEqual(e.get_roles_for_user("alice"), []) def test_delete_user(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) e.delete_user("alice") self.assertEqual(e.get_roles_for_user("alice"), []) def test_delete_role(self): - e = self.get_enforcer( - get_examples("rbac_model.conf"), get_examples("rbac_policy.csv") - ) + e = self.get_enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv")) e.delete_role("data2_admin") self.assertTrue(e.enforce("alice", "data1", "read")) self.assertFalse(e.enforce("alice", "data1", "write")) @@ -183,9 +167,7 @@ def test_enforce_implicit_roles_with_domain(self): get_examples("rbac_with_hierarchy_with_domains_policy.csv"), ) - self.assertEqual( - e.get_roles_for_user_in_domain("alice", "domain1"), ["role:global_admin"] - ) + self.assertEqual(e.get_roles_for_user_in_domain("alice", "domain1"), ["role:global_admin"]) self.assertEqual( sorted(e.get_implicit_roles_for_user("alice", "domain1")), sorted(["role:global_admin", "role:reader", "role:writer"]), @@ -228,9 +210,7 @@ def test_enforce_implicit_permissions_api_with_domain(self): get_examples("rbac_with_hierarchy_with_domains_policy.csv"), ) - self.assertEqual( - e.get_roles_for_user_in_domain("alice", "domain1"), ["role:global_admin"] - ) + self.assertEqual(e.get_roles_for_user_in_domain("alice", "domain1"), ["role:global_admin"]) self.assertEqual( sorted(e.get_implicit_roles_for_user("alice", "domain1")), sorted(["role:global_admin", "role:reader", "role:writer"]), @@ -280,9 +260,7 @@ def test_enforce_implicit_permissions_api_with_domain_matching_function(self): [], ) - self.assertEqual( - sorted(e.get_implicit_permissions_for_user("bob", "domain.1")), [] - ) + self.assertEqual(sorted(e.get_implicit_permissions_for_user("bob", "domain.1")), []) def test_enforce_implicit_permissions_api_with_domain_ignore_domain_policies_filter( self, @@ -292,18 +270,14 @@ def test_enforce_implicit_permissions_api_with_domain_ignore_domain_policies_fil get_examples("rbac_with_hierarchy_without_policy_domains.csv"), ) - self.assertEqual( - e.get_roles_for_user_in_domain("alice", "domain1"), ["role:global_admin"] - ) + self.assertEqual(e.get_roles_for_user_in_domain("alice", "domain1"), ["role:global_admin"]) self.assertEqual( sorted(e.get_implicit_roles_for_user("alice", "domain1")), sorted(["role:global_admin", "role:reader", "role:writer"]), ) self.assertEqual( sorted( - e.get_implicit_permissions_for_user( - "alice", "domain1", filter_policy_dom=False - ) + e.get_implicit_permissions_for_user("alice", "domain1", filter_policy_dom=False) ), sorted( [ @@ -384,18 +358,10 @@ def test_implicit_user_api(self): get_examples("rbac_with_hierarchy_policy.csv"), ) - self.assertEqual( - ["alice"], e.get_implicit_users_for_permission("data1", "read") - ) - self.assertEqual( - ["alice"], e.get_implicit_users_for_permission("data1", "write") - ) - self.assertEqual( - ["alice"], e.get_implicit_users_for_permission("data2", "read") - ) - self.assertEqual( - ["alice", "bob"], e.get_implicit_users_for_permission("data2", "write") - ) + self.assertEqual(["alice"], e.get_implicit_users_for_permission("data1", "read")) + self.assertEqual(["alice"], e.get_implicit_users_for_permission("data1", "write")) + self.assertEqual(["alice"], e.get_implicit_users_for_permission("data2", "read")) + self.assertEqual(["alice", "bob"], e.get_implicit_users_for_permission("data2", "write")) def test_domain_match_model(self): e = self.get_enforcer( diff --git a/tests/util/test_builtin_operators.py b/tests/util/test_builtin_operators.py index e981f8e8..c8a4c9cc 100644 --- a/tests/util/test_builtin_operators.py +++ b/tests/util/test_builtin_operators.py @@ -36,31 +36,23 @@ def test_key_match2(self): self.assertTrue(util.key_match2_func("/foo", "/foo")) self.assertTrue(util.key_match2_func("/foo", "/foo*")) self.assertFalse(util.key_match2_func("/foo", "/foo/*")) - self.assertFalse( - util.key_match2_func("/foo/bar", "/foo") - ) # different with KeyMatch. + self.assertFalse(util.key_match2_func("/foo/bar", "/foo")) # different with KeyMatch. self.assertFalse(util.key_match2_func("/foo/bar", "/foo*")) self.assertTrue(util.key_match2_func("/foo/bar", "/foo/*")) - self.assertFalse( - util.key_match2_func("/foobar", "/foo") - ) # different with KeyMatch. + self.assertFalse(util.key_match2_func("/foobar", "/foo")) # different with KeyMatch. self.assertFalse(util.key_match2_func("/foobar", "/foo*")) self.assertFalse(util.key_match2_func("/foobar", "/foo/*")) self.assertFalse(util.key_match2_func("/", "/:resource")) self.assertTrue(util.key_match2_func("/resource1", "/:resource")) self.assertFalse(util.key_match2_func("/myid", "/:id/using/:resId")) - self.assertTrue( - util.key_match2_func("/myid/using/myresid", "/:id/using/:resId") - ) + self.assertTrue(util.key_match2_func("/myid/using/myresid", "/:id/using/:resId")) self.assertFalse(util.key_match2_func("/proxy/myid", "/proxy/:id/*")) self.assertTrue(util.key_match2_func("/proxy/myid/", "/proxy/:id/*")) self.assertTrue(util.key_match2_func("/proxy/myid/res", "/proxy/:id/*")) self.assertTrue(util.key_match2_func("/proxy/myid/res/res2", "/proxy/:id/*")) - self.assertTrue( - util.key_match2_func("/proxy/myid/res/res2/res3", "/proxy/:id/*") - ) + self.assertTrue(util.key_match2_func("/proxy/myid/res/res2/res3", "/proxy/:id/*")) self.assertFalse(util.key_match2_func("/proxy/", "/proxy/:id/*")) self.assertTrue(util.key_match2_func("/alice", "/:id")) @@ -84,46 +76,30 @@ def test_key_match3(self): self.assertFalse(util.key_match3_func("/", "/{resource}")) self.assertTrue(util.key_match3_func("/resource1", "/{resource}")) self.assertFalse(util.key_match3_func("/myid", "/{id}/using/{resId}")) - self.assertTrue( - util.key_match3_func("/myid/using/myresid", "/{id}/using/{resId}") - ) + self.assertTrue(util.key_match3_func("/myid/using/myresid", "/{id}/using/{resId}")) self.assertFalse(util.key_match3_func("/proxy/myid", "/proxy/{id}/*")) self.assertTrue(util.key_match3_func("/proxy/myid/", "/proxy/{id}/*")) self.assertTrue(util.key_match3_func("/proxy/myid/res", "/proxy/{id}/*")) self.assertTrue(util.key_match3_func("/proxy/myid/res/res2", "/proxy/{id}/*")) - self.assertTrue( - util.key_match3_func("/proxy/myid/res/res2/res3", "/proxy/{id}/*") - ) + self.assertTrue(util.key_match3_func("/proxy/myid/res/res2/res3", "/proxy/{id}/*")) self.assertFalse(util.key_match3_func("/proxy/", "/proxy/{id}/*")) - self.assertFalse( - util.key_match3_func("/myid/using/myresid", "/{id/using/{resId}") - ) + self.assertFalse(util.key_match3_func("/myid/using/myresid", "/{id/using/{resId}")) def test_key_match4(self): - self.assertTrue( - util.key_match4_func("/parent/123/child/123", "/parent/{id}/child/{id}") - ) - self.assertFalse( - util.key_match4_func("/parent/123/child/456", "/parent/{id}/child/{id}") - ) + self.assertTrue(util.key_match4_func("/parent/123/child/123", "/parent/{id}/child/{id}")) + self.assertFalse(util.key_match4_func("/parent/123/child/456", "/parent/{id}/child/{id}")) self.assertTrue( - util.key_match4_func( - "/parent/123/child/123", "/parent/{id}/child/{another_id}" - ) + util.key_match4_func("/parent/123/child/123", "/parent/{id}/child/{another_id}") ) self.assertTrue( - util.key_match4_func( - "/parent/123/child/456", "/parent/{id}/child/{another_id}" - ) + util.key_match4_func("/parent/123/child/456", "/parent/{id}/child/{another_id}") ) self.assertTrue( - util.key_match4_func( - "/parent/123/child/456", "/parent/{id}/child/{another_id}" - ) + util.key_match4_func("/parent/123/child/456", "/parent/{id}/child/{another_id}") ) self.assertFalse( util.key_match4_func( @@ -136,19 +112,13 @@ def test_key_match4(self): ) ) self.assertFalse( - util.key_match4_func( - "/parent/123/child/456/book/", "/parent/{id}/child/{id}/book/{id}" - ) + util.key_match4_func("/parent/123/child/456/book/", "/parent/{id}/child/{id}/book/{id}") ) self.assertFalse( - util.key_match4_func( - "/parent/123/child/456", "/parent/{id}/child/{id}/book/{id}" - ) + util.key_match4_func("/parent/123/child/456", "/parent/{id}/child/{id}/book/{id}") ) - self.assertFalse( - util.key_match4_func("/parent/123/child/123", "/parent/{i/d}/child/{i/d}") - ) + self.assertFalse(util.key_match4_func("/parent/123/child/123", "/parent/{i/d}/child/{i/d}")) def test_regex_match(self): self.assertTrue(util.regex_match_func("/topic/create", "/topic/create")) @@ -157,15 +127,9 @@ def test_regex_match(self): self.assertFalse(util.regex_match_func("/topic/edit", "/topic/edit/[0-9]+")) self.assertTrue(util.regex_match_func("/topic/edit/123", "/topic/edit/[0-9]+")) self.assertFalse(util.regex_match_func("/topic/edit/abc", "/topic/edit/[0-9]+")) - self.assertFalse( - util.regex_match_func("/foo/delete/123", "/topic/delete/[0-9]+") - ) - self.assertTrue( - util.regex_match_func("/topic/delete/0", "/topic/delete/[0-9]+") - ) - self.assertFalse( - util.regex_match_func("/topic/edit/123s", "/topic/delete/[0-9]+") - ) + self.assertFalse(util.regex_match_func("/foo/delete/123", "/topic/delete/[0-9]+")) + self.assertTrue(util.regex_match_func("/topic/delete/0", "/topic/delete/[0-9]+")) + self.assertFalse(util.regex_match_func("/topic/edit/123s", "/topic/delete/[0-9]+")) def test_glob_match(self): self.assertTrue(util.glob_match_func("/foo", "/foo")) @@ -205,9 +169,7 @@ def test_glob_match2(self): self.assertFalse(util.glob_match_func("/foo", "*/foo/*")) self.assertFalse(util.glob_match_func("/foo/bar", "*/foo")) self.assertFalse(util.glob_match_func("/foo/bar", "*/foo*")) - self.assertFalse( - util.glob_match_func("/foo/bar", "*/foo/*") - ) # different from Go + self.assertFalse(util.glob_match_func("/foo/bar", "*/foo/*")) # different from Go self.assertFalse(util.glob_match_func("/foobar", "*/foo")) self.assertFalse(util.glob_match_func("/foobar", "*/foo*")) # different from Go self.assertFalse(util.glob_match_func("/foobar", "*/foo/*")) diff --git a/tests/util/test_util.py b/tests/util/test_util.py index c60ff230..c914eb34 100644 --- a/tests/util/test_util.py +++ b/tests/util/test_util.py @@ -18,28 +18,20 @@ class TestUtil(TestCase): def test_remove_comments(self): - self.assertEqual( - util.remove_comments("r.act == p.act # comments"), "r.act == p.act" - ) - self.assertEqual( - util.remove_comments("r.act == p.act#comments"), "r.act == p.act" - ) + self.assertEqual(util.remove_comments("r.act == p.act # comments"), "r.act == p.act") + self.assertEqual(util.remove_comments("r.act == p.act#comments"), "r.act == p.act") self.assertEqual(util.remove_comments("r.act == p.act###"), "r.act == p.act") self.assertEqual(util.remove_comments("### comments"), "") self.assertEqual(util.remove_comments("r.act == p.act"), "r.act == p.act") def test_escape_assertion(self): self.assertEqual( - util.escape_assertion( - "m = r.sub == p.sub && r.obj == p.obj && r.act == p.act" - ), + util.escape_assertion("m = r.sub == p.sub && r.obj == p.obj && r.act == p.act"), "m = r_sub == p_sub && r_obj == p_obj && r_act == p_act", ) def test_array_remove_duplicates(self): - res = util.array_remove_duplicates( - ["data", "data1", "data2", "data1", "data2", "data3"] - ) + res = util.array_remove_duplicates(["data", "data1", "data2", "data1", "data2", "data3"]) self.assertEqual(res, ["data", "data1", "data2", "data3"]) def test_array_to_string(self): @@ -74,15 +66,9 @@ def test_replace_eval(self): def test_get_eval_value(self): self.assertEqual(util.get_eval_value("eval(a) && a && b && c"), ["a"]) self.assertEqual(util.get_eval_value("a && eval(a) && b && c"), ["a"]) + self.assertEqual(util.get_eval_value("eval(a) && eval(b) && a && b && c"), ["a", "b"]) + self.assertEqual(util.get_eval_value("a && eval(a) && eval(b) && b && c"), ["a", "b"]) self.assertEqual( - util.get_eval_value("eval(a) && eval(b) && a && b && c"), ["a", "b"] - ) - self.assertEqual( - util.get_eval_value("a && eval(a) && eval(b) && b && c"), ["a", "b"] - ) - self.assertEqual( - util.get_eval_value( - "eval(p.sub_rule) || p.obj == r.obj && eval(p.domain_rule)" - ), + util.get_eval_value("eval(p.sub_rule) || p.obj == r.obj && eval(p.domain_rule)"), ["p.sub_rule", "p.domain_rule"], )