Skip to content

Commit

Permalink
[Feature] Add UnaryExpression and Negate (open-algebra#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-mccall authored Apr 5, 2024
1 parent aa0063c commit dc52199
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 2 deletions.
3 changes: 1 addition & 2 deletions include/Oasis/Add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ class Add<

[[nodiscard]] auto ToString() const -> std::string final;

static auto Specialize(const Expression& other) -> std::unique_ptr<Add>;
static auto Specialize(const Expression& other, tf::Subflow& subflow) -> std::unique_ptr<Add>;
DECL_SPECIALIZE(Add)

EXPRESSION_TYPE(Add)
EXPRESSION_CATEGORY(Associative | Commutative)
Expand Down
6 changes: 6 additions & 0 deletions include/Oasis/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ enum class ExpressionType {
Divide,
Exponent,
Log,
Negate,
Sqrt,
};

/**
Expand Down Expand Up @@ -280,6 +282,10 @@ class Expression {
return category; \
}

#define DECL_SPECIALIZE(type) \
static auto Specialize(const Expression& other) -> std::unique_ptr<type>; \
static auto Specialize(const Expression& other, tf::Subflow&) -> std::unique_ptr<type>;

} // namespace Oasis

std::unique_ptr<Oasis::Expression> operator+(const std::unique_ptr<Oasis::Expression>& lhs, const std::unique_ptr<Oasis::Expression>& rhs);
Expand Down
60 changes: 60 additions & 0 deletions include/Oasis/Negate.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//
// Created by Matthew McCall on 3/29/24.
//

#ifndef NEGATE_HPP
#define NEGATE_HPP

#include "Multiply.hpp"
#include "fmt/core.h"

#include "UnaryExpression.hpp"

namespace Oasis {

template <typename OperandT>
class Negate final : public UnaryExpression<Negate, OperandT> {
public:
Negate() = default;
Negate(const Negate& other)
: UnaryExpression<Negate, OperandT>(other)
{
}

explicit Negate(const OperandT& operand)
: UnaryExpression<Negate, OperandT>(operand)
{
}

[[nodiscard]] auto Simplify() const -> std::unique_ptr<Expression> override
{
return Multiply {
Real { -1.0 },
this->GetOperand()
}
.Simplify();
}

auto Simplify(tf::Subflow& subflow) const -> std::unique_ptr<Expression> override
{
return Multiply {
Real { -1.0 },
this->GetOperand()
}
.Simplify(subflow);
}

[[nodiscard]] auto ToString() const -> std::string override
{
return fmt::format("-({})", this->GetOperand().ToString());
}

IMPL_SPECIALIZE_UNARYEXPR(Negate, OperandT)

EXPRESSION_TYPE(Negate)
EXPRESSION_CATEGORY(None)
};

} // Oasis

#endif // NEGATE_HPP
127 changes: 127 additions & 0 deletions include/Oasis/UnaryExpression.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
//
// Created by Matthew McCall on 3/29/24.
//

#ifndef UNARYEXPRESSION_HPP
#define UNARYEXPRESSION_HPP

#include "Expression.hpp"

namespace Oasis {

template <template <IExpression> class DerivedT, IExpression OperandT>
class UnaryExpression : public Expression {

using DerivedSpecialized = DerivedT<OperandT>;
using DerivedGeneralized = DerivedT<Expression>;

public:
UnaryExpression() = default;

UnaryExpression(const UnaryExpression& other)
: Expression(other)
{
if (other.HasOperand()) {
SetOperand(other.GetOperand());
}
}

explicit UnaryExpression(const OperandT& operand)
{
SetOperand(operand);
}

[[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
{
return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
}

auto Copy(tf::Subflow&) const -> std::unique_ptr<Expression> final
{
return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
}

[[nodiscard]] auto Equals(const Expression& other) const -> bool final
{
if (!other.Is<DerivedSpecialized>()) {
return false;
}

// generalize
const auto otherGeneralized = other.Generalize();
const auto& otherUnaryGeneralized = dynamic_cast<const DerivedGeneralized&>(*otherGeneralized);

return op->Equals(otherUnaryGeneralized.GetOperand());
}

[[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
{
return std::make_unique<DerivedGeneralized>(*this);
}

auto Generalize(tf::Subflow& subflow) const -> std::unique_ptr<Expression> final
{
return DerivedGeneralized { *this }.Copy(subflow);
}

auto GetOperand() const -> const OperandT&
{
return *op;
}

auto HasOperand() const -> bool
{
return op != nullptr;
}

[[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
{
return this->GetType() == other.GetType();
}

auto StructurallyEquivalent(const Expression& other, tf::Subflow&) const -> bool final
{
return this->GetType() == other.GetType();
}

auto SetOperand(const OperandT& operand) -> void
{
if constexpr (std::same_as<OperandT, Expression>) {
this->op = operand.Copy();
} else {
this->op = std::make_unique<OperandT>(operand);
}
}

protected:
std::unique_ptr<OperandT> op;
};

#define IMPL_SPECIALIZE_UNARYEXPR(DerivedT, OperandT) \
static auto Specialize(const Expression& other) -> std::unique_ptr<DerivedT> \
{ \
if (!other.Is<DerivedT>()) { \
return nullptr; \
} \
\
auto specialized = std::make_unique<DerivedT<OperandT>>(); \
std::unique_ptr<Expression> otherGeneralized = other.Generalize(); \
const auto& otherUnary = dynamic_cast<const DerivedT<Expression>&>(*otherGeneralized); \
\
if (auto operand = OperandT::Specialize(otherUnary.GetOperand()); operand != nullptr) { \
specialized->op = std::move(operand); \
return specialized; \
} \
\
return nullptr; \
} \
\
static auto Specialize(const Expression& other, tf::Subflow&) -> std::unique_ptr<DerivedT> \
{ \
/* TODO: Actually implement */ \
return DerivedT<OperandT>::Specialize(other); \
}

} // Oasis

#endif // UNARYEXPRESSION_HPP
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(Oasis_SOURCES
Imaginary.cpp
Log.cpp
Multiply.cpp
Negate.cpp
Real.cpp
Subtract.cpp
Undefined.cpp
Expand All @@ -30,8 +31,11 @@ set(Oasis_HEADERS
../include/Oasis/LeafExpression.hpp
../include/Oasis/Log.hpp
../include/Oasis/Multiply.hpp
../include/Oasis/Negate.hpp
../include/Oasis/Real.hpp
../include/Oasis/Subtract.hpp
../include/Oasis/UnaryExpression.hpp
../include/Oasis/Undefined.hpp
../include/Oasis/Variable.hpp)

# Adds a library target called "Oasis" to be built from source files.
Expand Down
8 changes: 8 additions & 0 deletions src/Negate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//
// Created by Matthew McCall on 3/29/24.
//

#include "Oasis/Negate.hpp"

namespace Oasis {
} // Oasis
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(Oasis_TESTS
ExponentTests.cpp
LogTests.cpp
MultiplyTests.cpp
NegateTests.cpp
PolynomialTests.cpp
SubtractTests.cpp)

Expand Down
19 changes: 19 additions & 0 deletions tests/NegateTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//
// Created by Matthew McCall on 4/5/24.
//

#include "catch2/catch_test_macros.hpp"

#include <Oasis/Negate.hpp>

TEST_CASE("Negate", "[Negate]")
{
const Oasis::Negate negativeOne {
Oasis::Real { 1.0 }
};

const auto simplified = negativeOne.Simplify();

const Oasis::Real expected { -1.0 };
REQUIRE(simplified->Equals(expected));
}

0 comments on commit dc52199

Please sign in to comment.