Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ignore unused elements #84

Merged
merged 2 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions pytm/pytm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def _setLabel(element):
return "<br/>".join(wrap(element.name, 14))


def _sort(elements, addOrder=False):
ordered = sorted(elements, key=lambda flow: flow.order)
def _sort(flows, addOrder=False):
ordered = sorted(flows, key=lambda flow: flow.order)
if not addOrder:
return ordered
for i, flow in enumerate(ordered):
Expand All @@ -120,6 +120,27 @@ def _sort(elements, addOrder=False):
return ordered


def _sort_elem(elements):
orders = {}
for e in elements:
try:
order = e.order
except AttributeError:
continue
if e.source not in orders or orders[e.source] > order:
orders[e.source] = order
m = max(orders.values()) + 1
return sorted(
elements,
key=lambda e: (
orders.get(e, m),
e.__class__.__name__,
getattr(e, "order", 0),
str(e),
),
)


def _match_responses(flows):
"""Ensure that responses are pointing to requests"""
index = defaultdict(list)
Expand Down Expand Up @@ -149,8 +170,8 @@ def _match_responses(flows):
return flows


def _applyDefaults(elements):
for e in elements:
def _apply_defaults(flows):
for e in flows:
e._safeset("data", e.source.data)
if e.isResponse:
e._safeset("protocol", e.source.protocol)
Expand Down Expand Up @@ -190,6 +211,21 @@ def _describe_classes(classes):
print()


def _get_elements_and_boundaries(flows):
"""filter out elements and boundaries not used in this TM"""
elements = {}
boundaries = {}
for e in flows:
elements[e] = True
elements[e.source] = True
elements[e.sink] = True
if e.source.inBoundary is not None:
boundaries[e.source.inBoundary] = True
if e.sink.inBoundary is not None:
boundaries[e.sink.inBoundary] = True
return (elements.keys(), boundaries.keys())


''' End of help functions '''


Expand Down Expand Up @@ -279,6 +315,7 @@ class TM():
doc="JSON file with custom threats")
isOrdered = varBool(False, doc="Automatically order all Dataflows")
mergeResponses = varBool(False, doc="Merge response edges in DFDs")
ignoreUnused = varBool(False, doc="Ignore elements not used in any Dataflow")

def __init__(self, name, **kwargs):
for key, value in kwargs.items():
Expand Down Expand Up @@ -316,10 +353,15 @@ def resolve(self):
def check(self):
if self.description is None:
raise ValueError("Every threat model should have at least a brief description of the system being modeled.")
_applyDefaults(TM._BagOfFlows)
TM._BagOfFlows = _match_responses(_sort(TM._BagOfFlows, self.isOrdered))
_apply_defaults(TM._BagOfFlows)
if self.ignoreUnused:
TM._BagOfElements, TM._BagOfBoundaries = _get_elements_and_boundaries(TM._BagOfFlows)
for e in (TM._BagOfElements):
e.check()
TM._BagOfFlows = _match_responses(_sort(TM._BagOfFlows, self.isOrdered))
if self.ignoreUnused:
# cannot rely on user defined order if assets are re-used in multiple models
TM._BagOfElements = _sort_elem(TM._BagOfElements)

def dfd(self):
print("digraph tm {\n\tgraph [\n\tfontname = Arial;\n\tfontsize = 14;\n\t]")
Expand Down Expand Up @@ -691,6 +733,7 @@ def dfd(self, **kwargs):


class SetOfProcesses(Process):

def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)

Expand Down
15 changes: 15 additions & 0 deletions tests/seq_unused.plantuml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@startuml
actor actor_User_579e9aae81 as "User"
database datastore_SQLDatabase_d2006ce1bb as "SQL Database"
entity server_WebServer_f2eb7a3ff7 as "Web Server"
actor_User_579e9aae81 -> server_WebServer_f2eb7a3ff7: User enters comments (*)
note left
bbb
end note
server_WebServer_f2eb7a3ff7 -> datastore_SQLDatabase_d2006ce1bb: Insert query with comments
note left
ccc
end note
datastore_SQLDatabase_d2006ce1bb -> server_WebServer_f2eb7a3ff7: Retrieve comments
server_WebServer_f2eb7a3ff7 -> actor_User_579e9aae81: Show comments (*)
@enduml
30 changes: 30 additions & 0 deletions tests/test_pytmfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,35 @@ def test_seq(self):
Dataflow(db, web, "Retrieve comments")
Dataflow(web, user, "Show comments (*)")

tm.check()
with captured_output() as (out, err):
tm.seq()

output = out.getvalue().strip()
self.maxDiff = None
self.assertEqual(output, expected)

def test_seq_unused(self):
random.seed(0)
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(dir_path, 'seq_unused.plantuml')) as x:
expected = x.read().strip()

TM.reset()
tm = TM("my test tm", description="aaa", ignoreUnused=True)
internet = Boundary("Internet")
server_db = Boundary("Server/DB")
user = Actor("User", inBoundary=internet)
web = Server("Web Server")
db = Datastore("SQL Database", inBoundary=server_db)
Lambda("Unused Lambda")

Dataflow(user, web, "User enters comments (*)", note="bbb")
Dataflow(web, db, "Insert query with comments", note="ccc")
Dataflow(db, web, "Retrieve comments")
Dataflow(web, user, "Show comments (*)")

tm.check()
with captured_output() as (out, err):
tm.seq()

Expand Down Expand Up @@ -76,6 +105,7 @@ def test_dfd(self):
Dataflow(db, web, "Retrieve comments")
Dataflow(web, user, "Show comments (*)")

tm.check()
with captured_output() as (out, err):
tm.dfd()

Expand Down