Skip to content

Commit

Permalink
Overloading all operators on value sets is done, fixed issues with si…
Browse files Browse the repository at this point in the history
…gned division and remainder
  • Loading branch information
ckirsch committed Feb 9, 2025
1 parent 797a00b commit 361b1af
Showing 1 changed file with 86 additions and 53 deletions.
139 changes: 86 additions & 53 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,70 @@ def __xor__(self, values):
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: x ^ y)

def __lshift__(self, values):
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: (x << y) % 2**self.sid_line.size)

def LShR(self, values):
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: (x >> y) % 2**self.sid_line.size)

def __rshift__(self, values):
# right shift operator computes arithmetic right shift in Python
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values,
lambda x, y: (self.sid_line.get_signed_value(x) >> y) % 2**self.sid_line.size)

def __add__(self, values):
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: (x + y) % 2**self.sid_line.size)

def __sub__(self, values):
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: (x - y) % 2**self.sid_line.size)

def __mul__(self, values):
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: (x * y) % 2**self.sid_line.size)

def __div__(self, values):
# using the integer portion of division, not floor division,
# because int(x / y) != x // y if x < 0 or y < 0 since
# the integer portion of division truncates towards 0 whereas
# floor division truncates towards negative infinity
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values,
lambda x, y: (int(self.sid_line.get_signed_value(x) / values.sid_line.get_signed_value(y))
if not (y == 0 or (self.sid_line.get_signed_value(x) == -2**(self.sid_line.size - 1) and
values.sid_line.get_signed_value(y) == -1))
else -1 if y == 0 else -2**(self.sid_line.size - 1)) % 2**self.sid_line.size)

def UDiv(self, values):
# using floor division is ok since x >= 0 and y >= 0
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values,
lambda x, y: x // y if y != 0 else 2**self.sid_line.size - 1)

def SRem(self, values):
# using the integer portion of division, not the % operator,
# because x % y != x - int(x / y) * y if x < 0 since
# the % operator in Python computes modulo, not remainder,
# such that x // y * y + x % y == x holds in Python for all x and y even if x < 0
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values,
lambda x, y: (self.sid_line.get_signed_value(x) -
int(self.sid_line.get_signed_value(x) / values.sid_line.get_signed_value(y)) *
values.sid_line.get_signed_value(y))
% 2**self.sid_line.size
if not (y == 0 or (self.sid_line.get_signed_value(x) == -2**(self.sid_line.size - 1) and
values.sid_line.get_signed_value(y) == -1))
else x if y == 0 else 0)

def URem(self, values):
# using the % operator is ok since x >= 0 and y >= 0
assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
return self.apply_binary(self.sid_line, values, lambda x, y: x % y if y != 0 else x)

def Concat(self, values, sid_line):
assert isinstance(self.sid_line, Bitvec) and isinstance(values.sid_line, Bitvec)
return self.apply_binary(sid_line, values, lambda x, y: (x << values.sid_line.size) + y)
Expand Down Expand Up @@ -1536,14 +1600,6 @@ def get_mapped_array_expression_for(self, index):
arg2_line = self.arg2_line.get_mapped_array_expression_for(None)
return self.copy(arg1_line, arg2_line)

def propagate(self, arg1_value, arg2_value, op_lambda):
results = Values(self.sid_line)
for value1 in arg1_value.values:
for value2 in arg2_value.values:
results.set_value(self.sid_line, op_lambda(value1, value2),
Constraints.intersection(arg1_value.values[value1], arg2_value.values[value2]))
return results

class Implies(Binary):
keyword = OP_IMPLIES

Expand Down Expand Up @@ -1813,58 +1869,33 @@ def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
if not arg1_line.sid_line.match_sorts(arg2_line.sid_line):
raise model_error("compatible first and second operand sorts", line_no)

def propagate(self, arg1_value, arg2_value, op_lambda):
return super().propagate(arg1_value, arg2_value, lambda x, y: op_lambda(x, y) % 2**self.sid_line.size)

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg2_value = self.arg2_line.get_values(step)
if Instance.PROPAGATE_BINARY:
if isinstance(arg1_value, Values) and isinstance(arg2_value, Values):
if self.op == OP_SLL:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x << y)
self.cache_values[step] = arg1_value << arg2_value
elif self.op == OP_SRL:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x >> y)
self.cache_values[step] = arg1_value.LShR(arg2_value)
elif self.op == OP_SRA:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: self.arg1_line.sid_line.get_signed_value(x) >> y)
self.cache_values[step] = arg1_value >> arg2_value
elif self.op == OP_ADD:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x + y)
self.cache_values[step] = arg1_value + arg2_value
elif self.op == OP_SUB:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x - y)
self.cache_values[step] = arg1_value - arg2_value
elif self.op == OP_MUL:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x * y)
self.cache_values[step] = arg1_value * arg2_value
elif self.op == OP_SDIV:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: int(self.arg1_line.sid_line.get_signed_value(x) /
self.arg2_line.sid_line.get_signed_value(y))
if not (y == 0 or
(self.arg1_line.sid_line.get_signed_value(x) ==
-2**(self.sid_line.size - 1) and
self.arg2_line.sid_line.get_signed_value(y) == -1))
else -1 if y == 0 else -2**(self.sid_line.size - 1))
self.cache_values[step] = arg1_value / arg2_value
elif self.op == OP_UDIV:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: int(x / y) if y != 0 else 2**self.sid_line.size - 1)
self.cache_values[step] = arg1_value.UDiv(arg2_value)
elif self.op == OP_SREM:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: self.arg1_line.sid_line.get_signed_value(x) %
self.arg2_line.sid_line.get_signed_value(y)
if not (y == 0 or
(self.arg1_line.sid_line.get_signed_value(x) ==
-2**(self.sid_line.size - 1) and
self.arg2_line.sid_line.get_signed_value(y) == -1))
else x if y == 0 else 0)
self.cache_values[step] = arg1_value.SRem(arg2_value)
else:
assert self.op == OP_UREM
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x % y if y != 0 else x)
self.cache_values[step] = arg1_value.URem(arg2_value)
return self.cache_values[step]
arg1_value = arg1_value.get_expression()
arg2_value = arg2_value.get_expression()
Expand All @@ -1873,27 +1904,29 @@ def get_values(self, step):

def get_z3(self):
if self.z3 is None:
z3_arg1 = self.arg1_line.get_z3()
z3_arg2 = self.arg2_line.get_z3()
if self.op == OP_SLL:
self.z3 = self.arg1_line.get_z3() << self.arg2_line.get_z3()
self.z3 = z3_arg1 << z3_arg2
elif self.op == OP_SRL:
self.z3 = z3.LShR(self.arg1_line.get_z3(), self.arg2_line.get_z3())
self.z3 = z3.LShR(z3_arg1, z3_arg2)
elif self.op == OP_SRA:
self.z3 = self.arg1_line.get_z3() >> self.arg2_line.get_z3()
self.z3 = z3_arg1 >> z3_arg2
elif self.op == OP_ADD:
self.z3 = self.arg1_line.get_z3() + self.arg2_line.get_z3()
self.z3 = z3_arg1 + z3_arg2
elif self.op == OP_SUB:
self.z3 = self.arg1_line.get_z3() - self.arg2_line.get_z3()
self.z3 = z3_arg1 - z3_arg2
elif self.op == OP_MUL:
self.z3 = self.arg1_line.get_z3() * self.arg2_line.get_z3()
self.z3 = z3_arg1 * z3_arg2
elif self.op == OP_SDIV:
self.z3 = self.arg1_line.get_z3() / self.arg2_line.get_z3()
self.z3 = z3_arg1 / z3_arg2
elif self.op == OP_UDIV:
self.z3 = z3.UDiv(self.arg1_line.get_z3(), self.arg2_line.get_z3())
self.z3 = z3.UDiv(z3_arg1, z3_arg2)
elif self.op == OP_SREM:
self.z3 = z3.SRem(self.arg1_line.get_z3(), self.arg2_line.get_z3())
self.z3 = z3.SRem(z3_arg1, z3_arg2)
else:
assert self.op == OP_UREM
self.z3 = z3.URem(self.arg1_line.get_z3(), self.arg2_line.get_z3())
self.z3 = z3.URem(z3_arg1, z3_arg2)
return self.z3

def get_bitwuzla(self, tm):
Expand Down

0 comments on commit 361b1af

Please sign in to comment.