diff --git a/src/Subtract.cpp b/src/Subtract.cpp index bbeee382..068aea61 100644 --- a/src/Subtract.cpp +++ b/src/Subtract.cpp @@ -19,11 +19,12 @@ Subtract::Subtract(const Expression& minuend, const Expression& subt auto Subtract::Simplify() const -> std::unique_ptr { - auto simplifiedMinuend = mostSigOp ? mostSigOp->Simplify() : nullptr; - auto simplifiedSubtrahend = leastSigOp ? leastSigOp->Simplify() : nullptr; + const auto simplifiedMinuend = mostSigOp ? mostSigOp->Simplify() : nullptr; + const auto simplifiedSubtrahend = leastSigOp ? leastSigOp->Simplify() : nullptr; - Subtract simplifiedSubtract { *simplifiedMinuend, *simplifiedSubtrahend }; + const Subtract simplifiedSubtract { *simplifiedMinuend, *simplifiedSubtrahend }; + // 2 - 1 = 1 if (auto realCase = Subtract::Specialize(simplifiedSubtract); realCase != nullptr) { const Real& minuend = realCase->GetMostSigOp(); const Real& subtrahend = realCase->GetLeastSigOp(); @@ -31,69 +32,41 @@ auto Subtract::Simplify() const -> std::unique_ptr return std::make_unique(minuend.GetValue() - subtrahend.GetValue()); } - if (auto ImgCase = Subtract::Specialize(simplifiedSubtract); ImgCase != nullptr) { - return std::make_unique>(Real { 2.0 }, Imaginary {}); + // x - x = 0 + if (simplifiedMinuend->Equals(*simplifiedSubtrahend)) { + return std::make_unique(Real { 0.0 }); } - if (auto ImgCase = Subtract, Imaginary>::Specialize(simplifiedSubtract); ImgCase != nullptr) { - return std::make_unique>( - *(Subtract { ImgCase->GetMostSigOp().GetMostSigOp(), Real { 1.0 } }.Simplify()), Imaginary {}); - } - - if (auto ImgCase = Subtract>::Specialize(simplifiedSubtract); ImgCase != nullptr) { - return std::make_unique>( - *(Subtract { Real { 1.0 }, ImgCase->GetLeastSigOp().GetMostSigOp() }.Simplify()), Imaginary {}); - } - - if (auto ImgCase = Subtract, Multiply>::Specialize(simplifiedSubtract); ImgCase != nullptr) { - return std::make_unique>( - *(Subtract { ImgCase->GetLeastSigOp().GetMostSigOp(), ImgCase->GetMostSigOp().GetMostSigOp() }.Simplify()), Imaginary {}); - } - - // exponent - exponent - if (auto exponentCase = Subtract, Exponent>::Specialize(simplifiedSubtract); exponentCase != nullptr) { - if (exponentCase->GetMostSigOp().GetMostSigOp().Equals(exponentCase->GetLeastSigOp().GetMostSigOp()) && exponentCase->GetMostSigOp().GetLeastSigOp().Equals(exponentCase->GetLeastSigOp().GetLeastSigOp())) { - return std::make_unique(Real { 0.0 }); - } - } - - // a*exponent - exponent - if (auto exponentCase = Subtract>, Exponent>::Specialize(simplifiedSubtract); exponentCase != nullptr) { - if (exponentCase->GetMostSigOp().GetLeastSigOp().GetMostSigOp().Equals(exponentCase->GetLeastSigOp().GetMostSigOp()) && exponentCase->GetMostSigOp().GetLeastSigOp().GetLeastSigOp().Equals(exponentCase->GetLeastSigOp().GetLeastSigOp())) { - if (Real { 1.0 }.Equals(exponentCase->GetMostSigOp().GetMostSigOp())) - return std::make_unique(Real { 0.0 }); - return std::make_unique>(*(Subtract { exponentCase->GetMostSigOp().GetMostSigOp(), Real { 1.0 } }.Simplify()), - exponentCase->GetLeastSigOp()); + // ax - x = (a-1)x + if (const auto minusOneCase = Subtract, Expression>::Specialize(simplifiedSubtract); minusOneCase != nullptr) { + if (minusOneCase->GetMostSigOp().GetLeastSigOp().Equals(minusOneCase->GetLeastSigOp())) { + const Subtract newCoefficient { minusOneCase->GetMostSigOp().GetMostSigOp(), Real { 1.0 } }; + return Multiply { newCoefficient, minusOneCase->GetLeastSigOp() }.Simplify(); } } - // exponent - a*exponent - if (auto exponentCase = Subtract, Multiply>>::Specialize(simplifiedSubtract); exponentCase != nullptr) { - if (exponentCase->GetLeastSigOp().GetLeastSigOp().GetMostSigOp().Equals(exponentCase->GetMostSigOp().GetMostSigOp()) && exponentCase->GetLeastSigOp().GetLeastSigOp().GetLeastSigOp().Equals(exponentCase->GetMostSigOp().GetLeastSigOp())) { - if (Real { 1.0 }.Equals(exponentCase->GetLeastSigOp().GetMostSigOp())) - return std::make_unique(Real { 0.0 }); - return std::make_unique>(*(Subtract { Real { 1.0 }, exponentCase->GetLeastSigOp().GetMostSigOp() }.Simplify()), - exponentCase->GetMostSigOp()); + // x-ax = (1-a)x + if (const auto oneMinusCase = Subtract>::Specialize(simplifiedSubtract); oneMinusCase != nullptr) { + if (oneMinusCase->GetMostSigOp().Equals(oneMinusCase->GetLeastSigOp().GetLeastSigOp())) { + const Subtract newCoefficient { Real { 1.0 }, oneMinusCase->GetLeastSigOp().GetMostSigOp() }; + return Multiply { newCoefficient, oneMinusCase->GetMostSigOp() }.Simplify(); } } - // a*exponent - b*exponent - if (auto exponentCase = Subtract>, Multiply>>::Specialize(simplifiedSubtract); exponentCase != nullptr) { - if (exponentCase->GetLeastSigOp().GetLeastSigOp().GetMostSigOp().Equals(exponentCase->GetMostSigOp().GetLeastSigOp().GetMostSigOp()) && exponentCase->GetLeastSigOp().GetLeastSigOp().GetLeastSigOp().Equals(exponentCase->GetMostSigOp().GetLeastSigOp().GetLeastSigOp())) { - if (Real { 1.0 }.Equals(exponentCase->GetLeastSigOp().GetMostSigOp())) - return std::make_unique(Real { 0.0 }); - return std::make_unique>( - *(Subtract { exponentCase->GetMostSigOp().GetMostSigOp(), exponentCase->GetLeastSigOp().GetMostSigOp() }.Simplify()), - exponentCase->GetMostSigOp().GetLeastSigOp()); + // ax-bx= (a-b)x + if (const auto coefficientCase = Subtract>::Specialize(simplifiedSubtract); coefficientCase != nullptr) { + if (coefficientCase->GetMostSigOp().GetLeastSigOp().Equals(coefficientCase->GetLeastSigOp().GetLeastSigOp())) { + const Subtract newCoefficient { coefficientCase->GetMostSigOp().GetMostSigOp(), coefficientCase->GetLeastSigOp().GetMostSigOp() }; + return Multiply { newCoefficient, coefficientCase->GetLeastSigOp().GetLeastSigOp() }.Simplify(); } } // log(a) - log(b) = log(a / b) - if (auto logCase = Subtract, Log>::Specialize(simplifiedSubtract); logCase != nullptr) { + if (const auto logCase = Subtract>::Specialize(simplifiedSubtract); logCase != nullptr) { if (logCase->GetMostSigOp().GetMostSigOp().Equals(logCase->GetLeastSigOp().GetMostSigOp())) { const IExpression auto& base = logCase->GetMostSigOp().GetMostSigOp(); - const IExpression auto& argument = Divide({ logCase->GetMostSigOp().GetLeastSigOp(), logCase->GetLeastSigOp().GetLeastSigOp() }); - return std::make_unique>(base, argument); + const IExpression auto& argument = Divide({ logCase->GetMostSigOp().GetLeastSigOp(), logCase->GetLeastSigOp().GetLeastSigOp() }); + return std::make_unique>(base, argument); } }