From 84c42fbbb6262ea42a18bab7e8526971fb5b0f99 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 14 Aug 2024 17:03:07 +0200 Subject: [PATCH] Fix numpy v2 breaking changes (#618) * [fix] Fix breaking changes with numpuy>=v2.0.0rc1 * [fix] Update version check to v2.0.0. * [fix] Disable numerical jvp checks for complex sign function. * Lift numpy<2 dependency. --------- Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> --- autograd/numpy/numpy_boxes.py | 8 +++++++- autograd/numpy/numpy_jvps.py | 4 +++- autograd/numpy/numpy_vjps.py | 3 ++- autograd/numpy/numpy_vspaces.py | 16 ++++++++++++++-- pyproject.toml | 2 +- tests/test_linalg.py | 6 +++++- tests/test_systematic.py | 6 ++++-- 7 files changed, 36 insertions(+), 9 deletions(-) diff --git a/autograd/numpy/numpy_boxes.py b/autograd/numpy/numpy_boxes.py index 579e9688c..e06f8b8d7 100644 --- a/autograd/numpy/numpy_boxes.py +++ b/autograd/numpy/numpy_boxes.py @@ -66,7 +66,13 @@ def __hash__(self): return id(self) # Flatten has no function, only a method. setattr(ArrayBox, 'flatten', anp.__dict__['ravel']) -if np.__version__ >= '1.25': +if np.lib.NumpyVersion(np.__version__) >= '2.0.0': + SequenceBox.register(np.linalg._linalg.EigResult) + SequenceBox.register(np.linalg._linalg.EighResult) + SequenceBox.register(np.linalg._linalg.QRResult) + SequenceBox.register(np.linalg._linalg.SlogdetResult) + SequenceBox.register(np.linalg._linalg.SVDResult) +elif np.__version__ >= '1.25': SequenceBox.register(np.linalg.linalg.EigResult) SequenceBox.register(np.linalg.linalg.EighResult) SequenceBox.register(np.linalg.linalg.QRResult) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 38c6e7353..779dbdbe4 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -1,3 +1,4 @@ +import numpy as onp from . import numpy_wrapper as anp from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero, dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, @@ -210,7 +211,8 @@ def fwd_grad_sort(g, ans, x, axis=-1, kind='quicksort', order=None): sort_perm = anp.argsort(x, axis, kind, order) return g[sort_perm] defjvp(anp.sort, fwd_grad_sort) -defjvp(anp.msort, lambda g, ans, x: fwd_grad_sort(g, ans, x, axis=0)) +if onp.lib.NumpyVersion(onp.__version__) < '2.0.0': + defjvp(anp.msort, lambda g, ans, x: fwd_grad_sort(g, ans, x, axis=0)) def fwd_grad_partition(g, ans, x, kth, axis=-1, kind='introselect', order=None): partition_perm = anp.argpartition(x, kth, axis, kind, order) diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index ec7c040ae..26f5dd03f 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -559,7 +559,8 @@ def grad_sort(ans, x, axis=-1, kind='quicksort', order=None): sort_perm = anp.argsort(x, axis, kind, order) return lambda g: unpermuter(g, sort_perm) defvjp(anp.sort, grad_sort) -defvjp(anp.msort, grad_sort) # Until multi-D is allowed, these are the same. +if onp.lib.NumpyVersion(onp.__version__) < '2.0.0': + defvjp(anp.msort, grad_sort) # Until multi-D is allowed, these are the same. def grad_partition(ans, x, kth, axis=-1, kind='introselect', order=None): #TODO: Cast input with np.asanyarray() diff --git a/autograd/numpy/numpy_vspaces.py b/autograd/numpy/numpy_vspaces.py index cbe53434b..788e4460f 100644 --- a/autograd/numpy/numpy_vspaces.py +++ b/autograd/numpy/numpy_vspaces.py @@ -4,7 +4,7 @@ class ArrayVSpace(VSpace): def __init__(self, value): - value = np.array(value, copy=False) + value = np.asarray(value) self.shape = value.shape self.dtype = value.dtype @@ -66,7 +66,19 @@ def _covector(self, x): ComplexArrayVSpace.register(type_) -if np.__version__ >= '1.25': +if np.lib.NumpyVersion(np.__version__) >= '2.0.0': + class EigResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.EigResult + class EighResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.EighResult + class QRResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.QRResult + class SlogdetResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.SlogdetResult + class SVDResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.SVDResult + + EigResultVSpace.register(np.linalg._linalg.EigResult) + EighResultVSpace.register(np.linalg._linalg.EighResult) + QRResultVSpace.register(np.linalg._linalg.QRResult) + SlogdetResultVSpace.register(np.linalg._linalg.SlogdetResult) + SVDResultVSpace.register(np.linalg._linalg.SVDResult) +elif np.__version__ >= '1.25': class EigResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EigResult class EighResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EighResult class QRResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.QRResult diff --git a/pyproject.toml b/pyproject.toml index 8745ed350..8c9ee473a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ keywords = [ "SciPy", ] dependencies = [ - "numpy<2", + "numpy", ] # dynamic = ["version"] diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 6a57b5a4f..db4f02bc0 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from builtins import range import itertools +import numpy as onp import autograd.numpy as np import autograd.numpy.random as npr from autograd.test_util import check_grads @@ -94,7 +95,10 @@ def test_solve_arg1_3d(): D = 4 A = npr.randn(D+1, D, D) + 5*np.eye(D) B = npr.randn(D+1, D) - fun = lambda A: np.linalg.solve(A, B) + if onp.lib.NumpyVersion(onp.__version__) < '2.0.0': + fun = lambda A: np.linalg.solve(A, B) + else: + fun = lambda A: np.linalg.solve(A, B[..., None])[..., 0] check_grads(fun)(A) def test_solve_arg1_3d_3d(): diff --git a/tests/test_systematic.py b/tests/test_systematic.py index 0ba80c2a2..f6169d0f4 100644 --- a/tests/test_systematic.py +++ b/tests/test_systematic.py @@ -1,4 +1,5 @@ from __future__ import absolute_import +import numpy as onp import autograd.numpy.random as npr import autograd.numpy as np import operator as op @@ -43,7 +44,7 @@ def test_log1p(): unary_ufunc_check(np.log1p, lims=[0.2, 2.0]) def test_log2(): unary_ufunc_check(np.log2, lims=[0.2, 2.0]) def test_rad2deg(): unary_ufunc_check(lambda x : np.rad2deg(x)/50.0, test_complex=False) def test_radians(): unary_ufunc_check(np.radians, test_complex=False) -def test_sign(): unary_ufunc_check(np.sign) +def test_sign(): unary_ufunc_check(np.sign, test_complex=False) def test_sin(): unary_ufunc_check(np.sin) def test_sinh(): unary_ufunc_check(np.sinh) def test_sqrt(): unary_ufunc_check(np.sqrt, lims=[1.0, 3.0]) @@ -155,7 +156,8 @@ def test_fmin(): combo_check(np.fmin, [0, 1])( [R(1), R(1,4), R(3, 4)]) def test_sort(): combo_check(np.sort, [0])([R(1), R(7)]) -def test_msort(): combo_check(np.msort, [0])([R(1), R(7)]) +if onp.lib.NumpyVersion(onp.__version__) < '2.0.0': + def test_msort(): combo_check(np.msort, [0])([R(1), R(7)]) def test_partition(): combo_check(np.partition, [0])( [R(7), R(14)], kth=[0, 3, 6])