Skip to content

Commit

Permalink
Merge pull request #8442 from diffblue/zero_extend
Browse files Browse the repository at this point in the history
zero extension expression
  • Loading branch information
kroening authored Nov 1, 2024
2 parents 20a1ecf + 5420b97 commit 8fcd9b1
Show file tree
Hide file tree
Showing 13 changed files with 137 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/solvers/flattening/boolbv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr)
return convert_replication(to_replication_expr(expr));
else if(expr.id()==ID_extractbits)
return convert_extractbits(to_extractbits_expr(expr));
else if(expr.id() == ID_zero_extend)
return convert_bitvector(to_zero_extend_expr(expr).lower());
else if(expr.id()==ID_bitnot || expr.id()==ID_bitand ||
expr.id()==ID_bitor || expr.id()==ID_bitxor ||
expr.id()==ID_bitxnor || expr.id()==ID_bitnor ||
Expand Down
8 changes: 5 additions & 3 deletions src/solvers/floatbv/float_bv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,10 @@ exprt float_bvt::mul(

// zero-extend the fractions (unpacked fraction has the hidden bit)
typet new_fraction_type=unsignedbv_typet((spec.f+1)*2);
const exprt fraction1=typecast_exprt(unpacked1.fraction, new_fraction_type);
const exprt fraction2=typecast_exprt(unpacked2.fraction, new_fraction_type);
const exprt fraction1 =
zero_extend_exprt{unpacked1.fraction, new_fraction_type};
const exprt fraction2 =
zero_extend_exprt{unpacked2.fraction, new_fraction_type};

// multiply the fractions
unbiased_floatt result;
Expand Down Expand Up @@ -750,7 +752,7 @@ exprt float_bvt::div(
unsignedbv_typet(div_width));

// zero-extend fraction2 to match fraction1
const typecast_exprt fraction2(unpacked2.fraction, fraction1.type());
const zero_extend_exprt fraction2{unpacked2.fraction, fraction1.type()};

// divide fractions
unbiased_floatt result;
Expand Down
4 changes: 4 additions & 0 deletions src/solvers/smt2/smt2_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,10 @@ void smt2_convt::convert_expr(const exprt &expr)
{
convert_expr(simplify_expr(to_bitreverse_expr(expr).lower(), ns));
}
else if(expr.id() == ID_zero_extend)
{
convert_expr(to_zero_extend_expr(expr).lower());
}
else if(expr.id() == ID_function_application)
{
const auto &function_application_expr = to_function_application_expr(expr);
Expand Down
13 changes: 13 additions & 0 deletions src/solvers/smt2_incremental/convert_expr_to_smt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,15 @@ static smt_termt convert_expr_to_smt(
count_trailing_zeros.pretty());
}

static smt_termt convert_expr_to_smt(
const zero_extend_exprt &zero_extend,
const sub_expression_mapt &converted)
{
UNREACHABLE_BECAUSE(
"zero_extend expression should have been lowered by the decision "
"procedure before conversion to smt terms");
}

static smt_termt convert_expr_to_smt(
const prophecy_r_or_w_ok_exprt &prophecy_r_or_w_ok,
const sub_expression_mapt &converted)
Expand Down Expand Up @@ -1822,6 +1831,10 @@ static smt_termt dispatch_expr_to_smt_conversion(
{
return convert_expr_to_smt(*count_trailing_zeros, converted);
}
if(const auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
{
return convert_expr_to_smt(*zero_extend, converted);
}
if(
const auto prophecy_r_or_w_ok =
expr_try_dynamic_cast<prophecy_r_or_w_ok_exprt>(expr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "smt2_incremental_decision_procedure.h"

#include <util/arith_tools.h>
#include <util/bitvector_expr.h>
#include <util/byte_operators.h>
#include <util/c_types.h>
#include <util/range.h>
Expand Down Expand Up @@ -296,6 +297,17 @@ static exprt lower_rw_ok_pointer_in_range(exprt expr, const namespacet &ns)
return expr;
}

static exprt lower_zero_extend(exprt expr, const namespacet &ns)
{
expr.visit_pre([](exprt &expr) {
if(auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
{
expr = zero_extend->lower();
}
});
return expr;
}

void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined(
const exprt &in_expr)
{
Expand Down Expand Up @@ -677,8 +689,10 @@ void smt2_incremental_decision_proceduret::define_object_properties()

exprt smt2_incremental_decision_proceduret::lower(exprt expression) const
{
const exprt lowered = struct_encoding.encode(lower_enum(
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
const exprt lowered = struct_encoding.encode(lower_zero_extend(
lower_enum(
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
ns),
ns));
log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) {
if(lowered != expression)
Expand Down
21 changes: 18 additions & 3 deletions src/util/bitvector_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ exprt update_bit_exprt::lower() const
typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted));

// zero-extend the replacement bit to match src
auto new_value_casted = typecast_exprt(
typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type);
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};

// shift the replacement bits
auto new_value_shifted = shl_exprt(new_value_casted, index());
Expand Down Expand Up @@ -85,7 +84,7 @@ exprt update_bits_exprt::lower() const
bitand_exprt(typecast_exprt(src(), src_bv_type), mask_shifted);

// zero-extend or shrink the replacement bits to match src
auto new_value_casted = typecast_exprt(new_value(), src_bv_type);
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};

// shift the replacement bits
auto new_value_shifted = shl_exprt(new_value_casted, index());
Expand Down Expand Up @@ -279,3 +278,19 @@ exprt find_first_set_exprt::lower() const

return typecast_exprt::conditional_cast(result, type());
}

exprt zero_extend_exprt::lower() const
{
const auto old_width = to_bitvector_type(op().type()).get_width();
const auto new_width = to_bitvector_type(type()).get_width();

if(new_width > old_width)
{
return concatenation_exprt{
bv_typet{new_width - old_width}.all_zeros_expr(), op(), type()};
}
else // new_width <= old_width
{
return extractbits_exprt{op(), 0, type()};
}
}
44 changes: 44 additions & 0 deletions src/util/bitvector_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1663,4 +1663,48 @@ inline find_first_set_exprt &to_find_first_set_expr(exprt &expr)
return ret;
}

/// \brief zero extension
/// The operand is converted to the given type by either
/// a) truncating if the new type is shorter, or
/// b) padding with most-significant zero bits if the new type is larger, or
/// c) reinterprets the operand as the given type if their widths match.
class zero_extend_exprt : public unary_exprt
{
public:
zero_extend_exprt(exprt _op, typet _type)
: unary_exprt(ID_zero_extend, std::move(_op), std::move(_type))
{
}

// a lowering to extraction or concatenation
exprt lower() const;
};

template <>
inline bool can_cast_expr<zero_extend_exprt>(const exprt &base)
{
return base.id() == ID_zero_extend;
}

/// \brief Cast an exprt to a \ref zero_extend_exprt
///
/// \a expr must be known to be \ref zero_extend_exprt.
///
/// \param expr: Source expression
/// \return Object of type \ref zero_extend_exprt
inline const zero_extend_exprt &to_zero_extend_expr(const exprt &expr)
{
PRECONDITION(expr.id() == ID_zero_extend);
zero_extend_exprt::check(expr);
return static_cast<const zero_extend_exprt &>(expr);
}

/// \copydoc to_zero_extend_expr(const exprt &)
inline zero_extend_exprt &to_zero_extend_expr(exprt &expr)
{
PRECONDITION(expr.id() == ID_zero_extend);
zero_extend_exprt::check(expr);
return static_cast<zero_extend_exprt &>(expr);
}

#endif // CPROVER_UTIL_BITVECTOR_EXPR_H
6 changes: 6 additions & 0 deletions src/util/format_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,12 @@ void format_expr_configt::setup()
<< format(expr.type()) << ')';
};

expr_map[ID_zero_extend] =
[](std::ostream &os, const exprt &expr) -> std::ostream & {
return os << "zero_extend(" << format(to_zero_extend_expr(expr).op())
<< ", " << format(expr.type()) << ')';
};

expr_map[ID_floatbv_typecast] =
[](std::ostream &os, const exprt &expr) -> std::ostream & {
const auto &floatbv_typecast_expr = to_floatbv_typecast_expr(expr);
Expand Down
1 change: 1 addition & 0 deletions src/util/irep_ids.def
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ IREP_ID_ONE(extractbit)
IREP_ID_ONE(extractbits)
IREP_ID_ONE(update_bit)
IREP_ID_ONE(update_bits)
IREP_ID_ONE(zero_extend)
IREP_ID_TWO(C_reference, #reference)
IREP_ID_TWO(C_rvalue_reference, #rvalue_reference)
IREP_ID_ONE(true)
Expand Down
19 changes: 10 additions & 9 deletions src/util/lower_byte_operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2491,15 +2491,16 @@ static exprt lower_byte_update(
exprt zero_extended;
if(bit_width > update_size_bits)
{
zero_extended = concatenation_exprt{
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
value,
bv_typet{bit_width}};

if(!is_little_endian)
to_concatenation_expr(zero_extended)
.op0()
.swap(to_concatenation_expr(zero_extended).op1());
if(is_little_endian)
zero_extended = zero_extend_exprt{value, bv_typet{bit_width}};
else
{
// Big endian -- the zero is added as LSB.
zero_extended = concatenation_exprt{
value,
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
bv_typet{bit_width}};
}
}
else
zero_extended = value;
Expand Down
4 changes: 4 additions & 0 deletions src/util/simplify_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3028,6 +3028,10 @@ simplify_exprt::resultt<> simplify_exprt::simplify_node(const exprt &node)
{
r = simplify_extractbits(to_extractbits_expr(expr));
}
else if(expr.id() == ID_zero_extend)
{
r = simplify_zero_extend(to_zero_extend_expr(expr));
}
else if(expr.id()==ID_ieee_float_equal ||
expr.id()==ID_ieee_float_notequal)
{
Expand Down
2 changes: 2 additions & 0 deletions src/util/simplify_expr_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class unary_overflow_exprt;
class unary_plus_exprt;
class update_exprt;
class with_exprt;
class zero_extend_exprt;

class simplify_exprt
{
Expand Down Expand Up @@ -152,6 +153,7 @@ class simplify_exprt
[[nodiscard]] resultt<> simplify_extractbit(const extractbit_exprt &);
[[nodiscard]] resultt<> simplify_extractbits(const extractbits_exprt &);
[[nodiscard]] resultt<> simplify_concatenation(const concatenation_exprt &);
[[nodiscard]] resultt<> simplify_zero_extend(const zero_extend_exprt &);
[[nodiscard]] resultt<> simplify_mult(const mult_exprt &);
[[nodiscard]] resultt<> simplify_div(const div_exprt &);
[[nodiscard]] resultt<> simplify_mod(const mod_exprt &);
Expand Down
12 changes: 12 additions & 0 deletions src/util/simplify_expr_int.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,18 @@ simplify_exprt::simplify_concatenation(const concatenation_exprt &expr)
return std::move(new_expr);
}

simplify_exprt::resultt<>
simplify_exprt::simplify_zero_extend(const zero_extend_exprt &expr)
{
if(!can_cast_type<bitvector_typet>(expr.type()))
return unchanged(expr);

if(!can_cast_type<bitvector_typet>(expr.op().type()))
return unchanged(expr);

return changed(simplify_node(expr.lower()));
}

simplify_exprt::resultt<>
simplify_exprt::simplify_shifts(const shift_exprt &expr)
{
Expand Down

0 comments on commit 8fcd9b1

Please sign in to comment.