Skip to content

Commit

Permalink
refactor: modify model structure (#174)
Browse files Browse the repository at this point in the history
Signed-off-by: ffyuanda <46557895+ffyuanda@users.noreply.github.com>
  • Loading branch information
ffyuanda authored Jul 9, 2021
1 parent e9f35bf commit 03b3cc8
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 55 deletions.
28 changes: 14 additions & 14 deletions casbin/core_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_with_model_and_adapter(self, m, adapter=None):

def _initialize(self):
self.rm_map = dict()
self.eft = get_effector(self.model.model["e"]["e"].value)
self.eft = get_effector(self.model["e"]["e"].value)
self.watcher = None

self.enabled = True
Expand Down Expand Up @@ -159,8 +159,8 @@ def clear_policy(self):
self.model.clear_policy()

def init_rm_map(self):
if "g" in self.model.model.keys():
for ptype in self.model.model["g"]:
if "g" in self.model.keys():
for ptype in self.model["g"]:
self.rm_map[ptype] = default_role_manager.RoleManager(10)

def load_policy(self):
Expand Down Expand Up @@ -268,24 +268,24 @@ def enforce_ex(self, *rvals):

functions = self.fm.get_functions()

if "g" in self.model.model.keys():
for key, ast in self.model.model["g"].items():
if "g" in self.model.keys():
for key, ast in self.model["g"].items():
rm = ast.rm
functions[key] = generate_g_function(rm)

if "m" not in self.model.model.keys():
if "m" not in self.model.keys():
raise RuntimeError("model is undefined")

if "m" not in self.model.model["m"].keys():
if "m" not in self.model["m"].keys():
raise RuntimeError("model is undefined")

r_tokens = self.model.model["r"]["r"].tokens
p_tokens = self.model.model["p"]["p"].tokens
r_tokens = self.model["r"]["r"].tokens
p_tokens = self.model["p"]["p"].tokens

if len(r_tokens) != len(rvals):
raise RuntimeError("invalid request size")

exp_string = self.model.model["m"]["m"].value
exp_string = self.model["m"]["m"].value
has_eval = util.has_eval(exp_string)
if not has_eval:
expression = self._get_expression(exp_string, functions)
Expand All @@ -294,11 +294,11 @@ def enforce_ex(self, *rvals):

r_parameters = dict(zip(r_tokens, rvals))

policy_len = len(self.model.model["p"]["p"].policy)
policy_len = len(self.model["p"]["p"].policy)

explain_index = -1
if not 0 == policy_len:
for i, pvals in enumerate(self.model.model["p"]["p"].policy):
for i, pvals in enumerate(self.model["p"]["p"].policy):
if len(p_tokens) != len(pvals):
raise RuntimeError("invalid policy size")

Expand Down Expand Up @@ -353,7 +353,7 @@ def enforce_ex(self, *rvals):

parameters = r_parameters.copy()

for token in self.model.model["p"]["p"].tokens:
for token in self.model["p"]["p"].tokens:
parameters[token] = ""

result = expression.eval(parameters)
Expand All @@ -380,7 +380,7 @@ def enforce_ex(self, *rvals):

explain_rule = []
if explain_index != -1 and explain_index < policy_len:
explain_rule = self.model.model["p"]["p"].policy[explain_index]
explain_rule = self.model["p"]["p"].policy[explain_index]

return result, explain_rule

Expand Down
8 changes: 4 additions & 4 deletions casbin/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def add_def(self, sec, key, value):
else:
ast.value = util.remove_comments(util.escape_assertion(ast.value))

if sec not in self.model.keys():
self.model[sec] = {}
if sec not in self.keys():
self[sec] = {}

self.model[sec][key] = ast
self[sec][key] = ast

return True

Expand Down Expand Up @@ -76,6 +76,6 @@ def load_model_from_text(self, text):

def print_model(self):
self.logger.info("Model:")
for k, v in self.model.items():
for k, v in self.items():
for i, j in v.items():
self.logger.info("%s.%s: %s", k, i, j.value)
87 changes: 51 additions & 36 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,66 @@ def __init__(self):
self.logger = logging.getLogger(__name__)
self.model = {}

def __getitem__(self, item):
return self.model.get(item)

def __setitem__(self, key, value):
self.model[key] = value

def keys(self):
return self.model.keys()

def values(self):
return self.model.values()

def items(self):
return self.model.items()

def build_role_links(self, rm_map):
"""initializes the roles in RBAC."""

if "g" not in self.model.keys():
if "g" not in self.keys():
return

for ptype, ast in self.model["g"].items():
for ptype, ast in self["g"].items():
rm = rm_map[ptype]
ast.build_role_links(rm)

def build_incremental_role_links(self, rm, op, sec, ptype, rules):
if sec == "g":
self.model.get(sec).get(ptype).build_incremental_role_links(rm, op, rules)
self[sec].get(ptype).build_incremental_role_links(rm, op, rules)

def print_policy(self):
"""Log using info"""

self.logger.info("Policy:")
for sec in ["p", "g"]:
if sec not in self.model.keys():
if sec not in self.keys():
continue

for key, ast in self.model[sec].items():
for key, ast in self[sec].items():
self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy))

def clear_policy(self):
"""clears all current policy."""

for sec in ["p", "g"]:
if sec not in self.model.keys():
if sec not in self.keys():
continue

for key in self.model[sec].keys():
self.model[sec][key].policy = []
for key in self[sec].keys():
self[sec][key].policy = []

def get_policy(self, sec, ptype):
"""gets all rules in a policy."""

return self.model[sec][ptype].policy
return self[sec][ptype].policy

def get_filtered_policy(self, sec, ptype, field_index, *field_values):
"""gets rules based on field filters from a policy."""
return [
rule
for rule in self.model[sec][ptype].policy
for rule in self[sec][ptype].policy
if all(
value == "" or rule[field_index + i] == value
for i, value in enumerate(field_values)
Expand All @@ -59,18 +74,18 @@ def get_filtered_policy(self, sec, ptype, field_index, *field_values):

def has_policy(self, sec, ptype, rule):
"""determines whether a model has the specified policy rule."""
if sec not in self.model.keys():
if sec not in self.keys():
return False
if ptype not in self.model[sec]:
if ptype not in self[sec]:
return False

return rule in self.model[sec][ptype].policy
return rule in self[sec][ptype].policy

def add_policy(self, sec, ptype, rule):
"""adds a policy rule to the model."""

if not self.has_policy(sec, ptype, rule):
self.model[sec][ptype].policy.append(rule)
self[sec][ptype].policy.append(rule)
return True

return False
Expand All @@ -83,19 +98,19 @@ def add_policies(self, sec, ptype, rules):
return False

for rule in rules:
self.model[sec][ptype].policy.append(rule)
self[sec][ptype].policy.append(rule)

return True

def update_policy(self, sec, ptype, old_rule, new_rule):
"""update a policy rule from the model."""

if sec not in self.model.keys():
if sec not in self.keys():
return False
if ptype not in self.model[sec]:
if ptype not in self[sec]:
return False

ast = self.model[sec][ptype]
ast = self[sec][ptype]

if old_rule in ast.policy:
rule_index = ast.policy.index(old_rule)
Expand All @@ -116,14 +131,14 @@ def update_policy(self, sec, ptype, old_rule, new_rule):
def update_policies(self, sec, ptype, old_rules, new_rules):
"""update policy rules from the model."""

if sec not in self.model.keys():
if sec not in self.keys():
return False
if ptype not in self.model[sec]:
if ptype not in self[sec]:
return False
if len(old_rules) != len(new_rules):
return False

ast = self.model[sec][ptype]
ast = self[sec][ptype]
old_rules_index = []

for old_rule in old_rules:
Expand Down Expand Up @@ -152,18 +167,18 @@ def remove_policy(self, sec, ptype, rule):
if not self.has_policy(sec, ptype, rule):
return False

self.model[sec][ptype].policy.remove(rule)
self[sec][ptype].policy.remove(rule)

return rule not in self.model[sec][ptype].policy
return rule not in self[sec][ptype].policy

def remove_policies(self, sec, ptype, rules):
"""RemovePolicies removes policy rules from the model."""

for rule in rules:
if not self.has_policy(sec, ptype, rule):
return False
self.model[sec][ptype].policy.remove(rule)
if rule in self.model[sec][ptype].policy:
self[sec][ptype].policy.remove(rule)
if rule in self[sec][ptype].policy:
return False

return True
Expand All @@ -188,12 +203,12 @@ def remove_filtered_policy_returns_effects(

if len(field_values) == 0:
return []
if sec not in self.model.keys():
if sec not in self.keys():
return []
if ptype not in self.model[sec]:
if ptype not in self[sec]:
return []

for rule in self.model[sec][ptype].policy:
for rule in self[sec][ptype].policy:
if all(
value == "" or rule[field_index + i] == value
for i, value in enumerate(field_values[0])
Expand All @@ -202,7 +217,7 @@ def remove_filtered_policy_returns_effects(
else:
tmp.append(rule)

self.model[sec][ptype].policy = tmp
self[sec][ptype].policy = tmp

return effects

Expand All @@ -211,12 +226,12 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
tmp = []
res = False

if sec not in self.model.keys():
if sec not in self.keys():
return res
if ptype not in self.model[sec]:
if ptype not in self[sec]:
return res

for rule in self.model[sec][ptype].policy:
for rule in self[sec][ptype].policy:
if all(
value == "" or rule[field_index + i] == value
for i, value in enumerate(field_values)
Expand All @@ -225,19 +240,19 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
else:
tmp.append(rule)

self.model[sec][ptype].policy = tmp
self[sec][ptype].policy = tmp

return res

def get_values_for_field_in_policy(self, sec, ptype, field_index):
"""gets all values for a field for all rules in a policy, duplicated values are removed."""
values = []
if sec not in self.model.keys():
if sec not in self.keys():
return values
if ptype not in self.model[sec]:
if ptype not in self[sec]:
return values

for rule in self.model[sec][ptype].policy:
for rule in self[sec][ptype].policy:
value = rule[field_index]
if value not in values:
values.append(value)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_model_set_load(self):
self.assertTrue(e.model is None)
# creating new model
e.load_model()
self.assertTrue(e.model)
self.assertTrue(e.model is not None)

def test_enforcer_basic_without_spaces(self):
e = self.get_enforcer(
Expand Down

0 comments on commit 03b3cc8

Please sign in to comment.