Skip to content
This repository has been archived by the owner on Aug 31, 2024. It is now read-only.

Edit plot #42

Merged
merged 1 commit into from
May 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions src/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,8 @@ def proto_names(self):
{'tcp', 'udp', 'oth'})))

@staticmethod
def proto_freq(record, labels, init_key, success_key, ev_key):
"""Average success rate by protocol"""
total = sum(record[init_key])
rec = smean(record[rarr(ev_key)])
return [round(sdiv(sum([
rget(b, lbl, 0) for b in record[success_key]]), total), 2)
if round(sdiv(total, rec), 2) > 0 else 0
for lbl in labels]
def flatten(two_d_list):
return [j for sub in two_d_list for j in sub]

def classifier_table(self):

Expand Down Expand Up @@ -124,16 +118,17 @@ def collapse(ds, cl, rb, rows):
def evasion_table(self):
def extract_values(record):
lbl = record[rarr('labels')]
vld = record[rarr('n_valid')]
nr = sum(record[rarr('n_records')])
ne = sum(record[rarr('n_evasions')])
vl = sum(record[rarr('n_valid')])
bm = sum([r['benign'] for r in lbl])
evades = sdiv(ne, nr)
valid = sdiv(vl, nr)
bl = sdiv(bm, vl)
does_evade = evades >= 0.005
valid, bl = sdiv(sum(vld), ne), sdiv(bm, ne)
return ResultsPlot.std_cols(record) + [
round(evades, 2),
round(valid, 2) if does_evade else 0,
round(valid, 2),
f"{100 * bl:.0f}--{100 * (1. - bl):.0f}"
if does_evade else '--']

Expand All @@ -142,21 +137,17 @@ def extract_values(record):
return self.std_hd + h, mat

def proto_table(self):
proto1 = rarr('n_evasions'), rarr('proto_evasions'), 'n_records'
proto2 = rarr('n_valid'), rarr('proto_valid'), 'n_evasions'

def extract_proto_values(record, labels):
evades = ResultsPlot.proto_freq(record, labels, *proto1)
valid = [v if e > 0 else 0 for e, v in zip(
evades, ResultsPlot.proto_freq(
record, labels, *proto2))]
return ResultsPlot.std_cols(record) + evades + valid

h = [f"e/{p}" for p in self.proto_names] + \
[f"v/{p}" for p in self.proto_names]
mat = [extract_proto_values(record, self.proto_names)
for record in self.raw_rata]
return self.std_hd + h, mat
headers, mat = self.flatten([[f"e/{p}", f"v/{p}"] for p in self.proto_names]), []
p_keys = ['proto_init', 'proto_evasions', 'proto_valid']
for record in self.raw_rata:
m = ResultsPlot.std_cols(record)
psum = lambda arr_key: sum(
[x[p] if p in x else 0 for x in record[rarr(arr_key)]])
for p in self.proto_names:
init, evs, val = [psum(k) for k in p_keys]
m += [round(sdiv(evs, init), 2), round(sdiv(val, init), 2)]
mat.append(m)
return self.std_hd + headers, mat

def reasons_table(self):
headers = ["DS", "ATK", "Proto", "Reason", "Freq"]
Expand Down Expand Up @@ -187,7 +178,8 @@ def reasons_table(self):
tot = sum([r[1] for r in reasons])
for r, v in reasons:
r, v = r.replace(proto, '', 1), v / tot
mat.append([ds, att, proto, r, round(v, 2)])
if round(v, 3) > 0:
mat.append([ds, att, proto, r, round(v, 3)])
return headers, mat

def write_table(self, headers, mat, file_name, sorter=None):
Expand Down