Skip to content

Commit

Permalink
Merge pull request #2301 from Smit-create/flip_sign
Browse files Browse the repository at this point in the history
PASS: Update FlipSign pass to use Intrinsic Function
  • Loading branch information
Smit-create authored Aug 26, 2023
2 parents ab5d201 + 1cb7340 commit a838e88
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 21 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ RUN(NAME expr_19 LABELS cpython llvm c)
RUN(NAME expr_20 LABELS cpython llvm c)
RUN(NAME expr_21 LABELS cpython llvm c)
RUN(NAME expr_22 LABELS cpython llvm c)
RUN(NAME expr_23 LABELS cpython llvm c)

RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
RUN(NAME expr_02u LABELS cpython llvm c NOFAST)
Expand Down
23 changes: 23 additions & 0 deletions integration_tests/expr_23.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from lpython import f32, i32

def flip_sign_check():
x: f32
eps: f32 = f32(1e-5)

number: i32 = 123
x = f32(5.5)

if (number%2 == 1):
x = -x

assert abs(x - f32(-5.5)) < eps

number = 124
x = f32(5.5)

if (number%2 == 1):
x = -x

assert abs(x - f32(5.5)) < eps

flip_sign_check()
10 changes: 6 additions & 4 deletions src/libasr/pass/flip_sign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ class FlipSignVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FlipSi
// xi = xor(shiftl(int(Nd),63), xi)
LCOMPILERS_ASSERT(flip_sign_signal_variable);
LCOMPILERS_ASSERT(flip_sign_variable);
ASR::stmt_t* flip_sign_call = PassUtils::get_flipsign(flip_sign_signal_variable,
flip_sign_variable, al, unit, pass_options, current_scope,
[&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
pass_result.push_back(al, flip_sign_call);
ASR::expr_t* flip_sign_result = PassUtils::get_flipsign(flip_sign_signal_variable,
flip_sign_variable, al, unit, x.base.base.loc);
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc,
flip_sign_variable, flip_sign_result, nullptr)));
}
}

Expand Down Expand Up @@ -212,6 +212,8 @@ void pass_replace_flip_sign(Allocator &al, ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& pass_options) {
FlipSignVisitor v(al, unit, pass_options);
v.visit_TranslationUnit(unit);
PassUtils::UpdateDependenciesVisitor u(al);
u.visit_TranslationUnit(unit);
}


Expand Down
86 changes: 86 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ enum class IntrinsicScalarFunctions : int64_t {
Exp2,
Expm1,
FMA,
FlipSign,
ListIndex,
Partition,
ListReverse,
Expand Down Expand Up @@ -95,6 +96,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(Exp2)
INTRINSIC_NAME_CASE(Expm1)
INTRINSIC_NAME_CASE(FMA)
INTRINSIC_NAME_CASE(FlipSign)
INTRINSIC_NAME_CASE(ListIndex)
INTRINSIC_NAME_CASE(Partition)
INTRINSIC_NAME_CASE(ListReverse)
Expand Down Expand Up @@ -1343,6 +1345,86 @@ namespace FMA {

} // namespace FMA

namespace FlipSign {

static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 2,
"ASR Verify: Call to FlipSign must have exactly 2 arguments",
x.base.base.loc, diagnostics);
ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]);
ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]);
ASRUtils::require_impl((is_integer(*type1) && is_real(*type2)),
"ASR Verify: Arguments to FlipSign must be of int and real type respectively",
x.base.base.loc, diagnostics);
}

static ASR::expr_t *eval_FlipSign(Allocator &al, const Location &loc,
ASR::ttype_t* t1, Vec<ASR::expr_t*> &args) {
int a = ASR::down_cast<ASR::IntegerConstant_t>(args[0])->m_n;
double b = ASR::down_cast<ASR::RealConstant_t>(args[1])->m_r;
if (a % 2 == 1) b = -b;
return make_ConstantWithType(make_RealConstant_t, b, t1, loc);
}

static inline ASR::asr_t* create_FlipSign(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 2) {
err("Intrinsic FlipSign function accepts exactly 2 arguments", loc);
}
ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]);
ASR::ttype_t *type2 = ASRUtils::expr_type(args[1]);
if (!ASRUtils::is_integer(*type1) || !ASRUtils::is_real(*type2)) {
err("Argument of the FlipSign function must be int and real respectively",
args[0]->base.loc);
}
ASR::expr_t *m_value = nullptr;
if (all_args_evaluated(args)) {
Vec<ASR::expr_t*> arg_values; arg_values.reserve(al, 2);
arg_values.push_back(al, expr_value(args[0]));
arg_values.push_back(al, expr_value(args[1]));
m_value = eval_FlipSign(al, loc, expr_type(args[1]), arg_values);
}
return ASR::make_IntrinsicScalarFunction_t(al, loc,
static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
args.p, args.n, 0, ASRUtils::expr_type(args[1]), m_value);
}

static inline ASR::expr_t* instantiate_FlipSign(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
declare_basic_variables("_lcompilers_optimization_flipsign_" + type_to_str_python(arg_types[1]));
fill_func_arg("signal", arg_types[0]);
fill_func_arg("variable", arg_types[1]);
auto result = declare(fn_name, return_type, ReturnVar);
/*
real(real32) function flipsigni32r32(signal, variable)
integer(int32), intent(in) :: signal
real(real32), intent(out) :: variable
integer(int32) :: q
q = signal/2
flipsigni32r32 = variable
if (signal - 2*q == 1 ) flipsigni32r32 = -variable
end subroutine
*/

ASR::expr_t *two = i(2, arg_types[0]);
ASR::expr_t *q = iDiv(args[0], two);
ASR::expr_t *cond = iSub(args[0], iMul(two, q));
body.push_back(al, b.If(iEq(cond, i(1, arg_types[0])), {
b.Assignment(result, f32_neg(args[1], arg_types[1]))
}, {
b.Assignment(result, args[1])
}));

ASR::symbol_t *f_sym = make_Function_t(fn_name, fn_symtab, dep, args,
body, result, Source, Implementation, nullptr);
scope->add_symbol(fn_name, f_sym);
return b.Call(f_sym, new_args, return_type, nullptr);
}

} // namespace FlipSign

#define create_exp_macro(X, stdeval) \
namespace X { \
static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \
Expand Down Expand Up @@ -2368,6 +2450,8 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
{&FMA::instantiate_FMA, &FMA::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
{&FlipSign::instantiate_FlipSign, &FMA::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
{&Abs::instantiate_Abs, &Abs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
Expand Down Expand Up @@ -2456,6 +2540,8 @@ namespace IntrinsicScalarFunctionRegistry {
"exp2"},
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
"fma"},
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
"flipsign"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
"expm1"},
{static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),
Expand Down
23 changes: 12 additions & 11 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,24 +588,25 @@ namespace LCompilers {
}


ASR::stmt_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit,
LCompilers::PassOptions& pass_options,
SymbolTable*& current_scope,
const std::function<void (const std::string &, const Location &)> err) {
ASR::symbol_t *v = import_generic_procedure("flipsign", "lfortran_intrinsic_optimization",
al, unit, pass_options, current_scope, arg0->base.loc);
ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc){
ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
arg_types.reserve(al, 2);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
Vec<ASR::call_arg_t> args;
args.reserve(al, 2);
ASR::call_arg_t arg0_, arg1_;
arg0_.loc = arg0->base.loc, arg0_.m_value = arg0;
args.push_back(al, arg0_);
arg1_.loc = arg1->base.loc, arg1_.m_value = arg1;
args.push_back(al, arg1_);
return ASRUtils::STMT(
ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
arg0->base.loc, v, args, current_scope, al,
err));
return instantiate_function(al, loc,
unit.m_global_scope, arg_types, type, args, 0);
}

ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int64type, Allocator& al) {
Expand Down
8 changes: 2 additions & 6 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,8 @@ namespace LCompilers {
ASR::expr_t* get_bound(ASR::expr_t* arr_expr, int dim, std::string bound,
Allocator& al);


ASR::stmt_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit,
LCompilers::PassOptions& pass_options,
SymbolTable*& current_scope,
const std::function<void (const std::string &, const Location &)> err);
ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc);

ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int32type, Allocator& al);

Expand Down

0 comments on commit a838e88

Please sign in to comment.