diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 645015e999..ddfe84d1f9 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -714,6 +714,7 @@ RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym) RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_11.py b/integration_tests/symbolics_11.py new file mode 100644 index 0000000000..9517b22303 --- /dev/null +++ b/integration_tests/symbolics_11.py @@ -0,0 +1,18 @@ +from sympy import Symbol, sin, pi +from lpython import S + +def test_extraction_of_elements(): + x: S = Symbol("x") + l1: list[S] = [x, pi, sin(x), Symbol("y")] + ele1: S = l1[0] + ele2: S = l1[1] + ele3: S = l1[2] + ele4: S = l1[3] + + assert(ele1 == x) + assert(ele2 == pi) + assert(ele3 == sin(x)) + assert(ele4 == Symbol("y")) + print(ele1, ele2, ele3, ele4) + +test_extraction_of_elements() \ No newline at end of file diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 9bb9ef414a..1433c00410 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -245,6 +245,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::List) { + ASR::List_t* list = ASR::down_cast(xx.m_type); + if (list->m_type->type == ASR::ttypeType::SymbolicExpression){ + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); + ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type)); + xx.m_type = list_type; + } } } @@ -920,6 +927,47 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { + ASR::ListConstant_t* list_constant = ASR::down_cast(x.m_value); + if (list_constant->m_type->type == ASR::ttypeType::List) { + ASR::List_t* list = ASR::down_cast(list_constant->m_type); + if (list->m_type->type == ASR::ttypeType::SymbolicExpression){ + Vec temp_list; + temp_list.reserve(al, list_constant->n_args + 1); + + for (size_t i = 0; i < list_constant->n_args; ++i) { + ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]); + temp_list.push_back(al, value); + } + + ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); + ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, type)); + ASR::expr_t* temp_list_const = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list.p, + temp_list.size(), list_type)); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const, nullptr)); + pass_result.push_back(al, stmt); + } + } + } else if (ASR::is_a(*x.m_value)) { + ASR::ListItem_t* list_item = ASR::down_cast(x.m_value); + if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) { + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); + ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope); + + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = x.m_target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a, + list_item->m_pos, CPtr_type, nullptr)); + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym, + basic_assign_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } } else if (ASR::is_a(*x.m_value)) { ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_value); if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {