Skip to content

Commit

Permalink
Correct argument handling for constructors
Browse files Browse the repository at this point in the history
Before, e.g. for the mpz:
>>> import gmpy2
>>> gmpy2.__version__
'2.2.1'
>>> gmpy2.mpz(s=1)
mpz(0)
>>> gmpy2.mpz(1, s=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: argument for function given by name ('s') and position (1)
  • Loading branch information
skirpichev authored and casevh committed Dec 23, 2024
1 parent 3c59644 commit e68afd3
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 22 deletions.
39 changes: 18 additions & 21 deletions src/gmpy2_cache.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ GMPy_MPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
PyObject *out = NULL;
int base = 0;
Py_ssize_t argc;
static char *kwlist[] = {"s", "base", NULL };
static char *kwlist[] = {"", "base", NULL};
CTXT_Object *context = NULL;

if (type != &MPZ_Type) {
Expand All @@ -81,7 +81,7 @@ GMPy_MPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)

argc = PyTuple_GET_SIZE(args);

if (argc == 0) {
if (argc == 0 && !keywds) {
return (PyObject*)GMPy_MPZ_New(context);
}

Expand Down Expand Up @@ -158,7 +158,7 @@ GMPy_MPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
return NULL;
}

if ((base != 0) && ((base < 2)|| (base > 62))) {
if (base != 0 && (base < 2 || base > 62)) {
VALUE_ERROR("base for mpz() must be 0 or in the interval [2, 62]");
return NULL;
}
Expand All @@ -181,7 +181,6 @@ GMPy_MPZ_Dealloc(MPZ_Object *self)
{
if (global.in_gmpympzcache < CACHE_SIZE &&
self->z->_mp_alloc <= MAX_CACHE_MPZ_LIMBS) {

global.gmpympzcache[(global.in_gmpympzcache)++] = self;
}
else {
Expand Down Expand Up @@ -220,7 +219,7 @@ GMPy_XMPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
PyObject *temp = NULL;
int base = 0;
Py_ssize_t argc;
static char *kwlist[] = {"s", "base", NULL };
static char *kwlist[] = {"", "base", NULL};
CTXT_Object *context = NULL;

if (type != &XMPZ_Type) {
Expand All @@ -232,7 +231,7 @@ GMPy_XMPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)

argc = PyTuple_GET_SIZE(args);

if (argc == 0) {
if (argc == 0 && !keywds) {
return (PyObject*)GMPy_XMPZ_New(context);
}

Expand Down Expand Up @@ -294,7 +293,7 @@ GMPy_XMPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
return NULL;
}

if ((base != 0) && ((base < 2)|| (base > 62))) {
if (base != 0 && (base < 2 || base > 62)) {
VALUE_ERROR("base for xmpz() must be 0 or in the interval [2, 62]");
return NULL;
}
Expand All @@ -317,7 +316,6 @@ GMPy_XMPZ_Dealloc(XMPZ_Object *self)
{
if (global.in_gmpyxmpzcache < CACHE_SIZE &&
self->z->_mp_alloc <= MAX_CACHE_MPZ_LIMBS) {

global.gmpyxmpzcache[(global.in_gmpyxmpzcache)++] = self;
}
else {
Expand Down Expand Up @@ -356,7 +354,7 @@ GMPy_MPQ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
PyObject *n = NULL, *m = NULL;
int base = 10;
Py_ssize_t argc, keywdc = 0;
static char *kwlist[] = {"s", "base", NULL };
static char *kwlist[] = {"", "base", NULL};
CTXT_Object *context = NULL;

if (type != &MPQ_Type) {
Expand Down Expand Up @@ -389,15 +387,15 @@ GMPy_MPQ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
n = PyTuple_GetItem(args, 0);

/* Handle the case where the first argument is a string. */
if (PyStrOrUnicode_Check(n)) {
if (PyStrOrUnicode_Check(n) || keywdc) {
/* keyword base is legal */
if (keywdc || argc > 1) {
if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|i", kwlist, &n, &base))) {
return NULL;
}
}

if ((base != 0) && ((base < 2) || (base > 62))) {
if (base != 0 && (base < 2 || base > 62)) {
VALUE_ERROR("base for mpq() must be 0 or in the interval [2, 62]");
return NULL;
}
Expand Down Expand Up @@ -505,8 +503,8 @@ GMPy_MPFR_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
/* Assumes mpfr_prec_t is the same as a long. */
mpfr_prec_t prec = 0;

static char *kwlist_s[] = {"s", "precision", "base", "context", NULL};
static char *kwlist_n[] = {"n", "precision", "context", NULL};
static char *kwlist_s[] = {"", "precision", "base", "context", NULL};
static char *kwlist_n[] = {"", "precision", "context", NULL};

if (type != &MPFR_Type) {
TYPE_ERROR("mpfr.__new__() requires mpfr type");
Expand Down Expand Up @@ -675,9 +673,9 @@ GMPy_MPC_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
/* Assumes mpfr_prec_t is the same as a long. */
mpfr_prec_t rprec = 0, iprec = 0;

static char *kwlist_c[] = {"c", "precision", "context", NULL};
static char *kwlist_r[] = {"real", "imag", "precision", "context", NULL};
static char *kwlist_s[] = {"s", "precision", "base", "context", NULL};
static char *kwlist_c[] = {"", "precision", "context", NULL};
static char *kwlist_r[] = {"", "imag", "precision", "context", NULL};
static char *kwlist_s[] = {"", "precision", "base", "context", NULL};

if (type != &MPC_Type) {
TYPE_ERROR("mpc.__new__() requires mpc type");
Expand Down Expand Up @@ -777,7 +775,7 @@ GMPy_MPC_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)
if (IS_REAL(arg0)) {
if (keywdc || argc > 1) {
if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|OOO", kwlist_r,
&arg0, &arg1, &prec, &context)))
&arg0, &arg1, &prec, &context)))
return NULL;
}

Expand Down Expand Up @@ -843,9 +841,9 @@ GMPy_MPC_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds)

if (IS_COMPLEX_ONLY(arg0)) {
if (keywdc || argc > 1) {
if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|O", kwlist_c,
&arg0, &prec)))
return NULL;
if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|OO", kwlist_c,
&arg0, &prec, &context)))
return NULL;
}

if (prec) {
Expand Down Expand Up @@ -900,4 +898,3 @@ GMPy_MPC_Dealloc(MPC_Object *self)
PyObject_Free(self);
}
}

8 changes: 8 additions & 0 deletions test/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ def test_mpc_creation():

assert mpc('1_2+4_5j') == mpc('12.0+45.0j')

pytest.raises(TypeError, lambda: mpc(1, base=2))
pytest.raises(TypeError, lambda: mpc(1, s=2))
pytest.raises(TypeError, lambda: mpc("1", s=2))
pytest.raises(TypeError, lambda: mpc("1", imag=2))
pytest.raises(TypeError, lambda: mpc(1j, imag=2))
pytest.raises(TypeError, lambda: mpc(1j, base=2))
pytest.raises(TypeError, lambda: mpc(1j, s=2))


def test_mpc_random():
assert (mpc_random(random_state(42))
Expand Down
3 changes: 3 additions & 0 deletions test/test_mpfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def test_mpfr_create():
assert repr(mpfr(1.0/7, precision=1)) == "mpfr('0.14285714285714285')"
assert repr(mpfr(1.0/7, precision=5)) == "mpfr('0.141',5)"

pytest.raises(TypeError, lambda: mpfr(1, base=2))
pytest.raises(TypeError, lambda: mpfr("1", s=2))


@settings(max_examples=1000)
@given(floats())
Expand Down
4 changes: 3 additions & 1 deletion test/test_mpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
to_binary, xmpz)


def test_mpz_constructor():
def test_mpq_constructor():
assert mpq('1/1') == mpq(1,1)
assert mpq('1_/_1') == mpq(1,1)

pytest.raises(TypeError, lambda: mpq('1', s=1))


def test_mpq_as_integer_ratio():
assert mpq(2, 3).as_integer_ratio() == (mpz(2), mpz(3))
Expand Down
3 changes: 3 additions & 0 deletions test/test_mpz.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def test_mpz_create():
assert mpz('1 2') == mpz(12)
assert mpz(' 1 2') == mpz(12)

raises(TypeError, lambda: mpz(s=1))
raises(TypeError, lambda: mpz(1, s=2))


@settings(max_examples=10000)
@given(integers())
Expand Down
5 changes: 5 additions & 0 deletions test/test_xmpz.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def test_xmpz_misc():
assert repr(xmpz(42)) == 'xmpz(42)'


def test_xmpz_create():
pytest.raises(TypeError, lambda: xmpz(s=1))
pytest.raises(TypeError, lambda: xmpz(1, s=2))


def test_xmpz_subscripts():
x = xmpz(10)

Expand Down

0 comments on commit e68afd3

Please sign in to comment.