diff --git a/dash-pipeline/python_model/__model_server.py b/dash-pipeline/python_model/__model_server.py index b6fe9311f..3a47d5153 100644 --- a/dash-pipeline/python_model/__model_server.py +++ b/dash-pipeline/python_model/__model_server.py @@ -32,10 +32,57 @@ class Range: priority : int +def insert_request_to_table_entry(insertRequest: InsertRequest, key_format: list): + entry = Entry() + + entry.values = [] + for idx, val in enumerate(insertRequest.values): + if key_format[idx] is EXACT: + entry.values.append(int(val.exact, 0)) + elif key_format[idx] is TERNARY: + ternary = Entry.Ternary() + ternary.value = int(val.ternary.value , 0) + ternary.mask = int(val.ternary.mask , 0) + entry.values.append(ternary) + elif key_format[idx] is LIST: + ternary_list = [] + for t in val.ternary_list: + ternary = Entry.Ternary() + ternary.value = int(t.value , 0) + ternary.mask = int(t.mask , 0) + ternary_list.append(ternary) + entry.values.append(ternary_list) + elif key_format[idx] is RANGE: + range = Entry.Range() + range.first = int(val.range.first , 0) + range.last = int(val.range.last , 0) + entry.values.append(range) + elif key_format[idx] is RANGE_LIST: + range_list = [] + for r in val.range_list: + range = Entry.Range() + range.first = int(r.first , 0) + range.last = int(r.last , 0) + range_list.append(range) + entry.values.append(range_list) + elif key_format[idx] is LPM: + lpm = Entry.LPM() + lpm.value = int(val.prefix.value , 0) + lpm.prefix_len = val.prefix.prefix_len + entry.values.append(lpm) + + entry.action = id_map[insertRequest.action] + + entry.params = [] + for param_str in insertRequest.params: + entry.params.append(int(param_str , 0)) + + entry.priority = insertRequest.priority + return entry + def table_insert_api(insertRequest: InsertRequest): table = id_map[insertRequest.table] - - + table.insert(insert_request_to_table_entry(insertRequest, list(table.key.values()))) def json_obj_to_insert_request(json_obj): insertRequest = InsertRequest() @@ -83,9 +130,6 @@ def json_obj_to_insert_request(json_obj): insertRequest.priority = json_obj["priority"] return insertRequest - - - class ModelTCPHandler(socketserver.BaseRequestHandler): def handle(self): api_id = self.request.recv(1)[0] @@ -93,10 +137,8 @@ def handle(self): json_buf_size = int(self.request.recv(8).decode("ascii"), 16) json_obj = json.loads(self.request.recv(json_buf_size)) insertRequest = json_obj_to_insert_request(json_obj) - - -# self.request.sendall(data) - + table_insert_api(insertRequest) + self.request.sendall(b'\x00') HOST, PORT = "localhost", 46500 diff --git a/dash-pipeline/python_model/__table.py b/dash-pipeline/python_model/__table.py index 5277e69ff..6c32f6db9 100644 --- a/dash-pipeline/python_model/__table.py +++ b/dash-pipeline/python_model/__table.py @@ -1,49 +1,66 @@ from inspect import * from __vars import * -from threading import Lock from __sai_keys import * -def EXACT(entry_value, match_value, width): +class Entry: + class Ternary: + value : int + mask : int + + class LPM: + value : int + prefix_len : int + + class Range: + first : int + last : int + + values : list + action : function + params : list[int] + priority : int + +def EXACT(entry_value: int, match_value: int, width: int): return entry_value == match_value -def TERNARY(entry_value, match_value, width): - value = entry_value["value"] - mask = entry_value["mask"] +def TERNARY(entry_value: Entry.Ternary, match_value: int, width: int): + value = entry_value.value + mask = entry_value.mask return (value & mask) == (match_value & mask) -def LIST(entry_value, match_value, width): +def LIST(entry_value: list[Entry.Ternary], match_value: int, width: int): for ev in entry_value: if TERNARY(ev, match_value, width): return True return False -def RANGE(entry_value, match_value, width): - first = entry_value["first"] - last = entry_value["last"] +def RANGE(entry_value: Entry.Range, match_value: int, width: int): + first = entry_value.first + last = entry_value.last return match_value >= first and match_value <= last -def RANGE_LIST(entry_value, match_value, width): +def RANGE_LIST(entry_value: list[Entry.Range], match_value, width): for ev in entry_value: if RANGE(ev, match_value, width): return True return False -def LPM(entry_value, match_value, width): - value = entry_value["value"] - prefix_len = entry_value["prefix_len"] +def LPM(entry_value: Entry.LPM, match_value: int, width: int): + value = entry_value.value + prefix_len = entry_value.prefix_len mask = ((1 << prefix_len) - 1) << (width - prefix_len) return (value & mask) == (match_value & mask) -def _winning_criteria_PRIORITY(a, b, key): - return a["priority"] < b["priority"] +def _winning_criteria_PRIORITY(a: Entry, b: Entry, key): + return a.priority < b.priority -def _winning_criteria_PREFIX_LEN(a, b, key): - lpm_key = None +def _winning_criteria_PREFIX_LEN(a: Entry, b: Entry, key): + idx = 0 for k in key: if key[k] == LPM: - lpm_key = k break - return a[lpm_key]["prefix_len"] > b[lpm_key]["prefix_len"] + idx = idx + 1 + return a.values[idx].prefix_len > b.values[idx].prefix_len class Table: def __init__(self, key, actions, default_action=NoAction, default_params=[], per_entry_stats = False, api_name=None, is_object=None): @@ -56,28 +73,11 @@ def __init__(self, key, actions, default_action=NoAction, default_params=[], per if (default_action is NoAction) and (NoAction not in self.actions): self.actions.append((NoAction, {DEFAULT_ONLY : True})) self.api_hints = self.__extract_api_hints(api_name, is_object) - self.lock = Lock() def insert(self, entry): - self.lock.acquire() - self.__insert(entry) - self.lock.release() - - def apply(self): - self.lock.acquire() - res = self.__apply() - self.lock.release() - return res - - def delete(self, entry): - self.lock.acquire() - self.__delete(entry) - self.lock.release() - - def __insert(self, entry): self.entries.append(entry) - def __apply(self): + def apply(self): entry = self.__lookup() res = {} if entry is None: @@ -87,25 +87,27 @@ def __apply(self): res["hit"] = False res["action_run"] = action else: - action = entry["action"] - params = entry["params"] + action = entry.action + params = entry.params action(*params) res["hit"] = True res["action_run"] = action return res - def __delete(self, entry): + def delete(self, entry): self.entries.remove(entry) - def __match_entry(self, entry): + def __match_entry(self, entry: Entry): + idx = 0 for k in self.key: _read_value_res = _read_value(k) match_value = _read_value_res[0] width = _read_value_res[1] match_routine = self.key[k] - entry_value = entry[k] + entry_value = entry.values[idx] if not match_routine(entry_value, match_value, width): return False + idx = idx + 1 return True def __get_all_matching_entries(self):