Symbolic ASR pass #2107

Closed
certik opened this issue Jul 5, 2023 · 9 comments
@certik

certik commented Jul 5, 2023

The idea is to take this ASR:

                                    (=
                                        (Var 2 z)
                                        (IntrinsicFunction
                                            SymbolicAdd
                                            [(Var 2 x)
                                            (Var 2 y)]
                                            0
                                            (SymbolicExpression)
                                            ()
                                        )
                                        ()
                                    )

And transform it into something like this:

(FunctionCall "symengine_add" x y)
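A minimal sketch of what such a rewrite could look like, using a toy tree of Python tuples and lists. This is an assumption made purely for illustration: the node shapes and the helper name lower_symbolic are hypothetical and are not LPython's real ASR classes or pass API. The name "symengine_add" is the placeholder used above.

```python
# Toy model only: nodes are plain tuples/lists, not LPython's real ASR classes.

def lower_symbolic(node):
    """Recursively replace SymbolicAdd intrinsic nodes with a call node."""
    if isinstance(node, tuple):
        is_symbolic_add = (
            len(node) >= 3
            and node[0] == "IntrinsicFunction"
            and node[1] == "SymbolicAdd"
        )
        if is_symbolic_add:
            args = [lower_symbolic(arg) for arg in node[2]]
            # "symengine_add" is the placeholder name from this issue.
            return ("FunctionCall", "symengine_add", args)
        return tuple(lower_symbolic(child) for child in node)
    if isinstance(node, list):
        return [lower_symbolic(child) for child in node]
    return node  # strings and other leaves are returned unchanged


# z = SymbolicAdd(x, y), written in the toy representation
assignment = (
    "=",
    ("Var", "z"),
    ("IntrinsicFunction", "SymbolicAdd", [("Var", "x"), ("Var", "y")]),
)
lowered = lower_symbolic(assignment)
print(lowered)
```

The walk leaves every node it does not recognise untouched, so only the symbolic intrinsic is replaced and the surrounding assignment keeps its shape.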
@certik

certik commented Jul 5, 2023

As a first step, let's write Python code using existing LPython:

from lpython import ccall, CPtr

@ccall(header="symengine/cwrapper.h")
def basic_add(x: CPtr, y: CPtr, z: CPtr) -> None:
    pass

@ccall(header="symengine/cwrapper.h")
def basic_new_stack(x: CPtr) -> None:
    pass

def main0():
    x: CPtr
    basic_new_stack(x)
    y: CPtr
    basic_new_stack(y)
    z: CPtr
    basic_new_stack(z)
    basic_add(z, x, y)
    print("ok")

main0()

And make it work for the symbolics_01.py fully, by calling symengine manually like any other C library.

@anutosh491

Hello @certik. I tried a few workarounds for this over the last couple of days.
The C code we get from the above program is quite close to what we want:

// Implementations
void main0()
{
    void* x;
    void* y;
    void* z;
    basic_new_stack(x);
    symbol_set(x, "x");
    basic_new_stack(y);
    symbol_set(y, "y");
    basic_new_stack(z);
    basic_add(z, x, y);
}

If only we had basic x, basic y, basic z here, we would be done.
So what exactly is SymEngine's basic type?

  1. An array containing 1 element of type basic_struct
  2. basic_struct is nothing but a structure containing a void pointer
struct CRCPBasic_C {
    void *data;
#if !defined(WITH_SYMENGINE_RCP)
    void *teuchos_handle;
    int teuchos_strength;
#endif
};

/* As compiled in C: gcc reports basic_struct * as "struct CRCPBasic_C *". */
typedef struct CRCPBasic_C basic_struct;

typedef basic_struct basic[1];
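
The two typedefs above can be modelled with CPython's ctypes to check the layout. This is a sketch, not LPython code, and it assumes WITH_SYMENGINE_RCP is defined so that data is the only member. The class names BasicStruct and Basic are made up for this illustration.

```python
import ctypes


class BasicStruct(ctypes.Structure):
    # Assumption: WITH_SYMENGINE_RCP is defined, so only `data` is present.
    _fields_ = [("data", ctypes.c_void_p)]


# Models `typedef basic_struct basic[1];`
Basic = BasicStruct * 1

x = Basic()  # like `basic x;` in C: storage for exactly one struct

# With a single pointer member, the array is as large as one pointer.
print(ctypes.sizeof(Basic), ctypes.sizeof(ctypes.c_void_p))

# Passing `x` to a C function hands over the address of its first element,
# which is the same address as the array itself.
elem_ptr = ctypes.cast(x, ctypes.POINTER(BasicStruct))
same_address = ctypes.addressof(elem_ptr.contents) == ctypes.addressof(x)
print(same_address)
```

Under that assumption, a function declared as taking basic receives a pointer to storage that holds one pointer-sized slot, which is the shape the later comments in this thread try to reproduce.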

@anutosh491

So I tried a few things:
1)

from lpython import ccall, CPtr, dataclass, empty_c_void_p

@dataclass
class BasicStruct:
    x: CPtr

@ccall(header="symengine/cwrapper.h")
def basic_new_stack(x: list[BasicStruct]) -> None:
    pass

def main0():
    lst: list[BasicStruct]
    y: BasicStruct = BasicStruct()
    y.x = empty_c_void_p()
    lst = [y]
    basic_new_stack(lst)

But I think declaring a list of a non-primitive, user-defined data type is prohibited:

(lf) anutosh491@spbhat68:~/lpython/lpython$ lpython --show-c examples/expr2.py 
Internal Compiler Error: Unhandled exception
Traceback (most recent call last):

    get_type(list_type->m_type);
AssertFailed: false

2) I thought a NumPy array might be a workaround, so I tried:

from lpython import ccall, CPtr, dataclass, empty_c_void_p
from numpy import array

@dataclass
class BasicStruct:
    x: CPtr

@ccall(header="symengine/cwrapper.h")
def basic_new_stack(x: BasicStruct[1]) -> None:
    pass

def main0():
    b: CPtr = empty_c_void_p()
    y: BasicStruct = BasicStruct(b)
    lst: BasicStruct[1] = array([y])
    basic_new_stack(lst)

main0()

This generates the C code

struct BasicStruct {
 void* x;
};


struct xBasicStruct
{
    struct BasicStruct *data;
    struct dimension_descriptor dims[32];
    int32_t n_dims;
    bool is_allocated;
};


inline void struct_deepcopy_BasicStruct(struct BasicStruct* src, struct BasicStruct* dest);


// Implementations
void main0()
{
    int32_t __1_k;
    void* b;
    struct xBasicStruct lst_value;
    struct xBasicStruct* lst = &lst_value;
    struct BasicStruct lst_data[1];
    lst->data = lst_data;
    lst->n_dims = 1;
    lst->dims[0].lower_bound = 0;
    lst->dims[0].length = 1;
    struct BasicStruct y_value;
    struct BasicStruct* y = &y_value;
    b = NULL;
    y->x = b;
    __1_k = ((int32_t)lst->dims[1-1].lower_bound);
    struct_deepcopy_BasicStruct(y, &lst->data[(__1_k - lst->dims[0].lower_bound)]);
    __1_k = __1_k + 1;
    basic_new_stack(lst);
}

void _lpython_main_program()
{
    main0();
}

This is actually close to what we want if you go through it; the only difference is how the types are named. Hence the following diagnostic when compiling against SymEngine's header:

(lf) anutosh491@spbhat68:~/lpython/lpython$ lpython --backend=c --enable-symengine examples/expr2.py 
expr2__tmp__generated__.c: In function ‘main0’:
expr2__tmp__generated__.c:51:21: warning: passing argument 1 of ‘basic_new_stack’ from incompatible pointer type [-Wincompatible-pointer-types]
   51 |     basic_new_stack(lst);
      |                     ^~~
      |                     |
      |                     struct xBasicStruct *
In file included from expr2__tmp__generated__.c:3:
/home/anutosh491/conda_root/envs/lf/include/symengine/cwrapper.h:100:28: note: expected ‘basic_struct *’ {aka ‘struct CRCPBasic_C *’} but argument is of type ‘struct xBasicStruct *’
  100 | void basic_new_stack(basic s);
      |                      ~~~~~~^

Essentially, struct CRCPBasic_C * and struct xBasicStruct * are similar structures here!
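
One way to read the diagnostic: basic_new_stack expects a pointer to the element struct, while the generated code passes a pointer to the array descriptor that wraps it. The ctypes model below mirrors the two generated structs to make that visible. The field types of dimension_descriptor are an assumption (the generated code only shows the member names), and the Python class names are hypothetical.

```python
import ctypes


class BasicStruct(ctypes.Structure):
    _fields_ = [("x", ctypes.c_void_p)]


class DimensionDescriptor(ctypes.Structure):
    # Assumed field types; only the names appear in the generated C code.
    _fields_ = [("lower_bound", ctypes.c_int32), ("length", ctypes.c_int32)]


class XBasicStruct(ctypes.Structure):
    # Mirrors `struct xBasicStruct` from the generated C code.
    _fields_ = [
        ("data", ctypes.POINTER(BasicStruct)),
        ("dims", DimensionDescriptor * 32),
        ("n_dims", ctypes.c_int32),
        ("is_allocated", ctypes.c_bool),
    ]


lst_data = (BasicStruct * 1)()  # like `struct BasicStruct lst_data[1];`
lst = XBasicStruct()
lst.data = ctypes.cast(lst_data, ctypes.POINTER(BasicStruct))
lst.n_dims = 1
lst.dims[0].lower_bound = 0
lst.dims[0].length = 1

descriptor_addr = ctypes.addressof(lst)          # what `lst` refers to
element_addr = ctypes.addressof(lst.data.contents)  # what `lst->data` refers to
print(descriptor_addr != element_addr)
print(ctypes.sizeof(XBasicStruct), ctypes.sizeof(BasicStruct))
```

In this model the element storage lives behind lst.data, so the descriptor and the element are distinct objects of different sizes, even though both begin with a pointer-sized member.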

@certik

certik commented Jul 9, 2023

From a binary compatibility point of view, I think basic is just a raw pointer. It would probably work in LLVM, but fail in the C backend because the C compiler refuses to identify void* with basic.

So we need to figure out how to improve LPython to allow us to interface with basic.

@anutosh491

Well yes, unlike the C backend, where we get a segmentation fault, the LLVM backend does compile everything successfully.
Though the result returned might not reflect that:

(lf) anutosh491@spbhat68:~/lpython/lpython$ cat examples/expr2.py 
from lpython import ccall

@ccall(header="symengine/cwrapper.h")
def basic_new_stack(x: CPtr) -> None:
    pass

@ccall(header="symengine/cwrapper.h")
def basic_const_pi(x: CPtr) -> None:
    pass

def main0():
    x: CPtr
    basic_new_stack(x)
    basic_const_pi(x)
    print(x)

main0()

Generated LLVM IR:

define void @__module__global_symbols_main0() {
.entry:
  %x = alloca void*, align 8
  %0 = load void*, void** %x, align 8
  call void @basic_new_stack(void* %0)
  %1 = load void*, void** %x, align 8
  call void @basic_const_pi(void* %1)
  %2 = load void*, void** %x, align 8
  %3 = ptrtoint void* %2 to i64
  call void (i8*, ...) @_lfortran_printf(i8* getelementptr inbounds ([7 x i8], [7 x i8]* @2, i32 0, i32 0), i64 %3, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
  br label %return

return:                                           ; preds = %.entry
  ret void
}
(lf) anutosh491@spbhat68:~/lpython/lpython$ lpython --enable-symengine examples/expr2.py 
139665039103144

The LLVM code generated here looks correct to me, and we are also calling basic_new_stack, so a valid memory address should have been assigned to x. I am not sure where the garbage value is coming from. Looking into it.

@anutosh491

We could implement a workaround for the basic type using this framework:

#include "symengine/cwrapper.h"

#include <stdlib.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <lfortran_intrinsics.h>


// Implementations
void main0()
{   
    void* x0;
    void* x;
    x = &x0;
    basic_new_stack(x);
    basic_const_pi(x);
    printf("%s\n", basic_str(x));
}

void _lpython_main_program()
{
    main0();
}

int main(int argc, char* argv[])
{
    _lpython_set_argv(argc, argv);
    _lpython_main_program();
    return 0;
}
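
The key line in the C code above is x = &x0: x now points at real storage before it is handed to basic_new_stack. The ctypes sketch below shows the same pattern in isolation. fake_basic_new_stack is a hypothetical stand-in that only writes through the pointer it receives; it is not SymEngine's function and the handle value is arbitrary.

```python
import ctypes


def fake_basic_new_stack(s):
    """Hypothetical stand-in for basic_new_stack: writes through `s`."""
    s.contents.value = 0x1234  # pretend this is the handle being stored


x0 = ctypes.c_void_p()   # like `void* x0;`  the actual storage
x = ctypes.pointer(x0)   # like `x = &x0;`   a pointer to that storage

fake_basic_new_stack(x)

# The write landed in x0, because x pointed at it.
print(hex(x0.value))
```

Without the extra slot there is nothing valid for the callee to write into, which matches the earlier LLVM IR where %x is loaded before anything has been stored to it.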

@anutosh491

anutosh491 commented Jul 14, 2023

ASR for the following program:

from lpython import ccall, CPtr

@ccall(header="symengine/cwrapper.h")
def basic_new_heap() -> CPtr:
    pass

@ccall(header="symengine/cwrapper.h")
def basic_const_pi(x: CPtr) -> None:
    pass

@ccall(header="symengine/cwrapper.h")
def basic_str(x: CPtr) -> str:
    pass

def main0():
    x: CPtr = basic_new_heap()
    basic_const_pi(x)
    print(basic_str(x))

main0()

(TranslationUnit
    (SymbolTable
        1
        {
            _global_symbols:
                (Module
                    (SymbolTable
                        8
                        {
                            _lpython_main_program:
                                (Function
                                    (SymbolTable
                                        7
                                        {
                                            
                                        })
                                    _lpython_main_program
                                    (FunctionType
                                        []
                                        ()
                                        Source
                                        Implementation
                                        ()
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        []
                                        []
                                        .false.
                                    )
                                    [main0]
                                    []
                                    [(SubroutineCall
                                        8 main0
                                        ()
                                        []
                                        ()
                                    )]
                                    ()
                                    Public
                                    .false.
                                    .false.
                                    ()
                                ),
                            basic_const_pi:
                                (Function
                                    (SymbolTable
                                        3
                                        {
                                            x:
                                                (Variable
                                                    3
                                                    x
                                                    []
                                                    In
                                                    ()
                                                    ()
                                                    Default
                                                    (CPtr)
                                                    ()
                                                    BindC
                                                    Public
                                                    Required
                                                    .true.
                                                )
                                        })
                                    basic_const_pi
                                    (FunctionType
                                        [(CPtr)]
                                        ()
                                        BindC
                                        Interface
                                        ()
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        []
                                        []
                                        .false.
                                    )
                                    []
                                    [(Var 3 x)]
                                    []
                                    ()
                                    Public
                                    .false.
                                    .false.
                                    "symengine/cwrapper.h"
                                ),
                            basic_new_heap:
                                (Function
                                    (SymbolTable
                                        2
                                        {
                                            _lpython_return_variable:
                                                (Variable
                                                    2
                                                    _lpython_return_variable
                                                    []
                                                    ReturnVar
                                                    ()
                                                    ()
                                                    Default
                                                    (CPtr)
                                                    ()
                                                    BindC
                                                    Public
                                                    Required
                                                    .false.
                                                )
                                        })
                                    basic_new_heap
                                    (FunctionType
                                        []
                                        (CPtr)
                                        BindC
                                        Interface
                                        ()
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        []
                                        []
                                        .false.
                                    )
                                    []
                                    []
                                    []
                                    (Var 2 _lpython_return_variable)
                                    Public
                                    .false.
                                    .false.
                                    "symengine/cwrapper.h"
                                ),
                            basic_str:
                                (Function
                                    (SymbolTable
                                        4
                                        {
                                            _lpython_return_variable:
                                                (Variable
                                                    4
                                                    _lpython_return_variable
                                                    []
                                                    ReturnVar
                                                    ()
                                                    ()
                                                    Default
                                                    (Character 1 -2 ())
                                                    ()
                                                    BindC
                                                    Public
                                                    Required
                                                    .false.
                                                ),
                                            x:
                                                (Variable
                                                    4
                                                    x
                                                    []
                                                    In
                                                    ()
                                                    ()
                                                    Default
                                                    (CPtr)
                                                    ()
                                                    BindC
                                                    Public
                                                    Required
                                                    .true.
                                                )
                                        })
                                    basic_str
                                    (FunctionType
                                        [(CPtr)]
                                        (Character 1 -2 ())
                                        BindC
                                        Interface
                                        ()
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        []
                                        []
                                        .false.
                                    )
                                    []
                                    [(Var 4 x)]
                                    []
                                    (Var 4 _lpython_return_variable)
                                    Public
                                    .false.
                                    .false.
                                    "symengine/cwrapper.h"
                                ),
                            main0:
                                (Function
                                    (SymbolTable
                                        5
                                        {
                                            x:
                                                (Variable
                                                    5
                                                    x
                                                    []
                                                    Local
                                                    ()
                                                    ()
                                                    Default
                                                    (CPtr)
                                                    ()
                                                    Source
                                                    Public
                                                    Required
                                                    .false.
                                                )
                                        })
                                    main0
                                    (FunctionType
                                        []
                                        ()
                                        Source
                                        Implementation
                                        ()
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        .false.
                                        []
                                        []
                                        .false.
                                    )
                                    [basic_new_heap
                                    basic_const_pi
                                    basic_str]
                                    []
                                    [(=
                                        (Var 5 x)
                                        (FunctionCall
                                            8 basic_new_heap
                                            ()
                                            []
                                            (CPtr)
                                            ()
                                            ()
                                        )
                                        ()
                                    )
                                    (SubroutineCall
                                        8 basic_const_pi
                                        ()
                                        [((Var 5 x))]
                                        ()
                                    )
                                    (Print
                                        ()
                                        [(FunctionCall
                                            8 basic_str
                                            ()
                                            [((Var 5 x))]
                                            (Character 1 -2 ())
                                            ()
                                            ()
                                        )]
                                        ()
                                        ()
                                    )]
                                    ()
                                    Public
                                    .false.
                                    .false.
                                    ()
                                )
                        })
                    _global_symbols
                    []
                    .false.
                    .false.
                ),
            main_program:
                (Program
                    (SymbolTable
                        6
                        {
                            _lpython_main_program:
                                (ExternalSymbol
                                    6
                                    _lpython_main_program
                                    8 _lpython_main_program
                                    _global_symbols
                                    []
                                    _lpython_main_program
                                    Public
                                )
                        })
                    main_program
                    [_global_symbols]
                    [(SubroutineCall
                        6 _lpython_main_program
                        ()
                        []
                        ()
                    )]
                )
        })
    []
)

@anutosh491

The pass has been correctly implemented through #2255. However, the ASR pass doesn't yet support freeing variables through the basic_free_stack function. I will try adding support for it, and once that is done the issue can be closed.
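
A rough sketch of what the missing step could look like, again on a toy statement list rather than LPython's real ASR. The helper append_frees and the tuple node shapes are hypothetical. The sketch only covers the simplest case of appending the calls at the end of a function body; a real pass would also have to deal with early returns.

```python
# Toy model only: statements are tuples, not LPython's real ASR nodes.

def append_frees(body, symbolic_vars):
    """Return a new body with one basic_free_stack call per symbolic variable."""
    frees = [
        ("SubroutineCall", "basic_free_stack", [("Var", name)])
        for name in symbolic_vars
    ]
    return list(body) + frees  # the input list is left unmodified


body = [
    ("SubroutineCall", "basic_new_stack", [("Var", "x")]),
    ("SubroutineCall", "basic_const_pi", [("Var", "x")]),
]
with_frees = append_frees(body, ["x"])
for stmt in with_frees:
    print(stmt)
```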

@certik

certik commented Aug 9, 2023

I think this is implemented; I opened #2262 for the rest.

@certik certik closed this as completed Aug 9, 2023