Skip to content

Commit b35e744

Browse files
Added IntrinsicFunction in ASR (#1616)
Co-authored-by: Thirumalai-Shaktivel <thirumalaishaktivel@gmail.com>
1 parent dfb552e commit b35e744

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+971
-127
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# IntrinsicFunction
2+
3+
An intrinsic function. An **expr** node.
4+
5+
## Declaration
6+
7+
### Syntax
8+
9+
```
10+
IntrinsicFunction(expr* args, int intrinsic_id, int overload_id,
11+
ttype type, expr? value)
12+
```
13+
14+
### Arguments
15+
16+
* `args` represents all arguments passed to the function
17+
* `intrinsic_id` is the unique ID of the generic intrinsic function
18+
* `overload_id` is the ID of the signature within the given generic function
19+
* `type` represents the type of the output
20+
* `value` is an optional compile time value
21+
22+
### Return values
23+
24+
The return value is the expression that the `IntrinsicFunction` represents.
25+
26+
## Description
27+
28+
**IntrinsicFunction** represents an intrinsic function (such as `Abs`,
29+
`Modulo`, `Sin`, `Cos`, `LegendreP`, `FlipSign`, ...) that either the backend
30+
or the middle-end (optimizer) needs to have some special logic for. Typically a
31+
math function, but does not have to be.
32+
33+
IntrinsicFunction is both side-effect-free (no writes to global variables) and
34+
deterministic (no reads from global variables). They are also elemental: can be
35+
vectorized over any argument(s). They can be used inside parallel code and
36+
cached.
37+
38+
The `intrinsic_id` determines the generic function uniquely (`Sin` and `Abs`
39+
have different number, but `IntegerAbs` and `RealAbs` share the number) and
40+
`overload_id` uniquely determines the signature starting from 0 for each
41+
generic function (e.g., `IntegerAbs`, `RealAbs` and `ComplexAbs` can have
42+
`overload_id` equal to 0, 1 and 2, and `RealSin`, `ComplexSin` can be 0, 1).
43+
44+
Backend use cases: Some architectures have special hardware instructions for
45+
operations like Sqrt or Sin and if they are faster than a software
46+
implementation, the backend will use it. This includes the `FlipSign` function
47+
which is our own "special function" that the optimizer emits for certain
48+
conditional floating point operations, and the backend emits an efficient bit
49+
manipulation implementation for architectures that support it.
50+
51+
Middle-end use cases: the middle-end can use the high level semantics to
52+
simplify, such as `sin(e)**2 + cos(e)**2 -> 1`, or it could approximate
53+
expressions like `if (abs(sin(x) - 0.5) < 0.3)` with a lower accuracy version
54+
of `sin`.
55+
56+
We provide ASR -> ASR lowering transformations that substitute the given
57+
intrinsic function with an ASR implementation using more primitive ASR nodes,
58+
typically implemented in the surface language (say a `sin` implementation using
59+
argument reduction and a polynomial fit, or a `sqrt` implementation using a
60+
general power formula `x**(0.5)`, or `LegendreP(2,x)` implementation using a
61+
formula `(3*x**2-1)/2`).
62+
63+
This design also makes it possible to allow selecting using command line
64+
options how certain intrinsic functions should be implemented, for example if
65+
trigonometric functions should be implemented using our own fast
66+
implementation, `libm` accurate implementation, we could also call into other
67+
libraries. These choices should happen at the ASR level, and then the result
68+
further optimized (such as inlined) as needed.
69+
70+
## Types
71+
72+
The argument types in `args` have the types of the corresponding signature as
73+
determined by `intrinsic_id`. For example `IntegerAbs` accepts an integer, but
74+
`RealAbs` accepts a real.
75+
76+
## Examples
77+
78+
The following example code creates `IntrinsicFunction` ASR node:
79+
80+
```fortran
81+
sin(0.5)
82+
```
83+
84+
ASR:
85+
86+
```
87+
(TranslationUnit
88+
(SymbolTable
89+
1
90+
{
91+
})
92+
[(IntrinsicFunction
93+
[(RealConstant
94+
0.500000
95+
(Real 4 [])
96+
)]
97+
0
98+
0
99+
(Real 4 [])
100+
(RealConstant 0.479426 (Real 4 []))
101+
)]
102+
)
103+
```
104+
105+
## See Also
106+
107+
[FunctionCall]()

integration_tests/test_builtin_abs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def test_abs():
3232

3333
b: bool
3434
b = True
35-
assert abs(b) == 1
35+
assert abs(i32(b)) == 1
3636
b = False
37-
assert abs(b) == 0
37+
assert abs(i32(b)) == 0
3838

3939
test_abs()

src/libasr/ASR.asdl

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ expr
224224
| NamedExpr(expr target, expr value, ttype type)
225225
| FunctionCall(symbol name, symbol? original_name, call_arg* args,
226226
ttype type, expr? value, expr? dt)
227+
| IntrinsicFunction(int intrinsic_id, expr* args, int overload_id,
228+
ttype type, expr? value)
227229
| StructTypeConstructor(symbol dt_sym, call_arg* args, ttype type, expr? value)
228230
| EnumTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
229231
| UnionTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)

src/libasr/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ set(SRC
4747
pass/unused_functions.cpp
4848
pass/flip_sign.cpp
4949
pass/div_to_mul.cpp
50+
pass/intrinsic_function.cpp
5051
pass/fma.cpp
5152
pass/loop_vectorise.cpp
5253
pass/sign_from_value.cpp

src/libasr/asdl_cpp.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1520,7 +1520,10 @@ def visitField(self, field, cons):
15201520
self.emit( 's.append("()");', 3)
15211521
self.emit("}", 2)
15221522
else:
1523-
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
1523+
if field.name == "intrinsic_id":
1524+
self.emit('s.append(self().convert_intrinsic_id(x.m_%s));' % field.name, 2)
1525+
else:
1526+
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
15241527
elif field.type == "float" and not field.seq and not field.opt:
15251528
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
15261529
elif field.type == "bool" and not field.seq and not field.opt:

src/libasr/asr_utils.h

+17-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <libasr/string_utils.h>
1212
#include <libasr/utils.h>
1313

14+
#include <complex>
15+
1416
namespace LCompilers {
1517

1618
namespace ASRUtils {
@@ -801,7 +803,21 @@ static inline bool all_args_evaluated(const Vec<ASR::array_index_t> &args) {
801803
return true;
802804
}
803805

804-
template <typename T>
806+
static inline bool extract_value(ASR::expr_t* value_expr,
807+
std::complex<double>& value) {
808+
if( !ASR::is_a<ASR::ComplexConstant_t>(*value_expr) ) {
809+
return false;
810+
}
811+
812+
ASR::ComplexConstant_t* value_const = ASR::down_cast<ASR::ComplexConstant_t>(value_expr);
813+
value = std::complex(value_const->m_re, value_const->m_im);
814+
return true;
815+
}
816+
817+
template <typename T,
818+
typename = typename std::enable_if<
819+
std::is_same<T, std::complex<double>>::value == false &&
820+
std::is_same<T, std::complex<float>>::value == false>::type>
805821
static inline bool extract_value(ASR::expr_t* value_expr, T& value) {
806822
if( !is_value_constant(value_expr) ) {
807823
return false;

src/libasr/codegen/asr_to_c.cpp

+13-12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <libasr/pass/class_constructor.h>
1414
#include <libasr/pass/array_op.h>
1515
#include <libasr/pass/subroutine_from_function.h>
16+
#include <libasr/pass/intrinsic_function.h>
1617

1718
#include <map>
1819
#include <utility>
@@ -680,6 +681,18 @@ R"(
680681
}
681682
}
682683

684+
// Process global functions
685+
size_t i;
686+
for (i = 0; i < global_func_order.size(); i++) {
687+
ASR::symbol_t* sym = x.m_global_scope->get_symbol(global_func_order[i]);
688+
// Ignore external symbols because they are already defined by the loop above.
689+
if( !sym || ASR::is_a<ASR::ExternalSymbol_t>(*sym) ) {
690+
continue ;
691+
}
692+
visit_symbol(*sym);
693+
unit_src += src;
694+
}
695+
683696
// Process modules in the right order
684697
std::vector<std::string> build_order
685698
= ASRUtils::determine_module_dependencies(x);
@@ -693,18 +706,6 @@ R"(
693706
}
694707
}
695708

696-
// Process global functions
697-
size_t i;
698-
for (i = 0; i < global_func_order.size(); i++) {
699-
ASR::symbol_t* sym = x.m_global_scope->get_symbol(global_func_order[i]);
700-
// Ignore external symbols because they are already defined by the loop above.
701-
if( !sym || ASR::is_a<ASR::ExternalSymbol_t>(*sym) ) {
702-
continue ;
703-
}
704-
visit_symbol(*sym);
705-
unit_src += src;
706-
}
707-
708709
// Then the main program:
709710
for (auto &item : x.m_global_scope->get_scope()) {
710711
if (ASR::is_a<ASR::Program_t>(*item.second)) {

src/libasr/codegen/asr_to_cpp.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <libasr/asr_utils.h>
1111
#include <libasr/string_utils.h>
1212
#include <libasr/pass/unused_functions.h>
13+
#include <libasr/pass/intrinsic_function.h>
1314

1415

1516
namespace LCompilers {
@@ -745,6 +746,7 @@ Result<std::string> asr_to_cpp(Allocator &al, ASR::TranslationUnit_t &asr,
745746
LCompilers::PassOptions pass_options;
746747
pass_options.always_run = true;
747748
pass_unused_functions(al, asr, pass_options);
749+
pass_replace_intrinsic_function(al, asr, pass_options);
748750
ASRToCPPVisitor v(diagnostics, co, default_lower_bound);
749751
try {
750752
v.visit_asr((ASR::asr_t &)asr);

src/libasr/codegen/asr_to_wasm.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include <libasr/pass/unused_functions.h>
1616
#include <libasr/pass/pass_array_by_data.h>
1717
#include <libasr/pass/print_arr.h>
18+
#include <libasr/pass/intrinsic_function.h>
19+
#include <libasr/exception.h>
20+
#include <libasr/asr_utils.h>
1821

1922
// #define SHOW_ASR
2023

@@ -3136,6 +3139,7 @@ Result<Vec<uint8_t>> asr_to_wasm_bytes_stream(ASR::TranslationUnit_t &asr,
31363139
pass_array_by_data(al, asr, pass_options);
31373140
pass_replace_print_arr(al, asr, pass_options);
31383141
pass_replace_do_loops(al, asr, pass_options);
3142+
pass_replace_intrinsic_function(al, asr, pass_options);
31393143
pass_options.always_run = true;
31403144
pass_unused_functions(al, asr, pass_options);
31413145

src/libasr/pass/array_op.cpp

+83
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,89 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
594594
replace_ArrayOpCommon<ASR::LogicalCompare_t>(x, "_logical_comp_op_res");
595595
}
596596

597+
void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) {
598+
LCOMPILERS_ASSERT(current_scope != nullptr);
599+
const Location& loc = x->base.base.loc;
600+
std::vector<bool> array_mask(x->n_args, false);
601+
bool at_least_one_array = false;
602+
for( size_t iarg = 0; iarg < x->n_args; iarg++ ) {
603+
array_mask[iarg] = ASRUtils::is_array(
604+
ASRUtils::expr_type(x->m_args[iarg]));
605+
at_least_one_array = at_least_one_array || array_mask[iarg];
606+
}
607+
if (!at_least_one_array) {
608+
return ;
609+
}
610+
std::string res_prefix = "_elemental_func_call_res";
611+
ASR::expr_t* result_var_copy = result_var;
612+
bool is_all_rank_0 = true;
613+
std::vector<ASR::expr_t*> operands;
614+
ASR::expr_t* operand = nullptr;
615+
int common_rank = 0;
616+
bool are_all_rank_same = true;
617+
for( size_t iarg = 0; iarg < x->n_args; iarg++ ) {
618+
result_var = nullptr;
619+
ASR::expr_t** current_expr_copy_9 = current_expr;
620+
current_expr = &(x->m_args[iarg]);
621+
self().replace_expr(x->m_args[iarg]);
622+
operand = *current_expr;
623+
current_expr = current_expr_copy_9;
624+
operands.push_back(operand);
625+
int rank_operand = PassUtils::get_rank(operand);
626+
if( common_rank == 0 ) {
627+
common_rank = rank_operand;
628+
}
629+
if( common_rank != rank_operand &&
630+
rank_operand > 0 ) {
631+
are_all_rank_same = false;
632+
}
633+
array_mask[iarg] = (rank_operand > 0);
634+
is_all_rank_0 = is_all_rank_0 && (rank_operand <= 0);
635+
}
636+
if( is_all_rank_0 ) {
637+
return ;
638+
}
639+
if( !are_all_rank_same ) {
640+
throw LCompilersException("Broadcasting support not yet available "
641+
"for different shape arrays.");
642+
}
643+
result_var = result_var_copy;
644+
if( result_var == nullptr ) {
645+
result_var = PassUtils::create_var(result_counter, res_prefix,
646+
loc, operand, al, current_scope);
647+
result_counter += 1;
648+
}
649+
*current_expr = result_var;
650+
651+
Vec<ASR::expr_t*> idx_vars, loop_vars;
652+
std::vector<int> loop_var_indices;
653+
Vec<ASR::stmt_t*> doloop_body;
654+
create_do_loop(loc, common_rank,
655+
idx_vars, loop_vars, loop_var_indices, doloop_body,
656+
[=, &operands, &idx_vars, &doloop_body] () {
657+
Vec<ASR::expr_t*> ref_args;
658+
ref_args.reserve(al, x->n_args);
659+
for( size_t iarg = 0; iarg < x->n_args; iarg++ ) {
660+
ASR::expr_t* ref = operands[iarg];
661+
if( array_mask[iarg] ) {
662+
ref = PassUtils::create_array_ref(operands[iarg], idx_vars, al);
663+
}
664+
ref_args.push_back(al, ref);
665+
}
666+
Vec<ASR::dimension_t> empty_dim;
667+
empty_dim.reserve(al, 1);
668+
ASR::ttype_t* dim_less_type = ASRUtils::duplicate_type(al, x->m_type, &empty_dim);
669+
ASR::expr_t* op_el_wise = ASRUtils::EXPR(ASR::make_IntrinsicFunction_t(al, loc,
670+
x->m_intrinsic_id, ref_args.p, ref_args.size(), x->m_overload_id,
671+
dim_less_type, nullptr));
672+
ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al);
673+
ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, res, op_el_wise, nullptr));
674+
doloop_body.push_back(al, assign);
675+
});
676+
use_custom_loop_params = false;
677+
result_var = nullptr;
678+
}
679+
597680
void replace_FunctionCall(ASR::FunctionCall_t* x) {
598681
std::string x_name;
599682
if( x->m_name->type == ASR::symbolType::ExternalSymbol ) {

src/libasr/pass/global_symbols.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace LCompilers {
1616

1717
void pass_wrap_global_syms_into_module(Allocator &al,
1818
ASR::TranslationUnit_t &unit,
19-
const LCompilers::PassOptions& pass_options) {
19+
const LCompilers::PassOptions &/*pass_options*/) {
2020
Location loc = unit.base.base.loc;
2121
char *module_name = s2c(al, "_global_symbols");
2222
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);

0 commit comments

Comments
 (0)