diff --git a/pytm/pytm.py b/pytm/pytm.py index a193c6d..f50a367 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -98,8 +98,8 @@ def _setLabel(element): return "
".join(wrap(element.name, 14)) -def _sort(elements, addOrder=False): - ordered = sorted(elements, key=lambda flow: flow.order) +def _sort(flows, addOrder=False): + ordered = sorted(flows, key=lambda flow: flow.order) if not addOrder: return ordered for i, flow in enumerate(ordered): @@ -109,6 +109,27 @@ def _sort(elements, addOrder=False): return ordered +def _sort_elem(elements): + orders = {} + for e in elements: + try: + order = e.order + except AttributeError: + continue + if e.source not in orders or orders[e.source] > order: + orders[e.source] = order + m = max(orders.values()) + 1 + return sorted( + elements, + key=lambda e: ( + orders.get(e, m), + e.__class__.__name__, + getattr(e, "order", 0), + str(e), + ), + ) + + def _match_responses(flows): """Ensure that responses are pointing to requests""" index = defaultdict(list) @@ -138,8 +159,8 @@ def _match_responses(flows): return flows -def _applyDefaults(elements): - for e in elements: +def _apply_defaults(flows): + for e in flows: e._safeset("data", e.source.data) if e.isResponse: e._safeset("protocol", e.source.protocol) @@ -152,6 +173,21 @@ def _applyDefaults(elements): e._safeset("isEncrypted", e.sink.isEncrypted) +def _get_elements_and_boundaries(flows): + """filter out elements and boundaries not used in this TM""" + elements = {} + boundaries = {} + for e in flows: + elements[e] = True + elements[e.source] = True + elements[e.sink] = True + if e.source.inBoundary is not None: + boundaries[e.source.inBoundary] = True + if e.sink.inBoundary is not None: + boundaries[e.sink.inBoundary] = True + return (elements.keys(), boundaries.keys()) + + ''' End of help functions ''' @@ -215,6 +251,7 @@ class TM(): onSet=lambda i, v: i._init_threats()) isOrdered = varBool(False) mergeResponses = varBool(False) + ignoreUnused = varBool(False) def __init__(self, name, **kwargs): for key, value in kwargs.items(): @@ -252,10 +289,15 @@ def resolve(self): def check(self): if self.description is None: raise ValueError("Every threat model should have at least a brief description of the system being modeled.") - _applyDefaults(TM._BagOfFlows) + _apply_defaults(TM._BagOfFlows) + if self.ignoreUnused: + TM._BagOfElements, TM._BagOfBoundaries = _get_elements_and_boundaries(TM._BagOfFlows) for e in (TM._BagOfElements): e.check() TM._BagOfFlows = _match_responses(_sort(TM._BagOfFlows, self.isOrdered)) + if self.ignoreUnused: + # cannot rely on user defined order if assets are re-used in multiple models + TM._BagOfElements = _sort_elem(TM._BagOfElements) def dfd(self): print("digraph tm {\n\tgraph [\n\tfontname = Arial;\n\tfontsize = 14;\n\t]") diff --git a/tests/seq_unused.plantuml b/tests/seq_unused.plantuml new file mode 100644 index 0000000..a83cc7d --- /dev/null +++ b/tests/seq_unused.plantuml @@ -0,0 +1,15 @@ +@startuml +actor actor_User_579e9aae81 as "User" +database datastore_SQLDatabase_d2006ce1bb as "SQL Database" +entity server_WebServer_f2eb7a3ff7 as "Web Server" +actor_User_579e9aae81 -> server_WebServer_f2eb7a3ff7: User enters comments (*) +note left +bbb +end note +server_WebServer_f2eb7a3ff7 -> datastore_SQLDatabase_d2006ce1bb: Insert query with comments +note left +ccc +end note +datastore_SQLDatabase_d2006ce1bb -> server_WebServer_f2eb7a3ff7: Retrieve comments +server_WebServer_f2eb7a3ff7 -> actor_User_579e9aae81: Show comments (*) +@enduml diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py index c1f017b..3f1d968 100644 --- a/tests/test_pytmfunc.py +++ b/tests/test_pytmfunc.py @@ -49,6 +49,35 @@ def test_seq(self): Dataflow(db, web, "Retrieve comments") Dataflow(web, user, "Show comments (*)") + tm.check() + with captured_output() as (out, err): + tm.seq() + + output = out.getvalue().strip() + self.maxDiff = None + self.assertEqual(output, expected) + + def test_seq_unused(self): + random.seed(0) + dir_path = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(dir_path, 'seq_unused.plantuml')) as x: + expected = x.read().strip() + + TM.reset() + tm = TM("my test tm", description="aaa", ignoreUnused=True) + internet = Boundary("Internet") + server_db = Boundary("Server/DB") + user = Actor("User", inBoundary=internet) + web = Server("Web Server") + db = Datastore("SQL Database", inBoundary=server_db) + Lambda("Unused Lambda") + + Dataflow(user, web, "User enters comments (*)", note="bbb") + Dataflow(web, db, "Insert query with comments", note="ccc") + Dataflow(db, web, "Retrieve comments") + Dataflow(web, user, "Show comments (*)") + + tm.check() with captured_output() as (out, err): tm.seq() @@ -76,6 +105,7 @@ def test_dfd(self): Dataflow(db, web, "Retrieve comments") Dataflow(web, user, "Show comments (*)") + tm.check() with captured_output() as (out, err): tm.dfd()