Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ cdef extern from "<symengine/symengine_rcp.h>" namespace "SymEngine":
RCP[const FunctionSymbol] rcp_static_cast_FunctionSymbol "SymEngine::rcp_static_cast<const SymEngine::FunctionSymbol>"(RCP[const Basic] &b) nogil
RCP[const FunctionWrapper] rcp_static_cast_FunctionWrapper "SymEngine::rcp_static_cast<const SymEngine::FunctionWrapper>"(RCP[const Basic] &b) nogil
RCP[const Abs] rcp_static_cast_Abs "SymEngine::rcp_static_cast<const SymEngine::Abs>"(RCP[const Basic] &b) nogil
RCP[const Max] rcp_static_cast_Max "SymEngine::rcp_static_cast<const SymEngine::Max>"(RCP[const Basic] &b) nogil
RCP[const Min] rcp_static_cast_Min "SymEngine::rcp_static_cast<const SymEngine::Min>"(RCP[const Basic] &b) nogil
RCP[const Gamma] rcp_static_cast_Gamma "SymEngine::rcp_static_cast<const SymEngine::Gamma>"(RCP[const Basic] &b) nogil
RCP[const Derivative] rcp_static_cast_Derivative "SymEngine::rcp_static_cast<const SymEngine::Derivative>"(RCP[const Basic] &b) nogil
RCP[const Subs] rcp_static_cast_Subs "SymEngine::rcp_static_cast<const SymEngine::Subs>"(RCP[const Basic] &b) nogil
Expand Down Expand Up @@ -251,6 +253,8 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
bool is_a_ASech "SymEngine::is_a<SymEngine::ASech>"(const Basic &b) nogil
bool is_a_FunctionSymbol "SymEngine::is_a<SymEngine::FunctionSymbol>"(const Basic &b) nogil
bool is_a_Abs "SymEngine::is_a<SymEngine::Abs>"(const Basic &b) nogil
bool is_a_Max "SymEngine::is_a<SymEngine::Max>"(const Basic &b) nogil
bool is_a_Min "SymEngine::is_a<SymEngine::Min>"(const Basic &b) nogil
bool is_a_Gamma "SymEngine::is_a<SymEngine::Gamma>"(const Basic &b) nogil
bool is_a_Derivative "SymEngine::is_a<SymEngine::Derivative>"(const Basic &b) nogil
bool is_a_Subs "SymEngine::is_a<SymEngine::Subs>"(const Basic &b) nogil
Expand Down Expand Up @@ -433,6 +437,8 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
cdef RCP[const Basic] asech(RCP[const Basic] &arg) nogil except+
cdef RCP[const Basic] function_symbol(string name, const vec_basic &arg) nogil except+
cdef RCP[const Basic] abs(RCP[const Basic] &arg) nogil except+
cdef RCP[const Basic] max(const vec_basic &arg) nogil except+
cdef RCP[const Basic] min(const vec_basic &arg) nogil except+
cdef RCP[const Basic] gamma(RCP[const Basic] &arg) nogil except+
cdef RCP[const Basic] atan2(RCP[const Basic] &num, RCP[const Basic] &den) nogil except+

Expand Down Expand Up @@ -539,6 +545,12 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
cdef cppclass Abs(Function):
RCP[const Basic] get_arg() nogil

cdef cppclass Max(Function):
pass

cdef cppclass Min(Function):
pass

cdef cppclass Gamma(Function):
pass

Expand Down
74 changes: 74 additions & 0 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ cdef c2py(RCP[const symengine.Basic] o):
r = FunctionSymbol.__new__(FunctionSymbol)
elif (symengine.is_a_Abs(deref(o))):
r = Abs.__new__(Abs)
elif (symengine.is_a_Max(deref(o))):
r = Max.__new__(Max)
elif (symengine.is_a_Min(deref(o))):
r = Min.__new__(Min)
elif (symengine.is_a_Gamma(deref(o))):
r = Gamma.__new__(Gamma)
elif (symengine.is_a_Derivative(deref(o))):
Expand Down Expand Up @@ -208,6 +212,10 @@ def sympy2symengine(a, raise_error=False):
return log(a.args[0])
elif isinstance(a, sympy.Abs):
return abs(sympy2symengine(a.args[0], raise_error))
elif isinstance(a, sympy.Max):
return _max(*a.args)
elif isinstance(a, sympy.Min):
return _min(*a.args)
elif isinstance(a, sympy.gamma):
return gamma(a.args[0])
elif isinstance(a, sympy.Derivative):
Expand Down Expand Up @@ -654,6 +662,20 @@ cdef class Basic(object):
def has(self, *symbols):
return any([has_symbol(self, symbol) for symbol in symbols])

def args_as_sage(Basic self):
cdef symengine.vec_basic Y = deref(self.thisptr).get_args()
s = []
for i in range(Y.size()):
s.append(c2py(<RCP[const symengine.Basic]>(Y[i]))._sage_())
return s

def args_as_sympy(Basic self):
cdef symengine.vec_basic Y = deref(self.thisptr).get_args()
s = []
for i in range(Y.size()):
s.append(c2py(<RCP[const symengine.Basic]>(Y[i]))._sympy_())
return s

def series(ex, x=None, x0=0, n=6, as_deg_coef_pair=False):
# TODO: check for x0 an infinity, see sympy/core/expr.py
# TODO: nonzero x0
Expand Down Expand Up @@ -1393,6 +1415,42 @@ cdef class Abs(Function):
return abs(arg)


class Max(Function):

def __new__(cls, *args):
if not args:
return super(Max, cls).__new__(cls)
return _max(*args)

def _sympy_(self):
import sympy
s = self.args_as_sympy()
return sympy.Max(*s)

def _sage_(self):
import sage.all as sage
s = self.args_as_sage()
return sage.max(*s)


class Min(Function):

def __new__(cls, *args):
if not args:
return super(Min, cls).__new__(cls)
return _min(*args)

def _sympy_(self):
import sympy
s = self.args_as_sympy()
return sympy.Min(*s)

def _sage_(self):
import sage.all as sage
s = self.args_as_sage()
return sage.min(*s)


cdef class Derivative(Basic):

@property
Expand Down Expand Up @@ -2304,6 +2362,22 @@ def log(x, y = None):
cdef Basic Y = _sympify(y)
return c2py(symengine.log(X.thisptr, Y.thisptr))

def _max(*args):
cdef symengine.vec_basic v
cdef Basic e_
for e in args:
e_ = sympify(e)
v.push_back(e_.thisptr)
return c2py(symengine.max(v))

def _min(*args):
cdef symengine.vec_basic v
cdef Basic e_
for e in args:
e_ = sympify(e)
v.push_back(e_.thisptr)
return c2py(symengine.min(v))

def gamma(x):
cdef Basic X = _sympify(x)
return c2py(symengine.gamma(X.thisptr))
Expand Down
2 changes: 1 addition & 1 deletion symengine/sympy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
SympifyError, sqrt, I, E, pi, Matrix, Derivative, exp,
nextprime, mod_inverse, primitive_root, Lambdify as lambdify,
symarray, diff, eye, diag, ones, zeros, expand, Subs,
FunctionSymbol as AppliedUndef)
FunctionSymbol as AppliedUndef, Max, Min)
from types import ModuleType
import sys

Expand Down
36 changes: 34 additions & 2 deletions symengine/tests/test_sympy_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from symengine.sympy_compat import (Integer, Rational, S, Basic, Add, Mul,
Pow, symbols, Symbol, log, sin, sech, csch, zeros, atan2, Number, Float,
symengine)
Pow, symbols, Symbol, log, sin, cos, sech, csch, zeros, atan2, Number, Float,
symengine, Min, Max)
from symengine.utilities import raises


Expand Down Expand Up @@ -86,6 +86,38 @@ def test_Pow():
assert isinstance(i, Basic)


def test_Max():
x = Symbol("x")
y = Symbol("y")
z = Symbol("z")
assert Max(Integer(6)/3, 1) == 2
assert Max(-2, 2) == 2
assert Max(2, 2) == 2
assert Max(0.2, 0.3) == 0.3
assert Max(x, x) == x
assert Max(x, y) == Max(y, x)
assert Max(x, y, z) == Max(z, y, x)
assert Max(x, Max(y, z)) == Max(z, y, x)
assert Max(1000, 100, -100, x, y, z) == Max(x, y, z, 1000)
assert Max(cos(x), sin(x)) == Max(sin(x), cos(x))


def test_Min():
x = Symbol("x")
y = Symbol("y")
z = Symbol("z")
assert Min(Integer(6)/3, 1) == 1
assert Min(-2, 2) == -2
assert Min(2, 2) == 2
assert Min(0.2, 0.3) == 0.2
assert Min(x, x) == x
assert Min(x, y) == Min(y, x)
assert Min(x, y, z) == Min(z, y, x)
assert Min(x, Min(y, z)) == Min(z, y, x)
assert Min(1000, 100, -100, x, y, z) == Min(x, y, z, -100)
assert Min(cos(x), sin(x)) == Min(cos(x), sin(x))


def test_sin():
x = symbols("x")
i = sin(0)
Expand Down