Skip to content

Commit

Permalink
remove unused elements
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick committed Mar 15, 2020
1 parent 79da0ee commit 6d82349
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
52 changes: 47 additions & 5 deletions pytm/pytm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def _setLabel(element):
return "<br/>".join(wrap(element.name, 14))


def _sort(elements, addOrder=False):
ordered = sorted(elements, key=lambda flow: flow.order)
def _sort(flows, addOrder=False):
ordered = sorted(flows, key=lambda flow: flow.order)
if not addOrder:
return ordered
for i, flow in enumerate(ordered):
Expand All @@ -109,6 +109,27 @@ def _sort(elements, addOrder=False):
return ordered


def _sort_elem(elements):
orders = {}
for e in elements:
try:
order = e.order
except AttributeError:
continue
if e.source not in orders or orders[e.source] > order:
orders[e.source] = order
m = max(orders.values()) + 1
return sorted(
elements,
key=lambda e: (
orders.get(e, m),
e.__class__.__name__,
getattr(e, "order", 0),
str(e),
),
)


def _match_responses(flows):
"""Ensure that responses are pointing to requests"""
index = defaultdict(list)
Expand Down Expand Up @@ -138,8 +159,8 @@ def _match_responses(flows):
return flows


def _applyDefaults(elements):
for e in elements:
def _apply_defaults(flows):
for e in flows:
e._safeset("data", e.source.data)
if e.isResponse:
e._safeset("protocol", e.source.protocol)
Expand All @@ -152,6 +173,21 @@ def _applyDefaults(elements):
e._safeset("isEncrypted", e.sink.isEncrypted)


def _get_elements_and_boundaries(flows):
"""filter out elements and boundaries not used in this TM"""
elements = {}
boundaries = {}
for e in flows:
elements[e] = True
elements[e.source] = True
elements[e.sink] = True
if e.source.inBoundary is not None:
boundaries[e.source.inBoundary] = True
if e.sink.inBoundary is not None:
boundaries[e.sink.inBoundary] = True
return (elements.keys(), boundaries.keys())


''' End of help functions '''


Expand Down Expand Up @@ -215,6 +251,7 @@ class TM():
onSet=lambda i, v: i._init_threats())
isOrdered = varBool(False)
mergeResponses = varBool(False)
ignoreUnused = varBool(False)

def __init__(self, name, **kwargs):
for key, value in kwargs.items():
Expand Down Expand Up @@ -252,10 +289,15 @@ def resolve(self):
def check(self):
if self.description is None:
raise ValueError("Every threat model should have at least a brief description of the system being modeled.")
_applyDefaults(TM._BagOfFlows)
_apply_defaults(TM._BagOfFlows)
if self.ignoreUnused:
TM._BagOfElements, TM._BagOfBoundaries = _get_elements_and_boundaries(TM._BagOfFlows)
for e in (TM._BagOfElements):
e.check()
TM._BagOfFlows = _match_responses(_sort(TM._BagOfFlows, self.isOrdered))
if self.ignoreUnused:
# cannot rely on user defined order if assets are re-used in multiple models
TM._BagOfElements = _sort_elem(TM._BagOfElements)

def dfd(self):
print("digraph tm {\n\tgraph [\n\tfontname = Arial;\n\tfontsize = 14;\n\t]")
Expand Down
15 changes: 15 additions & 0 deletions tests/seq_unused.plantuml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@startuml
actor actor_User_579e9aae81 as "User"
database datastore_SQLDatabase_d2006ce1bb as "SQL Database"
entity server_WebServer_f2eb7a3ff7 as "Web Server"
actor_User_579e9aae81 -> server_WebServer_f2eb7a3ff7: User enters comments (*)
note left
bbb
end note
server_WebServer_f2eb7a3ff7 -> datastore_SQLDatabase_d2006ce1bb: Insert query with comments
note left
ccc
end note
datastore_SQLDatabase_d2006ce1bb -> server_WebServer_f2eb7a3ff7: Retrieve comments
server_WebServer_f2eb7a3ff7 -> actor_User_579e9aae81: Show comments (*)
@enduml
30 changes: 30 additions & 0 deletions tests/test_pytmfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,35 @@ def test_seq(self):
Dataflow(db, web, "Retrieve comments")
Dataflow(web, user, "Show comments (*)")

tm.check()
with captured_output() as (out, err):
tm.seq()

output = out.getvalue().strip()
self.maxDiff = None
self.assertEqual(output, expected)

def test_seq_unused(self):
random.seed(0)
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(dir_path, 'seq_unused.plantuml')) as x:
expected = x.read().strip()

TM.reset()
tm = TM("my test tm", description="aaa", ignoreUnused=True)
internet = Boundary("Internet")
server_db = Boundary("Server/DB")
user = Actor("User", inBoundary=internet)
web = Server("Web Server")
db = Datastore("SQL Database", inBoundary=server_db)
Lambda("Unused Lambda")

Dataflow(user, web, "User enters comments (*)", note="bbb")
Dataflow(web, db, "Insert query with comments", note="ccc")
Dataflow(db, web, "Retrieve comments")
Dataflow(web, user, "Show comments (*)")

tm.check()
with captured_output() as (out, err):
tm.seq()

Expand Down Expand Up @@ -76,6 +105,7 @@ def test_dfd(self):
Dataflow(db, web, "Retrieve comments")
Dataflow(web, user, "Show comments (*)")

tm.check()
with captured_output() as (out, err):
tm.dfd()

Expand Down

0 comments on commit 6d82349

Please sign in to comment.