Skip to content

Commit

Permalink
engine: make short circuiting configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
williballenthin committed Nov 8, 2021
1 parent ad119d7 commit 3e74da9
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 43 deletions.
104 changes: 68 additions & 36 deletions capa/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ def __str__(self):
def __repr__(self):
return str(self)

def evaluate(self, features: FeatureSet) -> Result:
def evaluate(self, features: FeatureSet, short_circuit=True) -> Result:
"""
classes that inherit `Statement` must implement `evaluate`
args:
short_circuit (bool): if true, then statements like and/or/some may short circuit.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -85,19 +88,24 @@ def __init__(self, children, description=None):
super(And, self).__init__(description=description)
self.children = children

def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.and"] += 1

results = []
for child in self.children:
result = child.evaluate(ctx)
results.append(result)
if not result:
# short circuit
return Result(False, self, results)
if short_circuit:
results = []
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
results.append(result)
if not result:
# short circuit
return Result(False, self, results)

return Result(True, self, results)
return Result(True, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
success = all(results)
return Result(success, self, results)


class Or(Statement):
Expand All @@ -113,19 +121,24 @@ def __init__(self, children, description=None):
super(Or, self).__init__(description=description)
self.children = children

def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.or"] += 1

results = []
for child in self.children:
result = child.evaluate(ctx)
results.append(result)
if result:
# short circuit as soon as we hit one match
return Result(True, self, results)
if short_circuit:
results = []
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
results.append(result)
if result:
# short circuit as soon as we hit one match
return Result(True, self, results)

return Result(False, self, results)
return Result(False, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
success = any(results)
return Result(success, self, results)


class Not(Statement):
Expand All @@ -135,11 +148,11 @@ def __init__(self, child, description=None):
super(Not, self).__init__(description=description)
self.child = child

def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.not"] += 1

results = [self.child.evaluate(ctx)]
results = [self.child.evaluate(ctx, short_circuit=short_circuit)]
success = not results[0]
return Result(success, self, results)

Expand All @@ -158,23 +171,32 @@ def __init__(self, count, children, description=None):
self.count = count
self.children = children

def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.some"] += 1

results = []
satisfied_children_count = 0
for child in self.children:
result = child.evaluate(ctx)
results.append(result)
if result:
satisfied_children_count += 1
if short_circuit:
results = []
satisfied_children_count = 0
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
results.append(result)
if result:
satisfied_children_count += 1

if satisfied_children_count >= self.count:
# short circuit as soon as we hit the threshold
return Result(True, self, results)
if satisfied_children_count >= self.count:
# short circuit as soon as we hit the threshold
return Result(True, self, results)

return Result(False, self, results)
return Result(False, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
# note that here we cast the child result as a bool
# because we've overridden `__bool__` above.
#
# we can't use `if child is True` because the instance is not True.
success = sum([1 for child in results if bool(child) is True]) >= self.count
return Result(success, self, results)


class Range(Statement):
Expand All @@ -186,7 +208,7 @@ def __init__(self, child, min=None, max=None, description=None):
self.min = min if min is not None else 0
self.max = max if max is not None else (1 << 64 - 1)

def evaluate(self, ctx):
def evaluate(self, ctx, **kwargs):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.range"] += 1

Expand Down Expand Up @@ -214,7 +236,7 @@ def __init__(self, scope, child):
self.scope = scope
self.child = child

def evaluate(self, ctx):
def evaluate(self, ctx, **kwargs):
raise ValueError("cannot evaluate a subscope directly!")


Expand Down Expand Up @@ -272,8 +294,18 @@ def match(rules: List["capa.rules.Rule"], features: FeatureSet, va: int) -> Tupl
features = collections.defaultdict(set, copy.copy(features))

for rule in rules:
res = rule.evaluate(features)
res = rule.evaluate(features, short_circuit=True)
if res:
# we first matched the rule with short circuiting enabled.
# this is much faster than without short circuiting.
# however, we want to collect all results thoroughly,
# so once we've found a match quickly,
# go back and capture results without short circuiting.
res = rule.evaluate(features, short_circuit=False)

# sanity check
assert bool(res) is True

results[rule.name].append((va, res))
# we need to update the current `features`
# because subsequent iterations of this loop may use newly added features,
Expand Down
16 changes: 12 additions & 4 deletions capa/features/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __str__(self):
def __repr__(self):
return str(self)

def evaluate(self, ctx: Dict["Feature", Set[int]]) -> Result:
def evaluate(self, ctx: Dict["Feature", Set[int]], **kwargs) -> Result:
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature." + self.name] += 1
return Result(self in ctx, self, [], locations=ctx.get(self, []))
Expand Down Expand Up @@ -192,7 +192,7 @@ def __init__(self, value: str, description=None):
super(Substring, self).__init__(value, description=description)
self.value = value

def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.substring"] += 1

Expand All @@ -210,6 +210,10 @@ def evaluate(self, ctx):

if self.value in feature.value:
matches[feature.value].extend(locations)
if short_circuit:
# we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode.
break

if matches:
# finalize: defaultdict -> dict
Expand Down Expand Up @@ -280,7 +284,7 @@ def __init__(self, value: str, description=None):
"invalid regular expression: %s it should use Python syntax, try it at https://pythex.org" % value
)

def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.regex"] += 1

Expand All @@ -302,6 +306,10 @@ def evaluate(self, ctx):
# so that they don't have to prefix/suffix their terms like: /.*foo.*/.
if self.re.search(feature.value):
matches[feature.value].extend(locations)
if short_circuit:
# we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode.
break

if matches:
# finalize: defaultdict -> dict
Expand Down Expand Up @@ -366,7 +374,7 @@ def __init__(self, value: bytes, description=None):
super(Bytes, self).__init__(value, description=description)
self.value = value

def evaluate(self, ctx):
def evaluate(self, ctx, **kwargs):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.bytes"] += 1

Expand Down
4 changes: 2 additions & 2 deletions capa/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,10 @@ def extract_subscope_rules(self):
for new_rule in self._extract_subscope_rules_rec(self.statement):
yield new_rule

def evaluate(self, features: FeatureSet):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.rule"] += 1
return self.statement.evaluate(features)
return self.statement.evaluate(features, short_circuit=short_circuit)

@classmethod
def from_dict(cls, d, definition):
Expand Down
10 changes: 9 additions & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,15 @@ def test_render_offset():
assert str(capa.features.insn.Offset(1, bitness=capa.features.common.BITNESS_X64)) == "offset/x64(0x1)"


def test_short_circuit_order():
def test_short_circuit():
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True

# with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=True).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=False).children) == 2


def test_eval_order():
# base cases.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True
Expand Down

0 comments on commit 3e74da9

Please sign in to comment.