Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR Pass] Handle ArraySection and SIMDArray BinOp (LFortran Sync) #2426

Merged
merged 4 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2615,7 +2615,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
} else {
l = "1";
}
size_str += "*" + l;
size_str += "*" + sym + "->dims[" + std::to_string(j) + "].length";
out += indent + sym + "->dims[" + std::to_string(j) + "].lower_bound = ";
out += st + ";\n";
out += indent + sym + "->dims[" + std::to_string(j) + "].length = ";
Expand Down
70 changes: 65 additions & 5 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::string current_der_type_name;

//! Helpful for debugging while testing LLVM code
void print_util(llvm::Value* v, std::string fmt_chars, std::string endline="\t") {
void print_util(llvm::Value* v, std::string fmt_chars, std::string endline) {
// Usage:
// print_util(tmp, "%d") // `tmp` to be an integer type
// print_util(tmp, "%d", "\n") // `tmp` is an integer type to match the format specifiers
std::vector<llvm::Value *> args;
std::vector<std::string> fmt;
args.push_back(v);
Expand Down Expand Up @@ -2320,7 +2320,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
Vec<llvm::Value*> llvm_diminfo;
llvm_diminfo.reserve(al, 2 * x.n_args + 1);
if( array_t->m_physical_type == ASR::array_physical_typeType::PointerToDataArray ||
array_t->m_physical_type == ASR::array_physical_typeType::FixedSizeArray ) {
array_t->m_physical_type == ASR::array_physical_typeType::FixedSizeArray ||
array_t->m_physical_type == ASR::array_physical_typeType::SIMDArray ) {
int ptr_loads_copy = ptr_loads;
for( size_t idim = 0; idim < x.n_args; idim++ ) {
ptr_loads = 2 - !LLVM::is_llvm_pointer(*ASRUtils::expr_type(m_dims[idim].m_start));
Expand Down Expand Up @@ -2354,7 +2355,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
} else {
tmp = arr_descr->get_single_element(array, indices, x.n_args,
array_t->m_physical_type == ASR::array_physical_typeType::PointerToDataArray,
array_t->m_physical_type == ASR::array_physical_typeType::FixedSizeArray,
array_t->m_physical_type == ASR::array_physical_typeType::FixedSizeArray || array_t->m_physical_type == ASR::array_physical_typeType::SIMDArray,
llvm_diminfo.p, is_polymorphic, current_select_type_block_type);
}
}
Expand Down Expand Up @@ -4705,10 +4706,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
ASR::ttype_t* target_type = ASRUtils::expr_type(x.m_target);
ASR::ttype_t* value_type = ASRUtils::expr_type(x.m_value);
ASR::expr_t *m_value = x.m_value;
if (ASRUtils::is_simd_array(x.m_target) && ASR::is_a<ASR::ArraySection_t>(*m_value)) {
m_value = ASR::down_cast<ASR::ArraySection_t>(m_value)->m_v;
}
int ptr_loads_copy = ptr_loads;
ptr_loads = 2 - (ASRUtils::is_character(*value_type) ||
ASRUtils::is_array(value_type));
this->visit_expr_wrapper(x.m_value, true);
this->visit_expr_wrapper(m_value, true);
ptr_loads = ptr_loads_copy;
if( ASR::is_a<ASR::Var_t>(*x.m_value) &&
ASR::is_a<ASR::Union_t>(*value_type) ) {
Expand Down Expand Up @@ -4752,6 +4757,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
bool is_value_data_only_array = (value_ptype == ASR::array_physical_typeType::PointerToDataArray);
bool is_target_fixed_sized_array = (target_ptype == ASR::array_physical_typeType::FixedSizeArray);
bool is_value_fixed_sized_array = (value_ptype == ASR::array_physical_typeType::FixedSizeArray);
bool is_target_simd_array = (target_ptype == ASR::array_physical_typeType::SIMDArray);
// bool is_target_descriptor_based_array = (target_ptype == ASR::array_physical_typeType::DescriptorArray);
bool is_value_descriptor_based_array = (value_ptype == ASR::array_physical_typeType::DescriptorArray);
if( is_value_fixed_sized_array && is_target_fixed_sized_array ) {
Expand Down Expand Up @@ -4844,6 +4850,27 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
arr_descr->copy_array_data_only(value_data, target_data, module.get(),
llvm_data_type, llvm_size);
}
} else if ( is_target_simd_array ) {
if (ASR::is_a<ASR::ArraySection_t>(*x.m_value)) {
int idx = 1;
ASR::ArraySection_t *arr = down_cast<ASR::ArraySection_t>(x.m_value);
(void) ASRUtils::extract_value(arr->m_args->m_left, idx);
value = llvm_utils->create_gep(value, idx-1);
target = llvm_utils->create_gep(target, 0);
ASR::dimension_t* asr_dims = nullptr;
size_t asr_n_dims = ASRUtils::extract_dimensions_from_ttype(target_type, asr_dims);
int64_t size = ASRUtils::get_fixed_size_of_array(asr_dims, asr_n_dims);
llvm::Type* llvm_data_type = llvm_utils->get_type_from_ttype_t_util(ASRUtils::type_get_past_array(
ASRUtils::type_get_past_allocatable(ASRUtils::type_get_past_pointer(target_type))), module.get());
llvm::DataLayout data_layout(module.get());
uint64_t data_size = data_layout.getTypeAllocSize(llvm_data_type);
llvm::Value* llvm_size = llvm::ConstantInt::get(context, llvm::APInt(32, size));
llvm_size = builder->CreateMul(llvm_size,
llvm::ConstantInt::get(context, llvm::APInt(32, data_size)));
builder->CreateMemCpy(target, llvm::MaybeAlign(), value, llvm::MaybeAlign(), llvm_size);
} else {
builder->CreateStore(value, target);
}
} else {
arr_descr->copy_array(value, target, module.get(),
target_type, false, false);
Expand Down Expand Up @@ -4929,6 +4956,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ASRUtils::expr_value(m_arg) == nullptr ) {
tmp = llvm_utils->create_gep(tmp, 0);
}
} else if (
m_new == ASR::array_physical_typeType::SIMDArray &&
m_old == ASR::array_physical_typeType::FixedSizeArray) {
// pass
} else if(
m_new == ASR::array_physical_typeType::DescriptorArray &&
m_old == ASR::array_physical_typeType::FixedSizeArray) {
Expand Down Expand Up @@ -5988,6 +6019,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value *right_val = tmp;
lookup_enum_value_for_nonints = false;
LCOMPILERS_ASSERT(ASRUtils::is_real(*x.m_type))
if (ASRUtils::is_simd_array(x.m_right) && is_a<ASR::Var_t>(*x.m_right)) {
right_val = CreateLoad(right_val);
}
if (ASRUtils::is_simd_array(x.m_left) && is_a<ASR::Var_t>(*x.m_left)) {
left_val = CreateLoad(left_val);
}
switch (x.m_op) {
case ASR::binopType::Add: {
tmp = builder->CreateFAdd(left_val, right_val);
Expand Down Expand Up @@ -9149,6 +9186,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = LLVM::CreateLoad(*builder, target);
break;
}
case ASR::array_physical_typeType::SIMDArray: {
if( x.m_bound == ASR::arrayboundType::LBound ) {
tmp = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
} else if( x.m_bound == ASR::arrayboundType::UBound ) {
int64_t size = ASRUtils::get_fixed_size_of_array(ASRUtils::expr_type(x.m_v));
tmp = llvm::ConstantInt::get(context, llvm::APInt(32, size));
}
break;
}
default: {
LCOMPILERS_ASSERT(false);
}
Expand Down Expand Up @@ -9178,6 +9224,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}

void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t &x) {
this->visit_expr_wrapper(x.m_array, true);
llvm::Value *value = tmp;
llvm::Type* ele_type = llvm_utils->get_type_from_ttype_t_util(
ASRUtils::type_get_past_array(x.m_type), module.get());
size_t n_eles = ASRUtils::get_fixed_size_of_array(x.m_type);
llvm::Type* vec_type = FIXED_VECTOR_TYPE::get(ele_type, n_eles);
llvm::AllocaInst *vec = builder->CreateAlloca(vec_type, nullptr);
for (size_t i=0; i < n_eles; i++) {
builder->CreateStore(value, llvm_utils->create_gep(vec, i));
}
tmp = CreateLoad(vec);
}

};


Expand Down
61 changes: 54 additions & 7 deletions src/libasr/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,16 +716,25 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
op_n_dims = x_dims.size();
}

ASR::ttype_t* x_m_type = ASRUtils::TYPE(ASR::make_Pointer_t(al, loc,
ASRUtils::type_get_past_allocatable(ASRUtils::duplicate_type(al,
ASRUtils::type_get_past_pointer(x->m_type), &empty_dims))));

ASR::ttype_t* x_m_type;
if (op_expr && ASRUtils::is_simd_array(op_expr)) {
x_m_type = ASRUtils::expr_type(op_expr);
} else {
x_m_type = ASRUtils::TYPE(ASR::make_Pointer_t(al, loc,
ASRUtils::type_get_past_allocatable(ASRUtils::duplicate_type(al,
ASRUtils::type_get_past_pointer(x->m_type), &empty_dims))));
}
ASR::expr_t* array_section_pointer = PassUtils::create_var(
result_counter, "_array_section_pointer_", loc,
x_m_type, al, current_scope);
result_counter += 1;
pass_result.push_back(al, ASRUtils::STMT(ASRUtils::make_Associate_t_util(
al, loc, array_section_pointer, *current_expr)));
if (op_expr && ASRUtils::is_simd_array(op_expr)) {
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(
al, loc, array_section_pointer, *current_expr, nullptr)));
} else {
pass_result.push_back(al, ASRUtils::STMT(ASRUtils::make_Associate_t_util(
al, loc, array_section_pointer, *current_expr)));
}
*current_expr = array_section_pointer;

// Might get used in other replace_* methods as well.
Expand All @@ -740,6 +749,33 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {

template <typename T>
void replace_ArrayOpCommon(T* x, std::string res_prefix) {
bool is_left_simd = ASRUtils::is_simd_array(x->m_left);
bool is_right_simd = ASRUtils::is_simd_array(x->m_right);
if ( is_left_simd && is_right_simd ) {
return;
} else if ( ( is_left_simd && !is_right_simd) ||
(!is_left_simd && is_right_simd) ) {
ASR::expr_t** current_expr_copy = current_expr;
ASR::expr_t* op_expr_copy = op_expr;
if (is_left_simd) {
// Replace ArraySection, case: a = a + b(:4)
if (ASR::is_a<ASR::ArraySection_t>(*x->m_right)) {
current_expr = &(x->m_right);
op_expr = x->m_left;
this->replace_expr(x->m_right);
}
} else {
// Replace ArraySection, case: a = b(:4) + a
if (ASR::is_a<ASR::ArraySection_t>(*x->m_left)) {
current_expr = &(x->m_left);
op_expr = x->m_right;
this->replace_expr(x->m_left);
}
}
current_expr = current_expr_copy;
op_expr = op_expr_copy;
return;
}
const Location& loc = x->base.base.loc;
bool current_status = use_custom_loop_params;
use_custom_loop_params = false;
Expand Down Expand Up @@ -1587,7 +1623,18 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit

void visit_Assignment(const ASR::Assignment_t &x) {
if (ASRUtils::is_simd_array(x.m_target)) {
return;
size_t n_dims = 1;
if (ASR::is_a<ASR::ArraySection_t>(*x.m_value)) {
n_dims = ASRUtils::extract_n_dims_from_ttype(
ASRUtils::expr_type(down_cast<ASR::ArraySection_t>(
x.m_value)->m_v));
}
if (n_dims == 1) {
if (!ASR::is_a<ASR::ArrayPhysicalCast_t>(*x.m_value)) {
this->visit_expr(*x.m_value);
}
return;
}
}
if( (ASR::is_a<ASR::Pointer_t>(*ASRUtils::expr_type(x.m_target)) &&
ASR::is_a<ASR::GetPointer_t>(*x.m_value)) ||
Expand Down
Loading