diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml
index 1e49f8bc..21b6cc3a 100644
--- a/.github/workflows/cd-build.yml
+++ b/.github/workflows/cd-build.yml
@@ -14,32 +14,28 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- - name: Set up Conda with Python 3.8
- uses: conda-incubator/setup-miniconda@v2
+ - uses: actions/setup-python@v4
with:
- auto-update-conda: true
- python-version: 3.8
- auto-activate-base: false
+ python-version: "3.10"
+ cache: pip
- name: Install dependencies
shell: bash -l {0}
run: |
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
python -m pip install twine
- python -m pip install .
+ python -m pip install .[doc,test]
- name: Run Tests
shell: bash -l {0}
run: |
- python setup.py test
+ pytest
- name: Check distribution
shell: bash -l {0}
run: |
- python setup.py sdist
twine check dist/*
- name: Upload coverage to Codecov
@@ -57,22 +53,15 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- - name: Set up Conda with Python 3.8
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: 3.8
- auto-activate-base: false
- name: Install dependencies
shell: bash -l {0}
run: |
conda install -c conda-forge pandoc
python -m pip install --upgrade pip
- python -m pip install -r docs/requirements.txt
- python -m pip install .
+ python -m pip install .[doc]
- name: Build API documentation
shell: bash -l {0}
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index 3ffcb6f4..3a209d12 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -16,70 +16,41 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: [3.8]
+ python-version: ["3.8", "3.9", "3.10"]
steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Report WPS Errors
- uses: wemake-services/wemake-python-styleguide@0.14.1
- continue-on-error: true
- with:
- reporter: 'github-pr-review'
- path: './modopt'
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v4
with:
- auto-update-conda: true
python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Check Conda
- shell: bash -l {0}
- run: |
- conda info
- conda list
- python --version
+ cache: pip
- name: Install Dependencies
shell: bash -l {0}
run: |
python --version
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install -r docs/requirements.txt
- python -m pip install astropy scikit-image scikit-learn
- python -m pip install tensorflow>=2.4.1
- python -m pip install twine
- python -m pip install .
+ python -m pip install .[test]
+ python -m pip install astropy scikit-image scikit-learn matplotlib
+ python -m pip install "tensorflow>=2.4.1" torch
- name: Run Tests
shell: bash -l {0}
run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- python setup.py test
+ pytest -n 2
- name: Save Test Results
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v4
with:
name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
- path: pytest.xml
-
- - name: Check Distribution
- shell: bash -l {0}
- run: |
- python setup.py sdist
- twine check dist/*
+ path: coverage.xml
- name: Check API Documentation build
shell: bash -l {0}
run: |
- conda install -c conda-forge pandoc
+ sudo apt-get install -y pandoc
+ pip install .[doc] ipykernel
sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
sphinx-build -b doctest -E docs/source docs/_build
@@ -90,38 +61,3 @@ jobs:
file: coverage.xml
flags: unittests
- test-basic:
- name: Basic Test Suite
- runs-on: ${{ matrix.os }}
-
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest]
- python-version: [3.6, 3.7, 3.9]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Install Dependencies
- shell: bash -l {0}
- run: |
- python --version
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install astropy scikit-image scikit-learn
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- python setup.py test
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
new file mode 100644
index 00000000..45fc2b23
--- /dev/null
+++ b/.github/workflows/style.yml
@@ -0,0 +1,38 @@
+name: Style checking
+
+on:
+ push:
+ branches: [ "master", "main", "develop" ]
+ pull_request:
+ branches: [ "master", "main", "develop" ]
+
+ workflow_dispatch:
+
+env:
+ PYTHON_VERSION: "3.10"
+
+jobs:
+ linter-check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Set up Python ${{ env.PYTHON_VERSION }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ env.PYTHON_VERSION }}
+ cache: pip
+
+ - name: Install Python deps
+ shell: bash
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[test,dev]
+
+ - name: Black Check
+ shell: bash
+ run: black . --diff --color --check
+
+ - name: ruff Check
+ shell: bash
+ run: ruff check
diff --git a/.gitignore b/.gitignore
index 06dff8db..f9eaaa68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,6 +73,7 @@ instance/
docs/_build/
docs/source/fortuna.*
docs/source/scripts.*
+docs/source/auto_examples/
docs/source/*.nblink
# PyBuilder
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 3ac9aef9..00000000
--- a/.pylintrc
+++ /dev/null
@@ -1,2 +0,0 @@
-[MASTER]
-ignore-patterns=**/docs/**/*.py
diff --git a/.pyup.yml b/.pyup.yml
deleted file mode 100644
index 8fdac7ff..00000000
--- a/.pyup.yml
+++ /dev/null
@@ -1,14 +0,0 @@
-# autogenerated pyup.io config file
-# see https://pyup.io/docs/configuration/ for all available options
-
-schedule: ''
-update: all
-label_prs: update
-assignees: sfarrens
-requirements:
- - requirements.txt:
- pin: False
- - develop.txt:
- pin: False
- - docs/requirements.txt:
- pin: True
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 9a2f374e..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,5 +0,0 @@
-include requirements.txt
-include develop.txt
-include docs/requirements.txt
-include README.rst
-include LICENSE.txt
diff --git a/README.md b/README.md
index acb316ad..223d0b73 100644
--- a/README.md
+++ b/README.md
@@ -37,11 +37,11 @@ All packages required by ModOpt should be installed automatically. Optional pack
In order to run the code in this repository the following packages must be
installed:
-* [Python](https://www.python.org/) [> 3.6]
+* [Python](https://www.python.org/) [> 3.7]
* [importlib_metadata](https://importlib-metadata.readthedocs.io/en/latest/) [==3.7.0]
* [Numpy](http://www.numpy.org/) [==1.19.5]
* [Scipy](http://www.scipy.org/) [==1.5.4]
-* [Progressbar 2](https://progressbar-2.readthedocs.io/) [==3.53.1]
+* [tqdm](https://tqdm.github.io/) [>=4.64.0]
### Optional Packages
diff --git a/develop.txt b/develop.txt
deleted file mode 100644
index 8beef0ff..00000000
--- a/develop.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-coverage>=5.5
-flake8>=4
-nose>=1.3.7
-pytest>=6.2.2
-pytest-cov>=2.11.1
-pytest-pep8>=1.0.6
-pytest-emoji>=0.2.0
-pytest-flake8>=1.0.7
-wemake-python-styleguide>=0.15.2
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 4d2a14fb..c9e29c88 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -6,3 +6,4 @@ numpydoc==1.1.0
sphinx==4.3.1
sphinxcontrib-bibtex==2.4.1
sphinxawesome-theme==3.2.1
+sphinx-gallery==0.11.1
diff --git a/docs/source/conf.py b/docs/source/conf.py
index fb954f6d..69921008 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Python Template sphinx config
# Import relevant modules
@@ -9,55 +8,53 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../..'))
+sys.path.insert(0, os.path.abspath("../.."))
# -- General configuration ------------------------------------------------
# General information about the project.
-project = 'modopt'
+project = "modopt"
mdata = metadata(project)
-author = mdata['Author']
-version = mdata['Version']
-copyright = '2020, {}'.format(author)
-gh_user = 'sfarrens'
-
-# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '3.3'
+author = "Samuel Farrens, Pierre-Antoine Comby, Chaithya GR, Philippe Ciuciu"
+version = mdata["Version"]
+copyright = f"2020, {author}"
+gh_user = "sfarrens"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.coverage',
- 'sphinx.ext.doctest',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.todo',
- 'sphinx.ext.viewcode',
- 'sphinxawesome_theme',
- 'sphinxcontrib.bibtex',
- 'myst_parser',
- 'nbsphinx',
- 'nbsphinx_link',
- 'numpydoc',
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.coverage",
+ "sphinx.ext.doctest",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+ "sphinxawesome_theme.highlighting",
+ "sphinxcontrib.bibtex",
+ "myst_parser",
+ "nbsphinx",
+ "nbsphinx_link",
+ "numpydoc",
+ "sphinx_gallery.gen_gallery",
]
# Include module names for objects
add_module_names = False
# Set class documentation standard.
-autoclass_content = 'class'
+autoclass_content = "class"
# Autodoc options
autodoc_default_options = {
- 'member-order': 'bysource',
- 'private-members': True,
- 'show-inheritance': True
+ "member-order": "bysource",
+ "private-members": True,
+ "show-inheritance": True,
}
# Generate summaries
@@ -68,17 +65,17 @@
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
-source_suffix = ['.rst', '.md']
+source_suffix = [".rst", ".md"]
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = True
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'default'
+pygments_style = "default"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -87,7 +84,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'sphinxawesome_theme'
+html_theme = "sphinxawesome_theme"
# html_theme = 'sphinx_book_theme'
# Theme options are theme-specific and customize the look and feel of a theme
@@ -100,11 +97,10 @@
"breadcrumbs_separator": "/",
"show_prev_next": True,
"show_scrolltop": True,
-
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = 'modopt_logo.jpg'
+html_logo = "modopt_logo.png"
html_permalinks_icon = (
'…'
+ r"""…"""
+ r"""…"""
)
nbsphinx_prolog = nb_header_pt1 + nb_header_pt2
@@ -227,28 +233,28 @@ def add_notebooks(nb_path='../../notebooks'):
# Refer to the package libraries for type definitions
intersphinx_mapping = {
- 'python': ('http://docs.python.org/3', None),
- 'numpy': ('https://numpy.org/doc/stable/', None),
- 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
- 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None),
- 'matplotlib': ('https://matplotlib.org', None),
- 'astropy': ('http://docs.astropy.org/en/latest/', None),
- 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None),
- 'torch': ('https://pytorch.org/docs/stable/', None),
- 'sklearn': (
- 'http://scikit-learn.org/stable',
- (None, './_intersphinx/sklearn-objects.inv')
+ "python": ("http://docs.python.org/3", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None),
+ "matplotlib": ("https://matplotlib.org", None),
+ "astropy": ("http://docs.astropy.org/en/latest/", None),
+ "cupy": ("https://docs-cupy.chainer.org/en/stable/", None),
+ "torch": ("https://pytorch.org/docs/stable/", None),
+ "sklearn": (
+ "http://scikit-learn.org/stable",
+ (None, "./_intersphinx/sklearn-objects.inv"),
),
- 'tensorflow': (
- 'https://www.tensorflow.org/api_docs/python',
+ "tensorflow": (
+ "https://www.tensorflow.org/api_docs/python",
(
- 'https://github.com/GPflow/tensorflow-intersphinx/'
- + 'raw/master/tf2_py_objects.inv')
- )
-
+ "https://github.com/GPflow/tensorflow-intersphinx/"
+ + "raw/master/tf2_py_objects.inv"
+ ),
+ ),
}
# -- BibTeX Setting ----------------------------------------------
-bibtex_bibfiles = ['refs.bib', 'my_ref.bib']
-bibtex_default_style = 'alpha'
+bibtex_bibfiles = ["refs.bib", "my_ref.bib"]
+bibtex_default_style = "alpha"
diff --git a/docs/source/refs.bib b/docs/source/refs.bib
index d8365e71..7782ca52 100644
--- a/docs/source/refs.bib
+++ b/docs/source/refs.bib
@@ -207,3 +207,15 @@ @article{zou2005
journal = {Journal of the Royal Statistical Society Series B},
doi = {10.1111/j.1467-9868.2005.00527.x}
}
+
+@article{Goldstein2014,
+ author={Goldstein, Tom and O'Donoghue, Brendan and Setzer, Simon and Baraniuk, Richard},
+ year={2014},
+ month={Jan},
+ pages={1588--1623},
+ title={Fast Alternating Direction Optimization Methods},
+ journal={SIAM Journal on Imaging Sciences},
+ volume={7},
+ ISSN={1936-4954},
+ doi={10/gdwr49},
+}
diff --git a/docs/source/toc.rst b/docs/source/toc.rst
index 84a6af87..ef5753f5 100644
--- a/docs/source/toc.rst
+++ b/docs/source/toc.rst
@@ -25,6 +25,7 @@
plugin_example
notebooks
+ auto_examples/index
.. toctree::
:hidden:
diff --git a/examples/README.rst b/examples/README.rst
new file mode 100644
index 00000000..e6ffbe27
--- /dev/null
+++ b/examples/README.rst
@@ -0,0 +1,5 @@
+========
+Examples
+========
+
+This is a collection of Python scripts demonstrating the use of ModOpt.
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 00000000..d7e77357
--- /dev/null
+++ b/examples/__init__.py
@@ -0,0 +1,10 @@
+"""EXAMPLES.
+
+This module contains documented examples that demonstrate the usage of various
+ModOpt tools.
+
+These examples also serve as integration tests for various methods.
+
+:Author: Pierre-Antoine Comby
+
+"""
diff --git a/examples/conftest.py b/examples/conftest.py
new file mode 100644
index 00000000..f3ed371b
--- /dev/null
+++ b/examples/conftest.py
@@ -0,0 +1,49 @@
+"""TEST CONFIGURATION.
+
+This module contains methods for configuring the testing of the example
+scripts.
+
+:Author: Pierre-Antoine Comby
+
+Notes
+-----
+Based on:
+https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test
+
+"""
+
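+# With this conftest in place, running ``pytest examples/`` collects every
+# ``*.py`` file whose name contains "example" and executes it as a single
+# test item; a script passes if it runs to completion without raising.
+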
+from pathlib import Path
+import runpy
+import pytest
+
+
+def pytest_collect_file(path, parent):
+ """Pytest hook.
+
+ Create a collector for the given path, or None if not relevant.
+ The new node needs to have the specified parent as parent.
+ """
+ p = Path(path)
+ if p.suffix == ".py" and "example" in p.name:
+ return Script.from_parent(parent, path=p, name=p.name)
+
+
+class Script(pytest.File):
+ """Script files collected by pytest."""
+
+ def collect(self):
+ """Collect the script as its own item."""
+ yield ScriptItem.from_parent(self, name=self.name)
+
+
+class ScriptItem(pytest.Item):
+ """Item script collected by pytest."""
+
+ def runtest(self):
+ """Run the script as a test."""
+ runpy.run_path(str(self.path))
+
+ def repr_failure(self, excinfo):
+ """Return only the error traceback of the script."""
+ excinfo.traceback = excinfo.traceback.cut(path=self.path)
+ return super().repr_failure(excinfo)
diff --git a/examples/example_lasso_forward_backward.py b/examples/example_lasso_forward_backward.py
new file mode 100644
index 00000000..f3e5091d
--- /dev/null
+++ b/examples/example_lasso_forward_backward.py
@@ -0,0 +1,153 @@
+"""
+Solving the LASSO Problem with the Forward-Backward Algorithm
+==============================================================
+
+This example shows how to solve a LASSO problem using the Forward-Backward
+algorithm.
+
+In this example we use:
+    - ModOpt operators (linear, gradient, proximal)
+    - ModOpt solver implementations
+    - the ModOpt metrics API
+
+The LASSO was introduced by Tibshirani (1996), *Regression Shrinkage and
+Selection via the Lasso*.
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from modopt.opt.algorithms import ForwardBackward, POGM
+from modopt.opt.cost import costObj
+from modopt.opt.linear import LinearParent, Identity
+from modopt.opt.gradient import GradBasic
+from modopt.opt.proximity import SparseThreshold
+from modopt.math.matrix import PowerMethod
+from modopt.math.stats import mse
+
+# %%
+# Here we create an instance of the LASSO problem, which reads
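+#
+# .. math::
+#     \hat{\beta} = \arg\min_{\beta}
+#     \frac{1}{2}\|y - X\beta\|_2^2 + \lambda\|\beta\|_1,
+#
+# a composite objective with a smooth quadratic data-fidelity term and a
+# non-smooth :math:`\ell_1` penalty handled via its proximity operator.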
+
+BETA_TRUE = np.array(
+    [3.0, 1.5, 0, 0, 2, 0, 0, 0]
+)  # the 8 coefficient values from the original LASSO paper
+DIM = len(BETA_TRUE)
+
+
+rng = np.random.default_rng()
+sigma_noise = 1
+obs = 20
+# Create a measurement matrix whose rows have exponentially decaying covariance.
+cov = 0.4 ** abs((np.arange(DIM) * np.ones((DIM, DIM))).T - np.arange(DIM))
+x = rng.multivariate_normal(np.zeros(DIM), cov, obs)
+
+y = x @ BETA_TRUE
+y_noise = y + sigma_noise * rng.standard_normal(obs)
+
+
+# %%
+# Next we create Operators for solving the problem.
+
+# MatrixOperator could also work here.
+lin_op = LinearParent(lambda b: x @ b, lambda bb: x.T @ bb)
+grad_op = GradBasic(y_noise, op=lin_op.op, trans_op=lin_op.adj_op)
+
+prox_op = SparseThreshold(Identity(), 1, thresh_type="soft")
+
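+# %%
+# ``SparseThreshold`` with ``thresh_type="soft"`` applies soft thresholding,
+# the proximity operator of the :math:`\ell_1` norm:
+#
+# .. math::
+#     \operatorname{prox}_{\lambda\|\cdot\|_1}(v)_i
+#     = \operatorname{sign}(v_i)\,\max(|v_i| - \lambda, 0)
+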
+# %%
+# To get the best convergence rate, we first compute the Lipschitz constant of
+# the gradient operator, which sets the step size :math:`\beta = 1/L`.
+
+calc_lips = PowerMethod(grad_op.trans_op_op, 8, data_type="float32", auto_run=True)
+lip = calc_lips.spec_rad
+print("lipschitz constant:", lip)
+
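+# %%
+# As an optional sanity check (possible here because ``x`` is an explicit
+# matrix), the Lipschitz constant of the gradient is the largest eigenvalue
+# of :math:`X^\top X`, which we can compute directly and compare with the
+# power-method estimate.
+
+lip_exact = np.linalg.norm(x.T @ x, 2)
+print("exact Lipschitz constant:", lip_exact)
+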
+# %%
+# Solving using FISTA algorithm
+# -----------------------------
+#
+# FISTA (Beck & Teboulle, 2009) accelerates the basic forward-backward
+# iteration with a Nesterov-style momentum term, improving the worst-case
+# convergence rate of the cost from :math:`O(1/k)` to :math:`O(1/k^2)`.
+
+cost_op_fista = costObj([grad_op, prox_op], verbose=False)
+
+fb_fista = ForwardBackward(
+ np.zeros(8),
+ beta_param=1 / lip,
+ grad=grad_op,
+ prox=prox_op,
+ cost=cost_op_fista,
+ metric_call_period=1,
+ auto_iterate=False,  # so that we can call iterate() ourselves below
+)
+
+fb_fista.iterate()
+
+# %%
+# After the run we can have a look at the results
+
+print(fb_fista.x_final)
+mse_fista = mse(fb_fista.x_final, BETA_TRUE)
+plt.stem(fb_fista.x_final, label="estimation", linefmt="C0-")
+plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
+plt.legend()
+plt.title(f"FISTA Estimation MSE={mse_fista:.4f}")
+
+# sphinx_gallery_start_ignore
+assert mse(fb_fista.x_final, BETA_TRUE) < 1
+# sphinx_gallery_end_ignore
+
+
+# %%
+# Solving using the POGM Algorithm
+# --------------------------------
+#
+# POGM (Proximal Optimized Gradient Method) combines the proximal step with
+# the acceleration scheme of the Optimized Gradient Method, and its worst-case
+# convergence bound is roughly a factor of two better than FISTA's. The four
+# positional arguments below are the initial values of its internal sequences
+# (u, x, y, z).
+
+
+cost_op_pogm = costObj([grad_op, prox_op], verbose=False)
+
+fb_pogm = POGM(
+ np.zeros(8),
+ np.zeros(8),
+ np.zeros(8),
+ np.zeros(8),
+ beta_param=1 / lip,
+ grad=grad_op,
+ prox=prox_op,
+ cost=cost_op_pogm,
+ metric_call_period=1,
+ auto_iterate=False,  # so that we can call iterate() ourselves below
+)
+
+fb_pogm.iterate()
+
+# %%
+# After the run we can have a look at the results
+
+print(fb_pogm.x_final)
+mse_pogm = mse(fb_pogm.x_final, BETA_TRUE)
+
+plt.stem(fb_pogm.x_final, label="estimation", linefmt="C0-")
+plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
+plt.legend()
+plt.title(f"FISTA Estimation MSE={mse_pogm:.4f}")
+#
+# sphinx_gallery_start_ignore
+assert mse(fb_pogm.x_final, BETA_TRUE) < 1
+# sphinx_gallery_end_ignore
+
+# %%
+# Comparing the two algorithms
+# ----------------------------
+
+plt.figure()
+plt.semilogy(cost_op_fista._cost_list, label="FISTA convergence")
+plt.semilogy(cost_op_pogm._cost_list, label="POGM convergence")
+plt.xlabel("iterations")
+plt.ylabel("Cost Function")
+plt.legend()
+plt.show()
+
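+# %%
+# The recorded cost histories also give the number of cost-function
+# evaluations per solver. Note that ``_cost_list`` is an internal attribute
+# of ``costObj``, used here (as in the plot above) purely for illustration.
+
+print("FISTA cost evaluations:", len(cost_op_fista._cost_list))
+print("POGM cost evaluations:", len(cost_op_pogm._cost_list))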
+
+# %%
+# We can see that both algorithms converge quickly, and that POGM requires
+# fewer iterations. However, each POGM iteration is more costly, so a proper
+# benchmark with timing measurements is needed; see the benchopt benchmark
+# for more details.
diff --git a/modopt/__init__.py b/modopt/__init__.py
deleted file mode 100644
index 2c06c1db..00000000
--- a/modopt/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""MODOPT PACKAGE.
-
-ModOpt is a series of Modular Optimisation tools for solving inverse problems.
-
-"""
-
-from warnings import warn
-
-from importlib_metadata import version
-
-from modopt.base import *
-
-try:
- _version = version('modopt')
-except Exception: # pragma: no cover
- _version = 'Unkown'
- warn(
- 'Could not extract package metadata. Make sure the package is '
- + 'correctly installed.',
- )
-
-__version__ = _version
diff --git a/modopt/base/wrappers.py b/modopt/base/wrappers.py
deleted file mode 100644
index baedb891..00000000
--- a/modopt/base/wrappers.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""WRAPPERS.
-
-This module contains wrappers for adding additional features to functions.
-
-:Author: Samuel Farrens
-
-"""
-
-from functools import wraps
-from inspect import getfullargspec as argspec
-
-
-def add_args_kwargs(func):
- """Add args and kwargs.
-
- This wrapper adds support for additional arguments and keyword arguments to
- any callable function.
-
- Parameters
- ----------
- func : callable
- Callable function
-
- Returns
- -------
- callable
- wrapper
-
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
-
- props = argspec(func)
-
- # if 'args' not in props:
- if isinstance(props[1], type(None)):
- args = args[:len(props[0])]
-
- if (
- (not isinstance(props[2], type(None)))
- or (not isinstance(props[3], type(None)))
- ):
- return func(*args, **kwargs)
-
- return func(*args)
-
- return wrapper
diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py
deleted file mode 100644
index 7ff96a8b..00000000
--- a/modopt/tests/test_algorithms.py
+++ /dev/null
@@ -1,470 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""UNIT TESTS FOR OPT.ALGORITHMS.
-
-This module contains unit tests for the modopt.opt.algorithms module.
-
-:Author: Samuel Farrens
-
-"""
-
-from unittest import TestCase
-
-import numpy as np
-import numpy.testing as npt
-
-from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
-
-# Basic functions to be used as operators or as dummy functions
-func_identity = lambda x_val: x_val
-func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val ** 2
-func_cube = lambda x_val: x_val ** 3
-
-
-class Dummy(object):
- """Dummy class for tests."""
-
- pass
-
-
-class AlgorithmTestCase(TestCase):
- """Test case for algorithms module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6
- self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1
-
- grad_inst = gradient.GradBasic(
- self.data1,
- func_identity,
- func_identity,
- )
-
- prox_inst = proximity.Positivity()
- prox_dual_inst = proximity.IdentityProx()
- linear_inst = linear.Identity()
- reweight_inst = reweight.cwbReweight(self.data3)
- cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
- self.setup = algorithms.SetUp()
- self.max_iter = 20
-
- self.fb_all_iter = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- auto_iterate=False,
- beta_update=func_identity,
- )
- self.fb_all_iter.iterate(self.max_iter)
-
- self.fb1 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- )
-
- self.fb2 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- lambda_update=None,
- )
-
- self.fb3 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- a_cd=3,
- )
-
- self.fb4 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- r_lazy=3,
- p_lazy=0.7,
- q_lazy=0.7,
- )
-
- self.fb5 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='adaptive',
- xi_restart=0.9,
- )
-
- self.fb6 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='greedy',
- xi_restart=0.9,
- min_beta=1.0,
- s_greedy=1.1,
- )
-
- self.gfb_all_iter = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=None,
- auto_iterate=False,
- gamma_update=func_identity,
- beta_update=func_identity,
- )
- self.gfb_all_iter.iterate(self.max_iter)
-
- self.gfb1 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- gamma_update=func_identity,
- lambda_update=func_identity,
- )
-
- self.gfb2 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- )
-
- self.gfb3 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- step_size=2,
- )
-
- self.condat_all_iter = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- auto_iterate=False,
- )
- self.condat_all_iter.iterate(self.max_iter)
-
- self.condat1 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- )
-
- self.condat2 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=linear_inst,
- cost=cost_inst,
- reweight=reweight_inst,
- )
-
- self.condat3 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=Dummy(),
- cost=cost_inst,
- auto_iterate=False,
- )
-
- self.pogm_all_iter = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- auto_iterate=False,
- cost=None,
- )
- self.pogm_all_iter.iterate(self.max_iter)
-
- self.pogm1 = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- )
-
- self.vanilla_grad = algorithms.VanillaGenericGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.ada_grad = algorithms.AdaGenericGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.adam_grad = algorithms.ADAMGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.momentum_grad = algorithms.MomentumGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.rms_grad = algorithms.RMSpropGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.saga_grad = algorithms.SAGAOptGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
-
- self.dummy = Dummy()
- self.dummy.cost = func_identity
- self.setup._check_operator(self.dummy.cost)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.setup = None
- self.fb_all_iter = None
- self.fb1 = None
- self.fb2 = None
- self.gfb_all_iter = None
- self.gfb1 = None
- self.gfb2 = None
- self.condat_all_iter = None
- self.condat1 = None
- self.condat2 = None
- self.condat3 = None
- self.pogm1 = None
- self.pogm_all_iter = None
- self.dummy = None
-
- def test_set_up(self):
- """Test set_up."""
- npt.assert_raises(TypeError, self.setup._check_input_data, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param_update, 1)
-
- def test_all_iter(self):
- """Test if all opt run for all iterations."""
- opts = [
- self.fb_all_iter,
- self.gfb_all_iter,
- self.condat_all_iter,
- self.pogm_all_iter,
- ]
- for opt in opts:
- npt.assert_equal(opt.idx, self.max_iter - 1)
-
- def test_forward_backward(self):
- """Test forward_backward."""
- npt.assert_array_equal(
- self.fb1.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb2.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb3.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb4.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb5.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb6.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- def test_gen_forward_backward(self):
- """Test gen_forward_backward."""
- npt.assert_array_equal(
- self.gfb1.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb2.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb3.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_equal(
- self.gfb3.step_size,
- 2,
- err_msg='Incorrect step size.',
- )
-
- npt.assert_raises(
- TypeError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=1,
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[1],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5, 0.5],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5],
- )
-
- def test_condat(self):
- """Test gen_condat."""
- npt.assert_almost_equal(
- self.condat1.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
-
- npt.assert_almost_equal(
- self.condat2.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
-
- def test_pogm(self):
- """Test pogm."""
- npt.assert_almost_equal(
- self.pogm1.x_final,
- self.data1,
- err_msg='Incorrect POGM result.',
- )
-
- def test_ada_grad(self):
- """Test ADA Gradient Descent."""
- self.ada_grad.iterate()
- npt.assert_almost_equal(
- self.ada_grad.x_final,
- self.data1,
- err_msg='Incorrect ADAGrad results.',
- )
-
- def test_adam_grad(self):
- """Test ADAM Gradient Descent."""
- self.adam_grad.iterate()
- npt.assert_almost_equal(
- self.adam_grad.x_final,
- self.data1,
- err_msg='Incorrect ADAMGrad results.',
- )
-
- def test_momemtum_grad(self):
- """Test Momemtum Gradient Descent."""
- self.momentum_grad.iterate()
- npt.assert_almost_equal(
- self.momentum_grad.x_final,
- self.data1,
- err_msg='Incorrect MomentumGrad results.',
- )
-
- def test_rmsprop_grad(self):
- """Test RMSProp Gradient Descent."""
- self.rms_grad.iterate()
- npt.assert_almost_equal(
- self.rms_grad.x_final,
- self.data1,
- err_msg='Incorrect RMSPropGrad results.',
- )
-
- def test_saga_grad(self):
- """Test SAGA Descent."""
- self.saga_grad.iterate()
- npt.assert_almost_equal(
- self.saga_grad.x_final,
- self.data1,
- err_msg='Incorrect SAGA Grad results.',
- )
-
- def test_vanilla_grad(self):
- """Test Vanilla Gradient Descent."""
- self.vanilla_grad.iterate()
- npt.assert_almost_equal(
- self.vanilla_grad.x_final,
- self.data1,
- err_msg='Incorrect VanillaGrad results.',
- )
diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py
deleted file mode 100644
index 873a4506..00000000
--- a/modopt/tests/test_base.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""UNIT TESTS FOR BASE.
-
-This module contains unit tests for the modopt.base module.
-
-:Author: Samuel Farrens
-
-"""
-
-from builtins import range
-from unittest import TestCase, skipIf
-
-import numpy as np
-import numpy.testing as npt
-
-from modopt.base import np_adjust, transform, types
-from modopt.base.backend import (LIBRARIES, change_backend, get_array_module,
- get_backend)
-
-
-class NPAdjustTestCase(TestCase):
- """Test case for np_adjust module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape((3, 3))
- self.data2 = np.arange(18).reshape((2, 3, 3))
- self.data3 = np.array([
- [0, 0, 0, 0, 0],
- [0, 0, 1, 2, 0],
- [0, 3, 4, 5, 0],
- [0, 6, 7, 8, 0],
- [0, 0, 0, 0, 0],
- ])
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
-
- def test_rotate(self):
- """Test rotate."""
- npt.assert_array_equal(
- np_adjust.rotate(self.data1),
- np.array([[8, 7, 6], [5, 4, 3], [2, 1, 0]]),
- err_msg='Incorrect rotation',
- )
-
- def test_rotate_stack(self):
- """Test rotate_stack."""
- npt.assert_array_equal(
- np_adjust.rotate_stack(self.data2),
- np.array([
- [[8, 7, 6], [5, 4, 3], [2, 1, 0]],
- [[17, 16, 15], [14, 13, 12], [11, 10, 9]],
- ]),
- err_msg='Incorrect stack rotation',
- )
-
- def test_pad2d(self):
- """Test pad2d."""
- npt.assert_array_equal(
- np_adjust.pad2d(self.data1, (1, 1)),
- self.data3,
- err_msg='Incorrect padding',
- )
-
- npt.assert_array_equal(
- np_adjust.pad2d(self.data1, 1),
- self.data3,
- err_msg='Incorrect padding',
- )
-
- npt.assert_array_equal(
- np_adjust.pad2d(self.data1, np.array([1, 1])),
- self.data3,
- err_msg='Incorrect padding',
- )
-
- npt.assert_raises(ValueError, np_adjust.pad2d, self.data1, '1')
-
- def test_fancy_transpose(self):
- """Test fancy_transpose."""
- npt.assert_array_equal(
- np_adjust.fancy_transpose(self.data2),
- np.array([
- [[0, 3, 6], [9, 12, 15]],
- [[1, 4, 7], [10, 13, 16]],
- [[2, 5, 8], [11, 14, 17]],
- ]),
- err_msg='Incorrect fancy transpose',
- )
-
- def test_ftr(self):
- """Test ftr."""
- npt.assert_array_equal(
- np_adjust.ftr(self.data2),
- np.array([
- [[0, 3, 6], [9, 12, 15]],
- [[1, 4, 7], [10, 13, 16]],
- [[2, 5, 8], [11, 14, 17]],
- ]),
- err_msg='Incorrect fancy transpose: ftr',
- )
-
- def test_ftl(self):
- """Test ftl."""
- npt.assert_array_equal(
- np_adjust.ftl(self.data2),
- np.array([
- [[0, 9], [1, 10], [2, 11]],
- [[3, 12], [4, 13], [5, 14]],
- [[6, 15], [7, 16], [8, 17]],
- ]),
- err_msg='Incorrect fancy transpose: ftl',
- )
-
-
-class TransformTestCase(TestCase):
- """Test case for transform module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.cube = np.arange(16).reshape((4, 2, 2))
- self.map = np.array(
- [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]],
- )
- self.matrix = np.array(
- [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]],
- )
- self.layout = (2, 2)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.cube = None
- self.map = None
- self.layout = None
-
- def test_cube2map(self):
- """Test cube2map."""
- npt.assert_array_equal(
- transform.cube2map(self.cube, self.layout),
- self.map,
- err_msg='Incorrect transformation: cube2map',
- )
-
- npt.assert_raises(
- ValueError,
- transform.cube2map,
- self.map,
- self.layout,
- )
-
- npt.assert_raises(ValueError, transform.cube2map, self.cube, (3, 3))
-
- def test_map2cube(self):
- """Test map2cube."""
- npt.assert_array_equal(
- transform.map2cube(self.map, self.layout),
- self.cube,
- err_msg='Incorrect transformation: map2cube',
- )
-
- npt.assert_raises(ValueError, transform.map2cube, self.map, (3, 3))
-
- def test_map2matrix(self):
- """Test map2matrix."""
- npt.assert_array_equal(
- transform.map2matrix(self.map, self.layout),
- self.matrix,
- err_msg='Incorrect transformation: map2matrix',
- )
-
- def test_matrix2map(self):
- """Test matrix2map."""
- npt.assert_array_equal(
- transform.matrix2map(self.matrix, self.map.shape),
- self.map,
- err_msg='Incorrect transformation: matrix2map',
- )
-
- def test_cube2matrix(self):
- """Test cube2matrix."""
- npt.assert_array_equal(
- transform.cube2matrix(self.cube),
- self.matrix,
- err_msg='Incorrect transformation: cube2matrix',
- )
-
- def test_matrix2cube(self):
- """Test matrix2cube."""
- npt.assert_array_equal(
- transform.matrix2cube(self.matrix, self.cube[0].shape),
- self.cube,
- err_msg='Incorrect transformation: matrix2cube',
- )
-
-
-class TypesTestCase(TestCase):
- """Test case for types module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = list(range(5))
- self.data2 = np.arange(5)
- self.data3 = np.arange(5).astype(float)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
-
- def test_check_float(self):
- """Test check_float."""
- npt.assert_array_equal(
- types.check_float(1.0),
- 1.0,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_float(1),
- 1.0,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_float(self.data1),
- self.data3,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_float(self.data2),
- self.data3,
- err_msg='Float check failed',
- )
-
- npt.assert_raises(TypeError, types.check_float, '1')
-
- def test_check_int(self):
- """Test check_int."""
- npt.assert_array_equal(
- types.check_int(1),
- 1,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_int(1.0),
- 1,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_int(self.data1),
- self.data2,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_int(self.data3),
- self.data2,
- err_msg='Int check failed',
- )
-
- npt.assert_raises(TypeError, types.check_int, '1')
-
- def test_check_npndarray(self):
- """Test check_npndarray."""
- npt.assert_raises(
- TypeError,
- types.check_npndarray,
- self.data3,
- dtype=np.integer,
- )
-
-
-class TestBackend(TestCase):
- """Test the backend codes."""
-
- def setUp(self):
- """Set test parameter values."""
- self.input = np.array([10, 10])
-
- @skipIf(LIBRARIES['tensorflow'] is None, 'tensorflow library not installed')
- def test_tf_backend(self):
- """Test tensorflow backend."""
- xp, backend = get_backend('tensorflow')
- if backend != 'tensorflow' or xp != LIBRARIES['tensorflow']:
- raise AssertionError('tensorflow get_backend fails!')
- tf_input = change_backend(self.input, 'tensorflow')
- if (
- get_array_module(LIBRARIES['tensorflow'].ones(1)) != LIBRARIES['tensorflow']
- or get_array_module(tf_input) != LIBRARIES['tensorflow']
- ):
- raise AssertionError('tensorflow backend fails!')
-
- @skipIf(LIBRARIES['cupy'] is None, 'cupy library not installed')
- def test_cp_backend(self):
- """Test cupy backend."""
- xp, backend = get_backend('cupy')
- if backend != 'cupy' or xp != LIBRARIES['cupy']:
- raise AssertionError('cupy get_backend fails!')
- cp_input = change_backend(self.input, 'cupy')
- if (
- get_array_module(LIBRARIES['cupy'].ones(1)) != LIBRARIES['cupy']
- or get_array_module(cp_input) != LIBRARIES['cupy']
- ):
- raise AssertionError('cupy backend fails!')
-
- def test_np_backend(self):
- """Test numpy backend."""
- xp, backend = get_backend('numpy')
- if backend != 'numpy' or xp != LIBRARIES['numpy']:
- raise AssertionError('numpy get_backend fails!')
- np_input = change_backend(self.input, 'numpy')
- if (
- get_array_module(LIBRARIES['numpy'].ones(1)) != LIBRARIES['numpy']
- or get_array_module(np_input) != LIBRARIES['numpy']
- ):
- raise AssertionError('numpy backend fails!')
-
- def tearDown(self):
- """Tear Down of objects."""
- self.input = None
diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py
deleted file mode 100644
index 99908e02..00000000
--- a/modopt/tests/test_math.py
+++ /dev/null
@@ -1,496 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""UNIT TESTS FOR MATH.
-
-This module contains unit tests for the modopt.math module.
-
-:Author: Samuel Farrens
-
-"""
-
-from unittest import TestCase, skipIf, skipUnless
-
-import numpy as np
-import numpy.testing as npt
-
-from modopt.math import convolve, matrix, metrics, stats
-
-try:
- import astropy
-except ImportError: # pragma: no cover
- import_astropy = False
-else: # pragma: no cover
- import_astropy = True
-try:
- from skimage.metrics import structural_similarity as compare_ssim
-except ImportError: # pragma: no cover
- import_skimage = False
-else:
- import_skimage = True
-
-
-class ConvolveTestCase(TestCase):
- """Test case for convolve module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(18).reshape(2, 3, 3)
- self.data2 = self.data1 + 1
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_convolve_astropy(self):
- """Test convolve using astropy."""
- npt.assert_allclose(
- convolve.convolve(self.data1[0], self.data2[0], method='astropy'),
- np.array([
- [210.0, 201.0, 210.0],
- [129.0, 120.0, 129.0],
- [210.0, 201.0, 210.0],
- ]),
- err_msg='Incorrect convolution: astropy',
- )
-
- npt.assert_raises(
- ValueError,
- convolve.convolve,
- self.data1[0],
- self.data2,
- )
-
- npt.assert_raises(
- ValueError,
- convolve.convolve,
- self.data1[0],
- self.data2[0],
- method='bla',
- )
-
- def test_convolve_scipy(self):
- """Test convolve using scipy."""
- npt.assert_allclose(
- convolve.convolve(self.data1[0], self.data2[0], method='scipy'),
- np.array([
- [14.0, 35.0, 38.0],
- [57.0, 120.0, 111.0],
- [110.0, 197.0, 158.0],
- ]),
- err_msg='Incorrect convolution: scipy',
- )
-
- def test_convolve_stack(self):
- """Test convolve_stack."""
- npt.assert_allclose(
- convolve.convolve_stack(self.data1, self.data2),
- np.array([
- [
- [14.0, 35.0, 38.0],
- [57.0, 120.0, 111.0],
- [110.0, 197.0, 158.0],
- ],
- [
- [518.0, 845.0, 614.0],
- [975.0, 1578.0, 1137.0],
- [830.0, 1331.0, 950.0],
- ],
- ]),
- err_msg='Incorrect convolution: stack',
- )
-
- def test_convolve_stack_rot(self):
- """Test convolve_stack rotated."""
- npt.assert_allclose(
- convolve.convolve_stack(self.data1, self.data2, rot_kernel=True),
- np.array([
- [
- [66.0, 115.0, 82.0],
- [153.0, 240.0, 159.0],
- [90.0, 133.0, 82.0],
- ],
- [
- [714.0, 1087.0, 730.0],
- [1125.0, 1698.0, 1131.0],
- [738.0, 1105.0, 730.0],
- ],
- ]),
- err_msg='Incorrect convolution: stack rot',
- )
-
-
-class MatrixTestCase(TestCase):
- """Test case for matrix module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3)
- self.data2 = np.arange(3)
- self.data3 = np.arange(6).reshape(2, 3)
- np.random.seed(1)
- self.pmInstance1 = matrix.PowerMethod(
- lambda x_val: x_val.dot(x_val.T),
- self.data1.shape,
- verbose=True,
- )
- np.random.seed(1)
- self.pmInstance2 = matrix.PowerMethod(
- lambda x_val: x_val.dot(x_val.T),
- self.data1.shape,
- auto_run=False,
- verbose=True,
- )
- self.pmInstance2.get_spec_rad(max_iter=1)
- self.gram_schmidt_out = (
- np.array([
- [0, 1.0, 2.0],
- [3.0, 1.2, -6e-1],
- [-1.77635684e-15, 0, 0],
- ]),
- np.array([
- [0, 0.4472136, 0.89442719],
- [0.91287093, 0.36514837, -0.18257419],
- [-1.0, 0, 0],
- ]),
- )
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.pmInstance1 = None
- self.pmInstance2 = None
- self.gram_schmidt_out = None
-
- def test_gram_schmidt_orthonormal(self):
- """Test gram_schmidt with orthonormal output."""
- npt.assert_allclose(
- matrix.gram_schmidt(self.data1),
- self.gram_schmidt_out[1],
- err_msg='Incorrect Gram-Schmidt: orthonormal',
- )
-
- npt.assert_raises(
- ValueError,
- matrix.gram_schmidt,
- self.data1,
- return_opt='bla',
- )
-
- def test_gram_schmidt_orthogonal(self):
- """Test gram_schmidt with orthogonal output."""
- npt.assert_allclose(
- matrix.gram_schmidt(self.data1, return_opt='orthogonal'),
- self.gram_schmidt_out[0],
- err_msg='Incorrect Gram-Schmidt: orthogonal',
- )
-
- def test_gram_schmidt_both(self):
- """Test gram_schmidt with both outputs."""
- npt.assert_allclose(
- matrix.gram_schmidt(self.data1, return_opt='both'),
- self.gram_schmidt_out,
- err_msg='Incorrect Gram-Schmidt: both',
- )
-
- def test_nuclear_norm(self):
- """Test nuclear_norm."""
- npt.assert_almost_equal(
- matrix.nuclear_norm(self.data1),
- 15.49193338482967,
- err_msg='Incorrect nuclear norm',
- )
-
- def test_project(self):
- """Test project."""
- npt.assert_array_equal(
- matrix.project(self.data2, self.data2 + 3),
- np.array([0, 2.8, 5.6]),
- err_msg='Incorrect projection',
- )
-
- def test_rot_matrix(self):
- """Test rot_matrix."""
- npt.assert_allclose(
- matrix.rot_matrix(np.pi / 6),
- np.array([[0.8660254, -0.5], [0.5, 0.8660254]]),
- err_msg='Incorrect rotation matrix',
- )
-
- def test_rotate(self):
- """Test rotate."""
- npt.assert_array_equal(
- matrix.rotate(self.data1, np.pi / 2),
- np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]),
- err_msg='Incorrect rotation',
- )
-
- npt.assert_raises(ValueError, matrix.rotate, self.data3, np.pi / 2)
-
- def test_powermethod_converged(self):
- """Test PowerMethod converged."""
- npt.assert_almost_equal(
- self.pmInstance1.spec_rad,
- 0.90429242629600837,
- err_msg='Incorrect spectral radius: converged',
- )
-
- npt.assert_almost_equal(
- self.pmInstance1.inv_spec_rad,
- 1.1058369736612865,
- err_msg='Incorrect inverse spectral radius: converged',
- )
-
- def test_powermethod_unconverged(self):
- """Test PowerMethod unconverged."""
- npt.assert_almost_equal(
- self.pmInstance2.spec_rad,
- 0.92048833577059219,
- err_msg='Incorrect spectral radius: unconverged',
- )
-
- npt.assert_almost_equal(
- self.pmInstance2.inv_spec_rad,
- 1.0863798715741946,
- err_msg='Incorrect inverse spectral radius: unconverged',
- )
-
-
-class MetricsTestCase(TestCase):
- """Test case for metrics module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(49).reshape(7, 7)
- self.mask = np.ones(self.data1.shape)
- self.ssim_res = 0.8963363560519094
- self.ssim_mask_res = 0.805154442543846
- self.snr_res = 10.134554256920536
- self.psnr_res = 14.860761791850397
- self.mse_res = 0.03265305507330247
- self.nrmse_res = 0.31136678840022625
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.mask = None
- self.ssim_res = None
- self.ssim_mask_res = None
- self.psnr_res = None
- self.mse_res = None
- self.nrmse_res = None
-
- @skipIf(import_skimage, 'skimage is installed.') # pragma: no cover
- def test_ssim_skimage_error(self):
- """Test ssim skimage error."""
- npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1)
-
- @skipUnless(import_skimage, 'skimage not installed.') # pragma: no cover
- def test_ssim(self):
- """Test ssim."""
- npt.assert_almost_equal(
- metrics.ssim(self.data1, self.data1 ** 2),
- self.ssim_res,
- err_msg='Incorrect SSIM result',
- )
-
- npt.assert_almost_equal(
- metrics.ssim(self.data1, self.data1 ** 2, mask=self.mask),
- self.ssim_mask_res,
- err_msg='Incorrect SSIM result',
- )
-
- npt.assert_raises(
- ValueError,
- metrics.ssim,
- self.data1,
- self.data1,
- mask=1,
- )
-
- def test_snr(self):
- """Test snr."""
- npt.assert_almost_equal(
- metrics.snr(self.data1, self.data1 ** 2),
- self.snr_res,
- err_msg='Incorrect SNR result',
- )
-
- npt.assert_almost_equal(
- metrics.snr(self.data1, self.data1 ** 2, mask=self.mask),
- self.snr_res,
- err_msg='Incorrect SNR result',
- )
-
- def test_psnr(self):
- """Test psnr."""
- npt.assert_almost_equal(
- metrics.psnr(self.data1, self.data1 ** 2),
- self.psnr_res,
- err_msg='Incorrect PSNR result',
- )
-
- npt.assert_almost_equal(
- metrics.psnr(self.data1, self.data1 ** 2, mask=self.mask),
- self.psnr_res,
- err_msg='Incorrect PSNR result',
- )
-
- def test_mse(self):
- """Test mse."""
- npt.assert_almost_equal(
- metrics.mse(self.data1, self.data1 ** 2),
- self.mse_res,
- err_msg='Incorrect MSE result',
- )
-
- npt.assert_almost_equal(
- metrics.mse(self.data1, self.data1 ** 2, mask=self.mask),
- self.mse_res,
- err_msg='Incorrect MSE result',
- )
-
- def test_nrmse(self):
- """Test nrmse."""
- npt.assert_almost_equal(
- metrics.nrmse(self.data1, self.data1 ** 2),
- self.nrmse_res,
- err_msg='Incorrect NRMSE result',
- )
-
- npt.assert_almost_equal(
- metrics.nrmse(self.data1, self.data1 ** 2, mask=self.mask),
- self.nrmse_res,
- err_msg='Incorrect NRMSE result',
- )
-
-
-class StatsTestCase(TestCase):
- """Test case for stats module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3)
- self.data2 = np.arange(18).reshape(2, 3, 3)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
-
- @skipIf(import_astropy, 'Astropy is installed.') # pragma: no cover
- def test_gaussian_kernel_astropy_error(self):
- """Test gaussian_kernel astropy error."""
- npt.assert_raises(
- ImportError,
- stats.gaussian_kernel,
- self.data1.shape,
- 1,
- )
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_gaussian_kernel_max(self):
- """Test gaussian_kernel with max norm."""
- npt.assert_allclose(
- stats.gaussian_kernel(self.data1.shape, 1),
- np.array([
- [0.36787944, 0.60653066, 0.36787944],
- [0.60653066, 1.0, 0.60653066],
- [0.36787944, 0.60653066, 0.36787944],
- ]),
- err_msg='Incorrect gaussian kernel: max norm',
- )
-
- npt.assert_raises(
- ValueError,
- stats.gaussian_kernel,
- self.data1.shape,
- 1,
- norm='bla',
- )
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_gaussian_kernel_sum(self):
- """Test gaussian_kernel with sum norm."""
- npt.assert_allclose(
- stats.gaussian_kernel(self.data1.shape, 1, norm='sum'),
- np.array([
- [0.07511361, 0.1238414, 0.07511361],
- [0.1238414, 0.20417996, 0.1238414],
- [0.07511361, 0.1238414, 0.07511361],
- ]),
- err_msg='Incorrect gaussian kernel: sum norm',
- )
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_gaussian_kernel_none(self):
- """Test gaussian_kernel with no norm."""
- npt.assert_allclose(
- stats.gaussian_kernel(self.data1.shape, 1, norm='none'),
- np.array([
- [0.05854983, 0.09653235, 0.05854983],
- [0.09653235, 0.15915494, 0.09653235],
- [0.05854983, 0.09653235, 0.05854983],
- ]),
- err_msg='Incorrect gaussian kernel: sum norm',
- )
-
- def test_mad(self):
- """Test mad."""
- npt.assert_equal(
- stats.mad(self.data1),
- 2.0,
- err_msg='Incorrect median absolute deviation',
- )
-
- def test_mse(self):
- """Test mse."""
- npt.assert_equal(
- stats.mse(self.data1, self.data1 + 2),
- 4.0,
- err_msg='Incorrect mean squared error',
- )
-
- def test_psnr_starck(self):
- """Test psnr."""
- npt.assert_almost_equal(
- stats.psnr(self.data1, self.data1 + 2),
- 12.041199826559248,
- err_msg='Incorrect PSNR: starck',
- )
-
- npt.assert_raises(
- ValueError,
- stats.psnr,
- self.data1,
- self.data1,
- method='bla',
- )
-
- def test_psnr_wiki(self):
- """Test psnr wiki method."""
- npt.assert_almost_equal(
- stats.psnr(self.data1, self.data1 + 2, method='wiki'),
- 42.110203695399477,
- err_msg='Incorrect PSNR: wiki',
- )
-
- def test_psnr_stack(self):
- """Test psnr stack."""
- npt.assert_almost_equal(
- stats.psnr_stack(self.data2, self.data2 + 2),
- 12.041199826559248,
- err_msg='Incorrect PSNR stack',
- )
-
- npt.assert_raises(ValueError, stats.psnr_stack, self.data1, self.data1)
-
- def test_sigma_mad(self):
- """Test sigma_mad."""
- npt.assert_almost_equal(
- stats.sigma_mad(self.data1),
- 2.9651999999999998,
- err_msg='Incorrect sigma from MAD',
- )
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
deleted file mode 100644
index d5547783..00000000
--- a/modopt/tests/test_opt.py
+++ /dev/null
@@ -1,1071 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""UNIT TESTS FOR OPT.
-
-This module contains unit tests for the modopt.opt module.
-
-:Author: Samuel Farrens
-
-"""
-
-from builtins import zip
-from unittest import TestCase, skipIf, skipUnless
-
-import numpy as np
-import numpy.testing as npt
-
-from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
-
-try:
- import sklearn
-except ImportError: # pragma: no cover
- import_sklearn = False
-else:
- import_sklearn = True
-
-
-# Basic functions to be used as operators or as dummy functions
-func_identity = lambda x_val: x_val
-func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val ** 2
-func_cube = lambda x_val: x_val ** 3
-
-
-class Dummy(object):
- """Dummy class for tests."""
-
- pass
-
-
-class AlgorithmTestCase(TestCase):
- """Test case for algorithms module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6
- self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1
-
- grad_inst = gradient.GradBasic(
- self.data1,
- func_identity,
- func_identity,
- )
-
- prox_inst = proximity.Positivity()
- prox_dual_inst = proximity.IdentityProx()
- linear_inst = linear.Identity()
- reweight_inst = reweight.cwbReweight(self.data3)
- cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
- self.setup = algorithms.SetUp()
- self.max_iter = 20
-
- self.fb_all_iter = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- auto_iterate=False,
- beta_update=func_identity,
- )
- self.fb_all_iter.iterate(self.max_iter)
-
- self.fb1 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- )
-
- self.fb2 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- lambda_update=None,
- )
-
- self.fb3 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- a_cd=3,
- )
-
- self.fb4 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- r_lazy=3,
- p_lazy=0.7,
- q_lazy=0.7,
- )
-
- self.fb5 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='adaptive',
- xi_restart=0.9,
- )
-
- self.fb6 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='greedy',
- xi_restart=0.9,
- min_beta=1.0,
- s_greedy=1.1,
- )
-
- self.gfb_all_iter = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=None,
- auto_iterate=False,
- gamma_update=func_identity,
- beta_update=func_identity,
- )
- self.gfb_all_iter.iterate(self.max_iter)
-
- self.gfb1 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- gamma_update=func_identity,
- lambda_update=func_identity,
- )
-
- self.gfb2 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- )
-
- self.gfb3 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- step_size=2,
- )
-
- self.condat_all_iter = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- auto_iterate=False,
- )
- self.condat_all_iter.iterate(self.max_iter)
-
- self.condat1 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- )
-
- self.condat2 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=linear_inst,
- cost=cost_inst,
- reweight=reweight_inst,
- )
-
- self.condat3 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=Dummy(),
- cost=cost_inst,
- auto_iterate=False,
- )
-
- self.pogm_all_iter = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- auto_iterate=False,
- cost=None,
- )
- self.pogm_all_iter.iterate(self.max_iter)
-
- self.pogm1 = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- )
-
- self.dummy = Dummy()
- self.dummy.cost = func_identity
- self.setup._check_operator(self.dummy.cost)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.setup = None
- self.fb_all_iter = None
- self.fb1 = None
- self.fb2 = None
- self.gfb_all_iter = None
- self.gfb1 = None
- self.gfb2 = None
- self.condat_all_iter = None
- self.condat1 = None
- self.condat2 = None
- self.condat3 = None
- self.pogm1 = None
- self.pogm_all_iter = None
- self.dummy = None
-
- def test_set_up(self):
- """Test set_up."""
- npt.assert_raises(TypeError, self.setup._check_input_data, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param_update, 1)
-
- def test_all_iter(self):
- """Test if all opt run for all iterations."""
- opts = [
- self.fb_all_iter,
- self.gfb_all_iter,
- self.condat_all_iter,
- self.pogm_all_iter,
- ]
- for opt in opts:
- npt.assert_equal(opt.idx, self.max_iter - 1)
-
- def test_forward_backward(self):
- """Test forward_backward."""
- npt.assert_array_equal(
- self.fb1.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb2.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb3.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb4.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb5.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb6.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- def test_gen_forward_backward(self):
- """Test gen_forward_backward."""
- npt.assert_array_equal(
- self.gfb1.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb2.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb3.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_equal(
- self.gfb3.step_size,
- 2,
- err_msg='Incorrect step size.',
- )
-
- npt.assert_raises(
- TypeError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=1,
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[1],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5, 0.5],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5],
- )
-
- def test_condat(self):
- """Test gen_condat."""
- npt.assert_almost_equal(
- self.condat1.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
-
- npt.assert_almost_equal(
- self.condat2.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
-
- def test_pogm(self):
- """Test pogm."""
- npt.assert_almost_equal(
- self.pogm1.x_final,
- self.data1,
- err_msg='Incorrect POGM result.',
- )
-
-
-class CostTestCase(TestCase):
- """Test case for cost module."""
-
- def setUp(self):
- """Set test parameter values."""
- dummy_inst1 = Dummy()
- dummy_inst1.cost = func_sq
- dummy_inst2 = Dummy()
- dummy_inst2.cost = func_cube
-
- self.inst1 = cost.costObj([dummy_inst1, dummy_inst2])
- self.inst2 = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=2)
- # Test that by default cost of False if interval is None
- self.inst_none = cost.costObj(
- [dummy_inst1, dummy_inst2],
- cost_interval=None,
- )
- for _ in range(2):
- self.inst1.get_cost(2)
- for _ in range(6):
- self.inst2.get_cost(2)
- self.inst_none.get_cost(2)
- self.dummy = Dummy()
-
- def tearDown(self):
- """Unset test parameter values."""
- self.inst = None
-
- def test_cost_object(self):
- """Test cost_object."""
- npt.assert_equal(
- self.inst1.get_cost(2),
- False,
- err_msg='Incorrect cost test result.',
- )
- npt.assert_equal(
- self.inst1.get_cost(2),
- True,
- err_msg='Incorrect cost test result.',
- )
- npt.assert_equal(
- self.inst_none.get_cost(2),
- False,
- err_msg='Incorrect cost test result.',
- )
-
- npt.assert_equal(self.inst1.cost, 12, err_msg='Incorrect cost value.')
-
- npt.assert_equal(self.inst2.cost, 12, err_msg='Incorrect cost value.')
-
- npt.assert_raises(TypeError, cost.costObj, 1)
-
- npt.assert_raises(ValueError, cost.costObj, [self.dummy, self.dummy])
-
-
-class GradientTestCase(TestCase):
- """Test case for gradient module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.gp = gradient.GradParent(
- self.data1,
- func_sq,
- func_cube,
- func_identity,
- lambda input_val: 1.0,
- data_type=np.floating,
- )
- self.gp.grad = self.gp.get_grad(self.data1)
- self.gb = gradient.GradBasic(
- self.data1,
- func_sq,
- func_cube,
- )
- self.gb.get_grad(self.data1)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.gp = None
- self.gb = None
-
- def test_grad_parent_operators(self):
- """Test GradParent."""
- npt.assert_array_equal(
- self.gp.op(self.data1),
- np.array([[0, 1.0, 4.0], [9.0, 16.0, 25.0], [36.0, 49.0, 64.0]]),
- err_msg='Incorrect gradient operation.',
- )
-
- npt.assert_array_equal(
- self.gp.trans_op(self.data1),
- np.array(
- [[0, 1.0, 8.0], [27.0, 64.0, 125.0], [216.0, 343.0, 512.0]],
- ),
- err_msg='Incorrect gradient transpose operation.',
- )
-
- npt.assert_array_equal(
- self.gp.trans_op_op(self.data1),
- np.array([
- [0, 1.0, 6.40000000e1],
- [7.29000000e2, 4.09600000e3, 1.56250000e4],
- [4.66560000e4, 1.17649000e5, 2.62144000e5],
- ]),
- err_msg='Incorrect gradient transpose operation operation.',
- )
-
- npt.assert_equal(
- self.gp.cost(self.data1),
- 1.0,
- err_msg='Incorrect cost.',
- )
-
- npt.assert_raises(
- TypeError,
- gradient.GradParent,
- 1,
- func_sq,
- func_cube,
- )
-
- def test_grad_basic_gradient(self):
- """Test GradBasic."""
- npt.assert_array_equal(
- self.gb.grad,
- np.array([
- [0, 0, 8.0],
- [2.16000000e2, 1.72800000e3, 8.0e3],
- [2.70000000e4, 7.40880000e4, 1.75616000e5],
- ]),
- err_msg='Incorrect gradient.',
- )
-
-
-class LinearTestCase(TestCase):
- """Test case for linear module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.parent = linear.LinearParent(
- func_sq,
- func_cube,
- )
- self.ident = linear.Identity()
- filters = np.arange(8).reshape(2, 2, 2).astype(float)
- self.wave = linear.WaveletConvolve(filters)
- self.combo = linear.LinearCombo([self.parent, self.parent])
- self.combo_weight = linear.LinearCombo(
- [self.parent, self.parent],
- [1.0, 1.0],
- )
- self.data1 = np.arange(18).reshape(2, 3, 3).astype(float)
- self.data2 = np.arange(4).reshape(1, 2, 2).astype(float)
- self.data3 = np.arange(8).reshape(1, 2, 2, 2).astype(float)
- self.data4 = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]])
- self.data5 = np.array([[[28.0, 62.0], [68.0, 140.0]]])
- self.dummy = Dummy()
-
- def tearDown(self):
- """Unset test parameter values."""
- self.parent = None
- self.ident = None
- self.combo = None
- self.combo_weight = None
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
- self.dummy = None
-
- def test_linear_parent(self):
- """Test LinearParent."""
- npt.assert_equal(
- self.parent.op(2),
- 4,
- err_msg='Incorrect linear parent operation.',
- )
-
- npt.assert_equal(
- self.parent.adj_op(2),
- 8,
- err_msg='Incorrect linear parent adjoint operation.',
- )
-
- npt.assert_raises(TypeError, linear.LinearParent, 0, 0)
-
- def test_identity(self):
- """Test Identity."""
- npt.assert_equal(
- self.ident.op(1.0),
- 1.0,
- err_msg='Incorrect identity operation.',
- )
-
- npt.assert_equal(
- self.ident.adj_op(1.0),
- 1.0,
- err_msg='Incorrect identity adjoint operation.',
- )
-
- def test_wavelet_convolve(self):
- """Test WaveletConvolve."""
- npt.assert_almost_equal(
- self.wave.op(self.data2),
- self.data4,
- err_msg='Incorrect wavelet convolution operation.',
- )
-
- npt.assert_almost_equal(
- self.wave.adj_op(self.data3),
- self.data5,
- err_msg='Incorrect wavelet convolution adjoint operation.',
- )
-
- def test_linear_combo(self):
- """Test LinearCombo."""
- npt.assert_equal(
- self.combo.op(2),
- np.array([4, 4]).astype(object),
- err_msg='Incorrect combined linear operation',
- )
-
- npt.assert_equal(
- self.combo.adj_op([2, 2]),
- 8.0,
- err_msg='Incorrect combined linear adjoint operation',
- )
-
- npt.assert_raises(TypeError, linear.LinearCombo, self.parent)
-
- npt.assert_raises(ValueError, linear.LinearCombo, [])
-
- npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy])
-
- self.dummy.op = func_identity
-
- npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy])
-
- def test_linear_combo_weight(self):
- """Test LinearCombo with weight ."""
- npt.assert_equal(
- self.combo_weight.op(2),
- np.array([4, 4]).astype(object),
- err_msg='Incorrect combined linear operation',
- )
-
- npt.assert_equal(
- self.combo_weight.adj_op([2, 2]),
- 16.0,
- err_msg='Incorrect combined linear adjoint operation',
- )
-
- npt.assert_raises(
- ValueError,
- linear.LinearCombo,
- [self.parent, self.parent],
- [1.0],
- )
-
- npt.assert_raises(
- TypeError,
- linear.LinearCombo,
- [self.parent, self.parent],
- ['1', '1'],
- )
-
-
-class ProximityTestCase(TestCase):
- """Test case for proximity module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.parent = proximity.ProximityParent(
- func_sq,
- func_double,
- )
- self.identity = proximity.IdentityProx()
- self.positivity = proximity.Positivity()
- weights = np.ones(9).reshape(3, 3).astype(float) * 3
- self.sparsethresh = proximity.SparseThreshold(
- linear.Identity(),
- weights,
- )
- self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard')
- self.lowrank_rank = proximity.LowRankMatrix(
- 10.0,
- initial_rank=1,
- thresh_type='hard',
- )
- self.lowrank_ngole = proximity.LowRankMatrix(
- 10.0,
- lowr_type='ngole',
- operator=func_double,
- )
- self.linear_comp = proximity.LinearCompositionProx(
- linear_op=linear.Identity(),
- prox_op=self.sparsethresh,
- )
- self.combo = proximity.ProximityCombo([self.identity, self.positivity])
- if import_sklearn:
- self.owl = proximity.OrderedWeightedL1Norm(weights.flatten())
- self.ridge = proximity.Ridge(linear.Identity(), weights)
- self.elasticnet_alpha0 = proximity.ElasticNet(
- linear.Identity(),
- alpha=0,
- beta=weights,
- )
- self.elasticnet_beta0 = proximity.ElasticNet(
- linear.Identity(),
- alpha=weights,
- beta=0,
- )
- self.one_support = proximity.KSupportNorm(beta=0.2, k_value=1)
- self.five_support_norm = proximity.KSupportNorm(beta=3, k_value=5)
- self.d_support = proximity.KSupportNorm(beta=3.0 * 2, k_value=19)
- self.group_lasso = proximity.GroupLASSO(
- weights=np.tile(weights, (4, 1, 1)),
- )
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]])
- self.data3 = np.arange(18).reshape(2, 3, 3).astype(float)
- self.data4 = np.array([
- [
- [2.73843189, 3.14594066, 3.55344943],
- [3.9609582, 4.36846698, 4.77597575],
- [5.18348452, 5.59099329, 5.99850206],
- ],
- [
- [8.07085295, 9.2718846, 10.47291625],
- [11.67394789, 12.87497954, 14.07601119],
- [15.27704284, 16.47807449, 17.67910614],
- ],
- ])
- self.data5 = np.array([
- [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
- [
- [4.00795282, 4.60438026, 5.2008077],
- [5.79723515, 6.39366259, 6.99009003],
- [7.58651747, 8.18294492, 8.77937236],
- ],
- ])
- self.data6 = self.data3 * -1
- self.data7 = self.combo.op(self.data6)
- self.data8 = np.empty(2, dtype=np.ndarray)
- self.data8[0] = np.array(
- [[-0, -1.0, -2.0], [-3.0, -4.0, -5.0], [-6.0, -7.0, -8.0]],
- )
- self.data8[1] = np.array(
- [[-0, -0, -0], [-0, -0, -0], [-0, -0, -0]],
- )
- self.data9 = self.data1 * (1 + 1j)
- self.data10 = self.data9 / (2 * 3 + 1)
- self.data11 = np.asarray(
- [[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]],
- )
- self.random_data = 3 * np.random.random(
- self.group_lasso.weights[0].shape,
- )
- self.random_data_tile = np.tile(
- self.random_data,
- (self.group_lasso.weights.shape[0], 1, 1),
- )
- self.gl_result_data = 2 * self.random_data_tile - 3
- self.gl_result_data = np.array(
- (self.gl_result_data * (self.gl_result_data > 0).astype('int'))
- / 2,
- )
-
- self.dummy = Dummy()
-
- def tearDown(self):
- """Unset test parameter values."""
- self.parent = None
- self.identity = None
- self.positivity = None
- self.sparsethresh = None
- self.lowrank = None
- self.lowrank_rank = None
- self.lowrank_ngole = None
- self.combo = None
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
- self.data6 = None
- self.data7 = None
- self.data8 = None
- self.dummy = None
- self.random_data = None
- self.random_data_tile = None
- self.gl_result_data = None
-
- def test_proximity_parent(self):
- """Test ProximityParent."""
- npt.assert_equal(
- self.parent.op(3),
- 9,
- err_msg='Inccoret proximity parent operation.',
- )
-
- npt.assert_equal(
- self.parent.cost(3),
- 6,
- err_msg='Incorrect proximity parent cost.',
- )
-
- def test_identity(self):
- """Test IdentityProx."""
- npt.assert_equal(
- self.identity.op(3),
- 3,
- err_msg='Incorrect proximity identity operation.',
- )
-
- npt.assert_equal(
- self.identity.cost(3),
- 0,
- err_msg='Incorrect proximity identity cost.',
- )
-
- def test_positivity(self):
- """Test Positivity."""
- npt.assert_equal(
- self.positivity.op(-3),
- 0,
- err_msg='Incorrect proximity positivity operation.',
- )
-
- npt.assert_equal(
- self.positivity.cost(-3, verbose=True),
- 0,
- err_msg='Incorrect proximity positivity cost.',
- )
-
- def test_sparse_threshold(self):
- """Test SparseThreshold."""
- npt.assert_array_equal(
- self.sparsethresh.op(self.data1),
- self.data2,
- err_msg='Incorrect sparse threshold operation.',
- )
-
- npt.assert_equal(
- self.sparsethresh.cost(self.data1, verbose=True),
- 108.0,
- err_msg='Incorrect sparse threshold cost.',
- )
-
- def test_low_rank_matrix(self):
- """Test LowRankMatrix."""
- npt.assert_almost_equal(
- self.lowrank.op(self.data3),
- self.data4,
- err_msg='Incorrect low rank operation: standard',
- )
-
- npt.assert_almost_equal(
- self.lowrank_rank.op(self.data3),
- self.data4,
- err_msg='Incorrect low rank operation: standard with rank',
- )
- npt.assert_almost_equal(
- self.lowrank_ngole.op(self.data3),
- self.data5,
- err_msg='Incorrect low rank operation: ngole',
- )
-
- npt.assert_almost_equal(
- self.lowrank.cost(self.data3, verbose=True),
- 469.39132942464983,
- err_msg='Incorrect low rank cost.',
- )
-
- def test_linear_comp_prox(self):
- """Test LinearCompositionProx."""
- npt.assert_array_equal(
- self.linear_comp.op(self.data1),
- self.data2,
- err_msg='Incorrect sparse threshold operation.',
- )
-
- npt.assert_equal(
- self.linear_comp.cost(self.data1, verbose=True),
- 108.0,
- err_msg='Incorrect sparse threshold cost.',
- )
-
- def test_proximity_combo(self):
- """Test ProximityCombo."""
- for data7, data8 in zip(self.data7, self.data8):
- npt.assert_array_equal(
- data7,
- data8,
- err_msg='Incorrect combined operation',
- )
-
- npt.assert_equal(
- self.combo.cost(self.data6),
- 0,
- err_msg='Incorrect combined cost.',
- )
-
- npt.assert_raises(TypeError, proximity.ProximityCombo, 1)
-
- npt.assert_raises(ValueError, proximity.ProximityCombo, [])
-
- npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy])
-
- self.dummy.op = func_identity
-
- npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy])
-
- @skipIf(import_sklearn, 'sklearn is installed.') # pragma: no cover
- def test_owl_sklearn_error(self):
- """Test OrderedWeightedL1Norm with Scikit-Learn."""
- npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1)
-
- @skipUnless(import_sklearn, 'sklearn not installed.') # pragma: no cover
- def test_sparse_owl(self):
- """Test OrderedWeightedL1Norm."""
- npt.assert_array_equal(
- self.owl.op(self.data1.flatten()),
- self.data2.flatten(),
- err_msg='Incorrect sparse threshold operation.',
- )
-
- npt.assert_equal(
- self.owl.cost(self.data1.flatten(), verbose=True),
- 108.0,
- err_msg='Incorrect sparse threshold cost.',
- )
-
- npt.assert_raises(
- ValueError,
- proximity.OrderedWeightedL1Norm,
- np.arange(10),
- )
-
- def test_ridge(self):
- """Test Ridge."""
- npt.assert_array_equal(
- self.ridge.op(self.data9),
- self.data10,
- err_msg='Incorect shrinkage operation.',
- )
-
- npt.assert_equal(
- self.ridge.cost(self.data9, verbose=True),
- 408.0 * 3.0,
- err_msg='Incorect shrinkage cost.',
- )
-
- def test_elastic_net_alpha0(self):
- """Test ElasticNet."""
- npt.assert_array_equal(
- self.elasticnet_alpha0.op(self.data1),
- self.data2,
- err_msg='Incorect sparse threshold operation ElasticNet class.',
- )
-
- npt.assert_equal(
- self.elasticnet_alpha0.cost(self.data1),
- 108.0,
- err_msg='Incorect shrinkage cost in ElasticNet class.',
- )
-
- def test_elastic_net_beta0(self):
- """Test ElasticNet with beta=0."""
- npt.assert_array_equal(
- self.elasticnet_beta0.op(self.data9),
- self.data10,
- err_msg='Incorect ridge operation ElasticNet class.',
- )
-
- npt.assert_equal(
- self.elasticnet_beta0.cost(self.data9, verbose=True),
- 408.0 * 3.0,
- err_msg='Incorect shrinkage cost in ElasticNet class.',
- )
-
- def test_one_support_norm(self):
- """Test KSupportNorm with k=1."""
- npt.assert_allclose(
- self.one_support.op(self.data1.flatten()),
- self.data2.flatten(),
- err_msg='Incorect sparse threshold operation for 1-support norm',
- rtol=1e-6,
- )
-
- npt.assert_equal(
- self.one_support.cost(self.data1.flatten(), verbose=True),
- 259.2,
- err_msg='Incorect sparse threshold cost.',
- )
-
- npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
-
- def test_five_support_norm(self):
- """Test KSupportNorm with k=5."""
- npt.assert_allclose(
- self.five_support_norm.op(self.data1.flatten()),
- self.data11.flatten(),
- err_msg='Incorect sparse Ksupport norm operation',
- rtol=1e-6,
- )
-
- npt.assert_equal(
- self.five_support_norm.cost(self.data1.flatten(), verbose=True),
- 684.0,
- err_msg='Incorrect 5-support norm cost.',
- )
-
- npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
-
- def test_d_support_norm(self):
- """Test KSupportNorm with k=19."""
- npt.assert_allclose(
- self.d_support.op(self.data9.flatten()),
- self.data10.flatten(),
- err_msg='Incorect shrinkage operation for d-support norm',
- rtol=1e-6,
- )
-
- npt.assert_almost_equal(
- self.d_support.cost(self.data9.flatten(), verbose=True),
- 408.0 * 3.0,
- err_msg='Incorrect shrinkage cost for d-support norm.',
- )
-
- npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
-
- def test_group_lasso(self):
- """Test GroupLASSO."""
- npt.assert_allclose(
- self.group_lasso.op(self.random_data_tile),
- self.gl_result_data,
- )
- npt.assert_equal(
- self.group_lasso.cost(self.random_data_tile),
- np.sum(6 * self.random_data_tile),
- )
- # Check that for 0 weights operator doesnt change result
- self.group_lasso.weights = np.zeros_like(self.group_lasso.weights)
- npt.assert_equal(
- self.group_lasso.op(self.random_data_tile),
- self.random_data_tile,
- )
- npt.assert_equal(self.group_lasso.cost(self.random_data_tile), 0)
-
-
-class ReweightTestCase(TestCase):
- """Test case for reweight module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1
- self.data2 = np.array(
- [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]],
- )
- self.rw = reweight.cwbReweight(self.data1)
- self.rw.reweight(self.data1)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.rw = None
-
- def test_cwbreweight(self):
- """Test cwbReweight."""
- npt.assert_array_equal(
- self.rw.weights,
- self.data2,
- err_msg='Incorrect CWB re-weighting.',
- )
-
- npt.assert_raises(ValueError, self.rw.reweight, self.data1[0])
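
The unittest-style opt tests deleted above are superseded by pytest-style tests (the new pyproject.toml below pulls in pytest and pytest-cases). A minimal sketch of how the removed cwbReweight case could be rewritten, with the expected behaviour copied from the deleted setUp; the test function name and module layout are assumptions:

    import numpy as np
    import numpy.testing as npt
    import pytest

    from modopt.opt import reweight


    def test_cwb_reweight():
        data = np.arange(9).reshape(3, 3).astype(float) + 1
        rw = reweight.cwbReweight(data)
        rw.reweight(data)
        # Expected result taken from the deleted test: the weights are halved
        # when the re-weighting input equals the initial weights.
        npt.assert_array_equal(rw.weights, data / 2)
        # Re-weighting with a mismatched shape should raise.
        with pytest.raises(ValueError):
            rw.reweight(data[0])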
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
deleted file mode 100644
index 7490b98c..00000000
--- a/modopt/tests/test_signal.py
+++ /dev/null
@@ -1,414 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""UNIT TESTS FOR SIGNAL.
-
-This module contains unit tests for the modopt.signal module.
-
-:Author: Samuel Farrens
-
-"""
-
-from unittest import TestCase
-
-import numpy as np
-import numpy.testing as npt
-
-from modopt.signal import filter, noise, positivity, svd, validation, wavelet
-
-
-class FilterTestCase(TestCase):
- """Test case for filter module."""
-
- def test_guassian_filter(self):
- """Test guassian_filter."""
- npt.assert_almost_equal(
- filter.gaussian_filter(1, 1),
- 0.24197072451914337,
- err_msg='Incorrect Gaussian filter',
- )
-
- npt.assert_almost_equal(
- filter.gaussian_filter(1, 1, norm=False),
- 0.60653065971263342,
- err_msg='Incorrect Gaussian filter',
- )
-
- def test_mex_hat(self):
- """Test mex_hat."""
- npt.assert_almost_equal(
- filter.mex_hat(2, 1),
- -0.35213905225713371,
- err_msg='Incorrect Mexican hat filter',
- )
-
- def test_mex_hat_dir(self):
- """Test mex_hat_dir."""
- npt.assert_almost_equal(
- filter.mex_hat_dir(1, 2, 1),
- 0.17606952612856686,
- err_msg='Incorrect directional Mexican hat filter',
- )
-
-
-class NoiseTestCase(TestCase):
- """Test case for noise module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = np.array(
- [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]],
- )
- self.data3 = np.array([
- [1.62434536, 0.38824359, 1.47182825],
- [1.92703138, 4.86540763, 2.6984613],
- [7.74481176, 6.2387931, 8.3190391],
- ])
- self.data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
- self.data5 = np.array(
- [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
- )
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
-
- def test_add_noise_poisson(self):
- """Test add_noise with Poisson noise."""
- np.random.seed(1)
- npt.assert_array_equal(
- noise.add_noise(self.data1, noise_type='poisson'),
- self.data2,
- err_msg='Incorrect noise: Poisson',
- )
-
- npt.assert_raises(
- ValueError,
- noise.add_noise,
- self.data1,
- noise_type='bla',
- )
-
- npt.assert_raises(ValueError, noise.add_noise, self.data1, (1, 1))
-
- def test_add_noise_gaussian(self):
- """Test add_noise with Gaussian noise."""
- np.random.seed(1)
- npt.assert_almost_equal(
- noise.add_noise(self.data1),
- self.data3,
- err_msg='Incorrect noise: Gaussian',
- )
-
- np.random.seed(1)
- npt.assert_almost_equal(
- noise.add_noise(self.data1, sigma=(1, 1, 1)),
- self.data3,
- err_msg='Incorrect noise: Gaussian',
- )
-
- def test_thresh_hard(self):
- """Test thresh with hard threshold."""
- npt.assert_array_equal(
- noise.thresh(self.data1, 5),
- self.data4,
- err_msg='Incorrect threshold: hard',
- )
-
- npt.assert_raises(
- ValueError,
- noise.thresh,
- self.data1,
- 5,
- threshold_type='bla',
- )
-
- def test_thresh_soft(self):
- """Test thresh with soft threshold."""
- npt.assert_array_equal(
- noise.thresh(self.data1, 5, threshold_type='soft'),
- self.data5,
- err_msg='Incorrect threshold: soft',
- )
-
-
-class PositivityTestCase(TestCase):
- """Test case for positivity module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3) - 5
- self.data2 = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3]])
- self.data3 = np.array(
- [np.arange(5) - 3, np.arange(4) - 2],
- dtype=object,
- )
- self.data4 = np.array(
- [np.array([0, 0, 0, 0, 1]), np.array([0, 0, 0, 1])],
- dtype=object,
- )
- self.pos_dtype_obj = positivity.positive(self.data3)
- self.err = 'Incorrect positivity'
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
-
- def test_positivity(self):
- """Test positivity."""
- npt.assert_equal(positivity.positive(-1), 0, err_msg=self.err)
-
- npt.assert_equal(
- positivity.positive(-1.0),
- -float(0),
- err_msg=self.err,
- )
-
- npt.assert_equal(
- positivity.positive(self.data1),
- self.data2,
- err_msg=self.err,
- )
-
- for expected, output in zip(self.data4, self.pos_dtype_obj):
- print(expected, output)
- npt.assert_array_equal(expected, output, err_msg=self.err)
-
- npt.assert_raises(TypeError, positivity.positive, '-1')
-
-
-class SVDTestCase(TestCase):
- """Test case for svd module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(18).reshape(9, 2).astype(float)
- self.data2 = np.arange(32).reshape(16, 2).astype(float)
- self.data3 = np.array(
- [
- np.array([
- [-0.01744594, -0.61438865],
- [-0.08435304, -0.50397984],
- [-0.15126014, -0.39357102],
- [-0.21816724, -0.28316221],
- [-0.28507434, -0.17275339],
- [-0.35198144, -0.06234457],
- [-0.41888854, 0.04806424],
- [-0.48579564, 0.15847306],
- [-0.55270274, 0.26888188],
- ]),
- np.array([42.23492742, 1.10041151]),
- np.array([
- [-0.67608034, -0.73682791],
- [0.73682791, -0.67608034],
- ]),
- ],
- dtype=object,
- )
- self.data4 = np.array([
- [-1.05426832e-16, 1.0],
- [2.0, 3.0],
- [4.0, 5.0],
- [6.0, 7.0],
- [8.0, 9.0],
- [1.0e1, 1.1e1],
- [1.2e1, 1.3e1],
- [1.4e1, 1.5e1],
- [1.6e1, 1.7e1],
- ])
- self.data5 = np.array([
- [0.49815487, 0.54291537],
- [2.40863386, 2.62505584],
- [4.31911286, 4.70719631],
- [6.22959185, 6.78933678],
- [8.14007085, 8.87147725],
- [10.05054985, 10.95361772],
- [11.96102884, 13.03575819],
- [13.87150784, 15.11789866],
- [15.78198684, 17.20003913],
- ])
- self.svd = svd.calculate_svd(self.data1)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.svd = None
-
- def test_find_n_pc(self):
- """Test find_n_pc."""
- npt.assert_equal(
- svd.find_n_pc(svd.svd(self.data2)[0]),
- 2,
- err_msg='Incorrect number of principal components.',
- )
-
- npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3))
-
- def test_calculate_svd(self):
- """Test calculate_svd."""
- npt.assert_almost_equal(
- self.svd[0],
- np.array(self.data3)[0],
- err_msg='Incorrect SVD calculation: U',
- )
-
- npt.assert_almost_equal(
- self.svd[1],
- np.array(self.data3)[1],
- err_msg='Incorrect SVD calculation: S',
- )
-
- npt.assert_almost_equal(
- self.svd[2],
- np.array(self.data3)[2],
- err_msg='Incorrect SVD calculation: V',
- )
-
- def test_svd_thresh(self):
- """Test svd_thresh."""
- npt.assert_almost_equal(
- svd.svd_thresh(self.data1),
- self.data4,
- err_msg='Incorrect SVD tresholding',
- )
-
- npt.assert_almost_equal(
- svd.svd_thresh(self.data1, n_pc=1),
- self.data5,
- err_msg='Incorrect SVD tresholding',
- )
-
- npt.assert_almost_equal(
- svd.svd_thresh(self.data1, n_pc='all'),
- self.data1,
- err_msg='Incorrect SVD tresholding',
- )
-
- npt.assert_raises(TypeError, svd.svd_thresh, 1)
-
- npt.assert_raises(ValueError, svd.svd_thresh, self.data1, n_pc='bla')
-
- def test_svd_thresh_coef(self):
- """Test svd_thresh_coef."""
- npt.assert_almost_equal(
- svd.svd_thresh_coef(self.data1, lambda x_val: x_val, 0),
- self.data1,
- err_msg='Incorrect SVD coefficient tresholding',
- )
-
- npt.assert_raises(TypeError, svd.svd_thresh_coef, self.data1, 0, 0)
-
-
-class ValidationTestCase(TestCase):
- """Test case for validation module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
-
- def test_transpose_test(self):
- """Test transpose_test."""
- np.random.seed(2)
- npt.assert_equal(
- validation.transpose_test(
- lambda x_val, y_val: x_val.dot(y_val),
- lambda x_val, y_val: x_val.dot(y_val.T),
- self.data1.shape,
- x_args=self.data1,
- ),
- None,
- )
-
- npt.assert_raises(
- TypeError,
- validation.transpose_test,
- 0,
- 0,
- self.data1.shape,
- x_args=self.data1,
- )
-
-
-class WaveletTestCase(TestCase):
- """Test case for wavelet module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = np.arange(36).reshape(4, 3, 3).astype(float)
- self.data3 = np.array([
- [
- [6.0, 20, 26.0],
- [36.0, 84.0, 84.0],
- [90, 164.0, 134.0],
- ],
- [
- [78.0, 155.0, 134.0],
- [225.0, 408.0, 327.0],
- [270, 461.0, 350],
- ],
- [
- [150, 290, 242.0],
- [414.0, 732.0, 570],
- [450, 758.0, 566.0],
- ],
- [
- [222.0, 425.0, 350],
- [603.0, 1056.0, 813.0],
- [630, 1055.0, 782.0],
- ],
- ])
-
- self.data4 = np.array([
- [6496.0, 9796.0, 6544.0],
- [9924.0, 14910, 9924.0],
- [6544.0, 9796.0, 6496.0],
- ])
-
- self.data5 = np.array([
- [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]],
- [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]],
- [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]],
- ])
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
-
- def test_filter_convolve(self):
- """Test filter_convolve."""
- npt.assert_almost_equal(
- wavelet.filter_convolve(self.data1, self.data2),
- self.data3,
- err_msg='Inccorect filter comvolution.',
- )
-
- npt.assert_almost_equal(
- wavelet.filter_convolve(self.data2, self.data2, filter_rot=True),
- self.data4,
- err_msg='Inccorect filter comvolution.',
- )
-
- def test_filter_convolve_stack(self):
- """Test filter_convolve_stack."""
- npt.assert_almost_equal(
- wavelet.filter_convolve_stack(self.data1, self.data1),
- self.data5,
- err_msg='Inccorect filter stack comvolution.',
- )
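
The signal tests removed above can be condensed with pytest parametrization; a sketch for the thresholding cases, with the expected arrays copied from the deleted NoiseTestCase (the test name is an assumption):

    import numpy as np
    import numpy.testing as npt
    import pytest

    from modopt.signal import noise

    DATA = np.arange(9).reshape(3, 3).astype(float)


    @pytest.mark.parametrize(
        ("threshold_type", "expected"),
        [
            ("hard", np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])),
            ("soft", np.array([[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]])),
        ],
    )
    def test_thresh(threshold_type, expected):
        npt.assert_array_equal(
            noise.thresh(DATA, 5, threshold_type=threshold_type),
            expected,
        )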
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..84eb967a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,58 @@
+[project]
+name="modopt"
+description = 'Modular Optimisation tools for solving inverse problems.'
+version = "1.7.2"
+requires-python= ">=3.8"
+
+authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"},
+{name="Chaithya GR", email="chaithyagr@gmail.com"},
+{name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"},
+{name="Philippe Ciuciu", email="philippe.ciuciu@cea.fr"}
+]
+readme="README.md"
+license={file="LICENCE.txt"}
+
+dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"]
+
+[project.optional-dependencies]
+gpu=["torch", "ptwt"]
+doc=["myst-parser",
+"nbsphinx",
+"nbsphinx-link",
+"sphinx-gallery",
+"numpydoc",
+"sphinxawesome-theme",
+"sphinxcontrib-bibtex"]
+dev=["black", "ruff"]
+test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"]
+
+[build-system]
+requires=["setuptools", "setuptools-scm[toml]", "wheel"]
+
+[tool.coverage.run]
+omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
+
+[tool.coverage.report]
+precision = 2
+exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
+
+[tool.black]
+
+
+[tool.ruff]
+exclude = ["examples", "docs"]
+[tool.ruff.lint]
+select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
+
+ignore = ["F401"] # we like the try: import ... expect: ...
+
+[tool.ruff.lint.pydocstyle]
+convention="numpy"
+
+[tool.isort]
+profile="black"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+norecursedirs = ["tests/test_helpers"]
+addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"]
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 63a404ba..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-importlib_metadata>=3.7.0
-numpy>=1.19.5
-scipy>=1.5.4
-progressbar2>=3.53.1
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index cabd35a0..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,91 +0,0 @@
-[aliases]
-test=pytest
-
-[metadata]
-description_file = README.rst
-
-[darglint]
-docstring_style = numpy
-strictness = short
-
-[flake8]
-ignore =
- D107, #Justification: Don't need docstring for __init__ in numpydoc style
- RST304, #Justification: Need to use :cite: role for citations
- RST210, #Justification: RST210, RST213 Inconsistent with numpydoc
- RST213, # documentation for handling *args and **kwargs
- W503, #Justification: Have to choose one multiline operator format
- WPS202, #Todo: Rethink module size, possibly split large modules
- WPS337, #Todo: Consider simplifying multiline conditions.
- WPS338, #Todo: Consider changing method order
- WPS403, #Todo: Rethink no cover lines
- WPS421, #Todo: Review need for print statements
- WPS432, #Justification: Mathematical codes require "magic numbers"
- WPS433, #Todo: Rethink conditional imports
- WPS463, #Todo: Rename get_ methods
- WPS615, #Todo: Rename get_ methods
-per-file-ignores =
- #Justification: Needed for keeping package version and current API
- *__init__.py*: F401,F403,WPS347,WPS410,WPS412
- #Todo: Rethink conditional imports
- #Todo: How can we bypass mutable constants?
- modopt/base/backend.py: WPS229, WPS420, WPS407
- #Todo: Rethink conditional imports
- modopt/base/observable.py: WPS420,WPS604
- #Todo: Check string for log formatting
- modopt/interface/log.py: WPS323
- #Todo: Rethink conditional imports
- modopt/math/convolve.py: WPS301,WPS420
- #Todo: Rethink conditional imports
- modopt/math/matrix.py: WPS420
- #Todo: import has bad parenthesis
- modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410
- #Todo: x is a too short name.
- modopt/opt/algorithms/forward_backward.py: WPS111
- #Todo: Check need for del statement
- modopt/opt/algorithms/primal_dual.py: WPS111, WPS420
- #multiline parameters bug with tuples
- modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317
- #Todo: Consider changing costObj name
- modopt/opt/cost.py: N801,
- #Todo:
- # - Rethink subscript slice assignment
- # - Reduce complexity of KSupportNorm
- # - Check bitwise operations
- modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508
- #Todo: Consider changing cwbReweight name
- modopt/opt/reweight.py: N801
- #Justification: Needed to import matplotlib.pyplot
- modopt/plot/cost_plot.py: N802,WPS301
- #Todo: Investigate possible bug in find_n_pc function
- #Todo: Investigate darglint error
- modopt/signal/svd.py: WPS345, DAR000
- #Todo: Check security of using system executable call
- modopt/signal/wavelet.py: S404,S603
- #Todo: Clean up tests
- modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604
- #Todo: Import has bad parenthesis
- modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301
-#WPS Settings
-max-arguments = 25
-max-attributes = 40
-max-cognitive-score = 20
-max-function-expressions = 20
-max-line-complexity = 30
-max-local-variables = 10
-max-methods = 20
-max-module-expressions = 20
-max-string-usages = 20
-max-raises = 5
-
-[tool:pytest]
-testpaths =
- modopt
-addopts =
- --verbose
- --emoji
- --flake8
- --cov=modopt
- --cov-report=term
- --cov-report=xml
- --junitxml=pytest.xml
diff --git a/setup.py b/setup.py
deleted file mode 100644
index c93dd020..00000000
--- a/setup.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#! /usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from setuptools import setup, find_packages
-import os
-
-# Set the package release version
-major = 1
-minor = 6
-patch = 1
-
-# Set the package details
-name = 'modopt'
-version = '.'.join(str(value) for value in (major, minor, patch))
-author = 'Samuel Farrens'
-email = 'samuel.farrens@cea.fr'
-gh_user = 'cea-cosmic'
-url = 'https://github.com/{0}/{1}'.format(gh_user, name)
-description = 'Modular Optimisation tools for soliving inverse problems.'
-license = 'MIT'
-
-# Set the package classifiers
-python_versions_supported = ['3.6', '3.7', '3.8', '3.9']
-os_platforms_supported = ['Unix', 'MacOS']
-
-lc_str = 'License :: OSI Approved :: {0} License'
-ln_str = 'Programming Language :: Python'
-py_str = 'Programming Language :: Python :: {0}'
-os_str = 'Operating System :: {0}'
-
-classifiers = (
- [lc_str.format(license)]
- + [ln_str]
- + [py_str.format(ver) for ver in python_versions_supported]
- + [os_str.format(ops) for ops in os_platforms_supported]
-)
-
-# Source package description from README.md
-this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
- long_description = f.read()
-
-# Source package requirements from requirements.txt
-with open('requirements.txt') as open_file:
- install_requires = open_file.read()
-
-# Source test requirements from develop.txt
-with open('develop.txt') as open_file:
- tests_require = open_file.read()
-
-# Source doc requirements from docs/requirements.txt
-with open('docs/requirements.txt') as open_file:
- docs_require = open_file.read()
-
-
-setup(
- name=name,
- author=author,
- author_email=email,
- version=version,
- license=license,
- url=url,
- description=description,
- long_description=long_description,
- long_description_content_type='text/markdown',
- packages=find_packages(),
- install_requires=install_requires,
- python_requires='>=3.6',
- setup_requires=['pytest-runner'],
- tests_require=tests_require,
- extras_require={'develop': tests_require + docs_require},
- classifiers=classifiers,
-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
new file mode 100644
index 00000000..5e8de1b6
--- /dev/null
+++ b/src/modopt/__init__.py
@@ -0,0 +1,25 @@
+"""MODOPT PACKAGE.
+
+ModOpt is a series of Modular Optimisation tools for solving inverse problems.
+
+"""
+
+from warnings import warn
+
+from importlib_metadata import version
+
+from modopt.base import np_adjust, transform, types, observable
+
+__all__ = ["np_adjust", "transform", "types", "observable"]
+
+try:
+ _version = version("modopt")
+except Exception: # pragma: no cover
+ _version = "Unkown"
+ warn(
+ "Could not extract package metadata. Make sure the package is "
+ + "correctly installed.",
+ stacklevel=1,
+ )
+
+__version__ = _version
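
A quick usage check of the new package entry point; the version now comes from installed package metadata rather than a hand-maintained string, so it only resolves for an installed (or editable-installed) package:

    import modopt

    print(modopt.__version__)  # "1.7.2" when installed; "Unknown" plus a warning otherwise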
diff --git a/modopt/base/__init__.py b/src/modopt/base/__init__.py
similarity index 68%
rename from modopt/base/__init__.py
rename to src/modopt/base/__init__.py
index 1c0c8b2c..c4c681d7 100644
--- a/modopt/base/__init__.py
+++ b/src/modopt/base/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""BASE ROUTINES.
This module contains submodules for basic operations such as type
@@ -9,4 +7,4 @@
"""
-__all__ = ['np_adjust', 'transform', 'types', 'wrappers', 'observable']
+__all__ = ["np_adjust", "transform", "types", "observable"]
diff --git a/modopt/base/backend.py b/src/modopt/base/backend.py
similarity index 77%
rename from modopt/base/backend.py
rename to src/modopt/base/backend.py
index 1f4e9a72..485f649a 100644
--- a/modopt/base/backend.py
+++ b/src/modopt/base/backend.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""BACKEND MODULE.
This module contains methods for GPU Compatiblity.
@@ -26,22 +24,24 @@
# Handle the compatibility with variable
LIBRARIES = {
- 'cupy': None,
- 'tensorflow': None,
- 'numpy': np,
+ "cupy": None,
+ "tensorflow": None,
+ "numpy": np,
}
-if util.find_spec('cupy') is not None:
+if util.find_spec("cupy") is not None:
try:
import cupy as cp
- LIBRARIES['cupy'] = cp
+
+ LIBRARIES["cupy"] = cp
except ImportError:
pass
-if util.find_spec('tensorflow') is not None:
+if util.find_spec("tensorflow") is not None:
try:
from tensorflow.experimental import numpy as tnp
- LIBRARIES['tensorflow'] = tnp
+
+ LIBRARIES["tensorflow"] = tnp
except ImportError:
pass
@@ -66,12 +66,12 @@ def get_backend(backend):
"""
if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None:
msg = (
- '{0} backend not possible, please ensure that '
- + 'the optional libraries are installed.\n'
- + 'Reverting to numpy.'
+ "{0} backend not possible, please ensure that "
+ + "the optional libraries are installed.\n"
+ + "Reverting to numpy."
)
warn(msg.format(backend))
- backend = 'numpy'
+ backend = "numpy"
return LIBRARIES[backend], backend
@@ -92,16 +92,16 @@ def get_array_module(input_data):
The numpy or cupy module
"""
- if LIBRARIES['tensorflow'] is not None:
- if isinstance(input_data, LIBRARIES['tensorflow'].ndarray):
- return LIBRARIES['tensorflow']
- if LIBRARIES['cupy'] is not None:
- if isinstance(input_data, LIBRARIES['cupy'].ndarray):
- return LIBRARIES['cupy']
+ if LIBRARIES["tensorflow"] is not None:
+ if isinstance(input_data, LIBRARIES["tensorflow"].ndarray):
+ return LIBRARIES["tensorflow"]
+ if LIBRARIES["cupy"] is not None:
+ if isinstance(input_data, LIBRARIES["cupy"].ndarray):
+ return LIBRARIES["cupy"]
return np
-def change_backend(input_data, backend='cupy'):
+def change_backend(input_data, backend="cupy"):
"""Move data to device.
This method changes the backend of an array. This can be used to copy data
@@ -151,13 +151,13 @@ def move_to_cpu(input_data):
"""
xp = get_array_module(input_data)
- if xp == LIBRARIES['numpy']:
+ if xp == LIBRARIES["numpy"]:
return input_data
- elif xp == LIBRARIES['cupy']:
+ elif xp == LIBRARIES["cupy"]:
return input_data.get()
- elif xp == LIBRARIES['tensorflow']:
+ elif xp == LIBRARIES["tensorflow"]:
return input_data.data.numpy()
- raise ValueError('Cannot identify the array type.')
+ raise ValueError("Cannot identify the array type.")
def convert_to_tensor(input_data):
@@ -184,9 +184,9 @@ def convert_to_tensor(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
xp = get_array_module(input_data)
@@ -220,9 +220,9 @@ def convert_to_cupy_array(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
if input_data.is_cuda:
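
The dispatch logic reworked in this file can be exercised directly; get_backend falls back to numpy with a warning when the requested library is missing, and get_array_module recovers the module from an array instance:

    import numpy as np

    from modopt.base.backend import get_array_module, get_backend

    xp, name = get_backend("cupy")  # (cupy, "cupy") if installed, else (numpy, "numpy")
    data = xp.arange(9).reshape(3, 3)

    assert get_array_module(np.arange(3)) is np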
diff --git a/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py
similarity index 96%
rename from modopt/base/np_adjust.py
rename to src/modopt/base/np_adjust.py
index 6d290e43..10cb5c29 100644
--- a/modopt/base/np_adjust.py
+++ b/src/modopt/base/np_adjust.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""NUMPY ADJUSTMENT ROUTINES.
This module contains methods for adjusting the default output for certain
@@ -154,8 +152,7 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- 'Padding must be an integer or a tuple (or list, np.ndarray) '
- + 'of itegers',
+ "Padding must be an integer or a tuple (or list, np.ndarray) of integers",
)
if padding.size == 1:
@@ -164,7 +161,7 @@ def pad2d(input_data, padding):
pad_x = (padding[0], padding[0])
pad_y = (padding[1], padding[1])
- return np.pad(input_data, (pad_x, pad_y), 'constant')
+ return np.pad(input_data, (pad_x, pad_y), "constant")
def ftr(input_data):
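
For reference, the pad2d behaviour kept by the hunk above: a scalar pads both axes equally, a pair pads each axis separately, and any other input raises ValueError:

    import numpy as np

    from modopt.base.np_adjust import pad2d

    image = np.arange(9).reshape(3, 3)
    print(pad2d(image, 1).shape)       # (5, 5)
    print(pad2d(image, (1, 2)).shape)  # (5, 7)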
diff --git a/modopt/base/observable.py b/src/modopt/base/observable.py
similarity index 95%
rename from modopt/base/observable.py
rename to src/modopt/base/observable.py
index 6471ba58..bf8371c3 100644
--- a/modopt/base/observable.py
+++ b/src/modopt/base/observable.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""Observable.
This module contains observable classes
@@ -13,13 +11,13 @@
import numpy as np
-class SignalObject(object):
+class SignalObject:
"""Dummy class for signals."""
pass
-class Observable(object):
+class Observable:
"""Base class for observable classes.
This class defines a simple interface to add or remove observers
@@ -33,7 +31,6 @@ class Observable(object):
"""
def __init__(self, signals):
-
# Define class parameters
self._allowed_signals = []
self._observers = {}
@@ -177,7 +174,7 @@ def _remove_observer(self, signal, observer):
self._observers[signal].remove(observer)
-class MetricObserver(object):
+class MetricObserver:
"""Metric observer.
Wrapper of the metric to the observer object notify by the Observable
@@ -215,7 +212,6 @@ def __init__(
wind=6,
eps=1.0e-3,
):
-
self.name = name
self.metric = metric
self.mapping = mapping
@@ -264,9 +260,7 @@ def is_converge(self):
mid_idx = -(self.wind // 2)
old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean()
current_mean = np.array(self.list_cv_values[mid_idx:]).mean()
- normalize_residual_metrics = (
- np.abs(old_mean - current_mean) / np.abs(old_mean)
- )
+ normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean)
self.converge_flag = normalize_residual_metrics < self.eps
def retrieve_metrics(self):
@@ -287,7 +281,7 @@ def retrieve_metrics(self):
time_val -= time_val[0]
return {
- 'time': time_val,
- 'index': self.list_iters,
- 'values': self.list_cv_values,
+ "time": time_val,
+ "index": self.list_iters,
+ "values": self.list_cv_values,
}
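
The convergence test in is_converge splits a trailing window of metric values in two and compares the half-window means, stopping when the relative change drops below eps. A standalone sketch with the defaults wind=6 and eps=1e-3, assuming the window spans the last wind values (the start index is elided in the hunk above); the values are illustrative:

    import numpy as np

    values = [8.4, 8.4, 8.4, 8.4, 8.4, 8.4]
    wind, eps = 6, 1e-3

    mid = -(wind // 2)
    old_mean = np.array(values[-wind:mid]).mean()
    new_mean = np.array(values[mid:]).mean()
    print(abs(old_mean - new_mean) / abs(old_mean) < eps)  # True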
diff --git a/modopt/base/transform.py b/src/modopt/base/transform.py
similarity index 87%
rename from modopt/base/transform.py
rename to src/modopt/base/transform.py
index 07ce846f..25ed102a 100644
--- a/modopt/base/transform.py
+++ b/src/modopt/base/transform.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""DATA TRANSFORM ROUTINES.
This module contains methods for transforming data.
@@ -53,18 +51,17 @@ def cube2map(data_cube, layout):
"""
if data_cube.ndim != 3:
- raise ValueError('The input data must have 3 dimensions.')
+ raise ValueError("The input data must have 3 dimensions.")
if data_cube.shape[0] != np.prod(layout):
raise ValueError(
- 'The desired layout must match the number of input '
- + 'data layers.',
+ "The desired layout must match the number of input " + "data layers.",
)
- res = ([
+ res = [
np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))])
for elem in range(layout[0])
- ])
+ ]
return np.vstack(res)
@@ -118,20 +115,24 @@ def map2cube(data_map, layout):
"""
if np.all(np.array(data_map.shape) % np.array(layout)):
raise ValueError(
- 'The desired layout must be a multiple of the number '
- + 'pixels in the data map.',
+ "The desired layout must be a multiple of the number "
+ + "pixels in the data map.",
)
d_shape = np.array(data_map.shape) // np.array(layout)
- return np.array([
- data_map[(
- slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
- slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
- )]
- for i_elem in range(layout[0])
- for j_elem in range(layout[1])
- ])
+ return np.array(
+ [
+ data_map[
+ (
+ slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
+ slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
+ )
+ ]
+ for i_elem in range(layout[0])
+ for j_elem in range(layout[1])
+ ]
+ )
def map2matrix(data_map, layout):
@@ -186,9 +187,9 @@ def map2matrix(data_map, layout):
image_shape * (i_elem % layout[1] + 1),
)
data_matrix.append(
- (
- data_map[lower[0]:upper[0], lower[1]:upper[1]]
- ).reshape(image_shape ** 2),
+ (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape(
+ image_shape**2
+ ),
)
return np.array(data_matrix).T
@@ -232,7 +233,7 @@ def matrix2map(data_matrix, map_shape):
# Get the shape and layout of the images
image_shape = np.sqrt(data_matrix.shape[0]).astype(int)
- layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int')
+ layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int")
# Map objects from matrix
data_map = np.zeros(map_shape)
@@ -248,7 +249,7 @@ def matrix2map(data_matrix, map_shape):
image_shape * (i_elem // layout[1] + 1),
image_shape * (i_elem % layout[1] + 1),
)
- data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem]
+ data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem]
return data_map.astype(int)
@@ -285,7 +286,7 @@ def cube2matrix(data_cube):
"""
return data_cube.reshape(
- [data_cube.shape[0]] + [np.prod(data_cube.shape[1:])],
+ [data_cube.shape[0], np.prod(data_cube.shape[1:])],
).T
@@ -330,4 +331,4 @@ def matrix2cube(data_matrix, im_shape):
cube2matrix : complimentary function
"""
- return data_matrix.T.reshape([data_matrix.shape[1]] + list(im_shape))
+ return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)])
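
The reshape cleanups above keep cube2matrix and matrix2cube exact inverses: each slice of the cube becomes one column of the matrix. A round-trip sketch:

    import numpy as np

    from modopt.base.transform import cube2matrix, matrix2cube

    cube = np.arange(24).reshape(4, 2, 3)
    matrix = cube2matrix(cube)
    print(matrix.shape)  # (6, 4)

    np.testing.assert_array_equal(matrix2cube(matrix, (2, 3)), cube)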
diff --git a/modopt/base/types.py b/src/modopt/base/types.py
similarity index 74%
rename from modopt/base/types.py
rename to src/modopt/base/types.py
index 88051675..9e9a15b9 100644
--- a/modopt/base/types.py
+++ b/src/modopt/base/types.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""TYPE HANDLING ROUTINES.
This module contains methods for handing object types.
@@ -9,12 +7,10 @@
"""
import numpy as np
-
-from modopt.base.wrappers import add_args_kwargs
from modopt.interface.errors import warn
-def check_callable(input_obj, add_agrs=True):
+def check_callable(input_obj):
"""Check input object is callable.
This method checks if the input operator is a callable funciton and
@@ -25,30 +21,14 @@ def check_callable(input_obj, add_agrs=True):
----------
input_obj : callable
Callable function
- add_agrs : bool, optional
- Option to add support for agrs and kwargs (default is ``True``)
-
- Returns
- -------
- function
- Function wrapped by ``add_args_kwargs``
Raises
------
TypeError
For invalid input type
-
- See Also
- --------
- modopt.base.wrappers.add_args_kwargs : wrapper used
-
"""
if not callable(input_obj):
- raise TypeError('The input object must be a callable function.')
-
- if add_agrs:
- input_obj = add_args_kwargs(input_obj)
-
+ raise TypeError("The input object must be a callable function.")
return input_obj
@@ -89,14 +69,13 @@ def check_float(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, int):
input_obj = float(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=float)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.floating))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.floating)
):
input_obj = input_obj.astype(float)
@@ -139,14 +118,13 @@ def check_int(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, float):
input_obj = int(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=int)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.integer))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.integer)
):
input_obj = input_obj.astype(int)
@@ -178,19 +156,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
"""
if not isinstance(input_obj, np.ndarray):
- raise TypeError('Input is not a numpy array.')
+ raise TypeError("Input is not a numpy array.")
- if (
- (not isinstance(dtype, type(None)))
- and (not np.issubdtype(input_obj.dtype, dtype))
+ if (not isinstance(dtype, type(None))) and (
+ not np.issubdtype(input_obj.dtype, dtype)
):
raise (
TypeError(
- 'The numpy array elements are not of type: {0}'.format(dtype),
+ f"The numpy array elements are not of type: {dtype}",
),
)
if not writeable and verbose and input_obj.flags.writeable:
- warn('Making input data immutable.')
+ warn("Making input data immutable.")
input_obj.flags.writeable = writeable
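
The coercion helpers rewritten above in action; both raise TypeError for unsupported inputs:

    import numpy as np

    from modopt.base.types import check_float, check_int

    print(check_float(1))                   # 1.0
    print(check_float([1, 2]))              # [1. 2.]
    print(check_int(np.arange(3.0)).dtype)  # int64 (platform dependent)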
diff --git a/modopt/interface/__init__.py b/src/modopt/interface/__init__.py
similarity index 75%
rename from modopt/interface/__init__.py
rename to src/modopt/interface/__init__.py
index f9439747..a54f4bf5 100644
--- a/modopt/interface/__init__.py
+++ b/src/modopt/interface/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""INTERFACE ROUTINES.
This module contains submodules for error handling, logging and IO interaction.
@@ -8,4 +6,4 @@
"""
-__all__ = ['errors', 'log']
+__all__ = ["errors", "log"]
diff --git a/modopt/interface/errors.py b/src/modopt/interface/errors.py
similarity index 75%
rename from modopt/interface/errors.py
rename to src/modopt/interface/errors.py
index 0fbe7e71..84031e3c 100644
--- a/modopt/interface/errors.py
+++ b/src/modopt/interface/errors.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""ERROR HANDLING ROUTINES.
This module contains methods for handing warnings and errors.
@@ -34,16 +32,16 @@ def warn(warn_string, log=None):
"""
if import_fail:
- warn_txt = 'WARNING'
+ warn_txt = "WARNING"
else:
- warn_txt = colored('WARNING', 'yellow')
+ warn_txt = colored("WARNING", "yellow")
# Print warning to stdout.
- sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string))
+ sys.stderr.write(f"{warn_txt}: {warn_string}\n")
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- warnings.warn(warn_string)
+ warnings.warn(warn_string, stacklevel=2)
def catch_error(exception, log=None):
@@ -61,17 +59,17 @@ def catch_error(exception, log=None):
"""
if import_fail:
- err_txt = 'ERROR'
+ err_txt = "ERROR"
else:
- err_txt = colored('ERROR', 'red')
+ err_txt = colored("ERROR", "red")
# Print exception to stdout.
- stream_txt = '{0}: {1}\n'.format(err_txt, exception)
+ stream_txt = f"{err_txt}: {exception}\n"
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = 'ERROR: {0}\n'.format(exception)
+ log_txt = f"ERROR: {exception}\n"
log.exception(log_txt)
@@ -91,11 +89,11 @@ def file_name_error(file_name):
If file name not specified or file not found
"""
- if file_name == '' or file_name[0][0] == '-':
- raise IOError('Input file name not specified.')
+ if file_name == "" or file_name[0][0] == "-":
+ raise OSError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError('Input file name {0} not found!'.format(file_name))
+ raise OSError(f"Input file name {file_name} not found!")
def is_exe(fpath):
@@ -136,7 +134,7 @@ def is_executable(exe_name):
"""
if not isinstance(exe_name, str):
- raise TypeError('Executable name must be a string.')
+ raise TypeError("Executable name must be a string.")
fpath, fname = os.path.split(exe_name)
@@ -146,11 +144,9 @@ def is_executable(exe_name):
else:
res = any(
is_exe(os.path.join(path, exe_name))
- for path in os.environ['PATH'].split(os.pathsep)
+ for path in os.environ["PATH"].split(os.pathsep)
)
if not res:
- message = (
- '{0} does not appear to be a valid executable on this system.'
- )
- raise IOError(message.format(exe_name))
+ message = "{0} does not appear to be a valid executable on this system."
+ raise OSError(message.format(exe_name))
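
A minimal exercise of the message helpers reformatted above; both write to stderr, coloured when termcolor is available:

    from modopt.interface.errors import catch_error, warn

    warn("step size not set, using default")  # stderr: "WARNING: ..."
    catch_error(ValueError("bad input"))      # stderr: "ERROR: bad input"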
diff --git a/modopt/interface/log.py b/src/modopt/interface/log.py
similarity index 77%
rename from modopt/interface/log.py
rename to src/modopt/interface/log.py
index 3b2fa77a..50c316b7 100644
--- a/modopt/interface/log.py
+++ b/src/modopt/interface/log.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""LOGGING ROUTINES.
This module contains methods for handing logging.
@@ -30,22 +28,22 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = '{0}.log'.format(filename)
+ filename = f"{filename}.log"
if verbose:
- print('Preparing log file:', filename)
+ print("Preparing log file:", filename)
# Capture warnings.
logging.captureWarnings(True)
# Set output format.
formatter = logging.Formatter(
- fmt='%(asctime)s %(message)s',
- datefmt='%d/%m/%Y %H:%M:%S',
+ fmt="%(asctime)s %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
)
# Create file handler.
- fh = logging.FileHandler(filename=filename, mode='w')
+ fh = logging.FileHandler(filename=filename, mode="w")
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
@@ -55,7 +53,7 @@ def set_up_log(filename, verbose=True):
log.addHandler(fh)
# Send opening message.
- log.info('The log file has been set-up.')
+ log.info("The log file has been set-up.")
return log
@@ -74,10 +72,10 @@ def close_log(log, verbose=True):
"""
if verbose:
- print('Closing log file:', log.name)
+ print("Closing log file:", log.name)
# Send closing message.
- log.info('The log file has been closed.')
+ log.info("The log file has been closed.")
# Remove all handlers from log.
for log_handler in log.handlers:
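
Typical lifecycle of the logger configured above; set_up_log appends the .log extension, captures warnings, and logs an opening message, while close_log logs a closing message and detaches handlers:

    from modopt.interface.log import close_log, set_up_log

    log = set_up_log("run", verbose=False)  # writes run.log
    log.info("Iteration complete.")
    close_log(log, verbose=False)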
diff --git a/modopt/math/__init__.py b/src/modopt/math/__init__.py
similarity index 64%
rename from modopt/math/__init__.py
rename to src/modopt/math/__init__.py
index a22c0c98..d5ffc67a 100644
--- a/modopt/math/__init__.py
+++ b/src/modopt/math/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""MATHEMATICS ROUTINES.
This module contains submodules for mathematical applications.
@@ -8,4 +6,4 @@
"""
-__all__ = ['convolve', 'matrix', 'stats', 'metrics']
+__all__ = ["convolve", "matrix", "stats", "metrics"]
diff --git a/modopt/math/convolve.py b/src/modopt/math/convolve.py
similarity index 87%
rename from modopt/math/convolve.py
rename to src/modopt/math/convolve.py
index a4322ff2..21dc8b4e 100644
--- a/modopt/math/convolve.py
+++ b/src/modopt/math/convolve.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""CONVOLUTION ROUTINES.
This module contains methods for convolution.
@@ -18,7 +16,7 @@
from astropy.convolution import convolve_fft
except ImportError: # pragma: no cover
import_astropy = False
- warn('astropy not found, will default to scipy for convolution')
+ warn("astropy not found, will default to scipy for convolution")
else:
import_astropy = True
try:
@@ -30,7 +28,7 @@
warn('Using pyFFTW "monkey patch" for scipy.fftpack')
-def convolve(input_data, kernel, method='scipy'):
+def convolve(input_data, kernel, method="scipy"):
"""Convolve data with kernel.
This method convolves the input data with a given kernel using FFT and
@@ -80,29 +78,29 @@ def convolve(input_data, kernel, method='scipy'):
"""
if input_data.ndim != kernel.ndim:
- raise ValueError('Data and kernel must have the same dimensions.')
+ raise ValueError("Data and kernel must have the same dimensions.")
- if method not in {'astropy', 'scipy'}:
+ if method not in {"astropy", "scipy"}:
raise ValueError('Invalid method. Options are "astropy" or "scipy".')
if not import_astropy: # pragma: no cover
- method = 'scipy'
+ method = "scipy"
- if method == 'astropy':
+ if method == "astropy":
return convolve_fft(
input_data,
kernel,
- boundary='wrap',
+ boundary="wrap",
crop=False,
- nan_treatment='fill',
+ nan_treatment="fill",
normalize_kernel=False,
)
- elif method == 'scipy':
- return scipy.signal.fftconvolve(input_data, kernel, mode='same')
+ elif method == "scipy":
+ return scipy.signal.fftconvolve(input_data, kernel, mode="same")
-def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
+def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"):
"""Convolve stack of data with stack of kernels.
This method convolves the input data with a given kernel using FFT and
@@ -156,7 +154,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
if rot_kernel:
kernel = rotate_stack(kernel)
- return np.array([
- convolve(data_i, kernel_i, method=method)
- for data_i, kernel_i in zip(input_data, kernel)
- ])
+ return np.array(
+ [
+ convolve(data_i, kernel_i, method=method)
+ for data_i, kernel_i in zip(input_data, kernel)
+ ]
+ )
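A short sketch of the two backends; note that astropy (``boundary="wrap"``) and scipy (zero padding) treat edges differently, so results only agree away from the borders:

    import numpy as np
    from modopt.math.convolve import convolve, convolve_stack

    data = np.random.randn(32, 32)
    kernel = np.ones((5, 5)) / 25  # simple box blur

    res = convolve(data, kernel, method="scipy")
    stack = convolve_stack(np.array([data, data]), np.array([kernel, kernel]))
    print(res.shape, stack.shape)  # (32, 32) (2, 32, 32)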
diff --git a/modopt/math/matrix.py b/src/modopt/math/matrix.py
similarity index 88%
rename from modopt/math/matrix.py
rename to src/modopt/math/matrix.py
index 939cf41f..b200f15d 100644
--- a/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""MATRIX ROUTINES.
This module contains methods for matrix operations.
@@ -15,7 +13,7 @@
from modopt.base.backend import get_array_module, get_backend
-def gram_schmidt(matrix, return_opt='orthonormal'):
+def gram_schmidt(matrix, return_opt="orthonormal"):
r"""Gram-Schmit.
This method orthonormalizes the row vectors of the input matrix.
@@ -55,7 +53,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
"""
- if return_opt not in {'orthonormal', 'orthogonal', 'both'}:
+ if return_opt not in {"orthonormal", "orthogonal", "both"}:
raise ValueError(
'Invalid return_opt, options are: "orthonormal", "orthogonal" or '
+ '"both"',
@@ -65,7 +63,6 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
e_vec = []
for vector in matrix:
-
if u_vec:
u_now = vector - sum(project(u_i, vector) for u_i in u_vec)
else:
@@ -77,11 +74,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
u_vec = np.array(u_vec)
e_vec = np.array(e_vec)
- if return_opt == 'orthonormal':
+ if return_opt == "orthonormal":
return e_vec
- elif return_opt == 'orthogonal':
+ elif return_opt == "orthogonal":
return u_vec
- elif return_opt == 'both':
+ elif return_opt == "both":
return u_vec, e_vec
@@ -201,7 +198,7 @@ def rot_matrix(angle):
return np.around(
np.array(
[[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
- dtype='float',
+ dtype="float",
),
10,
)
@@ -243,22 +240,21 @@ def rotate(matrix, angle):
shape = np.array(matrix.shape)
if shape[0] != shape[1]:
- raise ValueError('Input matrix must be square.')
+ raise ValueError("Input matrix must be square.")
shift = (shape - 1) // 2
index = (
- np.array(list(product(*np.array([np.arange(sval) for sval in shape]))))
- - shift
+ np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift
)
- new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift
+ new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift
new_index[new_index >= shape[0]] -= shape[0]
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
-class PowerMethod(object):
+class PowerMethod:
"""Power method class.
- This method performs implements power method to calculate the spectral
+ This method implements the power method to calculate the spectral
@@ -277,6 +273,8 @@ class PowerMethod(object):
initialisation (default is ``True``)
verbose : bool, optional
Optional verbosity (default is ``False``)
+ rng : int, xp.random.Generator or None, optional
+ Random number generator or seed (default is ``None``)
Examples
--------
@@ -285,9 +283,9 @@ class PowerMethod(object):
>>> np.random.seed(1)
>>> pm = PowerMethod(lambda x: x.dot(x.T), (3, 3))
>>> np.around(pm.spec_rad, 6)
- 0.904292
+ 1.0
>>> np.around(pm.inv_spec_rad, 6)
- 1.105837
+ 1.0
Notes
-----
@@ -301,16 +299,17 @@ def __init__(
data_shape,
data_type=float,
auto_run=True,
- compute_backend='numpy',
+ compute_backend="numpy",
verbose=False,
+ rng=None,
):
-
self._operator = operator
self._data_shape = data_shape
self._data_type = data_type
self._verbose = verbose
xp, compute_backend = get_backend(compute_backend)
self.xp = xp
+ self.rng = rng
self.compute_backend = compute_backend
if auto_run:
self.get_spec_rad()
@@ -327,7 +326,8 @@ def _set_initial_x(self):
Random values of the same shape as the input data
"""
- return self.xp.random.random(self._data_shape).astype(self._data_type)
+ rng = self.xp.random.default_rng(self.rng)
+ return rng.random(self._data_shape).astype(self._data_type)
def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
"""Get spectral radius.
@@ -348,32 +348,33 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
# Set (or reset) values of x.
x_old = self._set_initial_x()
+ xp = get_array_module(x_old)
+ x_old_norm = xp.linalg.norm(x_old)
+
+ x_old /= x_old_norm
+
# Iterate until the L2 norm of x converges.
for i_elem in range(max_iter):
-
xp = get_array_module(x_old)
- x_old_norm = xp.linalg.norm(x_old)
-
- x_new = self._operator(x_old) / x_old_norm
+ x_new = self._operator(x_old)
x_new_norm = xp.linalg.norm(x_new)
- if (xp.abs(x_new_norm - x_old_norm) < tolerance):
- message = (
- ' - Power Method converged after {0} iterations!'
- )
+ x_new /= x_new_norm
+
+ if xp.abs(x_new_norm - x_old_norm) < tolerance:
+ message = " - Power Method converged after {0} iterations!"
if self._verbose:
print(message.format(i_elem + 1))
break
elif i_elem == max_iter - 1 and self._verbose:
- message = (
- ' - Power Method did not converge after {0} iterations!'
- )
+ message = " - Power Method did not converge after {0} iterations!"
print(message.format(max_iter))
xp.copyto(x_old, x_new)
+ x_old_norm = x_new_norm
self.spec_rad = x_new_norm * extra_factor
self.inv_spec_rad = 1.0 / self.spec_rad
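The reworked loop renormalises the iterate at every step, which is why the doctest output above changes to 1.0. A minimal usage sketch, assuming the ``self.rng = rng`` fix above:

    import numpy as np
    from modopt.math.matrix import PowerMethod

    m = np.random.default_rng(0).normal(size=(5, 5))
    # For x -> M^T M x the spectral radius equals ||M||_2 squared.
    pm = PowerMethod(lambda x: m.T @ (m @ x), (5,), rng=42)
    print(pm.spec_rad, np.linalg.norm(m, 2) ** 2)  # should roughly agree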
diff --git a/modopt/math/metrics.py b/src/modopt/math/metrics.py
similarity index 90%
rename from modopt/math/metrics.py
rename to src/modopt/math/metrics.py
index cf41a9c2..befd4fa4 100644
--- a/modopt/math/metrics.py
+++ b/src/modopt/math/metrics.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""METRICS.
This module contains classes of different metric functions for optimization.
@@ -23,7 +21,7 @@
def min_max_normalize(img):
"""Min-Max Normalize.
- Centre and normalize a given array.
+ Normalize a given array in the [0,1] range.
Parameters
----------
@@ -33,7 +31,7 @@ def min_max_normalize(img):
Returns
-------
numpy.ndarray
- Centred and normalized array
+ Normalized array
"""
min_img = img.min()
@@ -71,15 +69,13 @@ def _preprocess_input(test, ref, mask=None):
The SNR
"""
- test = np.abs(np.copy(test)).astype('float64')
- ref = np.abs(np.copy(ref)).astype('float64')
+ test = np.abs(np.copy(test)).astype("float64")
+ ref = np.abs(np.copy(ref)).astype("float64")
test = min_max_normalize(test)
ref = min_max_normalize(ref)
if (not isinstance(mask, np.ndarray)) and (mask is not None):
- message = (
- 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
- )
+ message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
raise ValueError(message.format(mask))
if mask is None:
@@ -119,9 +115,9 @@ def ssim(test, ref, mask=None):
"""
if not import_skimage: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Image package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Image package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
test, ref, mask = _preprocess_input(test, ref, mask)
@@ -270,6 +266,6 @@ def nrmse(test, ref, mask=None):
ref = mask * ref
num = np.sqrt(mse(test, ref))
- deno = np.sqrt(np.mean((np.square(test))))
+ deno = np.sqrt(np.mean(np.square(test)))
return num / deno
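The corrected docstring matters downstream: ``min_max_normalize`` rescales into [0, 1] (it does not centre), and ``_preprocess_input`` relies on that. For instance:

    import numpy as np
    from modopt.math.metrics import min_max_normalize

    print(min_max_normalize(np.array([2.0, 4.0, 6.0])))  # [0.  0.5 1. ]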
diff --git a/modopt/math/stats.py b/src/modopt/math/stats.py
similarity index 85%
rename from modopt/math/stats.py
rename to src/modopt/math/stats.py
index 3ac818a7..8583a8c3 100644
--- a/modopt/math/stats.py
+++ b/src/modopt/math/stats.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""STATISTICS ROUTINES.
This module contains methods for basic statistics.
@@ -11,6 +9,8 @@
import numpy as np
try:
+ from packaging import version
+ from astropy import __version__ as astropy_version
from astropy.convolution import Gaussian2DKernel
except ImportError: # pragma: no cover
import_astropy = False
@@ -18,7 +18,7 @@
import_astropy = True
-def gaussian_kernel(data_shape, sigma, norm='max'):
+def gaussian_kernel(data_shape, sigma, norm="max"):
"""Gaussian kernel.
- This method produces a Gaussian kerenal of a specified size and dispersion.
+ This method produces a Gaussian kernel of a specified size and dispersion.
@@ -29,9 +29,8 @@ def gaussian_kernel(data_shape, sigma, norm='max'):
- Desiered shape of the kernel
+ Desired shape of the kernel
sigma : float
Standard deviation of the kernel
- norm : {'max', 'sum', 'none'}, optional
- Normalisation of the kerenl (options are ``'max'``, ``'sum'`` or
- ``'none'``, default is ``'max'``)
+ norm : {'max', 'sum'}, optional, default='max'
+ Normalisation of the kernel
Returns
-------
@@ -60,22 +59,22 @@ def gaussian_kernel(data_shape, sigma, norm='max'):
"""
if not import_astropy: # pragma: no cover
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
- if norm not in {'max', 'sum', 'none'}:
+ if norm not in {"max", "sum"}:
- raise ValueError('Invalid norm, options are "max", "sum" or "none".')
+ raise ValueError('Invalid norm, options are "max" or "sum".')
kernel = np.array(
Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]),
)
- if norm == 'max':
+ if norm == "max":
return kernel / np.max(kernel)
- elif norm == 'sum':
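+ # norm == "sum": astropy >= 5.2 already returns a sum-normalised
+ # kernel, so only older versions need the explicit division below.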
+ elif version.parse(astropy_version) < version.parse("5.2"):
return kernel / np.sum(kernel)
- elif norm == 'none':
+ else:
return kernel
@@ -147,7 +146,7 @@ def mse(data1, data2):
return np.mean((data1 - data2) ** 2)
-def psnr(data1, data2, method='starck', max_pix=255):
+def psnr(data1, data2, method="starck", max_pix=255):
r"""Peak Signal-to-Noise Ratio.
This method calculates the Peak Signal-to-Noise Ratio between two data
@@ -202,23 +201,21 @@ def psnr(data1, data2, method='starck', max_pix=255):
10\log_{10}(\mathrm{MSE}))
"""
- if method == 'starck':
- return (
- 20 * np.log10(
- (data1.shape[0] * np.abs(np.max(data1) - np.min(data1)))
- / np.linalg.norm(data1 - data2),
- )
+ if method == "starck":
+ return 20 * np.log10(
+ (data1.shape[0] * np.abs(np.max(data1) - np.min(data1)))
+ / np.linalg.norm(data1 - data2),
)
- elif method == 'wiki':
- return (20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2)))
+ elif method == "wiki":
+ return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2))
raise ValueError(
'Invalid PSNR method. Options are "starck" and "wiki"',
)
-def psnr_stack(data1, data2, metric=np.mean, method='starck'):
+def psnr_stack(data1, data2, metric=np.mean, method="starck"):
"""Peak Signa-to-Noise for stack of images.
This method calculates the PSNRs for two stacks of 2D arrays.
@@ -261,12 +258,11 @@ def psnr_stack(data1, data2, metric=np.mean, method='starck'):
"""
if data1.ndim != 3 or data2.ndim != 3:
- raise ValueError('Input data must be a 3D np.ndarray')
+ raise ValueError("Input data must be a 3D np.ndarray")
- return metric([
- psnr(i_elem, j_elem, method=method)
- for i_elem, j_elem in zip(data1, data2)
- ])
+ return metric(
+ [psnr(i_elem, j_elem, method=method) for i_elem, j_elem in zip(data1, data2)]
+ )
def sigma_mad(input_data):
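The version branch above only affects ``norm="sum"``: the explicit division is needed for astropy < 5.2, while newer releases already return a sum-normalised kernel (this reading follows from the branch itself). A sketch of the two supported options, assuming astropy is installed:

    import numpy as np
    from modopt.math.stats import gaussian_kernel

    kernel = gaussian_kernel((11, 11), 2.0, norm="max")
    print(np.max(kernel))  # 1.0
    kernel = gaussian_kernel((11, 11), 2.0, norm="sum")
    print(np.sum(kernel))  # ~1.0 on any supported astropy version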
diff --git a/modopt/opt/__init__.py b/src/modopt/opt/__init__.py
similarity index 59%
rename from modopt/opt/__init__.py
rename to src/modopt/opt/__init__.py
index 2fd3d747..62d1f388 100644
--- a/modopt/opt/__init__.py
+++ b/src/modopt/opt/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""OPTIMISATION PROBLEM MODULES.
This module contains submodules for solving optimisation problems.
@@ -8,4 +6,4 @@
"""
-__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight']
+__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"]
diff --git a/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py
similarity index 53%
rename from modopt/opt/algorithms/__init__.py
rename to src/modopt/opt/algorithms/__init__.py
index e0ac2572..ff79502c 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/src/modopt/opt/algorithms/__init__.py
@@ -1,5 +1,4 @@
-# -*- coding: utf-8 -*-
-r"""OPTIMISATION ALGOTITHMS.
+r"""OPTIMISATION ALGORITHMS.
- This module contains class implementations of various optimisation algoritms.
+ This module contains class implementations of various optimisation algorithms.
@@ -45,15 +44,32 @@
"""
-from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (FISTA, POGM,
- ForwardBackward,
- GenForwardBackward)
-from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt)
-from modopt.opt.algorithms.primal_dual import Condat
+from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM
+from .primal_dual import Condat
+from .gradient_descent import (
+ ADAMGradOpt,
+ AdaGenericGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
+from .admm import ADMM, FastADMM
+
+__all__ = [
+ "FISTA",
+ "ForwardBackward",
+ "GenForwardBackward",
+ "POGM",
+ "Condat",
+ "ADAMGradOpt",
+ "AdaGenericGradOpt",
+ "GenericGradOpt",
+ "MomentumGradOpt",
+ "RMSpropGradOpt",
+ "SAGAOptGradOpt",
+ "VanillaGenericGradOpt",
+ "ADMM",
+ "FastADMM",
+]
diff --git a/src/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py
new file mode 100644
index 00000000..b2f45171
--- /dev/null
+++ b/src/modopt/opt/algorithms/admm.py
@@ -0,0 +1,340 @@
+"""ADMM Algorithms."""
+
+import numpy as np
+
+from modopt.base.backend import get_array_module
+from modopt.opt.algorithms.base import SetUp
+from modopt.opt.cost import CostParent
+
+
+class ADMMcostObj(CostParent):
+ r"""Cost Object for the ADMM problem class.
+
+ Parameters
+ ----------
+ cost_funcs : 2-tuple of callable
+ The functions f and g.
+ A : OperatorBase
+ First Operator
+ B : OperatorBase
+ Second Operator
+ b : numpy.ndarray
+ Observed data
+ **kwargs : dict
+ Extra parameters for cost operator configuration
+
+ Notes
+ -----
+ Compute :math:`f(u) + g(v) + \tau \|Au + Bv - b\|^2`
+
+ See Also
+ --------
+ CostParent: parent class
+ """
+
+ def __init__(self, cost_funcs, A, B, b, tau, **kwargs):
+ super().__init__(**kwargs)
+ self.cost_funcs = cost_funcs
+ self.A = A
+ self.B = B
+ self.b = b
+ self.tau = tau
+
+ def _calc_cost(self, u, v, **kwargs):
+ """Calculate the cost.
+
+ This method calculates the cost from each of the input operators.
+
+ Parameters
+ ----------
+ u: numpy.ndarray
+ First primal variable of ADMM
+ v: numpy.ndarray
+ Second primal variable of ADMM
+
+ Returns
+ -------
+ float
+ Cost value
+
+ """
+ xp = get_array_module(u)
+ cost = self.cost_funcs[0](u)
+ cost += self.cost_funcs[1](v)
+ cost += self.tau * xp.linalg.norm(self.A.op(u) + self.B.op(v) - self.b)
+ return cost
+
+
+class ADMM(SetUp):
+ r"""Fast ADMM Optimisation Algorihm.
+
+ This class implement the ADMM algorithm described in :cite:`Goldstein2014`
+ (Algorithm 1).
+
+ Parameters
+ ----------
+ u: numpy.ndarray
+ Initial value for first primal variable of ADMM
+ v: numpy.ndarray
+ Initial value for second primal variable of ADMM
+ mu: numpy.ndarray
+ Initial value for the Lagrange multiplier.
+ A : modopt.opt.linear.LinearOperator
+ Linear operator for u
+ B: modopt.opt.linear.LinearOperator
+ Linear operator for v
+ b : numpy.ndarray
+ Constraint vector
+ optimizers: tuple
+ 2-tuple of callables implementing the optimizers for u and v.
+ Each callable should accept ``init`` and ``obs`` arguments and return an estimate of:
+ .. math:: u_{k+1} = \arg\min_u H(u) + \frac{\tau}{2}\|Au - y\|^2
+ .. math:: v_{k+1} = \arg\min_v G(v) + \frac{\tau}{2}\|Bv - y\|^2
+ cost_funcs: tuple
+ 2-tuple of callables computing the values of H and G.
+ tau: float, default=1
+ Coupling parameter for ADMM.
+
+ Notes
+ -----
+ The algorithm solves the problem:
+
+ .. math:: u, v = \arg\min H(u) + G(v) + \frac\tau2 \|Au + Bv - b \|_2^2
+
+ with the following augmented Lagrangian:
+
+ .. math:: \mathcal{L}_{\tau}(u, v, \lambda) = H(u) + G(v)
+ + \langle \lambda, Au + Bv - b \rangle + \frac{\tau}{2} \|Au + Bv - b\|^2
+
+ To allow easy iterative solving, the change of variable
+ :math:`\mu = \lambda/\tau` is used. Hence, the Lagrangian of interest is:
+
+ .. math:: \tilde{\mathcal{L}}_{\tau}(u, v, \mu) = H(u) + G(v)
+ + \frac{\tau}{2} \left(\|\mu + Au + Bv - b\|^2 - \|\mu\|^2\right)
+
+ See Also
+ --------
+ SetUp: parent class
+ """
+
+ def __init__(
+ self,
+ u,
+ v,
+ mu,
+ A,
+ B,
+ b,
+ optimizers,
+ tau=1,
+ cost_funcs=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.A = A
+ self.B = B
+ self.b = b
+ self._opti_H = optimizers[0]
+ self._opti_G = optimizers[1]
+ self._tau = tau
+ if cost_funcs is not None:
+ self._cost_func = ADMMcostObj(cost_funcs, A, B, b, tau)
+ else:
+ self._cost_func = None
+
+ # init iteration variables.
+ self._u_old = self.xp.copy(u)
+ self._u_new = self.xp.copy(u)
+ self._v_old = self.xp.copy(v)
+ self._v_new = self.xp.copy(v)
+ self._mu_new = self.xp.copy(mu)
+ self._mu_old = self.xp.copy(mu)
+
+ def _update(self):
+ self._u_new = self._opti_H(
+ init=self._u_old,
+ obs=self.B.op(self._v_old) + self._mu_old - self.b,
+ )
+ tmp = self.A.op(self._u_new)
+ self._v_new = self._opti_G(
+ init=self._v_old,
+ obs=tmp + self._mu_old - self.b,
+ )
+
+ self._mu_new = self._mu_old + (tmp + self.B.op(self._v_new) - self.b)
+
+ # update cycle
+ self._u_old = self.xp.copy(self._u_new)
+ self._v_old = self.xp.copy(self._v_new)
+ self._mu_old = self.xp.copy(self._mu_new)
+
+ # Test cost function for convergence.
+ if self._cost_func:
+ self.converge = self.any_convergence_flag()
+ self.converge |= self._cost_func.get_cost(self._u_new, self._v_new)
+
+ def iterate(self, max_iter=150):
+ """Iterate.
+
+ This method calls update until either the convergence criterion is met or
+ the maximum number of iterations is reached.
+
+ Parameters
+ ----------
+ max_iter : int, optional
+ Maximum number of iterations (default is ``150``)
+ """
+ self._run_alg(max_iter)
+
+ # retrieve metrics results
+ self.retrieve_outputs()
+ # rename outputs as attributes
+ self.u_final = self._u_new
+ self.x_final = self.u_final # for backward compatibility
+ self.v_final = self._v_new
+
+ def get_notify_observers_kwargs(self):
+ """Notify observers.
+
+ Return the mapping between the metrics call and the iterated
+ variables.
+
+ Returns
+ -------
+ dict
+ The mapping between the iterated variables
+ """
+ return {
+ "x_new": self._u_new,
+ "v_new": self._v_new,
+ "idx": self.idx,
+ }
+
+ def retrieve_outputs(self):
+ """Retrieve outputs.
+
+ Declare the outputs of the algorithm as attributes: u_final,
+ v_final, metrics.
+ """
+ metrics = {}
+ for obs in self._observers["cv_metrics"]:
+ metrics[obs.name] = obs.retrieve_metrics()
+ self.metrics = metrics
+
+
+class FastADMM(ADMM):
+ r"""Fast ADMM Optimisation Algorihm.
+
+ This class implement the fast ADMM algorithm
+ (Algorithm 8 from :cite:`Goldstein2014`)
+
+ Parameters
+ ----------
+ u: numpy.ndarray
+ Initial value for first primal variable of ADMM
+ v: numpy.ndarray
+ Initial value for second primal variable of ADMM
+ mu: numpy.ndarray
+ Initial value for the Lagrange multiplier.
+ A : modopt.opt.linear.LinearOperator
+ Linear operator for u
+ B: modopt.opt.linear.LinearOperator
+ Linear operator for v
+ b : numpy.ndarray
+ Constraint vector
+ optimizers: tuple
+ 2-tuple of callables implementing the optimizers for u and v.
+ Each callable should accept ``init`` and ``obs`` arguments and return an estimate of:
+ .. math:: u_{k+1} = \arg\min_u H(u) + \frac{\tau}{2}\|Au - y\|^2
+ .. math:: v_{k+1} = \arg\min_v G(v) + \frac{\tau}{2}\|Bv - y\|^2
+ cost_funcs: tuple
+ 2-tuple of callables computing the values of H and G.
+ tau: float, default=1
+ Coupling parameter for ADMM.
+ eta: float, default=0.999
+ Convergence parameter for ADMM.
+ alpha: float, default=1.
+ Initial value for the FISTA-like acceleration parameter.
+
+ Notes
+ -----
+ This is an accelerated version of the ADMM algorithm. Its convergence
+ hypotheses are stronger than those of the standard ADMM algorithm.
+
+ See Also
+ --------
+ ADMM: parent class
+ """
+
+ def __init__(
+ self,
+ u,
+ v,
+ mu,
+ A,
+ B,
+ b,
+ optimizers,
+ cost_funcs=None,
+ alpha=1,
+ eta=0.999,
+ tau=1,
+ **kwargs,
+ ):
+ super().__init__(
+ u=u,
+ v=v,
+ mu=mu,
+ A=A,
+ B=B,
+ b=b,
+ optimizers=optimizers,
+ cost_funcs=cost_funcs,
+ **kwargs,
+ )
+ self._c_old = np.inf
+ self._c_new = 0
+ self._eta = eta
+ self._alpha_old = alpha
+ self._alpha_new = alpha
+ self._v_hat = self.xp.copy(self._v_new)
+ self._mu_hat = self.xp.copy(self._mu_new)
+
+ def _update(self):
+ # Classical ADMM steps
+ self._u_new = self._opti_H(
+ init=self._u_old,
+ obs=self.B.op(self._v_hat) + self._mu_hat - self.b,
+ )
+ tmp = self.A.op(self._u_new)
+ self._v_new = self._opti_G(
+ init=self._v_hat,
+ obs=tmp + self._mu_hat - self.b,
+ )
+
+ self._mu_new = self._mu_hat + (tmp + self.B.op(self._v_new) - self.b)
+
+ # restarting condition
+ self._c_new = self.xp.linalg.norm(self._mu_new - self._mu_hat)
+ self._c_new += self._tau * self.xp.linalg.norm(
+ self.B.op(self._v_new - self._v_hat),
+ )
+ if self._c_new < self._eta * self._c_old:
+ self._alpha_new = (1 + np.sqrt(1 + 4 * self._alpha_old**2)) / 2
+ beta = (self._alpha_old - 1) / self._alpha_new
+ self._v_hat = self._v_new + (self._v_new - self._v_old) * beta
+ self._mu_hat = self._mu_new + (self._mu_new - self._mu_old) * beta
+ else:
+ # reboot to old iteration
+ self._alpha_new = 1
+ self._v_hat = self._v_old
+ self._mu_hat = self._mu_old
+ self._c_new = self._c_old / self._eta
+
+ self.xp.copyto(self._u_old, self._u_new)
+ self.xp.copyto(self._v_old, self._v_new)
+ self.xp.copyto(self._mu_old, self._mu_new)
+ self._c_old = self._c_new
+ self._alpha_old = self._alpha_new
+ # Test cost function for convergence.
+ if self._cost_func:
+ self.converge = self.any_convergence_flag()
+ self.converge |= self._cost_func.get_cost(self._u_new, self._v_new)
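A toy consensus problem showing the expected ``optimizers`` interface, assuming the ``_mu_old``/``_mu_hat`` fixes above. With H = G = 0 each sub-problem reduces to minimising :math:`\frac{\tau}{2}\|x + \text{obs}\|^2`, whose solution is simply ``-obs``:

    import numpy as np
    from modopt.opt.algorithms import ADMM
    from modopt.opt.linear import Identity

    def solve(init, obs):
        # argmin_x tau/2 * ||x + obs||^2 with H = 0  ->  x = -obs
        return -obs

    b = np.ones(4)
    admm = ADMM(
        u=np.zeros(4), v=np.zeros(4), mu=np.zeros(4),
        A=Identity(), B=Identity(), b=b,
        optimizers=(solve, solve),
    )
    admm.iterate(max_iter=10)
    print(admm.u_final + admm.v_final)  # ~ [1. 1. 1. 1.]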
diff --git a/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py
similarity index 67%
rename from modopt/opt/algorithms/base.py
rename to src/modopt/opt/algorithms/base.py
index 85c36306..f7391063 100644
--- a/modopt/opt/algorithms/base.py
+++ b/src/modopt/opt/algorithms/base.py
@@ -1,10 +1,9 @@
-# -*- coding: utf-8 -*-
"""Base SetUp for optimisation algorithms."""
from inspect import getmro
import numpy as np
-from progressbar import ProgressBar
+from tqdm.auto import tqdm
from modopt.base import backend
from modopt.base.observable import MetricObserver, Observable
@@ -12,17 +11,17 @@
class SetUp(Observable):
- r"""Algorithm Set-Up.
+ """Algorithm Set-Up.
This class contains methods for checking the set-up of an optimisation
- algotithm and produces warnings if they do not comply.
+ algorithm and produces warnings if they do not comply.
Parameters
----------
metric_call_period : int, optional
Metric call period (default is ``5``)
metrics : dict, optional
- Metrics to be used (default is ``\{\}``)
+ Metrics to be used (default is ``None``)
verbose : bool, optional
Option for verbose output (default is ``False``)
progress : bool, optional
@@ -34,11 +33,32 @@ class SetUp(Observable):
use_gpu : bool, optional
Option to use available GPU
+ Notes
+ -----
+ If provided, the ``metrics`` argument should be a nested dictionary of the
+ following form::
+
+ metrics = {
+ 'metric_name': {
+ 'metric': callable,
+ 'mapping': {'x_new': 'test'},
+ 'cst_kwargs': {'ref': ref_image},
+ 'early_stopping': False,
+ }
+ }
+
+ where ``callable`` is a function taking, for instance, ``test`` and
+ ``ref`` as arguments. The argument mapping uses the same keys as the
+ output of ``get_notify_observers_kwargs``, and ``cst_kwargs`` defines
+ constant arguments that are always passed to the metric call.
+ If ``early_stopping`` is ``True``, the metric is used to check for
+ convergence of the algorithm; in that case it is recommended to set
+ ``metric_call_period = 1``.
+
See Also
--------
modopt.base.observable.Observable : parent class
modopt.base.observable.MetricObserver : definition of metrics
-
"""
def __init__(
@@ -48,7 +68,7 @@ def __init__(
verbose=False,
progress=True,
step_size=None,
- compute_backend='numpy',
+ compute_backend="numpy",
**dummy_kwargs,
):
self.idx = 0
@@ -58,26 +78,26 @@ def __init__(
self.metrics = metrics
self.step_size = step_size
self._op_parents = (
- 'GradParent',
- 'ProximityParent',
- 'LinearParent',
- 'costObj',
+ "GradParent",
+ "ProximityParent",
+ "LinearParent",
+ "costObj",
)
self.metric_call_period = metric_call_period
# Declaration of observers for metrics
- super().__init__(['cv_metrics'])
+ super().__init__(["cv_metrics"])
for name, dic in self.metrics.items():
observer = MetricObserver(
name,
- dic['metric'],
- dic['mapping'],
- dic['cst_kwargs'],
- dic['early_stopping'],
+ dic["metric"],
+ dic["mapping"],
+ dic["cst_kwargs"],
+ dic["early_stopping"],
)
- self.add_observer('cv_metrics', observer)
+ self.add_observer("cv_metrics", observer)
xp, compute_backend = backend.get_backend(compute_backend)
self.xp = xp
@@ -90,14 +110,13 @@ def metrics(self):
@metrics.setter
def metrics(self, metrics):
-
if isinstance(metrics, type(None)):
self._metrics = {}
elif isinstance(metrics, dict):
self._metrics = metrics
else:
raise TypeError(
- 'Metrics must be a dictionary, not {0}.'.format(type(metrics)),
+ f"Metrics must be a dictionary, not {type(metrics)}.",
)
def any_convergence_flag(self):
@@ -111,9 +130,7 @@ def any_convergence_flag(self):
True if any convergence criteria met
"""
- return any(
- obs.converge_flag for obs in self._observers['cv_metrics']
- )
+ return any(obs.converge_flag for obs in self._observers["cv_metrics"])
def copy_data(self, input_data):
"""Copy Data.
@@ -131,10 +148,12 @@ def copy_data(self, input_data):
Copy of input data
"""
- return self.xp.copy(backend.change_backend(
- input_data,
- self.compute_backend,
- ))
+ return self.xp.copy(
+ backend.change_backend(
+ input_data,
+ self.compute_backend,
+ )
+ )
def _check_input_data(self, input_data):
"""Check input data type.
@@ -154,7 +173,7 @@ def _check_input_data(self, input_data):
"""
if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))):
raise TypeError(
- 'Input data must be a numpy array or backend array',
+ "Input data must be a numpy array or backend array",
)
def _check_param(self, param_val):
@@ -174,7 +193,7 @@ def _check_param(self, param_val):
"""
if not isinstance(param_val, float):
- raise TypeError('Algorithm parameter must be a float value.')
+ raise TypeError("Algorithm parameter must be a float value.")
def _check_param_update(self, param_update):
"""Check algorithm parameter update methods.
@@ -192,14 +211,13 @@ def _check_param_update(self, param_update):
For invalid input type
"""
- param_conditions = (
- not isinstance(param_update, type(None))
- and not callable(param_update)
+ param_conditions = not isinstance(param_update, type(None)) and not callable(
+ param_update
)
if param_conditions:
raise TypeError(
- 'Algorithm parameter update must be a callabale function.',
+ "Algorithm parameter update must be a callabale function.",
)
def _check_operator(self, operator):
@@ -218,7 +236,7 @@ def _check_operator(self, operator):
tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)]
if not any(parent in tree for parent in self._op_parents):
- message = '{0} does not inherit an operator parent.'
+ message = "{0} does not inherit an operator parent."
warn(message.format(str(operator.__class__)))
def _compute_metrics(self):
@@ -229,7 +247,7 @@ def _compute_metrics(self):
"""
kwargs = self.get_notify_observers_kwargs()
- self.notify_observers('cv_metrics', **kwargs)
+ self.notify_observers("cv_metrics", **kwargs)
def _iterations(self, max_iter, progbar=None):
"""Iterate method.
@@ -240,9 +258,8 @@ def _iterations(self, max_iter, progbar=None):
----------
max_iter : int
Maximum number of iterations
- progbar : progressbar.bar.ProgressBar
- Progress bar (default is ``None``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
for idx in range(max_iter):
self.idx = idx
@@ -253,7 +270,6 @@ def _iterations(self, max_iter, progbar=None):
# We do not call metrics if metrics is empty or metric call
# period is None
if self.metrics and self.metric_call_period is not None:
-
metric_conditions = (
self.idx % self.metric_call_period == 0
or self.idx == (max_iter - 1)
@@ -265,13 +281,13 @@ def _iterations(self, max_iter, progbar=None):
if self.converge:
if self.verbose:
- print(' - Converged!')
+ print(" - Converged!")
break
- if not isinstance(progbar, type(None)):
- progbar.update(idx)
+ if progbar:
+ progbar.update()
- def _run_alg(self, max_iter):
+ def _run_alg(self, max_iter, progbar=None):
"""Run algorithm.
Run the update step of a given algorithm up to the maximum number of
@@ -281,17 +297,34 @@ def _run_alg(self, max_iter):
----------
max_iter : int
Maximum number of iterations
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
See Also
--------
- progressbar.bar.ProgressBar
+ tqdm.tqdm
"""
- if self.progress:
- with ProgressBar(
- redirect_stdout=True,
- max_value=max_iter,
- ) as progbar:
- self._iterations(max_iter, progbar=progbar)
+ if self.progress and progbar is None:
+ with tqdm(total=max_iter) as pb:
+ self._iterations(max_iter, progbar=pb)
+ elif progbar:
+ self._iterations(max_iter, progbar=progbar)
else:
self._iterations(max_iter)
+
+ def _update(self):
+ raise NotImplementedError
+
+ def get_notify_observers_kwargs(self):
+ """Notify Observers.
+
+ Return the mapping between the metrics call and the iterated
+ variables.
+
+ Raises
+ ------
+ NotImplementedError
+ This method should be overridden by subclasses.
+ """
+ raise NotImplementedError
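A sketch of the nested ``metrics`` dictionary described in the Notes block above; ``x_new`` is one of the keys returned by ``get_notify_observers_kwargs`` and the reference image is a stand-in (``ssim`` requires scikit-image):

    import numpy as np
    from modopt.math.metrics import ssim

    ref_image = np.random.rand(32, 32)  # stand-in reference
    metrics = {
        "ssim": {
            "metric": ssim,
            "mapping": {"x_new": "test"},  # feed the iterate in as ``test``
            "cst_kwargs": {"ref": ref_image},
            "early_stopping": True,  # pair with metric_call_period=1
        },
    }
    # passed as e.g. SomeAlgorithm(..., metrics=metrics, metric_call_period=1)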
diff --git a/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py
similarity index 86%
rename from modopt/opt/algorithms/forward_backward.py
rename to src/modopt/opt/algorithms/forward_backward.py
index e18f66c3..31927eb0 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/src/modopt/opt/algorithms/forward_backward.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Forward-Backward Algorithms."""
import numpy as np
@@ -9,7 +8,7 @@
from modopt.opt.linear import Identity
-class FISTA(object):
+class FISTA:
r"""FISTA.
This class is inherited by optimisation classes to speed up convergence
@@ -52,12 +51,12 @@ class FISTA(object):
"""
_restarting_strategies = (
- 'adaptive', # option 1 in alg 4
- 'adaptive-i',
- 'adaptive-1',
- 'adaptive-ii', # option 2 in alg 4
- 'adaptive-2',
- 'greedy', # alg 5
+ "adaptive", # option 1 in alg 4
+ "adaptive-i",
+ "adaptive-1",
+ "adaptive-ii", # option 2 in alg 4
+ "adaptive-2",
+ "greedy", # alg 5
None, # no restarting
)
@@ -73,26 +72,28 @@ def __init__(
r_lazy=4,
**kwargs,
):
-
if isinstance(a_cd, type(None)):
- self.mode = 'regular'
+ self.mode = "regular"
self.p_lazy = p_lazy
self.q_lazy = q_lazy
self.r_lazy = r_lazy
elif a_cd > 2:
- self.mode = 'CD'
+ self.mode = "CD"
self.a_cd = a_cd
self._n = 0
else:
raise ValueError(
- 'a_cd must either be None (for regular mode) or a number > 2',
+ "a_cd must either be None (for regular mode) or a number > 2",
)
if restart_strategy in self._restarting_strategies:
self._check_restart_params(
- restart_strategy, min_beta, s_greedy, xi_restart,
+ restart_strategy,
+ min_beta,
+ s_greedy,
+ xi_restart,
)
self.restart_strategy = restart_strategy
self.min_beta = min_beta
@@ -100,10 +101,10 @@ def __init__(
self.xi_restart = xi_restart
else:
- message = 'Restarting strategy must be one of {0}.'
+ message = "Restarting strategy must be one of {0}."
raise ValueError(
message.format(
- ', '.join(self._restarting_strategies),
+ ", ".join(self._restarting_strategies),
),
)
self._t_now = 1.0
@@ -155,22 +156,20 @@ def _check_restart_params(
if restart_strategy is None:
return True
- if self.mode != 'regular':
+ if self.mode != "regular":
raise ValueError(
- 'Restarting strategies can only be used with regular mode.',
+ "Restarting strategies can only be used with regular mode.",
)
- greedy_params_check = (
- min_beta is None or s_greedy is None or s_greedy <= 1
- )
+ greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1
- if restart_strategy == 'greedy' and greedy_params_check:
+ if restart_strategy == "greedy" and greedy_params_check:
raise ValueError(
- 'You need a min_beta and an s_greedy > 1 for greedy restart.',
+ "You need a min_beta and an s_greedy > 1 for greedy restart.",
)
if xi_restart is None or xi_restart >= 1:
- raise ValueError('You need a xi_restart < 1 for restart.')
+ raise ValueError("You need a xi_restart < 1 for restart.")
return True
@@ -210,12 +209,12 @@ def is_restart(self, z_old, x_new, x_old):
criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0
if criterion:
- if 'adaptive' in self.restart_strategy:
+ if "adaptive" in self.restart_strategy:
self.r_lazy *= self.xi_restart
- if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}:
+ if self.restart_strategy in {"adaptive-ii", "adaptive-2"}:
self._t_now = 1
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
cur_delta = xp.linalg.norm(x_new - x_old)
if self._delta0 is None:
self._delta0 = self.s_greedy * cur_delta
@@ -269,17 +268,17 @@ def update_lambda(self, *args, **kwargs):
- Implements steps 3 and 4 from algoritm 10.7 in :cite:`bauschke2009`.
+ Implements steps 3 and 4 from algorithm 10.7 in :cite:`bauschke2009`.
"""
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
return 2
# Steps 3 and 4 from alg.10.7.
self._t_prev = self._t_now
- if self.mode == 'regular':
- sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy
+ if self.mode == "regular":
+ sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy
self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5
- elif self.mode == 'CD':
+ elif self.mode == "CD":
self._t_now = (self._n + self.a_cd - 1) / self.a_cd
self._n += 1
@@ -344,18 +343,17 @@ def __init__(
x,
grad,
prox,
- cost='auto',
+ cost="auto",
beta_param=1.0,
lambda_param=1.0,
beta_update=None,
- lambda_update='fista',
+ lambda_update="fista",
auto_iterate=True,
metric_call_period=5,
metrics=None,
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -376,7 +374,7 @@ def __init__(
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -384,7 +382,7 @@ def __init__(
- # Check if there is a linear op, needed for metrics in the FB algoritm
+ # Check if there is a linear op, needed for metrics in the FB algorithm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -400,7 +398,7 @@ def __init__(
# Set the algorithm parameter update methods
self._check_param_update(beta_update)
self._beta_update = beta_update
- if isinstance(lambda_update, str) and lambda_update == 'fista':
+ if isinstance(lambda_update, str) and lambda_update == "fista":
fista = FISTA(**kwargs)
self._lambda_update = fista.update_lambda
self._is_restart = fista.is_restart
@@ -462,12 +460,11 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
- def iterate(self, max_iter=150):
+ def iterate(self, max_iter=150, progbar=None):
"""Iterate.
This method calls update until either the convergence criteria is met
@@ -477,9 +474,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
@@ -499,9 +497,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z_new,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -512,7 +510,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -576,7 +574,7 @@ def __init__(
x,
grad,
prox_list,
- cost='auto',
+ cost="auto",
gamma_param=1.0,
lambda_param=1.0,
gamma_update=None,
@@ -588,7 +586,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -601,22 +598,22 @@ def __init__(
self._x_old = self.xp.copy(x)
# Set the algorithm operators
- for operator in [grad, cost] + prox_list:
+ for operator in [grad, cost, *prox_list]:
self._check_operator(operator)
self._grad = grad
self._prox_list = self.xp.array(prox_list)
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([self._grad] + prox_list)
+ if cost == "auto":
+ self._cost_func = costObj([self._grad, *prox_list])
else:
self._cost_func = cost
- # Check if there is a linear op, needed for metrics in the FB algoritm
+ # Check if there is a linear op, needed for metrics in the FB algorithm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -640,9 +637,7 @@ def __init__(
self._set_weights(weights)
# Set initial z
- self._z = self.xp.array([
- self._x_old for i in range(self._prox_list.size)
- ])
+ self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)])
# Automatically run the algorithm
if auto_iterate:
@@ -672,25 +667,25 @@ def _set_weights(self, weights):
self._prox_list.size,
)
elif not isinstance(weights, (list, tuple, np.ndarray)):
- raise TypeError('Weights must be provided as a list.')
+ raise TypeError("Weights must be provided as a list.")
weights = self.xp.array(weights)
if not np.issubdtype(weights.dtype, np.floating):
- raise ValueError('Weights must be list of float values.')
+ raise ValueError("Weights must be list of float values.")
if weights.size != self._prox_list.size:
raise ValueError(
- 'The number of weights must match the number of proximity '
- + 'operators.',
+ "The number of weights must match the number of proximity "
+ + "operators.",
)
expected_weight_sum = 1.0
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
- 'Proximity operator weights must sum to 1.0. Current sum of '
- + 'weights = {0}'.format(self.xp.sum(weights)),
+ "Proximity operator weights must sum to 1.0. Current sum of "
+ + f"weights = {self.xp.sum(weights)}",
)
self._weights = weights
@@ -725,9 +720,7 @@ def _update(self):
# Update z values.
for i in range(self._prox_list.size):
- z_temp = (
- 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
- )
+ z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
z_prox = self._prox_list[i].op(
z_temp,
extra_factor=self._gamma / self._weights[i],
@@ -750,7 +743,7 @@ def _update(self):
if self._cost_func:
self.converge = self._cost_func.get_cost(self._x_new)
- def iterate(self, max_iter=150):
+ def iterate(self, max_iter=150, progbar=None):
"""Iterate.
This method calls update until either convergence criteria is met or
@@ -760,9 +753,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
@@ -782,9 +776,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -795,7 +789,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -815,9 +809,9 @@ class POGM(SetUp):
Initial guess for the :math:`y` variable
z : numpy.ndarray
Initial guess for the :math:`z` variable
- grad
+ grad : GradBasic
Gradient operator class
- prox
+ prox : ProximityParent
Proximity operator class
cost : class instance or str, optional
Cost function class instance (default is ``'auto'``); Use ``'auto'`` to
@@ -869,7 +863,7 @@ def __init__(
z,
grad,
prox,
- cost='auto',
+ cost="auto",
linear=None,
beta_param=1.0,
sigma_bar=1.0,
@@ -878,7 +872,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -903,7 +896,7 @@ def __init__(
self._grad = grad
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -916,7 +909,7 @@ def __init__(
for param_val in (beta_param, sigma_bar):
self._check_param(param_val)
if sigma_bar < 0 or sigma_bar > 1:
- raise ValueError('The sigma bar parameter needs to be in [0, 1]')
+ raise ValueError("The sigma bar parameter needs to be in [0, 1]")
self._beta = self.step_size or beta_param
self._sigma_bar = sigma_bar
@@ -942,16 +935,18 @@ def _update(self):
"""
# Step 4 from alg. 3
self._grad.get_grad(self._x_old)
- self._u_new = self._x_old - self._beta * self._grad.grad
+ # self._u_new = self._x_old - self._beta * self._grad.grad
+ self._u_new = -self._beta * self._grad.grad
+ self._u_new += self._x_old
# Step 5 from alg. 3
- self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2))
+ self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))
# Step 6 from alg. 3
t_shifted_ratio = (self._t_old - 1) / self._t_new
sigma_t_ratio = self._sigma * self._t_old / self._t_new
beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
- self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z)
+ self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
self._z += self._u_new
self._z += t_shifted_ratio * (self._u_new - self._u_old)
self._z += sigma_t_ratio * (self._u_new - self._x_old)
@@ -964,15 +959,18 @@ def _update(self):
# Restarting and gamma-Decreasing
# Step 9 from alg. 3
- self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+ # self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+ self._g_new = self._z - self._x_new
+ self._g_new /= self._xi
+ self._g_new += self._grad.grad
# Step 10 from alg 3.
- self._y_new = self._x_old - self._beta * self._g_new
+ # self._y_new = self._x_old - self._beta * self._g_new
+ self._y_new = -self._beta * self._g_new
+ self._y_new += self._x_old
# Step 11 from alg. 3
- restart_crit = (
- self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
- )
+ restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
if restart_crit:
self._t_new = 1
self._sigma = 1
@@ -990,12 +988,11 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
- def iterate(self, max_iter=150):
+ def iterate(self, max_iter=150, progbar=None):
"""Iterate.
This method calls update until either convergence criteria is met or
@@ -1005,9 +1002,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
@@ -1027,14 +1025,14 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'u_new': self._u_new,
- 'x_new': self._linear.adj_op(self._x_new),
- 'y_new': self._y_new,
- 'z_new': self._z,
- 'xi': self._xi,
- 'sigma': self._sigma,
- 't': self._t_new,
- 'idx': self.idx,
+ "u_new": self._u_new,
+ "x_new": self._linear.adj_op(self._x_new),
+ "y_new": self._y_new,
+ "z_new": self._z,
+ "xi": self._xi,
+ "sigma": self._sigma,
+ "t": self._t_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -1045,6 +1043,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
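A minimal sketch of the new caller-owned progress bar, assuming ``GradBasic(data, op, trans_op)`` and the ``Positivity`` proximity operator; ``cost=None`` skips the convergence test for brevity:

    import numpy as np
    from tqdm.auto import tqdm
    from modopt.opt.algorithms import ForwardBackward
    from modopt.opt.gradient import GradBasic
    from modopt.opt.proximity import Positivity

    y = np.random.randn(16)
    grad = GradBasic(y, lambda x: x, lambda x: x)  # op = trans_op = identity
    fb = ForwardBackward(
        np.zeros(16), grad, Positivity(), cost=None, auto_iterate=False,
    )

    with tqdm(total=100) as pb:  # one bar, reusable across several runs
        fb.iterate(max_iter=100, progbar=pb)
    print(fb.x_final.shape)  # (16,)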
diff --git a/modopt/opt/algorithms/gradient_descent.py b/src/modopt/opt/algorithms/gradient_descent.py
similarity index 95%
rename from modopt/opt/algorithms/gradient_descent.py
rename to src/modopt/opt/algorithms/gradient_descent.py
index f3fe4b10..0960be5a 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/src/modopt/opt/algorithms/gradient_descent.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Gradient Descent Algorithms."""
import numpy as np
@@ -103,7 +102,7 @@ def __init__(
self._check_operator(operator)
self._grad = grad
self._prox = prox
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -157,9 +156,8 @@ def _update(self):
self._eta = self._eta_update(self._eta, self.idx)
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def _update_grad_dir(self, grad):
@@ -208,10 +206,10 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._x_new,
- 'dir_grad': self._dir_grad,
- 'speed_grad': self._speed_grad,
- 'idx': self.idx,
+ "x_new": self._x_new,
+ "dir_grad": self._dir_grad,
+ "speed_grad": self._speed_grad,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -222,7 +220,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -308,7 +306,7 @@ class RMSpropGradOpt(GenericGradOpt):
def __init__(self, *args, gamma=0.5, **kwargs):
super().__init__(*args, **kwargs)
if gamma < 0 or gamma > 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
self._check_param(gamma)
self._gamma = gamma
@@ -405,9 +403,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs):
self._check_param(gamma)
self._check_param(beta)
if gamma < 0 or gamma >= 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
if beta < 0 or beta >= 1:
- raise ValueError('beta is outside of range [0,1]')
+ raise ValueError("beta is outside of range [0,1]")
self._gamma = gamma
self._beta = beta
self._beta_pow = 1
diff --git a/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py
similarity index 86%
rename from modopt/opt/algorithms/primal_dual.py
rename to src/modopt/opt/algorithms/primal_dual.py
index c8566969..fee49a25 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/src/modopt/opt/algorithms/primal_dual.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Primal-Dual Algorithms."""
from modopt.opt.algorithms.base import SetUp
@@ -81,7 +80,7 @@ def __init__(
prox,
prox_dual,
linear=None,
- cost='auto',
+ cost="auto",
reweight=None,
rho=0.5,
sigma=1.0,
@@ -96,7 +95,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -123,12 +121,14 @@ def __init__(
self._linear = Identity()
else:
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([
- self._grad,
- self._prox,
- self._prox_dual,
- ])
+ if cost == "auto":
+ self._cost_func = costObj(
+ [
+ self._grad,
+ self._prox,
+ self._prox_dual,
+ ]
+ )
else:
self._cost_func = cost
@@ -187,22 +187,17 @@ def _update(self):
self._grad.get_grad(self._x_old)
x_prox = self._prox.op(
- self._x_old - self._tau * self._grad.grad - self._tau
- * self._linear.adj_op(self._y_old),
+ self._x_old
+ - self._tau * self._grad.grad
+ - self._tau * self._linear.adj_op(self._y_old),
)
# Step 2 from eq.9.
- y_temp = (
- self._y_old + self._sigma
- * self._linear.op(2 * x_prox - self._x_old)
- )
+ y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old)
- y_prox = (
- y_temp - self._sigma
- * self._prox_dual.op(
- y_temp / self._sigma,
- extra_factor=(1.0 / self._sigma),
- )
+ y_prox = y_temp - self._sigma * self._prox_dual.op(
+ y_temp / self._sigma,
+ extra_factor=(1.0 / self._sigma),
)
# Step 3 from eq.9.
@@ -220,12 +215,11 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new, self._y_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new, self._y_new
)
- def iterate(self, max_iter=150, n_rewightings=1):
+ def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
"""Iterate.
This method calls update until either convergence criteria is met or
@@ -237,14 +231,17 @@ def iterate(self, max_iter=150, n_rewightings=1):
Maximum number of iterations (default is ``150``)
n_rewightings : int, optional
Number of reweightings to perform (default is ``1``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
if not isinstance(self._reweight, type(None)):
for _ in range(n_rewightings):
self._reweight.reweight(self._linear.op(self._x_new))
- self._run_alg(max_iter)
+ if progbar:
+ progbar.reset(total=max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
@@ -264,7 +261,7 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
- return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}
+ return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx}
def retrieve_outputs(self):
"""Retrieve outputs.
@@ -274,6 +271,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/cost.py b/src/modopt/opt/cost.py
similarity index 67%
rename from modopt/opt/cost.py
rename to src/modopt/opt/cost.py
index 3cdfcc50..37771f16 100644
--- a/modopt/opt/cost.py
+++ b/src/modopt/opt/cost.py
@@ -6,6 +6,8 @@
"""
+import abc
+
import numpy as np
from modopt.base.backend import get_array_module
@@ -13,8 +15,8 @@
from modopt.plot.cost_plot import plotCost
-class costObj(object):
- """Generic cost function object.
+class CostParent(abc.ABC):
+ """Abstract cost function object.
- This class updates the cost according to the input operator classes and
- tests for convergence.
+ This class updates the cost and tests for convergence.
@@ -40,7 +42,8 @@ class costObj(object):
Notes
-----
- The costFunc class must contain a method called ``cost``.
+ All child classes should implement a ``_calc_cost`` method (returning
+ a float) or override ``get_cost`` for more complex convergence behaviour.
Examples
--------
@@ -71,7 +74,6 @@ class costObj(object):
def __init__(
self,
- operators,
initial_cost=1e6,
tolerance=1e-4,
cost_interval=1,
@@ -79,10 +81,6 @@ def __init__(
verbose=True,
plot_output=None,
):
-
- self._operators = operators
- if not isinstance(operators, type(None)):
- self._check_operators()
self.cost = initial_cost
self._cost_list = []
self._cost_interval = cost_interval
@@ -93,30 +91,6 @@ def __init__(
self._plot_output = plot_output
self._verbose = verbose
- def _check_operators(self):
- """Check operators.
-
- This method checks if the input operators have a ``cost`` method.
-
- Raises
- ------
- TypeError
- For invalid operators type
- ValueError
- For operators without ``cost`` method
-
- """
- if not isinstance(self._operators, (list, tuple, np.ndarray)):
- message = (
- 'Input operators must be provided as a list, not {0}'
- )
- raise TypeError(message.format(type(self._operators)))
-
- for op in self._operators:
- if not hasattr(op, 'cost'):
- raise ValueError('Operators must contain "cost" method.')
- op.cost = check_callable(op.cost)
-
def _check_cost(self):
"""Check cost function.
@@ -137,20 +111,19 @@ def _check_cost(self):
# Check if enough cost values have been collected
if len(self._test_list) == self._test_range:
-
# The mean of the first half of the test list
t1 = xp.mean(
- xp.array(self._test_list[len(self._test_list) // 2:]),
+ xp.array(self._test_list[len(self._test_list) // 2 :]),
axis=0,
)
# The mean of the second half of the test list
t2 = xp.mean(
- xp.array(self._test_list[:len(self._test_list) // 2]),
+ xp.array(self._test_list[: len(self._test_list) // 2]),
axis=0,
)
# Calculate the change across the test list
if xp.around(t1, decimals=16):
- cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1))
+ cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)
else:
cost_diff = 0
@@ -158,15 +131,16 @@ def _check_cost(self):
self._test_list = []
if self._verbose:
- print(' - CONVERGENCE TEST - ')
- print(' - CHANGE IN COST:', cost_diff)
- print('')
+ print(" - CONVERGENCE TEST - ")
+ print(" - CHANGE IN COST:", cost_diff)
+ print("")
# Check for convergence
return cost_diff <= self._tolerance
return False
+ @abc.abstractmethod
def _calc_cost(self, *args, **kwargs):
"""Calculate the cost.
@@ -178,14 +152,7 @@ def _calc_cost(self, *args, **kwargs):
Positional arguments
**kwargs : dict
Keyword arguments
-
- Returns
- -------
- float
- Cost value
-
"""
- return np.sum([op.cost(*args, **kwargs) for op in self._operators])
def get_cost(self, *args, **kwargs):
"""Get cost function.
@@ -207,8 +174,7 @@ def get_cost(self, *args, **kwargs):
"""
# Check if the cost should be calculated
test_conditions = (
- self._cost_interval is None
- or self._iteration % self._cost_interval
+ self._cost_interval is None or self._iteration % self._cost_interval
)
if test_conditions:
@@ -216,15 +182,15 @@ def get_cost(self, *args, **kwargs):
else:
if self._verbose:
- print(' - ITERATION:', self._iteration)
+ print(" - ITERATION:", self._iteration)
# Calculate the current cost
- self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
+ self.cost = self._calc_cost(*args, verbose=self._verbose, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
- print(' - COST:', self.cost)
- print('')
+ print(" - COST:", self.cost)
+ print("")
# Test for convergence
test_result = self._check_cost()
@@ -241,3 +207,108 @@ def plot_cost(self): # pragma: no cover
"""
plotCost(self._cost_list, self._plot_output)
+
+
+class costObj(CostParent):
+ """Abstract cost function object.
+
+ This class updates the cost according to the input operator classes and
+ tests for convergence.
+
+ Parameters
+ ----------
+ operators : list, tuple or numpy.ndarray
+ List of operators classes containing ``cost`` method
+ initial_cost : float, optional
+ Initial value of the cost (default is ``1e6``)
+ tolerance : float, optional
+ Tolerance threshold for convergence (default is ``1e-4``)
+ cost_interval : int, optional
+ Iteration interval to calculate cost (default is ``1``).
+ If ``cost_interval`` is ``None`` the cost is never calculated,
+ thereby saving on computation time.
+ test_range : int, optional
+ Number of cost values to be used in test (default is ``4``)
+ verbose : bool, optional
+ Option for verbose output (default is ``True``)
+ plot_output : str, optional
+ Output file name for cost function plot
+
+ Examples
+ --------
+ >>> from modopt.opt.cost import *
+ >>> class dummy(object):
+ ... def cost(self, x):
+ ... return x ** 2
+ ...
+ ...
+ >>> inst = costObj([dummy(), dummy()])
+ >>> inst.get_cost(2)
+ - ITERATION: 1
+ - COST: 8
+
+ False
+ >>> inst.get_cost(2)
+ - ITERATION: 2
+ - COST: 8
+
+ False
+ >>> inst.get_cost(2)
+ - ITERATION: 3
+ - COST: 8
+
+ False
+ """
+
+ def __init__(
+ self,
+ operators,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self._operators = operators
+ if not isinstance(operators, type(None)):
+ self._check_operators()
+
+ def _check_operators(self):
+ """Check operators.
+
+ This method checks if the input operators have a ``cost`` method.
+
+ Raises
+ ------
+ TypeError
+ For invalid operators type
+ ValueError
+ For operators without ``cost`` method
+
+ """
+ if not isinstance(self._operators, (list, tuple, np.ndarray)):
+ message = "Input operators must be provided as a list, not {0}"
+ raise TypeError(message.format(type(self._operators)))
+
+ for op in self._operators:
+ if not hasattr(op, "cost"):
+ raise ValueError('Operators must contain "cost" method.')
+ op.cost = check_callable(op.cost)
+
+ def _calc_cost(self, *args, **kwargs):
+ """Calculate the cost.
+
+ This method calculates the cost from each of the input operators.
+
+ Parameters
+ ----------
+ *args : tuple
+ Positional arguments
+ **kwargs : dict
+ Keyword arguments
+
+ Returns
+ -------
+ float
+ Cost value
+
+ """
+ return np.sum([op.cost(*args, **kwargs) for op in self._operators])
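For reference, the convergence test retained above reduces to comparing the means of the older and newer halves of the collected cost values. A standalone sketch in plain NumPy (the function name and sample values are illustrative):

    import numpy as np

    def has_converged(test_list, tolerance=1e-4):
        # Means of the older and newer halves of the collected cost values.
        half = len(test_list) // 2
        t_new = np.mean(np.array(test_list[half:]), axis=0)
        t_old = np.mean(np.array(test_list[:half]), axis=0)
        # Relative change, guarding against a zero-valued recent mean.
        if np.around(t_new, decimals=16):
            cost_diff = np.linalg.norm(t_new - t_old) / np.linalg.norm(t_new)
        else:
            cost_diff = 0
        return cost_diff <= tolerance

    print(has_converged([4.0, 4.0, 4.0, 4.0]))   # True: the cost has stalled
    print(has_converged([10.0, 8.0, 4.0, 2.0]))  # False: still decreasing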
diff --git a/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
similarity index 97%
rename from modopt/opt/gradient.py
rename to src/modopt/opt/gradient.py
index caa8fa9d..fe9b87d8 100644
--- a/modopt/opt/gradient.py
+++ b/src/modopt/opt/gradient.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""GRADIENT CLASSES.
This module contains classes for defining algorithm gradients.
@@ -14,7 +12,7 @@
from modopt.base.types import check_callable, check_float, check_npndarray
-class GradParent(object):
+class GradParent:
"""Gradient Parent Class.
This class defines the basic methods that will be inherited by specific
@@ -71,7 +69,6 @@ def __init__(
input_data_writeable=False,
verbose=True,
):
-
self.verbose = verbose
self._input_data_writeable = input_data_writeable
self._grad_data_type = data_type
@@ -100,7 +97,6 @@ def obs_data(self):
@obs_data.setter
def obs_data(self, input_data):
-
if self._grad_data_type in {float, np.floating}:
input_data = check_float(input_data)
check_npndarray(
@@ -128,7 +124,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -147,7 +142,6 @@ def trans_op(self):
@trans_op.setter
def trans_op(self, operator):
-
self._trans_op = check_callable(operator)
@property
@@ -157,7 +151,6 @@ def get_grad(self):
@get_grad.setter
def get_grad(self, method):
-
self._get_grad = check_callable(method)
@property
@@ -167,7 +160,6 @@ def grad(self):
@grad.setter
def grad(self, input_value):
-
if self._grad_data_type in {float, np.floating}:
input_value = check_float(input_value)
self._grad = input_value
@@ -179,7 +171,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
def trans_op_op(self, input_data):
@@ -243,7 +234,6 @@ class GradBasic(GradParent):
"""
def __init__(self, *args, **kwargs):
-
super().__init__(*args, **kwargs)
self.get_grad = self._get_grad_method
self.cost = self._cost_method
@@ -289,7 +279,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - DATA FIDELITY (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - DATA FIDELITY (X):", cost_val)
return cost_val
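A minimal usage sketch for GradBasic, assuming the standard ModOpt behaviour where ``get_grad`` stores ``trans_op(op(x) - obs_data)`` and ``cost`` is the data-fidelity term shown above (the identity operators are illustrative):

    import numpy as np
    from modopt.opt.gradient import GradBasic

    y = np.arange(9).reshape(3, 3).astype(float)  # observed data
    forward = lambda x: x   # forward operator (identity here)
    adjoint = lambda x: x   # its transpose

    grad = GradBasic(y, forward, adjoint)
    x = np.zeros((3, 3))
    grad.get_grad(x)     # stores trans_op(op(x) - y) in grad.grad
    print(grad.grad)     # -y for the identity operator
    print(grad.cost(x))  # 0.5 * ||y - op(x)||**2 = 102.0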
diff --git a/src/modopt/opt/linear/__init__.py b/src/modopt/opt/linear/__init__.py
new file mode 100644
index 00000000..d5c0d21f
--- /dev/null
+++ b/src/modopt/opt/linear/__init__.py
@@ -0,0 +1,21 @@
+"""LINEAR OPERATORS.
+
+This module contains linear operator classes.
+
+:Author: Samuel Farrens
+:Author: Pierre-Antoine Comby
+"""
+
+from .base import LinearParent, Identity, MatrixOperator, LinearCombo
+
+from .wavelet import WaveletConvolve, WaveletTransform
+
+
+__all__ = [
+ "LinearParent",
+ "Identity",
+ "MatrixOperator",
+ "LinearCombo",
+ "WaveletConvolve",
+ "WaveletTransform",
+]
diff --git a/modopt/opt/linear.py b/src/modopt/opt/linear/base.py
similarity index 79%
rename from modopt/opt/linear.py
rename to src/modopt/opt/linear/base.py
index d8679998..af748a73 100644
--- a/modopt/opt/linear.py
+++ b/src/modopt/opt/linear/base.py
@@ -1,20 +1,12 @@
-# -*- coding: utf-8 -*-
-
-"""LINEAR OPERATORS.
-
-This module contains linear operator classes.
-
-:Author: Samuel Farrens
-
-"""
+"""Base classes for linear operators."""
import numpy as np
-from modopt.base.types import check_callable, check_float
-from modopt.signal.wavelet import filter_convolve_stack
+from modopt.base.types import check_callable
+from modopt.base.backend import get_array_module
-class LinearParent(object):
+class LinearParent:
"""Linear Operator Parent Class.
This class sets the structure for defining linear operator instances.
@@ -38,7 +30,6 @@ class LinearParent(object):
"""
def __init__(self, op, adj_op):
-
self.op = op
self.adj_op = adj_op
@@ -49,7 +40,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -59,7 +49,6 @@ def adj_op(self):
@adj_op.setter
def adj_op(self, operator):
-
self._adj_op = check_callable(operator)
@@ -75,45 +64,26 @@ class Identity(LinearParent):
"""
def __init__(self):
-
self.op = lambda input_data: input_data
self.adj_op = self.op
+ self.cost = lambda *args, **kwargs: 0
-class WaveletConvolve(LinearParent):
- """Wavelet Convolution Class.
-
- This class defines the wavelet transform operators via convolution with
- predefined filters.
-
- Parameters
- ----------
- filters: numpy.ndarray
- Array of wavelet filter coefficients
- method : str, optional
- Convolution method (default is ``'scipy'``)
-
- See Also
- --------
- LinearParent : parent class
- modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution
+class MatrixOperator(LinearParent):
+ """
+ Matrix Operator class.
+ This class transforms an array into a suitable linear operator.
"""
- def __init__(self, filters, method='scipy'):
+ def __init__(self, array):
+ self.op = lambda x: array @ x
+ xp = get_array_module(array)
- self._filters = check_float(filters)
- self.op = lambda input_data: filter_convolve_stack(
- input_data,
- self._filters,
- method=method,
- )
- self.adj_op = lambda input_data: filter_convolve_stack(
- input_data,
- self._filters,
- filter_rot=True,
- method=method,
- )
+ if xp.any(xp.iscomplex(array)):
+ self.adj_op = lambda x: array.T.conjugate() @ x
+ else:
+ self.adj_op = lambda x: array.T @ x
class LinearCombo(LinearParent):
@@ -150,11 +120,9 @@ class LinearCombo(LinearParent):
See Also
--------
LinearParent : parent class
-
"""
def __init__(self, operators, weights=None):
-
operators, weights = self._check_inputs(operators, weights)
self.operators = operators
self.weights = weights
@@ -187,14 +155,13 @@ def _check_type(self, input_val):
"""
if not isinstance(input_val, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, input must be a list, tuple or numpy '
- + 'array.',
+ "Invalid input type, input must be a list, tuple or numpy " + "array.",
)
input_val = np.array(input_val)
if not input_val.size:
- raise ValueError('Input list is empty.')
+ raise ValueError("Input list is empty.")
return input_val
@@ -227,11 +194,10 @@ def _check_inputs(self, operators, weights):
operators = self._check_type(operators)
for operator in operators:
-
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'adj_op'):
+ if not hasattr(operator, "adj_op"):
raise ValueError('Operators must contain "adj_op" method.')
operator.op = check_callable(operator.op)
@@ -242,12 +208,11 @@ def _check_inputs(self, operators, weights):
if weights.size != operators.size:
raise ValueError(
- 'The number of weights must match the number of '
- + 'operators.',
+ "The number of weights must match the number of " + "operators.",
)
if not np.issubdtype(weights.dtype, np.floating):
- raise TypeError('The weights must be a list of float values.')
+ raise TypeError("The weights must be a list of float values.")
return operators, weights
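The conjugate-transpose branch in MatrixOperator above can be checked numerically with the usual inner-product identity; a self-contained sketch (the random matrix and vectors are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((4, 3)) + 1j * rng.standard_normal((4, 3))
    x = rng.standard_normal(3)
    y = rng.standard_normal(4)

    lhs = np.vdot(y, A @ x)           # <y, A x>
    rhs = np.vdot(A.conj().T @ y, x)  # <A^H y, x>
    print(np.allclose(lhs, rhs))      # True: A^H is the adjoint of A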
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
new file mode 100644
index 00000000..8012a072
--- /dev/null
+++ b/src/modopt/opt/linear/wavelet.py
@@ -0,0 +1,487 @@
+#!/usr/bin/env python3
+"""Wavelet operator, using either scipy filter or pywavelet."""
+import warnings
+
+import numpy as np
+
+from modopt.base.types import check_float
+from modopt.signal.wavelet import filter_convolve_stack
+
+from .base import LinearParent
+
+pywt_available = True
+try:
+ import pywt
+ from joblib import Parallel, cpu_count, delayed
+except ImportError:
+ pywt_available = False
+
+ptwt_available = True
+try:
+ import ptwt
+ import torch
+ import cupy as cp
+except ImportError:
+ ptwt_available = False
+
+
+class WaveletConvolve(LinearParent):
+ """Wavelet Convolution Class.
+
+ This class defines the wavelet transform operators via convolution with
+ predefined filters.
+
+ Parameters
+ ----------
+ filters: numpy.ndarray
+ Array of wavelet filter coefficients
+ method : str, optional
+ Convolution method (default is ``'scipy'``)
+
+ See Also
+ --------
+ LinearParent : parent class
+ modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution
+
+ """
+
+ def __init__(self, filters, method="scipy"):
+ self._filters = check_float(filters)
+ self.op = lambda input_data: filter_convolve_stack(
+ input_data,
+ self._filters,
+ method=method,
+ )
+ self.adj_op = lambda input_data: filter_convolve_stack(
+ input_data,
+ self._filters,
+ filter_rot=True,
+ method=method,
+ )
+
+
+class WaveletTransform(LinearParent):
+ """
+ 2D and 3D wavelet transform class.
+
+ This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU).
+
+ Parameters
+ ----------
+ wavelet_name: str
+ the wavelet name to be used during the decomposition.
+ shape: tuple[int,...]
+ Shape of the input data. The shape should be a tuple of length 2 or 3.
+        It should not contain coil or batch dimensions.
+    level: int, default 4
+        the number of levels in the decomposition.
+    mode: str, default "symmetric"
+        Boundary condition mode.
+ compute_backend: str, "numpy" or "cupy", default "numpy"
+ Backend library to use. "cupy" also requires a working installation of PyTorch
+ and PyTorch wavelets (ptwt).
+
+    **kwargs: extra keyword arguments for PyWavelets or PyTorch wavelets.
+ """
+
+ def __init__(
+ self,
+ wavelet_name,
+ shape,
+ level=4,
+ mode="symmetric",
+ compute_backend="numpy",
+ **kwargs,
+ ):
+ if compute_backend == "cupy" and ptwt_available:
+ self.operator = CupyWaveletTransform(
+ wavelet=wavelet_name, shape=shape, level=level, mode=mode
+ )
+ elif compute_backend == "numpy" and pywt_available:
+ self.operator = CPUWaveletTransform(
+                wavelet_name=wavelet_name, shape=shape, level=level, mode=mode, **kwargs
+ )
+ else:
+ raise ValueError(f"Compute Backend {compute_backend} not available")
+
+ self.op = self.operator.op
+ self.adj_op = self.operator.adj_op
+
+ @property
+ def coeffs_shape(self):
+ """Get the coeffs shapes."""
+ return self.operator.coeffs_shape
+
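A minimal round-trip sketch for the CPU backend (assumes PyWavelets and joblib are installed; the wavelet name, shape and level are illustrative):

    import numpy as np
    from modopt.opt.linear import WaveletTransform

    image = np.random.default_rng(0).standard_normal((32, 32))
    wt = WaveletTransform("haar", shape=(32, 32), level=2)
    coeffs = wt.op(image)             # raveled wavelet coefficients
    recon = wt.adj_op(coeffs)         # inverse transform
    print(np.allclose(image, recon))  # True: wavedec/waverec round trip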
+
+class CPUWaveletTransform(LinearParent):
+ """
+ 2D and 3D wavelet transform class.
+
+ This is a light wrapper around PyWavelet, with multicoil support.
+
+ Parameters
+ ----------
+ wavelet_name: str
+ the wavelet name to be used during the decomposition.
+ shape: tuple[int,...]
+ Shape of the input data. The shape should be a tuple of length 2 or 3.
+        It should not contain coil or batch dimensions.
+    level: int, default 4
+        the number of levels in the decomposition.
+    n_batch: int, default 1
+        the number of channel/batch dimensions.
+ n_jobs: int, default 1
+ the number of cores to use for multichannel.
+ backend: str, default "threading"
+ the backend to use for parallel multichannel linear operation.
+ verbose: int, default 0
+ the verbosity level.
+
+ Attributes
+ ----------
+    level: int
+        number of levels in the wavelet decomposition.
+    n_jobs: int
+        number of jobs for parallel computation.
+    n_batch: int
+        number of batches processed in parallel.
+    backend: str
+        backend used for parallel computation.
+ verbose: int
+ Verbosity level
+ """
+
+ def __init__(
+ self,
+ wavelet_name,
+ shape,
+ level=4,
+ n_batch=1,
+ n_jobs=1,
+ decimated=True,
+ backend="threading",
+ mode="symmetric",
+ ):
+ if not pywt_available:
+ raise ImportError(
+ "PyWavelet and/or joblib are not available. "
+ "Please install it to use WaveletTransform."
+ )
+ if wavelet_name not in pywt.wavelist(kind="all"):
+ raise ValueError(
+ "Invalid wavelet name. Availables are ``pywt.waveletlist(kind='all')``"
+ )
+
+ self.wavelet = wavelet_name
+ if isinstance(shape, int):
+ shape = (shape,)
+ self.shape = shape
+ self.n_jobs = n_jobs
+ self.mode = mode
+        self.level = level
+        self.verbose = verbose
+ if not decimated:
+ raise NotImplementedError(
+ "Undecimated Wavelet Transform is not implemented yet."
+ )
+ ca, *cds = pywt.wavedecn_shapes(
+ self.shape, wavelet=self.wavelet, mode=self.mode, level=self.level
+ )
+ self.coeffs_shape = [ca] + [s for cd in cds for s in cd.values()]
+
+ if len(shape) > 1:
+ self.dwt = pywt.wavedecn
+ self.idwt = pywt.waverecn
+ self._pywt_fun = "wavedecn"
+ else:
+ self.dwt = pywt.wavedec
+ self.idwt = pywt.waverec
+ self._pywt_fun = "wavedec"
+
+ self.n_batch = n_batch
+ if self.n_batch == 1 and self.n_jobs != 1:
+ warnings.warn(
+ "Making n_jobs = 1 for WaveletTransform as n_batchs = 1", stacklevel=1
+ )
+ self.n_jobs = 1
+ self.backend = backend
+ n_proc = self.n_jobs
+ if n_proc < 0:
+ n_proc = cpu_count() + self.n_jobs + 1
+
+ def op(self, data):
+ """Define the wavelet operator.
+
+        This method returns the wavelet decomposition of the input data.
+
+ Parameters
+ ----------
+ data: ndarray or Image
+            input 2D or 3D data array.
+
+ Returns
+ -------
+ coeffs: ndarray
+ the wavelet coefficients.
+ """
+ if self.n_batch > 1:
+ coeffs, self.coeffs_slices, self.raw_coeffs_shape = zip(
+ *Parallel(
+ n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose
+ )(delayed(self._op)(data[i]) for i in np.arange(self.n_batch))
+ )
+ coeffs = np.asarray(coeffs)
+ else:
+ coeffs, self.coeffs_slices, self.raw_coeffs_shape = self._op(data)
+ return coeffs
+
+ def _op(self, data):
+ """Single coil wavelet transform."""
+ return pywt.ravel_coeffs(
+ self.dwt(data, mode=self.mode, level=self.level, wavelet=self.wavelet)
+ )
+
+ def adj_op(self, coeffs):
+ """Define the wavelet adjoint operator.
+
+ This method returns the reconstructed image.
+
+ Parameters
+ ----------
+ coeffs: ndarray
+ the wavelet coefficients.
+
+ Returns
+ -------
+ data: ndarray
+ the reconstructed data.
+ """
+ if self.n_batch > 1:
+ images = Parallel(
+ n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose
+ )(
+                delayed(self._adj_op)(coeffs[i])
+ for i in np.arange(self.n_batch)
+ )
+ images = np.asarray(images)
+ else:
+ images = self._adj_op(coeffs)
+ return images
+
+ def _adj_op(self, coeffs):
+ """Single coil inverse wavelet transform."""
+ return self.idwt(
+ pywt.unravel_coeffs(
+ coeffs, self.coeffs_slices, self.raw_coeffs_shape, self._pywt_fun
+ ),
+ wavelet=self.wavelet,
+ mode=self.mode,
+ )
+
+
+class TorchWaveletTransform:
+ """Wavelet transform using pytorch."""
+
+ wavedec3_keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
+
+ def __init__(
+ self,
+ shape,
+ wavelet,
+ level,
+ mode,
+ ):
+ self.wavelet = wavelet
+ self.level = level
+ self.shape = shape
+ self.mode = mode
+ self.coeffs_shape = None # will be set after op.
+
+ def op(self, data):
+ """Apply the wavelet decomposition on.
+
+ Parameters
+ ----------
+ data: torch.Tensor
+ 2D or 3D, real or complex data with last axes matching shape of
+ the operator.
+
+ Returns
+ -------
+ list[torch.Tensor]
+ list of tensor each containing the data of a subband.
+ """
+ if data.shape == self.shape:
+ data = data[None, ...] # add a batch dimension
+
+ if len(self.shape) == 2:
+ if torch.is_complex(data):
+ # 2D Complex
+ data_ = torch.view_as_real(data)
+ coeffs_ = ptwt.wavedec2(
+ data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2)
+ )
+ # flatten list of tuple of tensors to a list of tensors
+ coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
+ torch.view_as_complex(cc.contiguous())
+ for c in coeffs_[1:]
+ for cc in c
+ ]
+
+ return coeffs
+ # 2D Real
+ coeffs_ = ptwt.wavedec2(
+ data, self.wavelet, level=self.level, mode=self.mode, axes=(-2, -1)
+ )
+ return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c]
+
+ if torch.is_complex(data):
+ # 3D Complex
+ data_ = torch.view_as_real(data)
+ coeffs_ = ptwt.wavedec3(
+ data_,
+ self.wavelet,
+ level=self.level,
+ mode=self.mode,
+ axes=(-4, -3, -2),
+ )
+ # flatten list of tuple of tensors to a list of tensors
+ coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
+ torch.view_as_complex(cc.contiguous())
+ for c in coeffs_[1:]
+ for cc in c.values()
+ ]
+
+ return coeffs
+ # 3D Real
+ coeffs_ = ptwt.wavedec3(
+ data, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2, -1)
+ )
+ return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()]
+
+ def adj_op(self, coeffs):
+ """Apply the wavelet recomposition.
+
+ Parameters
+ ----------
+        coeffs: list[torch.Tensor]
+ list of tensor each containing the data of a subband.
+
+ Returns
+ -------
+ data: torch.Tensor
+ 2D or 3D, real or complex data with last axes matching shape of the
+ operator.
+
+ """
+ if len(self.shape) == 2:
+ if torch.is_complex(coeffs[0]):
+ ## 2D Complex ##
+ # list of tensor to list of tuple of tensor
+ coeffs = [torch.view_as_real(coeffs[0])] + [
+ tuple(torch.view_as_real(coeffs[i + k]) for k in range(3))
+ for i in range(1, len(coeffs) - 2, 3)
+ ]
+ data = ptwt.waverec2(coeffs, wavelet=self.wavelet, axes=(-3, -2))
+ return torch.view_as_complex(data.contiguous())
+ ## 2D Real ##
+ coeffs_ = [coeffs[0]] + [
+ tuple(coeffs[i + k] for k in range(3))
+ for i in range(1, len(coeffs) - 2, 3)
+ ]
+ data = ptwt.waverec2(coeffs_, wavelet=self.wavelet, axes=(-2, -1))
+ return data
+
+ if torch.is_complex(coeffs[0]):
+ ## 3D Complex ##
+ # list of tensor to list of tuple of tensor
+ coeffs = [torch.view_as_real(coeffs[0])] + [
+ {
+ v: torch.view_as_real(coeffs[i + k])
+ for k, v in enumerate(self.wavedec3_keys)
+ }
+ for i in range(1, len(coeffs) - 6, 7)
+ ]
+ data = ptwt.waverec3(coeffs, wavelet=self.wavelet, axes=(-4, -3, -2))
+ return torch.view_as_complex(data.contiguous())
+ ## 3D Real ##
+ coeffs_ = [coeffs[0]] + [
+ {v: coeffs[i + k] for k, v in enumerate(self.wavedec3_keys)}
+ for i in range(1, len(coeffs) - 6, 7)
+ ]
+ data = ptwt.waverec3(coeffs_, wavelet=self.wavelet, axes=(-3, -2, -1))
+ return data
+
+
+class CupyWaveletTransform(LinearParent):
+ """Wrapper around torch wavelet transform to be compatible with the Modopt API."""
+
+ def __init__(
+ self,
+ shape,
+ wavelet,
+ level,
+ mode,
+ ):
+ self.wavelet = wavelet
+ self.level = level
+ self.shape = shape
+ self.mode = mode
+
+ self.operator = TorchWaveletTransform(
+ shape=shape, wavelet=wavelet, level=level, mode=mode
+ )
+ self.coeffs_shape = None # will be set after op
+
+ def op(self, data):
+ """Define the wavelet operator.
+
+        This method returns the wavelet decomposition of the input data.
+
+ Parameters
+ ----------
+ data: cp.ndarray
+            input 2D or 3D data array.
+
+ Returns
+ -------
+ coeffs: ndarray
+ the wavelet coefficients.
+ """
+ data_ = torch.as_tensor(data)
+ tensor_list = self.operator.op(data_)
+ # flatten the list of tensor to a cupy array
+ # this requires an on device copy...
+ self.coeffs_shape = [c.shape for c in tensor_list]
+ n_tot_coeffs = np.sum([np.prod(s) for s in self.coeffs_shape])
+ ret = cp.zeros(n_tot_coeffs, dtype=np.complex64) # FIXME get dtype from torch
+ start = 0
+ for t in tensor_list:
+ stop = start + np.prod(t.shape)
+ ret[start:stop] = cp.asarray(t.flatten())
+ start = stop
+
+ return ret
+
+ def adj_op(self, data):
+ """Define the wavelet adjoint operator.
+
+ This method returns the reconstructed image.
+
+ Parameters
+ ----------
+        data: cp.ndarray
+ the wavelet coefficients.
+
+ Returns
+ -------
+ data: ndarray
+ the reconstructed data.
+ """
+ start = 0
+ tensor_list = [None] * len(self.coeffs_shape)
+ for i, s in enumerate(self.coeffs_shape):
+ stop = start + np.prod(s)
+ tensor_list[i] = torch.as_tensor(data[start:stop].reshape(s), device="cuda")
+ start = stop
+ ret_tensor = self.operator.adj_op(tensor_list)
+ return cp.from_dlpack(ret_tensor)
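The packing scheme used by ``op`` and ``adj_op`` above, sketched in NumPy for readability (the subband shapes are illustrative): record each subband's shape, concatenate the raveled subbands into one vector, then slice and reshape to invert:

    import numpy as np

    subbands = [np.ones((2, 2)), np.full((2, 2), 2.0), np.full((4, 4), 3.0)]
    shapes = [s.shape for s in subbands]
    flat = np.concatenate([s.ravel() for s in subbands])  # one flat vector

    unpacked, start = [], 0
    for shape in shapes:
        stop = start + int(np.prod(shape))
        unpacked.append(flat[start:stop].reshape(shape))
        start = stop
    print(all(np.array_equal(a, b) for a, b in zip(subbands, unpacked)))  # True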
diff --git a/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
similarity index 89%
rename from modopt/opt/proximity.py
rename to src/modopt/opt/proximity.py
index f8f368ef..dea862ca 100644
--- a/modopt/opt/proximity.py
+++ b/src/modopt/opt/proximity.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""PROXIMITY OPERATORS.
This module contains classes of proximity operators for optimisation.
@@ -22,6 +20,7 @@
else:
import_sklearn = True
+from modopt.base.backend import get_array_module
from modopt.base.transform import cube2matrix, matrix2cube
from modopt.base.types import check_callable
from modopt.interface.errors import warn
@@ -31,7 +30,7 @@
from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast
-class ProximityParent(object):
+class ProximityParent:
"""Proximity Operator Parent Class.
This class sets the structure for defining proximity operator instances.
@@ -47,7 +46,6 @@ class ProximityParent(object):
"""
def __init__(self, op, cost):
-
self.op = op
self.cost = cost
@@ -58,7 +56,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -78,7 +75,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
@@ -98,9 +94,8 @@ class IdentityProx(ProximityParent):
"""
def __init__(self):
-
- self.op = lambda x_val: x_val
- self.cost = lambda x_val: 0
+ self.op = lambda x_val, *args, **kwargs: x_val
+ self.cost = lambda x_val, *args, **kwargs: 0
class Positivity(ProximityParent):
@@ -116,10 +111,25 @@ class Positivity(ProximityParent):
"""
def __init__(self):
-
- self.op = lambda input_data: positive(input_data)
self.cost = self._cost_method
+ def op(self, input_data, *args, **kwargs):
+ """
+ Make the data positive.
+
+ Parameters
+ ----------
+ input_data: np.ndarray
+ Input array
+        *args, **kwargs: extra arguments; ignored.
+
+ Returns
+ -------
+ np.ndarray
+ Positive data.
+ """
+ return positive(input_data)
+
def _cost_method(self, *args, **kwargs):
"""Calculate positivity component of the cost.
@@ -139,8 +149,8 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - Min (X):', np.min(args[0]))
+ if kwargs.get("verbose"):
+ print(" - Min (X):", np.min(args[0]))
return 0
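A short usage sketch (the values are illustrative): ``op`` clips negative entries to zero, while the cost contribution is always zero:

    import numpy as np
    from modopt.opt.proximity import Positivity

    prox = Positivity()
    data = np.array([-1.0, 0.5, -0.2, 2.0])
    print(prox.op(data))                  # [0.  0.5 0.  2. ]
    print(prox.cost(data, verbose=True))  # prints the minimum, returns 0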
@@ -166,8 +176,7 @@ class SparseThreshold(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
-
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
self._thresh_type = thresh_type
@@ -215,10 +224,13 @@ def _cost_method(self, *args, **kwargs):
Sparsity cost component
"""
- cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0])))
+ xp = get_array_module(args[0])
+ cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0])))
+ if isinstance(cost_val, xp.ndarray):
+ cost_val = cost_val.item()
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L1 NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - L1 NORM (X):", cost_val)
return cost_val
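A usage sketch with an identity linear operator, assuming (as in the released package) that ``op`` thresholds the input directly and the linear operator only enters the cost:

    import numpy as np
    from modopt.opt.linear import Identity
    from modopt.opt.proximity import SparseThreshold

    prox = SparseThreshold(Identity(), weights=1.0, thresh_type="soft")
    data = np.array([-3.0, -0.5, 0.5, 3.0])
    print(prox.op(data))    # [-2. -0.  0.  2.]: shrunk towards zero
    print(prox.cost(data))  # 7.0: the weighted l1 norm of the transform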
@@ -269,12 +281,11 @@ class LowRankMatrix(ProximityParent):
def __init__(
self,
threshold,
- thresh_type='soft',
- lowr_type='standard',
+ thresh_type="soft",
+ lowr_type="standard",
initial_rank=None,
operator=None,
):
-
self.thresh = threshold
self.thresh_type = thresh_type
self.lowr_type = lowr_type
@@ -311,13 +322,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
"""
# Update threshold with extra factor.
threshold = self.thresh * extra_factor
- if self.lowr_type == 'standard' and self.rank is None and rank is None:
+ if self.lowr_type == "standard" and self.rank is None and rank is None:
data_matrix = svd_thresh(
cube2matrix(input_data),
threshold,
thresh_type=self.thresh_type,
)
- elif self.lowr_type == 'standard':
+ elif self.lowr_type == "standard":
data_matrix, update_rank = svd_thresh_coef_fast(
cube2matrix(input_data),
threshold,
@@ -327,7 +338,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
)
self.rank = update_rank # save for future use
- elif self.lowr_type == 'ngole':
+ elif self.lowr_type == "ngole":
data_matrix = svd_thresh_coef(
cube2matrix(input_data),
self.operator,
@@ -335,7 +346,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
thresh_type=self.thresh_type,
)
else:
- raise ValueError('lowr_type should be standard or ngole')
+ raise ValueError("lowr_type should be standard or ngole")
# Return updated data.
return matrix2cube(data_matrix, input_data.shape[1:])
@@ -361,8 +372,8 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - NUCLEAR NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -466,7 +477,6 @@ class ProximityCombo(ProximityParent):
"""
def __init__(self, operators):
-
operators = self._check_operators(operators)
self.operators = operators
self.op = self._op_method
@@ -502,19 +512,19 @@ def _check_operators(self, operators):
"""
if not isinstance(operators, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, operators must be a list, tuple or '
- + 'numpy array.',
+ "Invalid input type, operators must be a list, tuple or "
+ + "numpy array.",
)
operators = np.array(operators)
if not operators.size:
- raise ValueError('Operator list is empty.')
+ raise ValueError("Operator list is empty.")
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'cost'):
+ if not hasattr(operator, "cost"):
raise ValueError('Operators must contain "cost" method.')
operator.op = check_callable(operator.op)
operator.cost = check_callable(operator.cost)
@@ -569,10 +579,12 @@ def _cost_method(self, *args, **kwargs):
        Combined cost components
"""
- return np.sum([
- operator.cost(input_data)
- for operator, input_data in zip(self.operators, args[0])
- ])
+ return np.sum(
+ [
+ operator.cost(input_data)
+ for operator, input_data in zip(self.operators, args[0])
+ ]
+ )
class OrderedWeightedL1Norm(ProximityParent):
@@ -613,16 +625,16 @@ class OrderedWeightedL1Norm(ProximityParent):
def __init__(self, weights):
if not import_sklearn: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Learn package not found see '
- + 'documentation for details: '
- + 'https://cea-cosmic.github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Learn package not found see "
+ + "documentation for details: "
+ + "https://cea-cosmic.github.io/ModOpt/#optional-packages",
)
if np.max(np.diff(weights)) > 0:
- raise ValueError('Weights must be non increasing')
+ raise ValueError("Weights must be non increasing")
self.weights = weights.flatten()
if (self.weights < 0).any():
raise ValueError(
- 'The weight values must be provided in descending order',
+ "The weight values must be provided in descending order",
)
self.op = self._op_method
self.cost = self._cost_method
@@ -660,7 +672,9 @@ def _op_method(self, input_data, extra_factor=1.0):
# Projection onto the monotone non-negative cone using
# isotonic_regression
data_abs = isotonic_regression(
- data_abs - threshold, y_min=0, increasing=False,
+ data_abs - threshold,
+ y_min=0,
+ increasing=False,
)
# Unsorting the data
@@ -668,7 +682,7 @@ def _op_method(self, input_data, extra_factor=1.0):
data_abs_unsorted[data_abs_sort_idx] = data_abs
# Putting the sign back
- with np.errstate(invalid='ignore'):
+ with np.errstate(invalid="ignore"):
sign_data = data_squeezed / np.abs(data_squeezed)
# Removing NAN caused by the sign
@@ -698,8 +712,8 @@ def _cost_method(self, *args, **kwargs):
            self.weights * np.sort(np.squeeze(np.abs(args[0])))[::-1],
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - OWL NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -730,8 +744,7 @@ class Ridge(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
-
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
self.op = self._op_method
@@ -782,8 +795,8 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L2 NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -818,7 +831,6 @@ class ElasticNet(ProximityParent):
"""
def __init__(self, linear, alpha, beta):
-
self._linear = linear
self.alpha = alpha
self.beta = beta
@@ -844,8 +856,8 @@ def _op_method(self, input_data, extra_factor=1.0):
"""
soft_threshold = self.beta * extra_factor
- normalization = (self.alpha * 2 * extra_factor + 1)
- return thresh(input_data, soft_threshold, 'soft') / normalization
+ normalization = self.alpha * 2 * extra_factor + 1
+ return thresh(input_data, soft_threshold, "soft") / normalization
def _cost_method(self, *args, **kwargs):
"""Calculate Ridge component of the cost.
@@ -871,8 +883,8 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - ELASTIC NET (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - ELASTIC NET (X):", cost_val)
return cost_val
@@ -938,7 +950,7 @@ def k_value(self):
def k_value(self, k_val):
if k_val < 1:
raise ValueError(
- 'The k parameter should be greater or equal than 1',
+ "The k parameter should be greater or equal than 1",
)
self._k_value = k_val
@@ -983,7 +995,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0):
alpha_beta = alpha_input - self.beta * extra_factor
theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0))
theta = np.nan_to_num(theta)
- theta += (alpha_input > (self.beta * extra_factor + 1))
+ theta += alpha_input > (self.beta * extra_factor + 1)
return theta
def _interpolate(self, alpha0, alpha1, sum0, sum1):
@@ -993,7 +1005,7 @@ def _interpolate(self, alpha0, alpha1, sum0, sum1):
:math:`\sum\theta(\alpha^*)=k` via a linear interpolation.
Parameters
- -----------
+ ----------
alpha0: float
            A value for which :math:`\sum\theta(\alpha^0) \leq k`
alpha1: float
@@ -1074,12 +1086,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
midpoint = 0
while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]):
-
midpoint = (first_idx + last_idx) // 2
cnt += 1
if prev_midpoint == midpoint:
-
# Particular case
sum0 = self._compute_theta(
data_abs,
@@ -1092,11 +1102,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
extra_factor,
).sum()
- if (np.abs(sum0 - self._k_value) <= tolerance):
+ if np.abs(sum0 - self._k_value) <= tolerance:
found = True
midpoint = first_idx
- if (np.abs(sum1 - self._k_value) <= tolerance):
+ if np.abs(sum1 - self._k_value) <= tolerance:
found = True
midpoint = last_idx - 1
# -1 because output is index such that
@@ -1141,13 +1151,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
if found:
return (
- midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1,
+ midpoint,
+ alpha[midpoint],
+ alpha[midpoint + 1],
+ sum0,
+ sum1,
)
raise ValueError(
- 'Cannot find the coordinate of alpha (i) such '
- + 'that sum(theta(alpha[i])) =< k and '
- + 'sum(theta(alpha[i+1])) >= k ',
+ "Cannot find the coordinate of alpha (i) such "
+ + "that sum(theta(alpha[i])) =< k and "
+ + "sum(theta(alpha[i+1])) >= k ",
)
def _find_alpha(self, input_data, extra_factor=1.0):
@@ -1171,15 +1185,13 @@ def _find_alpha(self, input_data, extra_factor=1.0):
data_size = input_data.shape[0]
# Computes the alpha^i points line 1 in Algorithm 1.
- alpha = np.zeros((data_size * 2))
+ alpha = np.zeros(data_size * 2)
data_abs = np.abs(input_data)
- alpha[:data_size] = (
- (self.beta * extra_factor)
- / (data_abs + sys.float_info.epsilon)
+ alpha[:data_size] = (self.beta * extra_factor) / (
+ data_abs + sys.float_info.epsilon
)
- alpha[data_size:] = (
- (self.beta * extra_factor + 1)
- / (data_abs + sys.float_info.epsilon)
+ alpha[data_size:] = (self.beta * extra_factor + 1) / (
+ data_abs + sys.float_info.epsilon
)
alpha = np.sort(np.unique(alpha))
@@ -1216,8 +1228,8 @@ def _op_method(self, input_data, extra_factor=1.0):
k_max = np.prod(data_shape)
if self._k_value > k_max:
warn(
- 'K value of the K-support norm is greater than the input '
- + 'dimension, its value will be set to {0}'.format(k_max),
+ "K value of the K-support norm is greater than the input "
+ + f"dimension, its value will be set to {k_max}",
)
self._k_value = k_max
@@ -1229,8 +1241,7 @@ def _op_method(self, input_data, extra_factor=1.0):
# Computes line 5. in Algorithm 1.
rslt = np.nan_to_num(
- (input_data.flatten() * theta)
- / (theta + self.beta * extra_factor),
+ (input_data.flatten() * theta) / (theta + self.beta * extra_factor),
)
return rslt.reshape(data_shape)
@@ -1271,25 +1282,20 @@ def _find_q(self, sorted_data):
found = True
q_val = 0
- elif (
- (sorted_data[self._k_value - 1:].sum())
- <= sorted_data[self._k_value - 1]
- ):
+ elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]:
found = True
q_val = self._k_value - 1
while (
- not found and not cnt == self._k_value
+ not found
+ and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
-
q_val = (first_idx + last_idx) // 2
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
- if (
- sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]
- ):
+ if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]:
found = True
else:
@@ -1324,15 +1330,12 @@ def _cost_method(self, *args, **kwargs):
data_abs = data_abs[ix] # Sorted absolute value of the data
q_val = self._find_q(data_abs)
cost_val = (
- (
- np.sum(data_abs[:q_val] ** 2) * 0.5
- + np.sum(data_abs[q_val:]) ** 2
- / (self._k_value - q_val)
- ) * self.beta
- )
+ np.sum(data_abs[:q_val] ** 2) * 0.5
+ + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
+ ) * self.beta
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - K-SUPPORT NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
@@ -1377,7 +1380,7 @@ def __init__(self, weights):
self.op = self._op_method
self.cost = self._cost_method
- def _op_method(self, input_data, extra_factor=1.0):
+ def _op_method(self, input_data, *args, extra_factor=1.0, **kwargs):
"""Operator.
This method returns the input data thresholded by the weights.
@@ -1388,6 +1391,7 @@ def _op_method(self, input_data, extra_factor=1.0):
Input data array
extra_factor : float
Additional multiplication factor (default is ``1.0``)
+        *args, **kwargs: extra arguments; ignored.
Returns
-------
@@ -1403,7 +1407,7 @@ def _op_method(self, input_data, extra_factor=1.0):
(1.0 - self.weights * extra_factor / denominator),
)
- def _cost_method(self, input_data):
+ def _cost_method(self, input_data, *args, **kwargs):
"""Calculate the group LASSO component of the cost.
        This method calculates the cost function of the proximable part.
@@ -1413,6 +1417,8 @@ def _cost_method(self, input_data):
input_data : numpy.ndarray
Input array of the sparse code
+        *args, **kwargs: extra arguments; ignored.
+
Returns
-------
float
diff --git a/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
similarity index 91%
rename from modopt/opt/reweight.py
rename to src/modopt/opt/reweight.py
index 8c4f2449..4a9bf44b 100644
--- a/modopt/opt/reweight.py
+++ b/src/modopt/opt/reweight.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""REWEIGHTING CLASSES.
This module contains classes for reweighting optimisation implementations.
@@ -13,7 +11,7 @@
from modopt.base.types import check_float
-class cwbReweight(object):
+class cwbReweight:
"""Candes, Wakin and Boyd reweighting class.
This class implements the reweighting scheme described in
@@ -45,7 +43,6 @@ class cwbReweight(object):
"""
def __init__(self, weights, thresh_factor=1.0, verbose=False):
-
self.weights = check_float(weights)
self.original_weights = np.copy(self.weights)
self.thresh_factor = check_float(thresh_factor)
@@ -81,7 +78,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(' - Reweighting: {0}'.format(self._rw_num))
+ print(f" - Reweighting: {self._rw_num}")
self._rw_num += 1
@@ -89,7 +86,7 @@ def reweight(self, input_data):
if input_data.shape != self.weights.shape:
raise ValueError(
- 'Input data must have the same shape as the initial weights.',
+ "Input data must have the same shape as the initial weights.",
)
thresh_weights = self.thresh_factor * self.original_weights
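A minimal sketch of a reweighting step (the input must match the shape of the initial weights; the values are illustrative):

    import numpy as np
    from modopt.opt.reweight import cwbReweight

    rw = cwbReweight(np.ones(4), verbose=True)
    rw.reweight(np.array([0.0, 0.1, 1.0, 10.0]))
    print(rw.weights)  # updated in place: relatively larger where the input is small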
diff --git a/modopt/plot/__init__.py b/src/modopt/plot/__init__.py
similarity index 73%
rename from modopt/plot/__init__.py
rename to src/modopt/plot/__init__.py
index 28d60be6..f31ed596 100644
--- a/modopt/plot/__init__.py
+++ b/src/modopt/plot/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""PLOTTING ROUTINES.
This module contains submodules for plotting applications.
@@ -8,4 +6,4 @@
"""
-__all__ = ['cost_plot']
+__all__ = ["cost_plot"]
diff --git a/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py
similarity index 66%
rename from modopt/plot/cost_plot.py
rename to src/modopt/plot/cost_plot.py
index aa855eaa..7fb7e39b 100644
--- a/modopt/plot/cost_plot.py
+++ b/src/modopt/plot/cost_plot.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""PLOTTING ROUTINES.
This module contains methods for making plots.
@@ -37,20 +35,20 @@ def plotCost(cost_list, output=None):
"""
if import_fail:
- raise ImportError('Matplotlib package not found')
+ raise ImportError("Matplotlib package not found")
else:
if isinstance(output, type(None)):
- file_name = 'cost_function.png'
+ file_name = "cost_function.png"
else:
- file_name = '{0}_cost_function.png'.format(output)
+ file_name = f"{output}_cost_function.png"
plt.figure()
- plt.plot(np.log10(cost_list), 'r-')
- plt.title('Cost Function')
- plt.xlabel('Iteration')
- plt.ylabel(r'$\log_{10}$ Cost')
+ plt.plot(np.log10(cost_list), "r-")
+ plt.title("Cost Function")
+ plt.xlabel("Iteration")
+ plt.ylabel(r"$\log_{10}$ Cost")
plt.savefig(file_name)
plt.close()
- print(' - Saving cost function data to:', file_name)
+ print(" - Saving cost function data to:", file_name)
diff --git a/modopt/signal/__init__.py b/src/modopt/signal/__init__.py
similarity index 58%
rename from modopt/signal/__init__.py
rename to src/modopt/signal/__init__.py
index dbc6d053..6bf0912b 100644
--- a/modopt/signal/__init__.py
+++ b/src/modopt/signal/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""SIGNAL PROCESSING ROUTINES.
This module contains submodules for signal processing.
@@ -8,4 +6,4 @@
"""
-__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet']
+__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"]
diff --git a/modopt/signal/filter.py b/src/modopt/signal/filter.py
similarity index 92%
rename from modopt/signal/filter.py
rename to src/modopt/signal/filter.py
index 8e24768c..33c3c105 100644
--- a/modopt/signal/filter.py
+++ b/src/modopt/signal/filter.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""FILTER ROUTINES.
This module contains methods for filtering data.
@@ -73,15 +71,15 @@ def mex_hat(data_point, sigma):
Examples
--------
>>> from modopt.signal.filter import mex_hat
- >>> mex_hat(2, 1)
- -0.3521390522571337
+ >>> round(mex_hat(2, 1), 15)
+ -0.352139052257134
"""
data_point = check_float(data_point)
sigma = check_float(sigma)
xs = (data_point / sigma) ** 2
- factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25
+ factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25
return factor * (1 - xs) * np.exp(-0.5 * xs)
@@ -108,8 +106,8 @@ def mex_hat_dir(data_gauss, data_mex, sigma):
Examples
--------
>>> from modopt.signal.filter import mex_hat_dir
- >>> mex_hat_dir(1, 2, 1)
- 0.17606952612856686
+ >>> round(mex_hat_dir(1, 2, 1), 16)
+ 0.1760695261285668
"""
data_gauss = check_float(data_gauss)
diff --git a/modopt/signal/noise.py b/src/modopt/signal/noise.py
similarity index 86%
rename from modopt/signal/noise.py
rename to src/modopt/signal/noise.py
index a59d5553..28307f52 100644
--- a/modopt/signal/noise.py
+++ b/src/modopt/signal/noise.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""NOISE ROUTINES.
This module contains methods for adding and removing noise from data.
@@ -8,14 +6,12 @@
"""
-from builtins import zip
-
import numpy as np
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type='gauss'):
+def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -29,6 +25,9 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
default is ``1.0``)
noise_type : {'gauss', 'poisson'}
Type of noise to be added (default is ``'gauss'``)
+ rng: np.random.Generator or int
+ A Random number generator or a seed to initialize one.
+
Returns
-------
@@ -68,9 +67,12 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
array([ 3.24869073, -1.22351283, -1.0563435 , -2.14593724, 1.73081526])
"""
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
+
input_data = np.array(input_data)
- if noise_type not in {'gauss', 'poisson'}:
+ if noise_type not in {"gauss", "poisson"}:
raise ValueError(
'Invalid noise type. Options are "gauss" or "poisson"',
)
@@ -78,15 +80,14 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
if isinstance(sigma, (list, tuple, np.ndarray)):
if len(sigma) != input_data.shape[0]:
raise ValueError(
- 'Number of sigma values must match first dimension of input '
- + 'data',
+ "Number of sigma values must match first dimension of input " + "data",
)
- if noise_type == 'gauss':
- random = np.random.randn(*input_data.shape)
+ if noise_type == "gauss":
+ random = rng.standard_normal(input_data.shape)
- elif noise_type == 'poisson':
- random = np.random.poisson(np.abs(input_data))
+ elif noise_type == "poisson":
+ random = rng.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
return input_data + sigma * random
@@ -96,7 +97,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
return input_data + noise
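The new ``rng`` argument makes the noise reproducible, since each call builds its own generator from the seed; a quick sketch:

    import numpy as np
    from modopt.signal.noise import add_noise

    clean = np.zeros(5)
    noisy_a = add_noise(clean, sigma=2.0, rng=42)
    noisy_b = add_noise(clean, sigma=2.0, rng=42)
    print(np.array_equal(noisy_a, noisy_b))  # True: same seed, same noise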
-def thresh(input_data, threshold, threshold_type='hard'):
+def thresh(input_data, threshold, threshold_type="hard"):
r"""Threshold data.
    This method performs hard or soft thresholding on the input data.
@@ -169,12 +170,12 @@ def thresh(input_data, threshold, threshold_type='hard'):
input_data = xp.array(input_data)
- if threshold_type not in {'hard', 'soft'}:
+ if threshold_type not in {"hard", "soft"}:
raise ValueError(
'Invalid threshold type. Options are "hard" or "soft"',
)
- if threshold_type == 'soft':
+ if threshold_type == "soft":
denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
max_value = xp.maximum((1.0 - threshold / denominator), 0)
diff --git a/modopt/signal/positivity.py b/src/modopt/signal/positivity.py
similarity index 90%
rename from modopt/signal/positivity.py
rename to src/modopt/signal/positivity.py
index e4ec098d..8d7aa46c 100644
--- a/modopt/signal/positivity.py
+++ b/src/modopt/signal/positivity.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""POSITIVITY.
This module contains a function that retains only positive coefficients in
@@ -47,8 +45,8 @@ def pos_recursive(input_data):
Positive coefficients
"""
- if input_data.dtype == 'O':
- res = np.array([pos_recursive(elem) for elem in input_data])
+ if input_data.dtype == "O":
+ res = np.array([pos_recursive(elem) for elem in input_data], dtype="object")
else:
res = pos_thresh(input_data)
@@ -97,15 +95,15 @@ def positive(input_data, ragged=False):
"""
if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid data type, input must be `int`, `float`, `list`, '
- + '`tuple` or `np.ndarray`.',
+ "Invalid data type, input must be `int`, `float`, `list`, "
+ + "`tuple` or `np.ndarray`.",
)
if isinstance(input_data, (int, float)):
return pos_thresh(input_data)
if ragged:
- input_data = np.array(input_data, dtype='object')
+ input_data = np.array(input_data, dtype="object")
else:
input_data = np.array(input_data)
diff --git a/modopt/signal/svd.py b/src/modopt/signal/svd.py
similarity index 84%
rename from modopt/signal/svd.py
rename to src/modopt/signal/svd.py
index 6dcb9eda..cf147503 100644
--- a/modopt/signal/svd.py
+++ b/src/modopt/signal/svd.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""SVD ROUTINES.
This module contains methods for thresholding singular values.
@@ -52,12 +50,12 @@ def find_n_pc(u_vec, factor=0.5):
"""
if np.sqrt(u_vec.shape[0]) % 1:
raise ValueError(
- 'Invalid left singular vector. The size of the first '
- + 'dimenion of ``u_vec`` must be perfect square.',
+ "Invalid left singular vector. The size of the first "
+ + "dimenion of ``u_vec`` must be perfect square.",
)
# Get the shape of the array
- array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2)
+ array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Find the auto correlation of the left singular vector.
u_auto = [
@@ -69,13 +67,12 @@ def find_n_pc(u_vec, factor=0.5):
]
# Return the required number of principal components.
- return np.sum([
- (
- u_val[tuple(zip(array_shape // 2))] ** 2 <= factor
- * np.sum(u_val ** 2),
- )
- for u_val in u_auto
- ])
+ return np.sum(
+ [
+ (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),)
+ for u_val in u_auto
+ ]
+ )
def calculate_svd(input_data):
@@ -101,17 +98,17 @@ def calculate_svd(input_data):
"""
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise TypeError('Input data must be a 2D np.ndarray.')
+ raise TypeError("Input data must be a 2D np.ndarray.")
return svd(
input_data,
check_finite=False,
- lapack_driver='gesvd',
+ lapack_driver="gesvd",
full_matrices=False,
)
-def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
+def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
"""Threshold the singular values.
This method thresholds the input data using singular value decomposition.
@@ -156,16 +153,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
"""
less_than_zero = isinstance(n_pc, int) and n_pc <= 0
- str_not_all = isinstance(n_pc, str) and n_pc != 'all'
+ str_not_all = isinstance(n_pc, str) and n_pc != "all"
- if (
- (not isinstance(n_pc, (int, str, type(None))))
- or less_than_zero
- or str_not_all
- ):
+ if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all:
raise ValueError(
- 'Invalid value for "n_pc", specify a positive integer value or '
- + '"all"',
+ 'Invalid value for "n_pc", specify a positive integer value or ' + '"all"',
)
# Get SVD of input data.
@@ -176,15 +168,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
# Find the required number of principal components if not specified.
if isinstance(n_pc, type(None)):
n_pc = find_n_pc(u_vec, factor=0.1)
-        print('xxxx', n_pc, u_vec)
# If the number of PCs is too large use all of the singular values.
- if (
- (isinstance(n_pc, int) and n_pc >= s_values.size)
- or (isinstance(n_pc, str) and n_pc == 'all')
+ if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
+ isinstance(n_pc, str) and n_pc == "all"
):
n_pc = s_values.size
- warn('Using all singular values.')
+ warn("Using all singular values.")
threshold = s_values[n_pc - 1]
@@ -192,7 +183,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
s_new = thresh(s_values, threshold, thresh_type)
if np.all(s_new == s_values):
- warn('No change to singular values.')
+ warn("No change to singular values.")
# Diagonalize the svd
s_new = np.diag(s_new)
@@ -206,7 +197,7 @@ def svd_thresh_coef_fast(
threshold,
n_vals=-1,
extra_vals=5,
- thresh_type='hard',
+ thresh_type="hard",
):
"""Threshold the singular values coefficients.
@@ -241,7 +232,7 @@ def svd_thresh_coef_fast(
ok = False
while not ok:
(u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
- ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1)
+ ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1
n_vals = min(n_vals + extra_vals, *input_data.shape)
s_values = thresh(
@@ -259,7 +250,7 @@ def svd_thresh_coef_fast(
)
-def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
+def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
"""Threshold the singular values coefficients.
This method thresholds the input data using singular value decomposition.
@@ -287,7 +278,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
"""
if not callable(operator):
- raise TypeError('Operator must be a callable function.')
+ raise TypeError("Operator must be a callable function.")
# Get SVD of data matrix
u_vec, s_values, v_vec = calculate_svd(input_data)
@@ -299,13 +290,12 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
a_matrix = np.dot(s_values, v_vec)
# Get the shape of the array
- array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2)
+ array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Compute threshold matrix.
- ti = np.array([
- np.linalg.norm(elem)
- for elem in operator(matrix2cube(u_vec, array_shape))
- ])
+ ti = np.array(
+ [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
+ )
threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)
# Threshold coefficients.
diff --git a/modopt/signal/validation.py b/src/modopt/signal/validation.py
similarity index 80%
rename from modopt/signal/validation.py
rename to src/modopt/signal/validation.py
index 422a987b..cdf69b7d 100644
--- a/modopt/signal/validation.py
+++ b/src/modopt/signal/validation.py
@@ -16,6 +16,7 @@ def transpose_test(
x_args=None,
y_shape=None,
y_args=None,
+ rng=None,
):
"""Transpose test.
@@ -36,6 +37,8 @@ def transpose_test(
Shape of transpose operator input data (default is ``None``)
y_args : tuple, optional
Arguments to be passed to transpose operator (default is ``None``)
+ rng: numpy.random.Generator or int or None (default is ``None``)
+ Initialized random number generator or seed.
Raises
------
@@ -54,7 +57,7 @@ def transpose_test(
"""
if not callable(operator) or not callable(operator_t):
- raise TypeError('The input operators must be callable functions.')
+ raise TypeError("The input operators must be callable functions.")
if isinstance(y_shape, type(None)):
y_shape = x_shape
@@ -62,9 +65,11 @@ def transpose_test(
if isinstance(y_args, type(None)):
y_args = x_args
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
# Generate random arrays.
- x_val = np.random.ranf(x_shape)
- y_val = np.random.ranf(y_shape)
+ x_val = rng.random(x_shape)
+ y_val = rng.random(y_shape)
# Calculate
mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val))
@@ -73,4 +78,4 @@ def transpose_test(
x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args)))
# Test the difference between the two.
-    print(' - |<MX, Y> - <X, M.TY>| =', np.abs(mx_y - x_mty))
+    print(" - |<MX, Y> - <X, M.TY>| =", np.abs(mx_y - x_mty))
diff --git a/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py
similarity index 88%
rename from modopt/signal/wavelet.py
rename to src/modopt/signal/wavelet.py
index bc4ffc70..b55b78d9 100644
--- a/modopt/signal/wavelet.py
+++ b/src/modopt/signal/wavelet.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""WAVELET MODULE.
This module contains methods for performing wavelet transformations using
@@ -58,20 +56,20 @@ def execute(command_line):
"""
if not isinstance(command_line, str):
- raise TypeError('Command line must be a string.')
+ raise TypeError("Command line must be a string.")
command = command_line.split()
process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = process.communicate()
- return stdout.decode('utf-8'), stderr.decode('utf-8')
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
def call_mr_transform(
input_data,
- opt='',
- path='./',
+ opt="",
+ path="./",
remove_files=True,
): # pragma: no cover
"""Call ``mr_transform``.
@@ -127,26 +125,23 @@ def call_mr_transform(
"""
if not import_astropy:
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise ValueError('Input data must be a 2D numpy array.')
+ raise ValueError("Input data must be a 2D numpy array.")
- executable = 'mr_transform'
+ executable = "mr_transform"
# Make sure mr_transform is installed.
is_executable(executable)
# Create a unique string using the current date and time.
- unique_string = (
- datetime.now().strftime('%Y.%m.%d_%H.%M.%S')
- + str(getrandbits(128))
- )
+ unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
    # Set the output file names.
- file_name = '{0}mr_temp_{1}'.format(path, unique_string)
- file_fits = '{0}.fits'.format(file_name)
- file_mr = '{0}.mr'.format(file_name)
+ file_name = f"{path}mr_temp_{unique_string}"
+ file_fits = f"{file_name}.fits"
+ file_mr = f"{file_name}.mr"
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -155,15 +150,15 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = ' '.join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable, *opt, file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
- if any(word in stdout for word in ('bad', 'Error', 'Sorry')):
+ if any(word in stdout for word in ("bad", "Error", "Sorry")):
remove(file_fits)
message = '{0} raised following exception: "{1}"'
raise RuntimeError(
- message.format(executable, stdout.rstrip('\n')),
+ message.format(executable, stdout.rstrip("\n")),
)
# Retrieve wavelet transformed data.
@@ -198,12 +193,12 @@ def trim_filter(filter_array):
min_idx = np.min(non_zero_indices, axis=-1)
max_idx = np.max(non_zero_indices, axis=-1)
- return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1]
+ return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1]
def get_mr_filters(
data_shape,
- opt='',
+ opt="",
coarse=False,
trim=False,
): # pragma: no cover
@@ -256,7 +251,7 @@ def get_mr_filters(
return mr_filters[:-1]
-def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
+def filter_convolve(input_data, filters, filter_rot=False, method="scipy"):
"""Filter convolve.
This method convolves the input image with the wavelet filters.
@@ -315,16 +310,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
axis=0,
)
- return np.array([
- convolve(input_data, filt, method=method) for filt in filters
- ])
+ return np.array([convolve(input_data, filt, method=method) for filt in filters])
def filter_convolve_stack(
input_data,
filters,
filter_rot=False,
- method='scipy',
+ method="scipy",
):
"""Filter convolve.
@@ -366,7 +359,9 @@ def filter_convolve_stack(
"""
# Return the convolved data cube.
- return np.array([
- filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
- for elem in input_data
- ])
+ return np.array(
+ [
+ filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
+ for elem in input_data
+ ]
+ )
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
new file mode 100644
index 00000000..63847764
--- /dev/null
+++ b/tests/test_algorithms.py
@@ -0,0 +1,277 @@
+"""UNIT TESTS FOR Algorithms.
+
+This module contains unit tests for the modopt.opt module.
+
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
+"""
+
+import numpy as np
+import numpy.testing as npt
+from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
+from pytest_cases import (
+ fixture,
+ parametrize,
+ parametrize_with_cases,
+)
+
+
+SKLEARN_AVAILABLE = True
+try:
+ import sklearn
+except ImportError:
+ SKLEARN_AVAILABLE = False
+
+
+rng = np.random.default_rng()
+
+
+@fixture
+def idty():
+ """Identity function."""
+ return lambda x: x
+
+
+@fixture
+def reweight_op():
+ """Reweight operator."""
+ data3 = np.arange(9).reshape(3, 3).astype(float) + 1
+ return reweight.cwbReweight(data3)
+
+
+def build_kwargs(kwargs, use_metrics):
+ """Build the kwargs for each algorithm, replacing placeholders by true values.
+
+ This function has to be call for each test, as direct parameterization somehow
+ is not working with pytest-xdist and pytest-cases.
+ It also adds dummy metric measurement to validate the metric api.
+ """
+ update_value = {
+ "idty": lambda x: x,
+ "lin_idty": linear.Identity(),
+ "reweight_op": reweight.cwbReweight(
+ np.arange(9).reshape(3, 3).astype(float) + 1
+ ),
+ }
+    new_kwargs = {}
+    # Replace placeholder values in the dict where possible.
+ for key in kwargs:
+ new_kwargs[key] = update_value.get(kwargs[key], kwargs[key])
+
+ if use_metrics:
+ new_kwargs["linear"] = linear.Identity()
+ new_kwargs["metrics"] = {
+ "diff": {
+ "metric": lambda test, ref: np.sum(test - ref),
+ "mapping": {"x_new": "test"},
+ "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))},
+ "early_stopping": False,
+ }
+ }
+
+ return new_kwargs
+
+
+@parametrize(use_metrics=[True, False])
+class AlgoCases:
+ r"""Cases for algorithms.
+
+    Most of the tests solve the trivial problem
+
+    .. math::
+        \min_x \frac{1}{2} \| y - x \|_2^2 \quad \text{s.t.}\quad x \geq 0
+
+    More complex and concrete use cases are shown in the examples.
+ """
+
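+    # data1 is non-negative, so it is its own projection onto the positive
+    # orthant and hence the solution of the problem above; test_algo checks
+    # that x_final recovers it.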
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data2 = data1 + rng.standard_normal(data1.shape) * 1e-6
+ max_iter = 20
+
+ @parametrize(
+ kwargs=[
+ {"beta_update": "idty", "auto_iterate": False, "cost": None},
+ {"beta_update": "idty"},
+ {"cost": None, "lambda_update": None},
+ {"beta_update": "idty", "a_cd": 3},
+ {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7},
+ {"restart_strategy": "adaptive", "xi_restart": 0.9},
+ {
+ "restart_strategy": "greedy",
+ "xi_restart": 0.9,
+ "min_beta": 1.0,
+ "s_greedy": 1.1,
+ },
+ ]
+ )
+ def case_forward_backward(self, kwargs, idty, use_metrics):
+ """Forward Backward case."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ algo = algorithms.ForwardBackward(
+ self.data1,
+ grad=gradient.GradBasic(self.data1, idty, idty),
+ prox=proximity.Positivity(),
+ **update_kwargs,
+ )
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
+
+ @parametrize(
+ kwargs=[
+ {
+ "cost": None,
+ "auto_iterate": False,
+ "gamma_update": "idty",
+ "beta_update": "idty",
+ },
+ {"gamma_update": "idty", "lambda_update": "idty"},
+ {"cost": True},
+ {"cost": True, "step_size": 2},
+ ]
+ )
+ def case_gen_forward_backward(self, kwargs, use_metrics, idty):
+ """General FB setup."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ prox_dual_inst = proximity.IdentityProx()
+ if update_kwargs.get("cost", None) is True:
+ update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
+ algo = algorithms.GenForwardBackward(
+ self.data1,
+ grad=grad_inst,
+ prox_list=[prox_inst, prox_dual_inst],
+ **update_kwargs,
+ )
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
+
+ @parametrize(
+ kwargs=[
+ {
+ "sigma_dual": "idty",
+ "tau_update": "idty",
+ "rho_update": "idty",
+ "auto_iterate": False,
+ },
+ {
+ "sigma_dual": "idty",
+ "tau_update": "idty",
+ "rho_update": "idty",
+ },
+ {
+ "linear": "lin_idty",
+ "cost": True,
+ "reweight": "reweight_op",
+ },
+ ]
+ )
+ def case_condat(self, kwargs, use_metrics, idty):
+ """Condat Vu Algorithm setup."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ prox_dual_inst = proximity.IdentityProx()
+ if update_kwargs.get("cost", None) is True:
+ update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
+
+ algo = algorithms.Condat(
+ self.data1,
+ self.data2,
+ grad=grad_inst,
+ prox=prox_inst,
+ prox_dual=prox_dual_inst,
+ **update_kwargs,
+ )
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
+
+ @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}])
+ def case_pogm(self, kwargs, use_metrics, idty):
+ """POGM setup."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ algo = algorithms.POGM(
+ u=self.data1,
+ x=self.data1,
+ y=self.data1,
+ z=self.data1,
+ grad=grad_inst,
+ prox=prox_inst,
+ **update_kwargs,
+ )
+
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
+
+ @parametrize(
+ GradDescent=[
+ algorithms.VanillaGenericGradOpt,
+ algorithms.AdaGenericGradOpt,
+ algorithms.ADAMGradOpt,
+ algorithms.MomentumGradOpt,
+ algorithms.RMSpropGradOpt,
+ algorithms.SAGAOptGradOpt,
+ ]
+ )
+ def case_grad(self, GradDescent, use_metrics, idty):
+ """Gradient Descent algorithm test."""
+ update_kwargs = build_kwargs({}, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ cost_inst = cost.costObj([grad_inst, prox_inst])
+
+ algo = GradDescent(
+ self.data1,
+ grad=grad_inst,
+ prox=prox_inst,
+ cost=cost_inst,
+ **update_kwargs,
+ )
+ algo.iterate()
+ return algo, update_kwargs
+
+ @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM])
+ def case_admm(self, admm, use_metrics, idty):
+ """ADMM setup."""
+
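+        # Trivial sub-problem solvers: each ADMM half-update simply returns
+        # the observation passed to it.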
+ def optim1(init, obs):
+ return obs
+
+ def optim2(init, obs):
+ return obs
+
+ update_kwargs = build_kwargs({}, use_metrics)
+ algo = admm(
+ u=self.data1,
+ v=self.data1,
+ mu=np.zeros_like(self.data1),
+ A=linear.Identity(),
+ B=linear.Identity(),
+ b=self.data1,
+ optimizers=(optim1, optim2),
+ **update_kwargs,
+ )
+ algo.iterate()
+ return algo, update_kwargs
+
+
+@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
+def test_algo(algo, kwargs):
+ """Test algorithms."""
+ if kwargs.get("auto_iterate") is False:
+        # iterate() was already called manually in the case setup
+ npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1)
+ else:
+ npt.assert_almost_equal(algo.x_final, AlgoCases.data1)
+
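+    # The "diff" metric compares x_new with np.arange(9).reshape((3, 3))
+    # (i.e. data1), so its last recorded value should vanish at convergence.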
+ if kwargs.get("metrics"):
+ print(algo.metrics)
+ npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3)
diff --git a/tests/test_base.py b/tests/test_base.py
new file mode 100644
index 00000000..62e09095
--- /dev/null
+++ b/tests/test_base.py
@@ -0,0 +1,219 @@
+"""
+Test for base module.
+
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
+"""
+
+import numpy as np
+import numpy.testing as npt
+import pytest
+from test_helpers import failparam, skipparam
+
+from modopt.base import backend, np_adjust, transform, types
+from modopt.base.backend import LIBRARIES
+
+
+class TestNpAdjust:
+ """Test for npadjust."""
+
+ array33 = np.arange(9).reshape((3, 3))
+ array233 = np.arange(18).reshape((2, 3, 3))
+ arraypad = np.array(
+ [
+ [0, 0, 0, 0, 0],
+ [0, 0, 1, 2, 0],
+ [0, 3, 4, 5, 0],
+ [0, 6, 7, 8, 0],
+ [0, 0, 0, 0, 0],
+ ]
+ )
+
+ def test_rotate(self):
+ """Test rotate."""
+ npt.assert_array_equal(
+ np_adjust.rotate(self.array33),
+ np.rot90(np.rot90(self.array33)),
+ err_msg="Incorrect rotation.",
+ )
+
+ def test_rotate_stack(self):
+ """Test rotate_stack."""
+ npt.assert_array_equal(
+ np_adjust.rotate_stack(self.array233),
+ np.rot90(self.array233, k=2, axes=(1, 2)),
+ err_msg="Incorrect stack rotation.",
+ )
+
+ @pytest.mark.parametrize(
+ "padding",
+ [
+ 1,
+ [1, 1],
+ np.array([1, 1]),
+ failparam("1", raises=ValueError),
+ ],
+ )
+ def test_pad2d(self, padding):
+ """Test pad2d."""
+ npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad)
+
+ def test_fancy_transpose(self):
+ """Test fancy transpose."""
+ npt.assert_array_equal(
+ np_adjust.fancy_transpose(self.array233),
+ np.array(
+ [
+ [[0, 3, 6], [9, 12, 15]],
+ [[1, 4, 7], [10, 13, 16]],
+ [[2, 5, 8], [11, 14, 17]],
+ ]
+ ),
+ err_msg="Incorrect fancy transpose",
+ )
+
+ def test_ftr(self):
+ """Test ftr."""
+ npt.assert_array_equal(
+ np_adjust.ftr(self.array233),
+ np.array(
+ [
+ [[0, 3, 6], [9, 12, 15]],
+ [[1, 4, 7], [10, 13, 16]],
+ [[2, 5, 8], [11, 14, 17]],
+ ]
+ ),
+ err_msg="Incorrect fancy transpose: ftr",
+ )
+
+ def test_ftl(self):
+ """Test fancy transpose left."""
+ npt.assert_array_equal(
+ np_adjust.ftl(self.array233),
+ np.array(
+ [
+ [[0, 9], [1, 10], [2, 11]],
+ [[3, 12], [4, 13], [5, 14]],
+ [[6, 15], [7, 16], [8, 17]],
+ ]
+ ),
+ err_msg="Incorrect fancy transpose: ftl",
+ )
+
+
+class TestTransforms:
+ """Test for the transform module."""
+
+ cube = np.arange(16).reshape((4, 2, 2))
+ map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]])
+ matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]])
+ layout = (2, 2)
+ fail_layout = (3, 3)
+
+ @pytest.mark.parametrize(
+ ("func", "indata", "layout", "outdata"),
+ [
+ (transform.cube2map, cube, layout, map),
+ failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError),
+ (transform.map2cube, map, layout, cube),
+ (transform.map2matrix, map, layout, matrix),
+ (transform.matrix2map, matrix, matrix.shape, map),
+ ],
+ )
+ def test_map(self, func, indata, layout, outdata):
+ """Test cube2map."""
+ npt.assert_array_equal(
+ func(indata, layout),
+ outdata,
+ )
+ if func.__name__ != "map2matrix":
+ npt.assert_raises(ValueError, func, indata, self.fail_layout)
+
+ def test_cube2matrix(self):
+ """Test cube2matrix."""
+ npt.assert_array_equal(
+ transform.cube2matrix(self.cube),
+ self.matrix,
+ )
+
+ def test_matrix2cube(self):
+ """Test matrix2cube."""
+ npt.assert_array_equal(
+ transform.matrix2cube(self.matrix, self.cube[0].shape),
+ self.cube,
+ err_msg="Incorrect transformation: matrix2cube",
+ )
+
+
+class TestType:
+ """Test for type module."""
+
+ data_list = list(range(5)) # noqa: RUF012
+ data_int = np.arange(5)
+ data_flt = np.arange(5).astype(float)
+
+ @pytest.mark.parametrize(
+ ("data", "checked"),
+ [
+ (1.0, 1.0),
+ (1, 1.0),
+ (data_list, data_flt),
+ (data_int, data_flt),
+ failparam("1.0", 1.0, raises=TypeError),
+ ],
+ )
+ def test_check_float(self, data, checked):
+ """Test check float."""
+ npt.assert_array_equal(types.check_float(data), checked)
+
+ @pytest.mark.parametrize(
+ ("data", "checked"),
+ [
+ (1.0, 1),
+ (1, 1),
+ (data_list, data_int),
+ (data_flt, data_int),
+ failparam("1", None, raises=TypeError),
+ ],
+ )
+ def test_check_int(self, data, checked):
+ """Test check int."""
+ npt.assert_array_equal(types.check_int(data), checked)
+
+ @pytest.mark.parametrize(
+ ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)]
+ )
+ def test_check_npndarray(self, data, dtype):
+ """Test check_npndarray."""
+ npt.assert_raises(
+ TypeError,
+ types.check_npndarray,
+ data,
+ dtype=dtype,
+ )
+
+ def test_check_callable(self):
+ """Test callable."""
+ npt.assert_raises(TypeError, types.check_callable, 1)
+
+
+@pytest.mark.parametrize(
+ "backend_name",
+ [
+ skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed")
+ for name in LIBRARIES
+ ],
+)
+def test_backend(backend_name):
+    """Test the ModOpt computational backends."""
+ xp, checked_backend_name = backend.get_backend(backend_name)
+ if checked_backend_name != backend_name or xp != LIBRARIES[backend_name]:
+ raise AssertionError(f"{backend_name} get_backend fails!")
+ xp_input = backend.change_backend(np.array([10, 10]), backend_name)
+ if (
+ backend.get_array_module(LIBRARIES[backend_name].ones(1))
+ != backend.LIBRARIES[backend_name]
+ or backend.get_array_module(xp_input) != LIBRARIES[backend_name]
+ ):
+ raise AssertionError(f"{backend_name} backend fails!")
diff --git a/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py
new file mode 100644
index 00000000..0ded847a
--- /dev/null
+++ b/tests/test_helpers/__init__.py
@@ -0,0 +1,5 @@
+"""Utilities for tests."""
+
+from .utils import Dummy, failparam, skipparam
+
+__all__ = ["Dummy", "failparam", "skipparam"]
diff --git a/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
new file mode 100644
index 00000000..41f948a6
--- /dev/null
+++ b/tests/test_helpers/utils.py
@@ -0,0 +1,27 @@
+"""
+Some helper functions for the test parametrization.
+
+They should be used inside a ``@pytest.mark.parametrize`` call.
+
+:Author: Pierre-Antoine Comby
+"""
+
+import pytest
+
+
+def failparam(*args, raises=None):
+    """Return a pytest parameterization that is expected to raise an error."""
+    if raises is None or not issubclass(raises, Exception):
+        raise ValueError("raises should be an expected Exception.")
+    return pytest.param(*args, marks=[pytest.mark.xfail(raises=raises)])
+
+
+def skipparam(*args, cond=True, reason=""):
+ """Return a pytest parameterization that should be skip if cond is valid."""
+ return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)])
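+
+
+# Illustrative usage (hypothetical test, not part of the suite):
+#
+#     @pytest.mark.parametrize(
+#         "value",
+#         [
+#             1,
+#             failparam("bad", raises=ValueError),
+#             skipparam(2, cond=True, reason="example skip"),
+#         ],
+#     )
+#     def test_value(value):
+#         assert isinstance(value, int)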
+
+
+class Dummy:
+ """Dummy Class."""
+
+ pass
diff --git a/tests/test_math.py b/tests/test_math.py
new file mode 100644
index 00000000..5c466e5e
--- /dev/null
+++ b/tests/test_math.py
@@ -0,0 +1,326 @@
+"""UNIT TESTS FOR MATH.
+
+This module contains unit tests for the modopt.math module.
+
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
+"""
+
+import numpy as np
+import numpy.testing as npt
+import pytest
+from test_helpers import failparam, skipparam
+
+from modopt.math import convolve, matrix, metrics, stats
+
+try:
+ import astropy
+except ImportError: # pragma: no cover
+ ASTROPY_AVAILABLE = False
+else: # pragma: no cover
+ ASTROPY_AVAILABLE = True
+try:
+ from skimage.metrics import structural_similarity as compare_ssim
+except ImportError: # pragma: no cover
+ SKIMAGE_AVAILABLE = False
+else:
+ SKIMAGE_AVAILABLE = True
+
+rng = np.random.default_rng(1)
+
+
+class TestConvolve:
+ """Test convolve functions."""
+
+ array233 = np.arange(18).reshape((2, 3, 3))
+ array233_1 = array233 + 1
+ result_astropy = np.array(
+ [
+ [210.0, 201.0, 210.0],
+ [129.0, 120.0, 129.0],
+ [210.0, 201.0, 210.0],
+ ]
+ )
+ result_scipy = np.array(
+ [
+ [
+ [14.0, 35.0, 38.0],
+ [57.0, 120.0, 111.0],
+ [110.0, 197.0, 158.0],
+ ],
+ [
+ [518.0, 845.0, 614.0],
+ [975.0, 1578.0, 1137.0],
+ [830.0, 1331.0, 950.0],
+ ],
+ ]
+ )
+
+ result_rot_kernel = np.array(
+ [
+ [
+ [66.0, 115.0, 82.0],
+ [153.0, 240.0, 159.0],
+ [90.0, 133.0, 82.0],
+ ],
+ [
+ [714.0, 1087.0, 730.0],
+ [1125.0, 1698.0, 1131.0],
+ [738.0, 1105.0, 730.0],
+ ],
+ ]
+ )
+
+ @pytest.mark.parametrize(
+ ("input_data", "kernel", "method", "result"),
+ [
+ skipparam(
+ array233[0],
+ array233_1[0],
+ "astropy",
+ result_astropy,
+ cond=not ASTROPY_AVAILABLE,
+ reason="astropy not available",
+ ),
+ failparam(
+ array233[0], array233_1, "astropy", result_astropy, raises=ValueError
+ ),
+ failparam(
+ array233[0], array233_1[0], "fail!", result_astropy, raises=ValueError
+ ),
+ (array233[0], array233_1[0], "scipy", result_scipy[0]),
+ ],
+ )
+ def test_convolve(self, input_data, kernel, method, result):
+ """Test convolve function."""
+ npt.assert_allclose(convolve.convolve(input_data, kernel, method), result)
+
+ @pytest.mark.parametrize(
+ ("result", "rot_kernel"),
+ [
+ (result_scipy, False),
+ (result_rot_kernel, True),
+ ],
+ )
+ def test_convolve_stack(self, result, rot_kernel):
+ """Test convolve stack function."""
+ npt.assert_allclose(
+ convolve.convolve_stack(
+ self.array233, self.array233_1, rot_kernel=rot_kernel
+ ),
+ result,
+ )
+
+
+class TestMatrix:
+ """Test matrix module."""
+
+ array3 = np.arange(3)
+ array33 = np.arange(9).reshape((3, 3))
+ array23 = np.arange(6).reshape((2, 3))
+ gram_schmidt_out = (
+ np.array(
+ [
+ [0, 1.0, 2.0],
+ [3.0, 1.2, -6e-1],
+ [-1.77635684e-15, 0, 0],
+ ]
+ ),
+ np.array(
+ [
+ [0, 0.4472136, 0.89442719],
+ [0.91287093, 0.36514837, -0.18257419],
+ [-1.0, 0, 0],
+ ]
+ ),
+ )
+
+ @pytest.fixture(scope="module")
+ def pm_instance(self, request):
+ """Power Method instance."""
+ pm = matrix.PowerMethod(
+ lambda x_val: x_val.dot(x_val.T),
+ self.array33.shape,
+ verbose=True,
+ rng=np.random.default_rng(0),
+ )
+ return pm
+
+ @pytest.mark.parametrize(
+ ("return_opt", "output"),
+ [
+ ("orthonormal", gram_schmidt_out[1]),
+ ("orthogonal", gram_schmidt_out[0]),
+ ("both", gram_schmidt_out),
+ failparam("fail!", gram_schmidt_out, raises=ValueError),
+ ],
+ )
+ def test_gram_schmidt(self, return_opt, output):
+ """Test gram schmidt."""
+ npt.assert_allclose(
+ matrix.gram_schmidt(self.array33, return_opt=return_opt), output
+ )
+
+ def test_nuclear_norm(self):
+ """Test nuclear norm."""
+ npt.assert_almost_equal(
+ matrix.nuclear_norm(self.array33),
+ 15.49193338482967,
+ )
+
+ def test_project(self):
+ """Test project."""
+ npt.assert_array_equal(
+ matrix.project(self.array3, self.array3 + 3),
+ np.array([0, 2.8, 5.6]),
+ )
+
+ def test_rot_matrix(self):
+ """Test rot_matrix."""
+ npt.assert_allclose(
+ matrix.rot_matrix(np.pi / 6),
+ np.array([[0.8660254, -0.5], [0.5, 0.8660254]]),
+ )
+
+ def test_rotate(self):
+ """Test rotate."""
+ npt.assert_array_equal(
+ matrix.rotate(self.array33, np.pi / 2),
+ np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]),
+ )
+
+ npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2)
+
+ def test_power_method(self, pm_instance, value=1):
+ """Test power method."""
+ npt.assert_almost_equal(pm_instance.spec_rad, value)
+ npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value)
+
+
+class TestMetrics:
+ """Test metrics module."""
+
+ data1 = np.arange(49).reshape(7, 7)
+ mask = np.ones(data1.shape)
+ ssim_res = 0.8958315888566867
+ ssim_mask_res = 0.8023827544418249
+ snr_res = 10.134554256920536
+ psnr_res = 14.860761791850397
+ mse_res = 0.03265305507330247
+ nrmse_res = 0.31136678840022625
+
+ @pytest.mark.skipif(not SKIMAGE_AVAILABLE, reason="skimage not installed")
+ @pytest.mark.parametrize(
+ ("data1", "data2", "result", "mask"),
+ [
+ (data1, data1**2, ssim_res, None),
+ (data1, data1**2, ssim_mask_res, mask),
+ failparam(data1, data1, None, 1, raises=ValueError),
+ ],
+ )
+ def test_ssim(self, data1, data2, result, mask):
+ """Test ssim."""
+ npt.assert_almost_equal(metrics.ssim(data1, data2, mask=mask), result)
+
+ @pytest.mark.skipif(SKIMAGE_AVAILABLE, reason="skimage installed")
+ def test_ssim_fail(self):
+ """Test ssim."""
+ npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1)
+
+ @pytest.mark.parametrize(
+ ("metric", "data", "result", "mask"),
+ [
+ (metrics.snr, data1, snr_res, None),
+ (metrics.snr, data1, snr_res, mask),
+ (metrics.psnr, data1, psnr_res, None),
+ (metrics.psnr, data1, psnr_res, mask),
+ (metrics.mse, data1, mse_res, None),
+ (metrics.mse, data1, mse_res, mask),
+ (metrics.nrmse, data1, nrmse_res, None),
+ (metrics.nrmse, data1, nrmse_res, mask),
+ failparam(metrics.snr, data1, snr_res, "maskfail", raises=ValueError),
+ ],
+ )
+ def test_metric(self, metric, data, result, mask):
+ """Test snr."""
+ npt.assert_almost_equal(metric(data, data**2, mask=mask), result)
+
+
+class TestStats:
+ """Test stats module."""
+
+ array33 = np.arange(9).reshape(3, 3)
+ array233 = np.arange(18).reshape(2, 3, 3)
+
+ @pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not installed")
+ @pytest.mark.parametrize(
+ ("norm", "result"),
+ [
+ (
+ "max",
+ np.array(
+ [
+ [0.36787944, 0.60653066, 0.36787944],
+ [0.60653066, 1.0, 0.60653066],
+ [0.36787944, 0.60653066, 0.36787944],
+ ]
+ ),
+ ),
+ (
+ "sum",
+ np.array(
+ [
+ [0.07511361, 0.1238414, 0.07511361],
+ [0.1238414, 0.20417996, 0.1238414],
+ [0.07511361, 0.1238414, 0.07511361],
+ ]
+ ),
+ ),
+ failparam("fail", None, raises=ValueError),
+ ],
+ )
+ def test_gaussian_kernel(self, norm, result):
+ """Test Gaussian kernel."""
+ npt.assert_allclose(
+ stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result
+ )
+
+ @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed")
+ def test_import_astropy(self):
+ """Test missing astropy."""
+ npt.assert_raises(ImportError, stats.gaussian_kernel, self.array33.shape, 1)
+
+ def test_mad(self):
+ """Test mad."""
+ npt.assert_equal(stats.mad(self.array33), 2.0)
+
+ def test_sigma_mad(self):
+ """Test sigma_mad."""
+ npt.assert_almost_equal(
+ stats.sigma_mad(self.array33),
+ 2.9651999999999998,
+ )
+
+ @pytest.mark.parametrize(
+ ("data1", "data2", "method", "result"),
+ [
+ (array33, array33 + 2, "starck", 12.041199826559248),
+ failparam(array33, array33, "fail", 0, raises=ValueError),
+ (array33, array33 + 2, "wiki", 42.110203695399477),
+ ],
+ )
+ def test_psnr(self, data1, data2, method, result):
+ """Test PSNR."""
+ npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result)
+
+ def test_psnr_stack(self):
+ """Test psnr stack."""
+ npt.assert_almost_equal(
+ stats.psnr_stack(self.array233, self.array233 + 2),
+ 12.041199826559248,
+ )
+
+ npt.assert_raises(ValueError, stats.psnr_stack, self.array33, self.array33)
diff --git a/tests/test_opt.py b/tests/test_opt.py
new file mode 100644
index 00000000..2ea58c27
--- /dev/null
+++ b/tests/test_opt.py
@@ -0,0 +1,572 @@
+"""UNIT TESTS FOR OPT.
+
+This module contains tests for the modopt.opt module.
+
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
+"""
+
+import numpy as np
+import numpy.testing as npt
+import pytest
+from pytest_cases import parametrize, parametrize_with_cases, case, fixture, fixture_ref
+
+from modopt.opt import cost, gradient, linear, proximity, reweight
+
+from test_helpers import Dummy
+
+SKLEARN_AVAILABLE = True
+try:
+ import sklearn
+except ImportError:
+ SKLEARN_AVAILABLE = False
+
+PTWT_AVAILABLE = True
+try:
+ import ptwt
+ import cupy
+except ImportError:
+ PTWT_AVAILABLE = False
+
+PYWT_AVAILABLE = True
+try:
+ import pywt
+ import joblib
+except ImportError:
+ PYWT_AVAILABLE = False
+
+rng = np.random.default_rng()
+
+
+# Basic functions to be used as operators or as dummy functions
+def func_identity(x_val, *args, **kwargs):
+ """Return x."""
+ return x_val
+
+
+def func_double(x_val, *args, **kwargs):
+ """Double x."""
+ return x_val * 2
+
+
+def func_sq(x_val, *args, **kwargs):
+ """Square x."""
+ return x_val**2
+
+
+def func_cube(x_val, *args, **kwargs):
+ """Cube x."""
+ return x_val**3
+
+
+@case(tags="cost")
+@parametrize(
+ ("cost_interval", "n_calls", "converged"),
+ [(1, 1, False), (1, 2, True), (2, 5, False), (None, 6, False)],
+)
+def case_cost_op(cost_interval, n_calls, converged):
+ """Case function for costs."""
+ dummy_inst1 = Dummy()
+ dummy_inst1.cost = func_sq
+ dummy_inst2 = Dummy()
+ dummy_inst2.cost = func_cube
+
+ cost_obj = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=cost_interval)
+
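+    # Warm up the cost object: with a constant input every evaluation yields
+    # the same cost, so whether the extra get_cost call in test_costs reports
+    # convergence depends only on cost_interval and n_calls.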
+ for _ in range(n_calls + 1):
+ cost_obj.get_cost(2)
+ return cost_obj, converged
+
+
+@parametrize_with_cases("cost_obj, converged", cases=".", has_tag="cost")
+def test_costs(cost_obj, converged):
+ """Test cost."""
+ npt.assert_equal(cost_obj.get_cost(2), converged)
+ if cost_obj._cost_interval:
+ npt.assert_equal(cost_obj.cost, 12)
+
+
+def test_raise_cost():
+ """Test error raising for cost."""
+ npt.assert_raises(TypeError, cost.costObj, 1)
+ npt.assert_raises(ValueError, cost.costObj, [Dummy(), Dummy()])
+
+
+@case(tags="grad")
+@parametrize(call=("op", "trans_op", "trans_op_op"))
+def case_grad_parent(call):
+ """Case for gradient parent."""
+ input_data = np.arange(9).reshape(3, 3)
+ callables = {
+ "op": func_sq,
+ "trans_op": func_cube,
+ "get_grad": func_identity,
+ "cost": lambda input_val: 1.0,
+ }
+
+ grad_op = gradient.GradParent(
+ input_data,
+ **callables,
+ data_type=np.floating,
+ )
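+    # "trans_op_op" is the composition trans_op(op(x)); the other calls are
+    # compared directly against their defining callables.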
+ if call != "trans_op_op":
+ result = callables[call](input_data)
+ else:
+ result = callables["trans_op"](callables["op"](input_data))
+
+ grad_call = getattr(grad_op, call)(input_data)
+ return grad_call, result
+
+
+@parametrize_with_cases("grad_values, result", cases=".", has_tag="grad")
+def test_grad_op(grad_values, result):
+ """Test Gradient operator."""
+ npt.assert_equal(grad_values, result)
+
+
+@pytest.fixture
+def grad_basic():
+ """Case for GradBasic."""
+ input_data = np.arange(9).reshape(3, 3)
+ grad_op = gradient.GradBasic(
+ input_data,
+ func_sq,
+ func_cube,
+ verbose=True,
+ )
+ grad_op.get_grad(input_data)
+ return grad_op
+
+
+def test_grad_basic(grad_basic):
+ """Test grad basic."""
+ npt.assert_array_equal(
+ grad_basic.grad,
+ np.array(
+ [
+ [0, 0, 8.0],
+ [2.16000000e2, 1.72800000e3, 8.0e3],
+ [2.70000000e4, 7.40880000e4, 1.75616000e5],
+ ]
+ ),
+ err_msg="Incorrect gradient.",
+ )
+
+
+def test_grad_basic_cost(grad_basic):
+ """Test grad_basic cost."""
+ npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3, 3)), 3192.0)
+
+
+def test_grad_op_raises():
+ """Test raise error."""
+ npt.assert_raises(
+ TypeError,
+ gradient.GradParent,
+ 1,
+ func_sq,
+ func_cube,
+ )
+
+
+#############
+# LINEAR OP #
+#############
+
+
+class LinearCases:
+ """Linear operator cases."""
+
+ def case_linear_identity(self):
+ """Case linear operator identity."""
+ linop = linear.Identity()
+
+ data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1
+
+ return linop, data_op, data_adj_op, res_op, res_adj_op
+
+ def case_linear_wavelet_convolve(self):
+ """Case linear operator wavelet."""
+ linop = linear.WaveletConvolve(
+ filters=np.arange(8).reshape(2, 2, 2).astype(float)
+ )
+ data_op = np.arange(4).reshape(1, 2, 2).astype(float)
+ data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float)
+ res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]])
+ res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]])
+
+ return linop, data_op, data_adj_op, res_op, res_adj_op
+
+ @parametrize(
+ compute_backend=[
+ pytest.param(
+ "numpy",
+ marks=pytest.mark.skipif(
+ not PYWT_AVAILABLE, reason="PyWavelet not available."
+ ),
+ ),
+ pytest.param(
+ "cupy",
+ marks=pytest.mark.skipif(
+ not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."
+ ),
+ ),
+ ]
+ )
+ def case_linear_wavelet_transform(self, compute_backend):
+ """Case linear wavelet operator."""
+ linop = linear.WaveletTransform(
+ wavelet_name="haar",
+ shape=(8, 8),
+ level=2,
+ )
+ data_op = np.arange(64).reshape(8, 8).astype(float)
+ res_op, slices, shapes = pywt.ravel_coeffs(
+ pywt.wavedecn(data_op, "haar", level=2)
+ )
+ data_adj_op = linop.op(data_op)
+ res_adj_op = pywt.waverecn(
+ pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar"
+ )
+ return linop, data_op, data_adj_op, res_op, res_adj_op
+
+ @parametrize(weights=[[1.0, 1.0], None])
+ def case_linear_combo(self, weights):
+ """Case linear operator combo with weights."""
+ parent = linear.LinearParent(
+ func_sq,
+ func_cube,
+ )
+ linop = linear.LinearCombo([parent, parent], weights)
+
+ data_op, data_adj_op, res_op, res_adj_op = (
+ 2,
+ np.array([2, 2]),
+ np.array([4, 4]),
+ 8.0 * (2 if weights else 1),
+ )
+
+ return linop, data_op, data_adj_op, res_op, res_adj_op
+
+ @parametrize(factor=[1, 1 + 1j])
+ def case_linear_matrix(self, factor):
+ """Case linear operator from matrix."""
+ linop = linear.MatrixOperator(np.eye(5) * factor)
+ data_op = np.arange(5)
+ data_adj_op = np.arange(5)
+ res_op = np.arange(5) * factor
+ res_adj_op = np.arange(5) * np.conjugate(factor)
+
+ return linop, data_op, data_adj_op, res_op, res_adj_op
+
+
+@fixture
+@parametrize_with_cases(
+ "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases
+)
+def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op):
+ """Get adj_op relative data."""
+ return linop.adj_op, data_adj_op, res_adj_op
+
+
+@fixture
+@parametrize_with_cases(
+ "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases
+)
+def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op):
+ """Get op relative data."""
+ return linop.op, data_op, res_op
+
+
+@parametrize(
+ ("action", "data", "result"), [fixture_ref(lin_op), fixture_ref(lin_adj_op)]
+)
+def test_linear_operator(action, data, result):
+ """Test linear operator."""
+ npt.assert_almost_equal(action(data), result)
+
+
+dummy_with_op = Dummy()
+dummy_with_op.op = lambda x: x
+
+
+@pytest.mark.parametrize(
+ ("args", "error"),
+ [
+ ([linear.LinearParent(func_sq, func_cube)], TypeError),
+ ([[]], ValueError),
+ ([[Dummy()]], ValueError),
+ ([[dummy_with_op]], ValueError),
+ ([[]], ValueError),
+ ([[linear.LinearParent(func_sq, func_cube)] * 2, [1.0]], ValueError),
+ ([[linear.LinearParent(func_sq, func_cube)] * 2, ["1", "1"]], TypeError),
+ ],
+)
+def test_linear_combo_errors(args, error):
+ """Test linear combo_errors."""
+ npt.assert_raises(error, linear.LinearCombo, *args)
+
+
+#############
+# Proximity #
+#############
+
+
+class ProxCases:
+ """Class containing all proximal operator cases.
+
+    Each case should return 4 parameters:
+
+    1. The proximal operator.
+    2. The test input data.
+    3. The expected result data.
+    4. The expected cost value.
+ """
+
+ weights = np.ones(9).reshape(3, 3).astype(float) * 3
+ array33 = np.arange(9).reshape(3, 3).astype(float)
+ array33_st = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]])
+ array33_st2 = array33_st * -1
+
+ array33_support = np.asarray([[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]])
+
+ array233 = np.arange(18).reshape(2, 3, 3).astype(float)
+ array233_2 = np.array(
+ [
+ [
+ [2.73843189, 3.14594066, 3.55344943],
+ [3.9609582, 4.36846698, 4.77597575],
+ [5.18348452, 5.59099329, 5.99850206],
+ ],
+ [
+ [8.07085295, 9.2718846, 10.47291625],
+ [11.67394789, 12.87497954, 14.07601119],
+ [15.27704284, 16.47807449, 17.67910614],
+ ],
+ ]
+ )
+ array233_3 = np.array(
+ [
+ [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
+ [
+ [4.00795282, 4.60438026, 5.2008077],
+ [5.79723515, 6.39366259, 6.99009003],
+ [7.58651747, 8.18294492, 8.77937236],
+ ],
+ ]
+ )
+
+ def case_prox_parent(self):
+ """Case prox parent."""
+ return (
+ proximity.ProximityParent(
+ func_sq,
+ func_double,
+ ),
+ 3,
+ 9,
+ 6,
+ )
+
+ def case_prox_identity(self):
+ """Case prox identity."""
+ return proximity.IdentityProx(), 3, 3, 0
+
+ def case_prox_positivity(self):
+ """Case prox positivity."""
+ return proximity.Positivity(), -3, 0, 0
+
+ def case_prox_sparsethresh(self):
+ """Case prox sparsethreshosld."""
+ return (
+ proximity.SparseThreshold(linear.Identity(), weights=self.weights),
+ self.array33,
+ self.array33_st,
+ 108,
+ )
+
+ @parametrize(
+ "lowr_type, initial_rank, operator, result, cost",
+ [
+ ("standard", None, None, array233_2, 469.3913294246498),
+ ("standard", 1, None, array233_2, 469.3913294246498),
+ ("ngole", None, func_double, array233_3, 469.3913294246498),
+ ],
+ )
+ def case_prox_lowrank(self, lowr_type, initial_rank, operator, result, cost):
+ """Case prox lowrank."""
+ return (
+ proximity.LowRankMatrix(
+ 10,
+ lowr_type=lowr_type,
+ initial_rank=initial_rank,
+ operator=operator,
+ thresh_type="hard" if lowr_type == "standard" else "soft",
+ ),
+ self.array233,
+ result,
+ cost,
+ )
+
+ def case_prox_linear_comp(self):
+ """Case prox linear comp."""
+ return (
+ proximity.LinearCompositionProx(
+ linear_op=linear.Identity(), prox_op=self.case_prox_sparsethresh()[0]
+ ),
+ self.array33,
+ self.array33_st,
+ 108,
+ )
+
+ def case_prox_ridge(self):
+ """Case prox ridge."""
+ return (
+ proximity.Ridge(linear.Identity(), self.weights),
+ self.array33 * (1 + 1j),
+ self.array33 * (1 + 1j) / 7,
+ 1224,
+ )
+
+ @parametrize("alpha, beta", [(0, weights), (weights, 0)])
+ def case_prox_elasticnet(self, alpha, beta):
+ """Case prox elastic net."""
+ if np.all(alpha == 0):
+ data = self.case_prox_sparsethresh()[1:]
+ else:
+ data = self.case_prox_ridge()[1:]
+ return (proximity.ElasticNet(linear.Identity(), alpha, beta), *data)
+
+ @parametrize(
+ "beta, k_value, data, result, cost",
+ [
+ (0.2, 1, array33.flatten(), array33_st.flatten(), 259.2),
+ (3, 5, array33.flatten(), array33_support.flatten(), 684.0),
+ (
+ 6.0,
+ 9,
+ array33.flatten() * (1 + 1j),
+ array33.flatten() * (1 + 1j) / 7,
+ 1224,
+ ),
+ ],
+ )
+ def case_prox_Ksupport(self, beta, k_value, data, result, cost):
+ """Case prox K-support norm."""
+ return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost)
+
+ @parametrize(use_weights=[True, False])
+ def case_prox_grouplasso(self, use_weights):
+ """Case GroupLasso proximity."""
+ if use_weights:
+ weights = np.tile(self.weights, (4, 1, 1))
+ else:
+ weights = np.tile(np.zeros((3, 3)), (4, 1, 1))
+
+ random_data = 3 * rng.random(weights[0].shape)
+ random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1))
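+        # With four identical copies, the group norm at each entry is
+        # 2 * x, so soft thresholding with weight 3 gives max(2 * x - 3, 0) / 2,
+        # which is the closed form computed below.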
+ if use_weights:
+ gl_result_data = 2 * random_data_tile - 3
+ gl_result_data = (
+ np.array(gl_result_data * (gl_result_data > 0).astype("int")) / 2
+ )
+ cost = np.sum(random_data_tile) * 6
+ else:
+ gl_result_data = random_data_tile
+ cost = 0
+ return (
+ proximity.GroupLASSO(
+ weights=weights,
+ ),
+ random_data_tile,
+ gl_result_data,
+ cost,
+ )
+
+ @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.")
+ def case_prox_owl(self):
+ """Case prox for Ordered Weighted L1 Norm."""
+ return (
+ proximity.OrderedWeightedL1Norm(self.weights.flatten()),
+ self.array33.flatten(),
+ self.array33_st.flatten(),
+ 108.0,
+ )
+
+
+@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases)
+def test_prox_op(operator, input_data, op_result, cost_result):
+ """Test proximity operator op."""
+ npt.assert_almost_equal(operator.op(input_data), op_result)
+
+
+@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases)
+def test_prox_cost(operator, input_data, op_result, cost_result):
+ """Test proximity operator cost."""
+ npt.assert_almost_equal(operator.cost(input_data, verbose=True), cost_result)
+
+
+@parametrize(
+ "arg, error",
+ [
+ (1, TypeError),
+ ([], ValueError),
+ ([Dummy()], ValueError),
+ ([dummy_with_op], ValueError),
+ ],
+)
+def test_error_prox_combo(arg, error):
+ """Test errors for proximity combo."""
+ npt.assert_raises(error, proximity.ProximityCombo, arg)
+
+
+@pytest.mark.skipif(SKLEARN_AVAILABLE, reason="sklearn is installed")
+def test_fail_sklearn():
+ """Test fail OWL with sklearn."""
+ npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1)
+
+
+@pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn is not installed.")
+def test_fail_owl():
+ """Test errors for Ordered Weighted L1 Norm."""
+ npt.assert_raises(
+ ValueError,
+ proximity.OrderedWeightedL1Norm,
+ np.arange(10),
+ )
+
+ npt.assert_raises(
+ ValueError,
+ proximity.OrderedWeightedL1Norm,
+ -np.arange(10),
+ )
+
+
+def test_fail_lowrank():
+ """Test fail for lowrank."""
+ prox_op = proximity.LowRankMatrix(10, lowr_type="fail")
+ npt.assert_raises(ValueError, prox_op.op, 0)
+
+
+def test_fail_Ksupport_norm():
+ """Test fail for K-support norm."""
+ npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
+
+
+def test_reweight():
+ """Test for reweight module."""
+ data1 = np.arange(9).reshape(3, 3).astype(float) + 1
+ data2 = np.array(
+ [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]],
+ )
+
+ rw = reweight.cwbReweight(data1)
+ rw.reweight(data1)
+
+ npt.assert_array_equal(
+ rw.weights,
+ data2,
+ err_msg="Incorrect CWB re-weighting.",
+ )
+
+ npt.assert_raises(ValueError, rw.reweight, data1[0])
diff --git a/tests/test_signal.py b/tests/test_signal.py
new file mode 100644
index 00000000..6dbb0bba
--- /dev/null
+++ b/tests/test_signal.py
@@ -0,0 +1,327 @@
+"""UNIT TESTS FOR SIGNAL.
+
+This module contains unit tests for the modopt.signal module.
+
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
+"""
+
+import numpy as np
+import numpy.testing as npt
+import pytest
+from test_helpers import failparam
+
+from modopt.signal import filter, noise, positivity, svd, validation, wavelet
+
+
+class TestFilter:
+ """Test filter module."""
+
+ @pytest.mark.parametrize(
+ ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
+ )
+ def test_gaussian_filter(self, norm, result):
+ """Test gaussian filter."""
+ npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result)
+
+ def test_mex_hat(self):
+ """Test mexican hat filter."""
+ npt.assert_almost_equal(
+ filter.mex_hat(2, 1),
+ -0.35213905225713371,
+ )
+
+ def test_mex_hat_dir(self):
+ """Test directional mexican hat filter."""
+ npt.assert_almost_equal(
+ filter.mex_hat_dir(1, 2, 1),
+ 0.17606952612856686,
+ )
+
+
+class TestNoise:
+ """Test noise module."""
+
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data2 = np.array(
+ [[0, 3.0, 4.0], [6.0, 9.0, 8.0], [14.0, 14.0, 17.0]],
+ )
+ data3 = np.array(
+ [
+ [0.3455842, 1.8216181, 2.3304371],
+ [1.6968428, 4.9053559, 5.4463746],
+ [5.4630468, 7.5811181, 8.3645724],
+ ]
+ )
+ data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
+ data5 = np.array(
+ [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
+ )
+
+ @pytest.mark.parametrize(
+ ("data", "noise_type", "sigma", "data_noise"),
+ [
+ (data1, "poisson", 1, data2),
+ (data1, "gauss", 1, data3),
+ (data1, "gauss", (1, 1, 1), data3),
+ failparam(data1, "fail", 1, data1, raises=ValueError),
+ ],
+ )
+ def test_add_noise(self, data, noise_type, sigma, data_noise):
+ """Test add_noise."""
+ rng = np.random.default_rng(1)
+ npt.assert_almost_equal(
+ noise.add_noise(data, sigma=sigma, noise_type=noise_type, rng=rng),
+ data_noise,
+ )
+
+ @pytest.mark.parametrize(
+ ("threshold_type", "result"),
+ [("hard", data4), ("soft", data5), failparam("fail", None, raises=ValueError)],
+ )
+ def test_thresh(self, threshold_type, result):
+ """Test threshold."""
+ npt.assert_array_equal(
+ noise.thresh(self.data1, 5, threshold_type=threshold_type), result
+ )
+
+
+class TestPositivity:
+ """Test positivity module."""
+
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
+ data5 = np.array(
+ [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
+ )
+
+ @pytest.mark.parametrize(
+ ("value", "expected"),
+ [
+ (-1.0, -float(0)),
+ (-1, 0),
+ (data1 - 5, data5),
+ (
+ np.array([np.arange(3) - 1, np.arange(2) - 1], dtype=object),
+ np.array([np.array([0, 0, 1]), np.array([0, 0])], dtype=object),
+ ),
+ failparam("-1", None, raises=TypeError),
+ ],
+ )
+ def test_positive(self, value, expected):
+ """Test positive."""
+ if isinstance(value, np.ndarray) and value.dtype == "O":
+ for v, e in zip(positivity.positive(value), expected):
+ npt.assert_array_equal(v, e)
+ else:
+ npt.assert_array_equal(positivity.positive(value), expected)
+
+
+class TestSVD:
+ """Test for svd module."""
+
+ @pytest.fixture
+ def data(self):
+ """Initialize test data."""
+ data1 = np.arange(18).reshape(9, 2).astype(float)
+ data2 = np.arange(32).reshape(16, 2).astype(float)
+ data3 = np.array(
+ [
+ np.array(
+ [
+ [-0.01744594, -0.61438865],
+ [-0.08435304, -0.50397984],
+ [-0.15126014, -0.39357102],
+ [-0.21816724, -0.28316221],
+ [-0.28507434, -0.17275339],
+ [-0.35198144, -0.06234457],
+ [-0.41888854, 0.04806424],
+ [-0.48579564, 0.15847306],
+ [-0.55270274, 0.26888188],
+ ]
+ ),
+ np.array([42.23492742, 1.10041151]),
+ np.array(
+ [
+ [-0.67608034, -0.73682791],
+ [0.73682791, -0.67608034],
+ ]
+ ),
+ ],
+ dtype=object,
+ )
+ data4 = np.array(
+ [
+ [-1.05426832e-16, 1.0],
+ [2.0, 3.0],
+ [4.0, 5.0],
+ [6.0, 7.0],
+ [8.0, 9.0],
+ [1.0e1, 1.1e1],
+ [1.2e1, 1.3e1],
+ [1.4e1, 1.5e1],
+ [1.6e1, 1.7e1],
+ ]
+ )
+
+ data5 = np.array(
+ [
+ [0.49815487, 0.54291537],
+ [2.40863386, 2.62505584],
+ [4.31911286, 4.70719631],
+ [6.22959185, 6.78933678],
+ [8.14007085, 8.87147725],
+ [10.05054985, 10.95361772],
+ [11.96102884, 13.03575819],
+ [13.87150784, 15.11789866],
+ [15.78198684, 17.20003913],
+ ]
+ )
+ return (data1, data2, data3, data4, data5)
+
+ @pytest.fixture
+ def svd0(self, data):
+ """Compute SVD of first data sample."""
+ return svd.calculate_svd(data[0])
+
+ def test_find_n_pc(self, data):
+ """Test find number of principal component."""
+ npt.assert_equal(
+ svd.find_n_pc(svd.svd(data[1])[0]),
+ 2,
+ err_msg="Incorrect number of principal components.",
+ )
+
+ def test_n_pc_fail_non_square(self):
+ """Test find_n_pc."""
+ npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3))
+
+ def test_calculate_svd(self, data, svd0):
+ """Test calculate_svd."""
+ errors = []
+ for i, name in enumerate("USV"):
+ try:
+ npt.assert_almost_equal(svd0[i], data[2][i])
+ except AssertionError:
+ errors.append(name)
+ if errors:
+ raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors))
+
+ @pytest.mark.parametrize(
+ ("n_pc", "idx_res"),
+ [(None, 3), (1, 4), ("all", 0), failparam("fail", 1, raises=ValueError)],
+ )
+ def test_svd_thresh(self, data, n_pc, idx_res):
+ """Test svd_tresh."""
+ npt.assert_almost_equal(
+ svd.svd_thresh(data[0], n_pc=n_pc),
+ data[idx_res],
+ )
+
+    def test_svd_thresh_invalid_type(self):
+        """Test svd_thresh failure."""
+ npt.assert_raises(TypeError, svd.svd_thresh, 1)
+
+ @pytest.mark.parametrize("operator", [lambda x: x, failparam(0, raises=TypeError)])
+ def test_svd_thresh_coef(self, data, operator):
+ """Test svd_tresh_coef."""
+ npt.assert_almost_equal(
+ svd.svd_thresh_coef(data[0], operator, 0),
+ data[0],
+ err_msg="Incorrect SVD coefficient tresholding",
+ )
+
+ # TODO test_svd_thresh_coef_fast
+
+
+class TestValidation:
+ """Test validation Module."""
+
+ array33 = np.arange(9).reshape(3, 3)
+
+ def test_transpose_test(self):
+ """Test transpose_test."""
+ npt.assert_equal(
+ validation.transpose_test(
+ lambda x_val, y_val: x_val.dot(y_val),
+ lambda x_val, y_val: x_val.dot(y_val.T),
+ self.array33.shape,
+ x_args=self.array33,
+ rng=2,
+ ),
+ None,
+ )
+
+
+class TestWavelet:
+ """Test Wavelet Module."""
+
+ @pytest.fixture
+ def data(self):
+ """Set test parameter values."""
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data2 = np.arange(36).reshape(4, 3, 3).astype(float)
+ data3 = np.array(
+ [
+ [
+ [6.0, 20, 26.0],
+ [36.0, 84.0, 84.0],
+ [90, 164.0, 134.0],
+ ],
+ [
+ [78.0, 155.0, 134.0],
+ [225.0, 408.0, 327.0],
+ [270, 461.0, 350],
+ ],
+ [
+ [150, 290, 242.0],
+ [414.0, 732.0, 570],
+ [450, 758.0, 566.0],
+ ],
+ [
+ [222.0, 425.0, 350],
+ [603.0, 1056.0, 813.0],
+ [630, 1055.0, 782.0],
+ ],
+ ]
+ )
+
+ data4 = np.array(
+ [
+ [6496.0, 9796.0, 6544.0],
+ [9924.0, 14910, 9924.0],
+ [6544.0, 9796.0, 6496.0],
+ ]
+ )
+
+ data5 = np.array(
+ [
+ [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]],
+ [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]],
+ [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]],
+ ]
+ )
+ return (data1, data2, data3, data4, data5)
+
+ @pytest.mark.parametrize(
+ ("idx_data", "idx_filter", "idx_res", "filter_rot"),
+ [(0, 1, 2, False), (1, 1, 3, True)],
+ )
+ def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot):
+ """Test filter_convolve."""
+ npt.assert_almost_equal(
+ wavelet.filter_convolve(
+ data[idx_data], data[idx_filter], filter_rot=filter_rot
+ ),
+ data[idx_res],
+ err_msg="Inccorect filter comvolution.",
+ )
+
+ def test_filter_convolve_stack(self, data):
+ """Test filter_convolve_stack."""
+ npt.assert_almost_equal(
+ wavelet.filter_convolve_stack(data[0], data[0]),
+ data[4],
+ err_msg="Inccorect filter stack comvolution.",
+ )