Skip to content

Commit

Permalink
Rewriters: Hide internal details of rewriters
Browse files Browse the repository at this point in the history
DivModRewriter is exposed to the outside world, but it was broken,
because it was including "OsmtInternalException.h" which is not suppose
to be exposed and which we do not install. Thus, trying to include
"DivModRewriter.h" in another application results in a compilation
error.

We fix this problem by hiding the existing rewriters and exposing to the
outside world only API methods to use this functionality.
  • Loading branch information
blishko committed Aug 7, 2024
1 parent 810046a commit c69c8a9
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 104 deletions.
3 changes: 2 additions & 1 deletion .clang-files
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
./src/proof/PGCheck.cc
./src/rewriters/ArithmeticEqualityRewriter.h
./src/rewriters/DistinctRewriter.h
./src/rewriters/DistinctRewriter.cc
./src/rewriters/DivModRewriter.h
./src/rewriters/Rewriter.h
./src/rewriters/Rewritings.cc
./src/rewriters/Rewritings.h
./src/rewriters/Substitutor.h

./src/tsolvers/lasolver/CutCreator.h
Expand Down
14 changes: 4 additions & 10 deletions src/logics/ArithLogic.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#include "ArithLogic.h"

#include "DivModRewriter.h"
#include "FastRational.h"
#include "OsmtApiException.h"
#include "OsmtInternalException.h"
#include "Polynomial.h"
#include "Rewriter.h"
#include "Rewritings.h"
#include "PtStore.h"
#include "SStore.h"
#include "StringConv.h"
Expand Down Expand Up @@ -1147,16 +1148,9 @@ PTRef ArithLogic::removeAuxVars(PTRef tr) {
class AuxSymbolMatcherConfig : public DefaultRewriterConfig {
ArithLogic & logic;
public:
AuxSymbolMatcherConfig(ArithLogic & logic) : logic(logic) {}
explicit AuxSymbolMatcherConfig(ArithLogic & logic) : logic(logic) {}
PTRef rewrite(PTRef tr) override {
if (not logic.isVar(tr)) return tr; // Only variables can match
auto symName = std::string_view(logic.getSymName(tr));
if (symName.compare(0, DivModConfig::divPrefix.size(), DivModConfig::divPrefix) == 0) {
return DivModConfig::getDivTermFor(logic, tr);
} else if (symName.compare(0, DivModConfig::divPrefix.size(), DivModConfig::modPrefix) == 0){
return DivModConfig::getModTermFor(logic, tr);
}
return tr;
return opensmt::tryGetOriginalDivModTerm(logic, tr).value_or(tr);
}
};
// Note: this has negligible impact on performance, no need to check if there are divs or mods
Expand Down
4 changes: 2 additions & 2 deletions src/logics/ArrayTheory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
#include "ArrayTheory.h"

#include "ArrayHelpers.h"
#include "DistinctRewriter.h"
#include "Rewritings.h"

PTRef ArrayTheory::preprocessAfterSubstitutions(PTRef fla, PreprocessingContext const &) {
// TODO: simplify select over store on the same index
fla = rewriteDistincts(getLogic(), fla);
fla = opensmt::rewriteDistincts(getLogic(), fla);
fla = instantiateReadOverStore(getLogic(), fla);
return fla;
}
7 changes: 3 additions & 4 deletions src/logics/LATheory.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
#define OPENSMT_LATHEORY_H
#include "Theory.h"
#include "ArithmeticEqualityRewriter.h"
#include "DistinctRewriter.h"
#include "DivModRewriter.h"
#include "Rewritings.h"

template<typename LinAlgLogic, typename LinAlgTHandler>
class LATheory : public Theory
Expand Down Expand Up @@ -40,14 +39,14 @@ PTRef rewriteDivMod(TLogic &, PTRef fla) { return fla; }
template<>
PTRef rewriteDivMod<ArithLogic>(ArithLogic & logic, PTRef fla) {
// Real logic cannot have div and mod
return not logic.hasIntegers() ? fla : DivModRewriter(logic).rewrite(fla);
return not logic.hasIntegers() ? fla : opensmt::rewriteDivMod(logic,fla);
}

}

template<typename LinAlgLogic, typename LinAlgTSHandler>
PTRef LATheory<LinAlgLogic,LinAlgTSHandler>::preprocessAfterSubstitutions(PTRef fla, PreprocessingContext const &) {
fla = rewriteDistincts(getLogic(), fla);
fla = opensmt::rewriteDistincts(getLogic(), fla);
fla = rewriteDivMod<LinAlgLogic>(lalogic, fla);
ArithmeticEqualityRewriter equalityRewriter(lalogic);
fla = equalityRewriter.rewrite(fla);
Expand Down
6 changes: 3 additions & 3 deletions src/logics/UFLATheory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

#include "ArithmeticEqualityRewriter.h"
#include "ArrayHelpers.h"
#include "DistinctRewriter.h"
#include "LATheory.h"
#include "OsmtInternalException.h"
#include "Substitutor.h"
#include "Rewritings.h"
#include "TreeOps.h"


PTRef UFLATheory::preprocessAfterSubstitutions(PTRef fla, PreprocessingContext const &) {
fla = rewriteDistincts(getLogic(), fla);
fla = opensmt::rewriteDistincts(getLogic(), fla);
fla = rewriteDivMod<ArithLogic>(logic, fla);
PTRef purified = purify(fla);
if (logic.hasArrays()) {
Expand Down
3 changes: 2 additions & 1 deletion src/logics/UFTheory.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include "Theory.h"
#include "TreeOps.h"
#include "DistinctRewriter.h"
#include "Rewritings.h"

PTRef UFTheory::preprocessBeforeSubstitutions(PTRef fla, PreprocessingContext const & context) {
return context.perPartition ? fla : getLogic().mkAnd(fla, getLogic().learnEqTransitivity(fla));
}

PTRef UFTheory::preprocessAfterSubstitutions(PTRef fla, PreprocessingContext const & context) {
using namespace opensmt;
fla = context.frameCount == 0 ? rewriteDistinctsKeepTopLevel(getLogic(), fla)
: rewriteDistincts(getLogic(), fla);
AppearsInUfVisitor(getLogic()).visit(fla);
Expand Down
4 changes: 2 additions & 2 deletions src/rewriters/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
add_library(rewriters OBJECT "")

target_sources(rewriters
PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/DistinctRewriter.cc"
PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/Rewritings.cc"
)

install(FILES
Substitutor.h
Rewriter.h
DivModRewriter.h
Rewritings.h
DESTINATION ${INSTALL_HEADERS_DIR}
)
18 changes: 0 additions & 18 deletions src/rewriters/DistinctRewriter.cc

This file was deleted.

6 changes: 0 additions & 6 deletions src/rewriters/DistinctRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,4 @@ class KeepTopLevelDistinctRewriter : public Rewriter<KeepTopLevelDistinctRewrite
: Rewriter<KeepTopLevelDistinctRewriteConfig>(logic, config), config(logic, std::move(topLevelDistincts)) {}
};

inline PTRef rewriteDistincts(Logic & logic, PTRef fla) {
return DistinctRewriter(logic).rewrite(fla);
}

PTRef rewriteDistinctsKeepTopLevel(Logic & logic, PTRef fla);

#endif // OPENSMT_DISTINCTREWRITER_H
23 changes: 5 additions & 18 deletions src/rewriters/DivModRewriter.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, Martin Blicha <martin.blicha@gmail.com>
* Copyright (c) 2021-2024, Martin Blicha <martin.blicha@gmail.com>
*
* SPDX-License-Identifier: MIT
*
Expand All @@ -8,17 +8,9 @@
#ifndef OPENSMT_DIVMODREWRITER_H
#define OPENSMT_DIVMODREWRITER_H

#include "Rewriter.h"

#include "ArithLogic.h"
#include "PTRef.h"

#include "OsmtApiException.h"
#include "OsmtInternalException.h"
#include "TypeUtils.h"

#include <string>
#include <unordered_map>
#include "Rewriter.h"

class DivModConfig : public DefaultRewriterConfig {
ArithLogic & logic;
Expand Down Expand Up @@ -64,7 +56,7 @@ class DivModConfig : public DefaultRewriterConfig {
static std::string_view constexpr divPrefix = ".div";
static std::string_view constexpr modPrefix = ".mod";

DivModConfig(ArithLogic & logic) : logic(logic) {}
explicit DivModConfig(ArithLogic & logic) : logic(logic) {}

PTRef rewrite(PTRef term) override {
SymRef symRef = logic.getSymRef(term);
Expand Down Expand Up @@ -125,7 +117,7 @@ class DivModRewriter : Rewriter<DivModConfig> {
DivModConfig config;

public:
DivModRewriter(ArithLogic & logic) : Rewriter<DivModConfig>(logic, config), logic(logic), config(logic) {}
explicit DivModRewriter(ArithLogic & logic) : Rewriter<DivModConfig>(logic, config), logic(logic), config(logic) {}

PTRef rewrite(PTRef term) override {
if (term == PTRef_Undef or not logic.hasSortBool(term)) {
Expand All @@ -139,9 +131,4 @@ class DivModRewriter : Rewriter<DivModConfig> {
}
};

// Simple single-use version
inline PTRef rewriteDivMod(ArithLogic & logic, PTRef root) {
return DivModRewriter(logic).rewrite(root);
}

#endif // OPENSMT_DIVMODEREWRITER_H
#endif // OPENSMT_DIVMODREWRITER_H
44 changes: 44 additions & 0 deletions src/rewriters/Rewritings.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2024 Martin Blicha <martin.blicha@gmail.com>
*
* SPDX-License-Identifier: MIT
*
*/

#include "Rewritings.h"

#include "DistinctRewriter.h"
#include "DivModRewriter.h"
#include "TreeOps.h"

namespace opensmt {

PTRef rewriteDistincts(Logic & logic, PTRef fla) {
return DistinctRewriter(logic).rewrite(fla);
}

PTRef rewriteDistinctsKeepTopLevel(Logic & logic, PTRef fla) {
vec<PTRef> topLevelConjuncts = ::topLevelConjuncts(logic, fla);
KeepTopLevelDistinctRewriter::TopLevelDistincts topLevelDistincts;
for (PTRef conj : topLevelConjuncts) {
if (logic.isDisequality(conj)) { topLevelDistincts.insert(conj); }
}
return KeepTopLevelDistinctRewriter(logic, std::move(topLevelDistincts)).rewrite(fla);
}

PTRef rewriteDivMod(ArithLogic & logic, PTRef term) {
return DivModRewriter(logic).rewrite(term);
}

std::optional<PTRef> tryGetOriginalDivModTerm(ArithLogic & logic, PTRef tr) {
if (not logic.isVar(tr)) return std::nullopt; // Only variables can match
auto symName = std::string_view(logic.getSymName(tr));
if (symName.compare(0, DivModConfig::divPrefix.size(), DivModConfig::divPrefix) == 0) {
return DivModConfig::getDivTermFor(logic, tr);
}
if (symName.compare(0, DivModConfig::modPrefix.size(), DivModConfig::modPrefix) == 0) {
return DivModConfig::getModTermFor(logic, tr);
}
return std::nullopt;
}
} // namespace opensmt
28 changes: 28 additions & 0 deletions src/rewriters/Rewritings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2024 Martin Blicha <martin.blicha@gmail.com>
*
* SPDX-License-Identifier: MIT
*
*/

#ifndef OPENSMT_REWRITINGS_H
#define OPENSMT_REWRITINGS_H

#include "ArithLogic.h"
#include "PTRef.h"

#include <optional>

namespace opensmt {

PTRef rewriteDistincts(Logic & logic, PTRef fla);

PTRef rewriteDistinctsKeepTopLevel(Logic & logic, PTRef fla);

PTRef rewriteDivMod(ArithLogic & logic, PTRef fla);

std::optional<PTRef> tryGetOriginalDivModTerm(ArithLogic & logic, PTRef term);

} // namespace opensmt

#endif // OPENSMT_REWRITINGS_H
30 changes: 6 additions & 24 deletions test/unit/test_LIALogicMkTerms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
//

#include <gtest/gtest.h>
#include <ArithLogic.h>
#include "DivModRewriter.h"
#include "ArithLogic.h"
#include "IteHandler.h"
#include "Rewritings.h"
#include "TreeOps.h"

#include <algorithm>
Expand Down Expand Up @@ -194,27 +194,10 @@ TEST_F(LIALogicMkTermsTest, test_EqualityNormalization_EqualityToConstant) {
}

TEST_F(LIALogicMkTermsTest, test_ReverseAuxRewrite) {

static constexpr std::initializer_list<std::string_view> prefixes = {IteHandler::itePrefix, DivModConfig::divPrefix, DivModConfig::modPrefix};


auto hasAuxSymbols = [this](PTRef tr) {
class AuxSymbolMatcher {
ArithLogic const & logic;

public:
AuxSymbolMatcher(ArithLogic const & logic) : logic(logic) {}
bool operator()(PTRef tr) {
std::string_view const name = logic.getSymName(tr);
return std::any_of(prefixes.begin(), prefixes.end(), [&name](std::string_view const prefix) {
return name.compare(0, prefix.size(), prefix) == 0;
});
};
};
auto predicate = AuxSymbolMatcher(logic);
auto config = TermCollectorConfig(predicate);
TermVisitor(logic, config).visit(tr);
return config.extractCollectedTerms().size() > 0;
auto auxiliaryVarsInTerm = matchingSubTerms(logic, tr, [&](PTRef subTerm) {
return opensmt::rewritings::tryGetOriginalDivModTerm(logic, subTerm).has_value(); });
return auxiliaryVarsInTerm.size() > 0;
};

PTRef a = logic.mkIntVar("a");
Expand All @@ -231,8 +214,7 @@ TEST_F(LIALogicMkTermsTest, test_ReverseAuxRewrite) {
PTRef nested = logic.mkEq(logic.getTerm_IntZero(), logic.mkMod(ite, c));

for (PTRef tr : {term, eq, nested}) {

PTRef termWithAux = DivModRewriter(logic).rewrite(IteHandler(logic).rewrite(tr));
PTRef termWithAux = opensmt::rewritings::rewriteDivMod(logic, IteHandler(logic).rewrite(tr));
ASSERT_TRUE(hasAuxSymbols(termWithAux));

PTRef termWithoutAux = logic.removeAuxVars(termWithAux);
Expand Down
Loading

0 comments on commit c69c8a9

Please sign in to comment.