Skip to content

Commit

Permalink
Simplified Subtract::Simplify by generalizing cases for Exponent and …
Browse files Browse the repository at this point in the history
…Imaginary (open-algebra#79)

Unbinkified code
  • Loading branch information
matthew-mccall authored Mar 26, 2024
1 parent 834ea90 commit 7cbf484
Showing 1 changed file with 25 additions and 52 deletions.
77 changes: 25 additions & 52 deletions src/Subtract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,81 +19,54 @@ Subtract<Expression>::Subtract(const Expression& minuend, const Expression& subt

auto Subtract<Expression>::Simplify() const -> std::unique_ptr<Expression>
{
auto simplifiedMinuend = mostSigOp ? mostSigOp->Simplify() : nullptr;
auto simplifiedSubtrahend = leastSigOp ? leastSigOp->Simplify() : nullptr;
const auto simplifiedMinuend = mostSigOp ? mostSigOp->Simplify() : nullptr;
const auto simplifiedSubtrahend = leastSigOp ? leastSigOp->Simplify() : nullptr;

Subtract simplifiedSubtract { *simplifiedMinuend, *simplifiedSubtrahend };
const Subtract simplifiedSubtract { *simplifiedMinuend, *simplifiedSubtrahend };

// 2 - 1 = 1
if (auto realCase = Subtract<Real>::Specialize(simplifiedSubtract); realCase != nullptr) {
const Real& minuend = realCase->GetMostSigOp();
const Real& subtrahend = realCase->GetLeastSigOp();

return std::make_unique<Real>(minuend.GetValue() - subtrahend.GetValue());
}

if (auto ImgCase = Subtract<Imaginary>::Specialize(simplifiedSubtract); ImgCase != nullptr) {
return std::make_unique<Multiply<Real, Imaginary>>(Real { 2.0 }, Imaginary {});
// x - x = 0
if (simplifiedMinuend->Equals(*simplifiedSubtrahend)) {
return std::make_unique<Real>(Real { 0.0 });
}

if (auto ImgCase = Subtract<Multiply<Expression, Imaginary>, Imaginary>::Specialize(simplifiedSubtract); ImgCase != nullptr) {
return std::make_unique<Multiply<Expression>>(
*(Subtract { ImgCase->GetMostSigOp().GetMostSigOp(), Real { 1.0 } }.Simplify()), Imaginary {});
}

if (auto ImgCase = Subtract<Imaginary, Multiply<Expression, Imaginary>>::Specialize(simplifiedSubtract); ImgCase != nullptr) {
return std::make_unique<Multiply<Expression>>(
*(Subtract { Real { 1.0 }, ImgCase->GetLeastSigOp().GetMostSigOp() }.Simplify()), Imaginary {});
}

if (auto ImgCase = Subtract<Multiply<Expression, Imaginary>, Multiply<Expression, Imaginary>>::Specialize(simplifiedSubtract); ImgCase != nullptr) {
return std::make_unique<Multiply<Expression>>(
*(Subtract { ImgCase->GetLeastSigOp().GetMostSigOp(), ImgCase->GetMostSigOp().GetMostSigOp() }.Simplify()), Imaginary {});
}

// exponent - exponent
if (auto exponentCase = Subtract<Exponent<Expression>, Exponent<Expression>>::Specialize(simplifiedSubtract); exponentCase != nullptr) {
if (exponentCase->GetMostSigOp().GetMostSigOp().Equals(exponentCase->GetLeastSigOp().GetMostSigOp()) && exponentCase->GetMostSigOp().GetLeastSigOp().Equals(exponentCase->GetLeastSigOp().GetLeastSigOp())) {
return std::make_unique<Real>(Real { 0.0 });
}
}

// a*exponent - exponent
if (auto exponentCase = Subtract<Multiply<Expression, Exponent<Expression>>, Exponent<Expression>>::Specialize(simplifiedSubtract); exponentCase != nullptr) {
if (exponentCase->GetMostSigOp().GetLeastSigOp().GetMostSigOp().Equals(exponentCase->GetLeastSigOp().GetMostSigOp()) && exponentCase->GetMostSigOp().GetLeastSigOp().GetLeastSigOp().Equals(exponentCase->GetLeastSigOp().GetLeastSigOp())) {
if (Real { 1.0 }.Equals(exponentCase->GetMostSigOp().GetMostSigOp()))
return std::make_unique<Real>(Real { 0.0 });
return std::make_unique<Multiply<Expression>>(*(Subtract { exponentCase->GetMostSigOp().GetMostSigOp(), Real { 1.0 } }.Simplify()),
exponentCase->GetLeastSigOp());
// ax - x = (a-1)x
if (const auto minusOneCase = Subtract<Multiply<>, Expression>::Specialize(simplifiedSubtract); minusOneCase != nullptr) {
if (minusOneCase->GetMostSigOp().GetLeastSigOp().Equals(minusOneCase->GetLeastSigOp())) {
const Subtract newCoefficient { minusOneCase->GetMostSigOp().GetMostSigOp(), Real { 1.0 } };
return Multiply { newCoefficient, minusOneCase->GetLeastSigOp() }.Simplify();
}
}

// exponent - a*exponent
if (auto exponentCase = Subtract<Exponent<Expression>, Multiply<Expression, Exponent<Expression>>>::Specialize(simplifiedSubtract); exponentCase != nullptr) {
if (exponentCase->GetLeastSigOp().GetLeastSigOp().GetMostSigOp().Equals(exponentCase->GetMostSigOp().GetMostSigOp()) && exponentCase->GetLeastSigOp().GetLeastSigOp().GetLeastSigOp().Equals(exponentCase->GetMostSigOp().GetLeastSigOp())) {
if (Real { 1.0 }.Equals(exponentCase->GetLeastSigOp().GetMostSigOp()))
return std::make_unique<Real>(Real { 0.0 });
return std::make_unique<Multiply<Expression>>(*(Subtract { Real { 1.0 }, exponentCase->GetLeastSigOp().GetMostSigOp() }.Simplify()),
exponentCase->GetMostSigOp());
// x-ax = (1-a)x
if (const auto oneMinusCase = Subtract<Expression, Multiply<>>::Specialize(simplifiedSubtract); oneMinusCase != nullptr) {
if (oneMinusCase->GetMostSigOp().Equals(oneMinusCase->GetLeastSigOp().GetLeastSigOp())) {
const Subtract newCoefficient { Real { 1.0 }, oneMinusCase->GetLeastSigOp().GetMostSigOp() };
return Multiply { newCoefficient, oneMinusCase->GetMostSigOp() }.Simplify();
}
}

// a*exponent - b*exponent
if (auto exponentCase = Subtract<Multiply<Expression, Exponent<Expression>>, Multiply<Expression, Exponent<Expression>>>::Specialize(simplifiedSubtract); exponentCase != nullptr) {
if (exponentCase->GetLeastSigOp().GetLeastSigOp().GetMostSigOp().Equals(exponentCase->GetMostSigOp().GetLeastSigOp().GetMostSigOp()) && exponentCase->GetLeastSigOp().GetLeastSigOp().GetLeastSigOp().Equals(exponentCase->GetMostSigOp().GetLeastSigOp().GetLeastSigOp())) {
if (Real { 1.0 }.Equals(exponentCase->GetLeastSigOp().GetMostSigOp()))
return std::make_unique<Real>(Real { 0.0 });
return std::make_unique<Multiply<Expression>>(
*(Subtract { exponentCase->GetMostSigOp().GetMostSigOp(), exponentCase->GetLeastSigOp().GetMostSigOp() }.Simplify()),
exponentCase->GetMostSigOp().GetLeastSigOp());
// ax-bx= (a-b)x
if (const auto coefficientCase = Subtract<Multiply<>>::Specialize(simplifiedSubtract); coefficientCase != nullptr) {
if (coefficientCase->GetMostSigOp().GetLeastSigOp().Equals(coefficientCase->GetLeastSigOp().GetLeastSigOp())) {
const Subtract newCoefficient { coefficientCase->GetMostSigOp().GetMostSigOp(), coefficientCase->GetLeastSigOp().GetMostSigOp() };
return Multiply { newCoefficient, coefficientCase->GetLeastSigOp().GetLeastSigOp() }.Simplify();
}
}

// log(a) - log(b) = log(a / b)
if (auto logCase = Subtract<Log<Expression, Expression>, Log<Expression, Expression>>::Specialize(simplifiedSubtract); logCase != nullptr) {
if (const auto logCase = Subtract<Log<>>::Specialize(simplifiedSubtract); logCase != nullptr) {
if (logCase->GetMostSigOp().GetMostSigOp().Equals(logCase->GetLeastSigOp().GetMostSigOp())) {
const IExpression auto& base = logCase->GetMostSigOp().GetMostSigOp();
const IExpression auto& argument = Divide<Expression>({ logCase->GetMostSigOp().GetLeastSigOp(), logCase->GetLeastSigOp().GetLeastSigOp() });
return std::make_unique<Log<Expression>>(base, argument);
const IExpression auto& argument = Divide({ logCase->GetMostSigOp().GetLeastSigOp(), logCase->GetLeastSigOp().GetLeastSigOp() });
return std::make_unique<Log<>>(base, argument);
}
}

Expand Down

0 comments on commit 7cbf484

Please sign in to comment.