Skip to content

Commit

Permalink
Merge pull request #10 from nschloe/coverage
Browse files Browse the repository at this point in the history
fix for krylov
  • Loading branch information
nschloe authored Apr 21, 2021
2 parents 17b4b4a + aa90818 commit 55740ed
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 11 deletions.
17 changes: 13 additions & 4 deletions npx/_krylov.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import Optional
from typing import Callable, Optional

import numpy as np
import scipy
Expand All @@ -15,12 +15,15 @@ def cg(
tol: float = 1e-05,
maxiter: Optional[int] = None,
M=None,
callback=None,
callback: Optional[Callable] = None,
atol: Optional[float] = 0.0,
exact_solution=None,
):
resnorms = []

if x0 is None:
x0 = np.zeros(A.shape[1])

if exact_solution is None:
errnorms = None
else:
Expand Down Expand Up @@ -60,12 +63,15 @@ def gmres(
restart: Optional[int] = None,
maxiter: Optional[int] = None,
M=None,
callback=None,
callback: Optional[Callable] = None,
atol: Optional[float] = 0.0,
exact_solution=None,
):
resnorms = []

if x0 is None:
x0 = np.zeros(A.shape[1])

if exact_solution is None:
errnorms = None
else:
Expand Down Expand Up @@ -114,11 +120,14 @@ def minres(
tol: float = 1e-05,
maxiter: Optional[int] = None,
M=None,
callback=None,
callback: Optional[Callable] = None,
exact_solution=None,
):
resnorms = []

if x0 is None:
x0 = np.zeros(A.shape[1])

if exact_solution is None:
errnorms = None
else:
Expand Down
4 changes: 3 additions & 1 deletion npx/_minimize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable

import numpy as np
import scipy.optimize


def minimize(fun, x0, *args, **kwargs):
def minimize(fun: Callable, x0, *args, **kwargs):
x0 = np.asarray(x0)
x0_shape = x0.shape

Expand Down
6 changes: 4 additions & 2 deletions npx/_nonlinear.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

import numpy as np


Expand All @@ -10,7 +12,7 @@ def _dot_last(a, b):
return out


def bisect(f, a, b, tol, max_num_steps=np.infty):
def bisect(f: Callable, a, b, tol: float, max_num_steps=np.infty):
a = np.asarray(a)
b = np.asarray(b)

Expand Down Expand Up @@ -51,7 +53,7 @@ def bisect(f, a, b, tol, max_num_steps=np.infty):
return a, b


def regula_falsi(f, a, b, tol, max_num_steps=np.infty):
def regula_falsi(f: Callable, a, b, tol: float, max_num_steps=np.infty):
a = np.asarray(a)
b = np.asarray(b)
fa = np.asarray(f(a))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.8
version = 0.0.9
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
Expand Down
8 changes: 5 additions & 3 deletions test/test_krylov.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import npx


def _run(fun, resnorms1, resnorms2, tol=1.0e-13):
def _run(method, resnorms1, resnorms2, tol=1.0e-13):
n = 10
data = -np.ones((3, n))
data[1] = 2.0
A = scipy.sparse.spdiags(data, [-1, 0, 1], n, n)
A = A.tocsr()
b = np.ones(n)

sol, info = fun(A, b)
exact_solution = scipy.sparse.linalg.spsolve(A, b)

sol, info = method(A, b, exact_solution=exact_solution)
assert sol is not None
assert info.success
resnorms1 = np.asarray(resnorms1)
Expand All @@ -23,7 +25,7 @@ def _run(fun, resnorms1, resnorms2, tol=1.0e-13):

# with "preconditioning"
M = scipy.sparse.linalg.LinearOperator((n, n), matvec=lambda x: 0.5 * x)
sol, info = fun(A, b, M=M)
sol, info = method(A, b, M=M)

assert sol is not None
assert info.success
Expand Down

0 comments on commit 55740ed

Please sign in to comment.