From e68afd3e9d04f2df112bae1db3a99cb9785373c8 Mon Sep 17 00:00:00 2001 From: Sergey B Kirpichev Date: Thu, 21 Nov 2024 06:17:30 +0300 Subject: [PATCH] Correct argument handling for constructors Before, e.g. for the mpz: >>> import gmpy2 >>> gmpy2.__version__ '2.2.1' >>> gmpy2.mpz(s=1) mpz(0) >>> gmpy2.mpz(1, s=1) Traceback (most recent call last): File "", line 1, in TypeError: argument for function given by name ('s') and position (1) --- src/gmpy2_cache.c | 39 ++++++++++++++++++--------------------- test/test_mpc.py | 8 ++++++++ test/test_mpfr.py | 3 +++ test/test_mpq.py | 4 +++- test/test_mpz.py | 3 +++ test/test_xmpz.py | 5 +++++ 6 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/gmpy2_cache.c b/src/gmpy2_cache.c index 93f1f459..7fd00038 100644 --- a/src/gmpy2_cache.c +++ b/src/gmpy2_cache.c @@ -69,7 +69,7 @@ GMPy_MPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) PyObject *out = NULL; int base = 0; Py_ssize_t argc; - static char *kwlist[] = {"s", "base", NULL }; + static char *kwlist[] = {"", "base", NULL}; CTXT_Object *context = NULL; if (type != &MPZ_Type) { @@ -81,7 +81,7 @@ GMPy_MPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) argc = PyTuple_GET_SIZE(args); - if (argc == 0) { + if (argc == 0 && !keywds) { return (PyObject*)GMPy_MPZ_New(context); } @@ -158,7 +158,7 @@ GMPy_MPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) return NULL; } - if ((base != 0) && ((base < 2)|| (base > 62))) { + if (base != 0 && (base < 2 || base > 62)) { VALUE_ERROR("base for mpz() must be 0 or in the interval [2, 62]"); return NULL; } @@ -181,7 +181,6 @@ GMPy_MPZ_Dealloc(MPZ_Object *self) { if (global.in_gmpympzcache < CACHE_SIZE && self->z->_mp_alloc <= MAX_CACHE_MPZ_LIMBS) { - global.gmpympzcache[(global.in_gmpympzcache)++] = self; } else { @@ -220,7 +219,7 @@ GMPy_XMPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) PyObject *temp = NULL; int base = 0; Py_ssize_t argc; - static char *kwlist[] = {"s", "base", NULL }; + static char *kwlist[] = {"", "base", NULL}; CTXT_Object *context = NULL; if (type != &XMPZ_Type) { @@ -232,7 +231,7 @@ GMPy_XMPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) argc = PyTuple_GET_SIZE(args); - if (argc == 0) { + if (argc == 0 && !keywds) { return (PyObject*)GMPy_XMPZ_New(context); } @@ -294,7 +293,7 @@ GMPy_XMPZ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) return NULL; } - if ((base != 0) && ((base < 2)|| (base > 62))) { + if (base != 0 && (base < 2 || base > 62)) { VALUE_ERROR("base for xmpz() must be 0 or in the interval [2, 62]"); return NULL; } @@ -317,7 +316,6 @@ GMPy_XMPZ_Dealloc(XMPZ_Object *self) { if (global.in_gmpyxmpzcache < CACHE_SIZE && self->z->_mp_alloc <= MAX_CACHE_MPZ_LIMBS) { - global.gmpyxmpzcache[(global.in_gmpyxmpzcache)++] = self; } else { @@ -356,7 +354,7 @@ GMPy_MPQ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) PyObject *n = NULL, *m = NULL; int base = 10; Py_ssize_t argc, keywdc = 0; - static char *kwlist[] = {"s", "base", NULL }; + static char *kwlist[] = {"", "base", NULL}; CTXT_Object *context = NULL; if (type != &MPQ_Type) { @@ -389,7 +387,7 @@ GMPy_MPQ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) n = PyTuple_GetItem(args, 0); /* Handle the case where the first argument is a string. */ - if (PyStrOrUnicode_Check(n)) { + if (PyStrOrUnicode_Check(n) || keywdc) { /* keyword base is legal */ if (keywdc || argc > 1) { if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|i", kwlist, &n, &base))) { @@ -397,7 +395,7 @@ GMPy_MPQ_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) } } - if ((base != 0) && ((base < 2) || (base > 62))) { + if (base != 0 && (base < 2 || base > 62)) { VALUE_ERROR("base for mpq() must be 0 or in the interval [2, 62]"); return NULL; } @@ -505,8 +503,8 @@ GMPy_MPFR_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) /* Assumes mpfr_prec_t is the same as a long. */ mpfr_prec_t prec = 0; - static char *kwlist_s[] = {"s", "precision", "base", "context", NULL}; - static char *kwlist_n[] = {"n", "precision", "context", NULL}; + static char *kwlist_s[] = {"", "precision", "base", "context", NULL}; + static char *kwlist_n[] = {"", "precision", "context", NULL}; if (type != &MPFR_Type) { TYPE_ERROR("mpfr.__new__() requires mpfr type"); @@ -675,9 +673,9 @@ GMPy_MPC_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) /* Assumes mpfr_prec_t is the same as a long. */ mpfr_prec_t rprec = 0, iprec = 0; - static char *kwlist_c[] = {"c", "precision", "context", NULL}; - static char *kwlist_r[] = {"real", "imag", "precision", "context", NULL}; - static char *kwlist_s[] = {"s", "precision", "base", "context", NULL}; + static char *kwlist_c[] = {"", "precision", "context", NULL}; + static char *kwlist_r[] = {"", "imag", "precision", "context", NULL}; + static char *kwlist_s[] = {"", "precision", "base", "context", NULL}; if (type != &MPC_Type) { TYPE_ERROR("mpc.__new__() requires mpc type"); @@ -777,7 +775,7 @@ GMPy_MPC_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) if (IS_REAL(arg0)) { if (keywdc || argc > 1) { if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|OOO", kwlist_r, - &arg0, &arg1, &prec, &context))) + &arg0, &arg1, &prec, &context))) return NULL; } @@ -843,9 +841,9 @@ GMPy_MPC_NewInit(PyTypeObject *type, PyObject *args, PyObject *keywds) if (IS_COMPLEX_ONLY(arg0)) { if (keywdc || argc > 1) { - if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|O", kwlist_c, - &arg0, &prec))) - return NULL; + if (!(PyArg_ParseTupleAndKeywords(args, keywds, "O|OO", kwlist_c, + &arg0, &prec, &context))) + return NULL; } if (prec) { @@ -900,4 +898,3 @@ GMPy_MPC_Dealloc(MPC_Object *self) PyObject_Free(self); } } - diff --git a/test/test_mpc.py b/test/test_mpc.py index c1529252..87cae879 100644 --- a/test/test_mpc.py +++ b/test/test_mpc.py @@ -109,6 +109,14 @@ def test_mpc_creation(): assert mpc('1_2+4_5j') == mpc('12.0+45.0j') + pytest.raises(TypeError, lambda: mpc(1, base=2)) + pytest.raises(TypeError, lambda: mpc(1, s=2)) + pytest.raises(TypeError, lambda: mpc("1", s=2)) + pytest.raises(TypeError, lambda: mpc("1", imag=2)) + pytest.raises(TypeError, lambda: mpc(1j, imag=2)) + pytest.raises(TypeError, lambda: mpc(1j, base=2)) + pytest.raises(TypeError, lambda: mpc(1j, s=2)) + def test_mpc_random(): assert (mpc_random(random_state(42)) diff --git a/test/test_mpfr.py b/test/test_mpfr.py index e931c695..e72f6c6c 100644 --- a/test/test_mpfr.py +++ b/test/test_mpfr.py @@ -212,6 +212,9 @@ def test_mpfr_create(): assert repr(mpfr(1.0/7, precision=1)) == "mpfr('0.14285714285714285')" assert repr(mpfr(1.0/7, precision=5)) == "mpfr('0.141',5)" + pytest.raises(TypeError, lambda: mpfr(1, base=2)) + pytest.raises(TypeError, lambda: mpfr("1", s=2)) + @settings(max_examples=1000) @given(floats()) diff --git a/test/test_mpq.py b/test/test_mpq.py index fb5859fe..f49ec8d5 100644 --- a/test/test_mpq.py +++ b/test/test_mpq.py @@ -16,10 +16,12 @@ to_binary, xmpz) -def test_mpz_constructor(): +def test_mpq_constructor(): assert mpq('1/1') == mpq(1,1) assert mpq('1_/_1') == mpq(1,1) + pytest.raises(TypeError, lambda: mpq('1', s=1)) + def test_mpq_as_integer_ratio(): assert mpq(2, 3).as_integer_ratio() == (mpz(2), mpz(3)) diff --git a/test/test_mpz.py b/test/test_mpz.py index 2d8288bc..261151a5 100644 --- a/test/test_mpz.py +++ b/test/test_mpz.py @@ -441,6 +441,9 @@ def test_mpz_create(): assert mpz('1 2') == mpz(12) assert mpz(' 1 2') == mpz(12) + raises(TypeError, lambda: mpz(s=1)) + raises(TypeError, lambda: mpz(1, s=2)) + @settings(max_examples=10000) @given(integers()) diff --git a/test/test_xmpz.py b/test/test_xmpz.py index 7ba81652..c4915d10 100644 --- a/test/test_xmpz.py +++ b/test/test_xmpz.py @@ -90,6 +90,11 @@ def test_xmpz_misc(): assert repr(xmpz(42)) == 'xmpz(42)' +def test_xmpz_create(): + pytest.raises(TypeError, lambda: xmpz(s=1)) + pytest.raises(TypeError, lambda: xmpz(1, s=2)) + + def test_xmpz_subscripts(): x = xmpz(10)