diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml index 1e49f8bc..21b6cc3a 100644 --- a/.github/workflows/cd-build.yml +++ b/.github/workflows/cd-build.yml @@ -14,32 +14,28 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Set up Conda with Python 3.8 - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/setup-python@v4 with: - auto-update-conda: true - python-version: 3.8 - auto-activate-base: false + python-version: "3.10" + cache: pip - name: Install dependencies shell: bash -l {0} run: | python -m pip install --upgrade pip - python -m pip install -r develop.txt python -m pip install twine - python -m pip install . + python -m pip install .[doc,test] - name: Run Tests shell: bash -l {0} run: | - python setup.py test + pytest - name: Check distribution shell: bash -l {0} run: | - python setup.py sdist twine check dist/* - name: Upload coverage to Codecov @@ -57,22 +53,15 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Set up Conda with Python 3.8 - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: 3.8 - auto-activate-base: false - name: Install dependencies shell: bash -l {0} run: | conda install -c conda-forge pandoc python -m pip install --upgrade pip - python -m pip install -r docs/requirements.txt - python -m pip install . + python -m pip install .[doc] - name: Build API documentation shell: bash -l {0} diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 3ffcb6f4..3a209d12 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -16,70 +16,41 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: [3.8] + python-version: ["3.8", "3.9", "3.10"] steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Report WPS Errors - uses: wemake-services/wemake-python-styleguide@0.14.1 - continue-on-error: true - with: - reporter: 'github-pr-review' - path: './modopt' - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up Conda with Python ${{ matrix.python-version }} - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 with: - auto-update-conda: true python-version: ${{ matrix.python-version }} - auto-activate-base: false - - - name: Check Conda - shell: bash -l {0} - run: | - conda info - conda list - python --version + cache: pip - name: Install Dependencies shell: bash -l {0} run: | python --version python -m pip install --upgrade pip - python -m pip install -r develop.txt - python -m pip install -r docs/requirements.txt - python -m pip install astropy scikit-image scikit-learn - python -m pip install tensorflow>=2.4.1 - python -m pip install twine - python -m pip install . + python -m pip install .[test] + python -m pip install astropy scikit-image scikit-learn matplotlib + python -m pip install tensorflow>=2.4.1 torch - name: Run Tests shell: bash -l {0} run: | - export PATH=/usr/share/miniconda/bin:$PATH - python setup.py test + pytest -n 2 - name: Save Test Results if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }} - path: pytest.xml - - - name: Check Distribution - shell: bash -l {0} - run: | - python setup.py sdist - twine check dist/* + path: coverage.xml - name: Check API Documentation build shell: bash -l {0} run: | - conda install -c conda-forge pandoc + apt install pandoc + pip install .[doc] ipykernel sphinx-apidoc -t docs/_templates -feTMo docs/source modopt sphinx-build -b doctest -E docs/source docs/_build @@ -90,38 +61,3 @@ jobs: file: coverage.xml flags: unittests - test-basic: - name: Basic Test Suite - runs-on: ${{ matrix.os }} - - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, macos-latest] - python-version: [3.6, 3.7, 3.9] - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Set up Conda with Python ${{ matrix.python-version }} - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: ${{ matrix.python-version }} - auto-activate-base: false - - - name: Install Dependencies - shell: bash -l {0} - run: | - python --version - python -m pip install --upgrade pip - python -m pip install -r develop.txt - python -m pip install astropy scikit-image scikit-learn - python -m pip install . - - - name: Run Tests - shell: bash -l {0} - run: | - export PATH=/usr/share/miniconda/bin:$PATH - python setup.py test diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml new file mode 100644 index 00000000..45fc2b23 --- /dev/null +++ b/.github/workflows/style.yml @@ -0,0 +1,38 @@ +name: Style checking + +on: + push: + branches: [ "master", "main", "develop" ] + pull_request: + branches: [ "master", "main", "develop" ] + + workflow_dispatch: + +env: + PYTHON_VERSION: "3.10" + +jobs: + linter-check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Install Python deps + shell: bash + run: | + python -m pip install --upgrade pip + python -m pip install -e .[test,dev] + + - name: Black Check + shell: bash + run: black . --diff --color --check + + - name: ruff Check + shell: bash + run: ruff check diff --git a/.gitignore b/.gitignore index 06dff8db..f9eaaa68 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,7 @@ instance/ docs/_build/ docs/source/fortuna.* docs/source/scripts.* +docs/source/auto_examples/ docs/source/*.nblink # PyBuilder diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 3ac9aef9..00000000 --- a/.pylintrc +++ /dev/null @@ -1,2 +0,0 @@ -[MASTER] -ignore-patterns=**/docs/**/*.py diff --git a/.pyup.yml b/.pyup.yml deleted file mode 100644 index 8fdac7ff..00000000 --- a/.pyup.yml +++ /dev/null @@ -1,14 +0,0 @@ -# autogenerated pyup.io config file -# see https://pyup.io/docs/configuration/ for all available options - -schedule: '' -update: all -label_prs: update -assignees: sfarrens -requirements: - - requirements.txt: - pin: False - - develop.txt: - pin: False - - docs/requirements.txt: - pin: True diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 9a2f374e..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,5 +0,0 @@ -include requirements.txt -include develop.txt -include docs/requirements.txt -include README.rst -include LICENSE.txt diff --git a/README.md b/README.md index acb316ad..223d0b73 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,11 @@ All packages required by ModOpt should be installed automatically. Optional pack In order to run the code in this repository the following packages must be installed: -* [Python](https://www.python.org/) [> 3.6] +* [Python](https://www.python.org/) [> 3.7] * [importlib_metadata](https://importlib-metadata.readthedocs.io/en/latest/) [==3.7.0] * [Numpy](http://www.numpy.org/) [==1.19.5] * [Scipy](http://www.scipy.org/) [==1.5.4] -* [Progressbar 2](https://progressbar-2.readthedocs.io/) [==3.53.1] +* [tqdm](https://tqdm.github.io/) [>=4.64.0] ### Optional Packages diff --git a/develop.txt b/develop.txt deleted file mode 100644 index 8beef0ff..00000000 --- a/develop.txt +++ /dev/null @@ -1,9 +0,0 @@ -coverage>=5.5 -flake8>=4 -nose>=1.3.7 -pytest>=6.2.2 -pytest-cov>=2.11.1 -pytest-pep8>=1.0.6 -pytest-emoji>=0.2.0 -pytest-flake8>=1.0.7 -wemake-python-styleguide>=0.15.2 diff --git a/docs/requirements.txt b/docs/requirements.txt index 4d2a14fb..c9e29c88 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,3 +6,4 @@ numpydoc==1.1.0 sphinx==4.3.1 sphinxcontrib-bibtex==2.4.1 sphinxawesome-theme==3.2.1 +sphinx-gallery==0.11.1 diff --git a/docs/source/conf.py b/docs/source/conf.py index fb954f6d..69921008 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Python Template sphinx config # Import relevant modules @@ -9,55 +8,53 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # -- General configuration ------------------------------------------------ # General information about the project. -project = 'modopt' +project = "modopt" mdata = metadata(project) -author = mdata['Author'] -version = mdata['Version'] -copyright = '2020, {}'.format(author) -gh_user = 'sfarrens' - -# If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '3.3' +author = "Samuel Farrens, Pierre-Antoine Comby, Chaithya GR, Philippe Ciuciu" +version = mdata["Version"] +copyright = f"2020, {author}" +gh_user = "sfarrens" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.coverage', - 'sphinx.ext.doctest', - 'sphinx.ext.ifconfig', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.todo', - 'sphinx.ext.viewcode', - 'sphinxawesome_theme', - 'sphinxcontrib.bibtex', - 'myst_parser', - 'nbsphinx', - 'nbsphinx_link', - 'numpydoc', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.ifconfig", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinxawesome_theme.highlighting", + "sphinxcontrib.bibtex", + "myst_parser", + "nbsphinx", + "nbsphinx_link", + "numpydoc", + "sphinx_gallery.gen_gallery", ] # Include module names for objects add_module_names = False # Set class documentation standard. -autoclass_content = 'class' +autoclass_content = "class" # Audodoc options autodoc_default_options = { - 'member-order': 'bysource', - 'private-members': True, - 'show-inheritance': True + "member-order": "bysource", + "private-members": True, + "show-inheritance": True, } # Generate summaries @@ -68,17 +65,17 @@ # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. show_authors = True # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'default' +pygments_style = "default" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -87,7 +84,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinxawesome_theme' +html_theme = "sphinxawesome_theme" # html_theme = 'sphinx_book_theme' # Theme options are theme-specific and customize the look and feel of a theme @@ -100,11 +97,10 @@ "breadcrumbs_separator": "/", "show_prev_next": True, "show_scrolltop": True, - } html_collapsible_definitions = True html_awesome_headerlinks = True -html_logo = 'modopt_logo.jpg' +html_logo = "modopt_logo.png" html_permalinks_icon = ( '' @@ -117,7 +113,7 @@ ) # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -html_title = '{0} v{1}'.format(project, version) +html_title = f"{project} v{version}" # A shorter title for the navigation bar. Default is the same as html_title. # html_short_title = None @@ -133,7 +129,7 @@ # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -html_last_updated_fmt = '%d %b, %Y' +html_last_updated_fmt = "%d %b, %Y" # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. @@ -145,16 +141,26 @@ # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. html_show_copyright = True + +# -- Options for Sphinx Gallery ---------------------------------------------- + +sphinx_gallery_conf = { + "examples_dirs": ["../../examples/"], + "filename_pattern": "/example_", + "ignore_pattern": r"/(__init__|conftest)\.py", +} + + # -- Options for nbshpinx output ------------------------------------------ # Custom fucntion to find notebooks, create .nblink files and update the # notebooks.rst file -def add_notebooks(nb_path='../../notebooks'): +def add_notebooks(nb_path="../../notebooks"): - print('Looking for notebooks') - nb_ext = '.ipynb' - nb_rst_file_name = 'notebooks.rst' + print("Looking for notebooks") + nb_ext = ".ipynb" + nb_rst_file_name = "notebooks.rst" nb_link_format = '{{\n "path": "{0}/{1}"\n}}' nbs = sorted([nb for nb in os.listdir(nb_path) if nb.endswith(nb_ext)]) @@ -163,21 +169,21 @@ def add_notebooks(nb_path='../../notebooks'): nb_name = nb.rstrip(nb_ext) - nb_link_file_name = nb_name + '.nblink' - print('Writing {0}'.format(nb_link_file_name)) - with open(nb_link_file_name, 'w') as nb_link_file: + nb_link_file_name = nb_name + ".nblink" + print(f"Writing {nb_link_file_name}") + with open(nb_link_file_name, "w") as nb_link_file: nb_link_file.write(nb_link_format.format(nb_path, nb)) - print('Looking for {0} in {1}'.format(nb_name, nb_rst_file_name)) - with open(nb_rst_file_name, 'r') as nb_rst_file: + print(f"Looking for {nb_name} in {nb_rst_file_name}") + with open(nb_rst_file_name) as nb_rst_file: check_name = nb_name not in nb_rst_file.read() if check_name: - print('Adding {0} to {1}'.format(nb_name, nb_rst_file_name)) - with open(nb_rst_file_name, 'a') as nb_rst_file: + print(f"Adding {nb_name} to {nb_rst_file_name}") + with open(nb_rst_file_name, "a") as nb_rst_file: if list_pos == 0: - nb_rst_file.write('\n') - nb_rst_file.write(' {0}\n'.format(nb_name)) + nb_rst_file.write("\n") + nb_rst_file.write(f" {nb_name}\n") return nbs @@ -185,13 +191,13 @@ def add_notebooks(nb_path='../../notebooks'): # Add notebooks add_notebooks() -binder = 'https://mybinder.org/v2/gh' -binder_badge = 'https://mybinder.org/badge_logo.svg' -github = 'https://github.com/' -github_badge = 'https://badgen.net/badge/icon/github?icon=github&label' +binder = "https://mybinder.org/v2/gh" +binder_badge = "https://mybinder.org/badge_logo.svg" +github = "https://github.com/" +github_badge = "https://badgen.net/badge/icon/github?icon=github&label" # Remove promts and add binder badge -nb_header_pt1 = r''' +nb_header_pt1 = r""" {% if env.metadata[env.docname]['nbsphinx-link-target'] %} {% set docpath = env.metadata[env.docname]['nbsphinx-link-target'] %} {% else %} @@ -207,18 +213,18 @@ def add_notebooks(nb_path='../../notebooks'): } -''' +""" nb_header_pt2 = ( - r'''

''' - r'''''' + - r'''Binder badge
''' - r'''
GitHub badge'''.format(github_badge) + - r'''

''' + r"""

""" + rf"""""" + + rf"""Binder badge
""" + r"""
GitHub badge""" + + r"""

""" ) nbsphinx_prolog = nb_header_pt1 + nb_header_pt2 @@ -227,28 +233,28 @@ def add_notebooks(nb_path='../../notebooks'): # Refer to the package libraries for type definitions intersphinx_mapping = { - 'python': ('http://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), - 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None), - 'matplotlib': ('https://matplotlib.org', None), - 'astropy': ('http://docs.astropy.org/en/latest/', None), - 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'sklearn': ( - 'http://scikit-learn.org/stable', - (None, './_intersphinx/sklearn-objects.inv') + "python": ("http://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None), + "matplotlib": ("https://matplotlib.org", None), + "astropy": ("http://docs.astropy.org/en/latest/", None), + "cupy": ("https://docs-cupy.chainer.org/en/stable/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "sklearn": ( + "http://scikit-learn.org/stable", + (None, "./_intersphinx/sklearn-objects.inv"), ), - 'tensorflow': ( - 'https://www.tensorflow.org/api_docs/python', + "tensorflow": ( + "https://www.tensorflow.org/api_docs/python", ( - 'https://github.com/GPflow/tensorflow-intersphinx/' - + 'raw/master/tf2_py_objects.inv') - ) - + "https://github.com/GPflow/tensorflow-intersphinx/" + + "raw/master/tf2_py_objects.inv" + ), + ), } # -- BibTeX Setting ---------------------------------------------- -bibtex_bibfiles = ['refs.bib', 'my_ref.bib'] -bibtex_default_style = 'alpha' +bibtex_bibfiles = ["refs.bib", "my_ref.bib"] +bibtex_default_style = "alpha" diff --git a/docs/source/refs.bib b/docs/source/refs.bib index d8365e71..7782ca52 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -207,3 +207,15 @@ @article{zou2005 journal = {Journal of the Royal Statistical Society Series B}, doi = {10.1111/j.1467-9868.2005.00527.x} } + +@article{Goldstein2014, + author={Goldstein, Tom and O’Donoghue, Brendan and Setzer, Simon and Baraniuk, Richard}, + year={2014}, + month={Jan}, + pages={1588–1623}, + title={Fast Alternating Direction Optimization Methods}, + journal={SIAM Journal on Imaging Sciences}, + volume={7}, + ISSN={1936-4954}, + doi={10/gdwr49}, +} diff --git a/docs/source/toc.rst b/docs/source/toc.rst index 84a6af87..ef5753f5 100644 --- a/docs/source/toc.rst +++ b/docs/source/toc.rst @@ -25,6 +25,7 @@ plugin_example notebooks + auto_examples/index .. toctree:: :hidden: diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 00000000..e6ffbe27 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,5 @@ +======== +Examples +======== + +This is a collection of Python scripts demonstrating the use of ModOpt. diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..d7e77357 --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,10 @@ +"""EXAMPLES. + +This module contains documented examples that demonstrate the usage of various +ModOpt tools. + +These examples also serve as integration tests for various methods. + +:Author: Pierre-Antoine Comby + +""" diff --git a/examples/conftest.py b/examples/conftest.py new file mode 100644 index 00000000..f3ed371b --- /dev/null +++ b/examples/conftest.py @@ -0,0 +1,49 @@ +"""TEST CONFIGURATION. + +This module contains methods for configuring the testing of the example +scripts. + +:Author: Pierre-Antoine Comby + +Notes +----- +Based on: +https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test + +""" + +from pathlib import Path +import runpy +import pytest + + +def pytest_collect_file(path, parent): + """Pytest hook. + + Create a collector for the given path, or None if not relevant. + The new node needs to have the specified parent as parent. + """ + p = Path(path) + if p.suffix == ".py" and "example" in p.name: + return Script.from_parent(parent, path=p, name=p.name) + + +class Script(pytest.File): + """Script files collected by pytest.""" + + def collect(self): + """Collect the script as its own item.""" + yield ScriptItem.from_parent(self, name=self.name) + + +class ScriptItem(pytest.Item): + """Item script collected by pytest.""" + + def runtest(self): + """Run the script as a test.""" + runpy.run_path(str(self.path)) + + def repr_failure(self, excinfo): + """Return only the error traceback of the script.""" + excinfo.traceback = excinfo.traceback.cut(path=self.path) + return super().repr_failure(excinfo) diff --git a/examples/example_lasso_forward_backward.py b/examples/example_lasso_forward_backward.py new file mode 100644 index 00000000..f3e5091d --- /dev/null +++ b/examples/example_lasso_forward_backward.py @@ -0,0 +1,153 @@ +""" +Solving the LASSO Problem with the Forward Backward Algorithm. +============================================================== + +This an example to show how to solve an example LASSO Problem +using the Forward-Backward Algorithm. + +In this example we are going to use: + - Modopt Operators (Linear, Gradient, Proximal) + - Modopt implementation of solvers + - Modopt Metric API. +TODO: add reference to LASSO paper. +""" + +import numpy as np +import matplotlib.pyplot as plt + +from modopt.opt.algorithms import ForwardBackward, POGM +from modopt.opt.cost import costObj +from modopt.opt.linear import LinearParent, Identity +from modopt.opt.gradient import GradBasic +from modopt.opt.proximity import SparseThreshold +from modopt.math.matrix import PowerMethod +from modopt.math.stats import mse + +# %% +# Here we create a instance of the LASSO Problem + +BETA_TRUE = np.array( + [3.0, 1.5, 0, 0, 2, 0, 0, 0] +) # 8 original values from lLASSO Paper +DIM = len(BETA_TRUE) + + +rng = np.random.default_rng() +sigma_noise = 1 +obs = 20 +# create a measurement matrix with decaying covariance matrix. +cov = 0.4 ** abs((np.arange(DIM) * np.ones((DIM, DIM))).T - np.arange(DIM)) +x = rng.multivariate_normal(np.zeros(DIM), cov, obs) + +y = x @ BETA_TRUE +y_noise = y + (sigma_noise * np.random.standard_normal(obs)) + + +# %% +# Next we create Operators for solving the problem. + +# MatrixOperator could also work here. +lin_op = LinearParent(lambda b: x @ b, lambda bb: x.T @ bb) +grad_op = GradBasic(y_noise, op=lin_op.op, trans_op=lin_op.adj_op) + +prox_op = SparseThreshold(Identity(), 1, thresh_type="soft") + +# %% +# In order to get the best convergence rate, we first determine the Lipschitz constant of the gradient Operator +# + +calc_lips = PowerMethod(grad_op.trans_op_op, 8, data_type="float32", auto_run=True) +lip = calc_lips.spec_rad +print("lipschitz constant:", lip) + +# %% +# Solving using FISTA algorithm +# ----------------------------- +# +# TODO: Add description/Reference of FISTA. + +cost_op_fista = costObj([grad_op, prox_op], verbose=False) + +fb_fista = ForwardBackward( + np.zeros(8), + beta_param=1 / lip, + grad=grad_op, + prox=prox_op, + cost=cost_op_fista, + metric_call_period=1, + auto_iterate=False, # Just to give us the pleasure of doing things by ourself. +) + +fb_fista.iterate() + +# %% +# After the run we can have a look at the results + +print(fb_fista.x_final) +mse_fista = mse(fb_fista.x_final, BETA_TRUE) +plt.stem(fb_fista.x_final, label="estimation", linefmt="C0-") +plt.stem(BETA_TRUE, label="reference", linefmt="C1-") +plt.legend() +plt.title(f"FISTA Estimation MSE={mse_fista:.4f}") + +# sphinx_gallery_start_ignore +assert mse(fb_fista.x_final, BETA_TRUE) < 1 +# sphinx_gallery_end_ignore + + +# %% +# Solving Using the POGM Algorithm +# -------------------------------- +# +# TODO: Add description/Reference to POGM. + + +cost_op_pogm = costObj([grad_op, prox_op], verbose=False) + +fb_pogm = POGM( + np.zeros(8), + np.zeros(8), + np.zeros(8), + np.zeros(8), + beta_param=1 / lip, + grad=grad_op, + prox=prox_op, + cost=cost_op_pogm, + metric_call_period=1, + auto_iterate=False, # Just to give us the pleasure of doing things by ourself. +) + +fb_pogm.iterate() + +# %% +# After the run we can have a look at the results + +print(fb_pogm.x_final) +mse_pogm = mse(fb_pogm.x_final, BETA_TRUE) + +plt.stem(fb_pogm.x_final, label="estimation", linefmt="C0-") +plt.stem(BETA_TRUE, label="reference", linefmt="C1-") +plt.legend() +plt.title(f"FISTA Estimation MSE={mse_pogm:.4f}") +# +# sphinx_gallery_start_ignore +assert mse(fb_pogm.x_final, BETA_TRUE) < 1 +# sphinx_gallery_end_ignore + +# %% +# Comparing the Two algorithms +# ---------------------------- + +plt.figure() +plt.semilogy(cost_op_fista._cost_list, label="FISTA convergence") +plt.semilogy(cost_op_pogm._cost_list, label="POGM convergence") +plt.xlabel("iterations") +plt.ylabel("Cost Function") +plt.legend() +plt.show() + + +# %% +# We can see that the two algorithm converges quickly, and POGM requires less iterations. +# However the POGM iterations are more costly, so a proper benchmark with time measurement is needed. +# Check the benchopt benchmark for more details. diff --git a/modopt/__init__.py b/modopt/__init__.py deleted file mode 100644 index 2c06c1db..00000000 --- a/modopt/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -"""MODOPT PACKAGE. - -ModOpt is a series of Modular Optimisation tools for solving inverse problems. - -""" - -from warnings import warn - -from importlib_metadata import version - -from modopt.base import * - -try: - _version = version('modopt') -except Exception: # pragma: no cover - _version = 'Unkown' - warn( - 'Could not extract package metadata. Make sure the package is ' - + 'correctly installed.', - ) - -__version__ = _version diff --git a/modopt/base/wrappers.py b/modopt/base/wrappers.py deleted file mode 100644 index baedb891..00000000 --- a/modopt/base/wrappers.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- - -"""WRAPPERS. - -This module contains wrappers for adding additional features to functions. - -:Author: Samuel Farrens - -""" - -from functools import wraps -from inspect import getfullargspec as argspec - - -def add_args_kwargs(func): - """Add args and kwargs. - - This wrapper adds support for additional arguments and keyword arguments to - any callable function. - - Parameters - ---------- - func : callable - Callable function - - Returns - ------- - callable - wrapper - - """ - @wraps(func) - def wrapper(*args, **kwargs): - - props = argspec(func) - - # if 'args' not in props: - if isinstance(props[1], type(None)): - args = args[:len(props[0])] - - if ( - (not isinstance(props[2], type(None))) - or (not isinstance(props[3], type(None))) - ): - return func(*args, **kwargs) - - return func(*args) - - return wrapper diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py deleted file mode 100644 index 7ff96a8b..00000000 --- a/modopt/tests/test_algorithms.py +++ /dev/null @@ -1,470 +0,0 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR OPT.ALGORITHMS. - -This module contains unit tests for the modopt.opt.algorithms module. - -:Author: Samuel Farrens - -""" - -from unittest import TestCase - -import numpy as np -import numpy.testing as npt - -from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight - -# Basic functions to be used as operators or as dummy functions -func_identity = lambda x_val: x_val -func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - - prox_inst = proximity.Positivity() - prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) - - self.condat3 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, - ) - - self.pogm_all_iter = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - auto_iterate=False, - cost=None, - ) - self.pogm_all_iter.iterate(self.max_iter) - - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) - - self.vanilla_grad = algorithms.VanillaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.ada_grad = algorithms.AdaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.adam_grad = algorithms.ADAMGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.momentum_grad = algorithms.MomentumGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.rms_grad = algorithms.RMSpropGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.saga_grad = algorithms.SAGAOptGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) - - def test_ada_grad(self): - """Test ADA Gradient Descent.""" - self.ada_grad.iterate() - npt.assert_almost_equal( - self.ada_grad.x_final, - self.data1, - err_msg='Incorrect ADAGrad results.', - ) - - def test_adam_grad(self): - """Test ADAM Gradient Descent.""" - self.adam_grad.iterate() - npt.assert_almost_equal( - self.adam_grad.x_final, - self.data1, - err_msg='Incorrect ADAMGrad results.', - ) - - def test_momemtum_grad(self): - """Test Momemtum Gradient Descent.""" - self.momentum_grad.iterate() - npt.assert_almost_equal( - self.momentum_grad.x_final, - self.data1, - err_msg='Incorrect MomentumGrad results.', - ) - - def test_rmsprop_grad(self): - """Test RMSProp Gradient Descent.""" - self.rms_grad.iterate() - npt.assert_almost_equal( - self.rms_grad.x_final, - self.data1, - err_msg='Incorrect RMSPropGrad results.', - ) - - def test_saga_grad(self): - """Test SAGA Descent.""" - self.saga_grad.iterate() - npt.assert_almost_equal( - self.saga_grad.x_final, - self.data1, - err_msg='Incorrect SAGA Grad results.', - ) - - def test_vanilla_grad(self): - """Test Vanilla Gradient Descent.""" - self.vanilla_grad.iterate() - npt.assert_almost_equal( - self.vanilla_grad.x_final, - self.data1, - err_msg='Incorrect VanillaGrad results.', - ) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py deleted file mode 100644 index 873a4506..00000000 --- a/modopt/tests/test_base.py +++ /dev/null @@ -1,329 +0,0 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR BASE. - -This module contains unit tests for the modopt.base module. - -:Author: Samuel Farrens - -""" - -from builtins import range -from unittest import TestCase, skipIf - -import numpy as np -import numpy.testing as npt - -from modopt.base import np_adjust, transform, types -from modopt.base.backend import (LIBRARIES, change_backend, get_array_module, - get_backend) - - -class NPAdjustTestCase(TestCase): - """Test case for np_adjust module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape((3, 3)) - self.data2 = np.arange(18).reshape((2, 3, 3)) - self.data3 = np.array([ - [0, 0, 0, 0, 0], - [0, 0, 1, 2, 0], - [0, 3, 4, 5, 0], - [0, 6, 7, 8, 0], - [0, 0, 0, 0, 0], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - - def test_rotate(self): - """Test rotate.""" - npt.assert_array_equal( - np_adjust.rotate(self.data1), - np.array([[8, 7, 6], [5, 4, 3], [2, 1, 0]]), - err_msg='Incorrect rotation', - ) - - def test_rotate_stack(self): - """Test rotate_stack.""" - npt.assert_array_equal( - np_adjust.rotate_stack(self.data2), - np.array([ - [[8, 7, 6], [5, 4, 3], [2, 1, 0]], - [[17, 16, 15], [14, 13, 12], [11, 10, 9]], - ]), - err_msg='Incorrect stack rotation', - ) - - def test_pad2d(self): - """Test pad2d.""" - npt.assert_array_equal( - np_adjust.pad2d(self.data1, (1, 1)), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_array_equal( - np_adjust.pad2d(self.data1, 1), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_array_equal( - np_adjust.pad2d(self.data1, np.array([1, 1])), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_raises(ValueError, np_adjust.pad2d, self.data1, '1') - - def test_fancy_transpose(self): - """Test fancy_transpose.""" - npt.assert_array_equal( - np_adjust.fancy_transpose(self.data2), - np.array([ - [[0, 3, 6], [9, 12, 15]], - [[1, 4, 7], [10, 13, 16]], - [[2, 5, 8], [11, 14, 17]], - ]), - err_msg='Incorrect fancy transpose', - ) - - def test_ftr(self): - """Test ftr.""" - npt.assert_array_equal( - np_adjust.ftr(self.data2), - np.array([ - [[0, 3, 6], [9, 12, 15]], - [[1, 4, 7], [10, 13, 16]], - [[2, 5, 8], [11, 14, 17]], - ]), - err_msg='Incorrect fancy transpose: ftr', - ) - - def test_ftl(self): - """Test ftl.""" - npt.assert_array_equal( - np_adjust.ftl(self.data2), - np.array([ - [[0, 9], [1, 10], [2, 11]], - [[3, 12], [4, 13], [5, 14]], - [[6, 15], [7, 16], [8, 17]], - ]), - err_msg='Incorrect fancy transpose: ftl', - ) - - -class TransformTestCase(TestCase): - """Test case for transform module.""" - - def setUp(self): - """Set test parameter values.""" - self.cube = np.arange(16).reshape((4, 2, 2)) - self.map = np.array( - [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]], - ) - self.matrix = np.array( - [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]], - ) - self.layout = (2, 2) - - def tearDown(self): - """Unset test parameter values.""" - self.cube = None - self.map = None - self.layout = None - - def test_cube2map(self): - """Test cube2map.""" - npt.assert_array_equal( - transform.cube2map(self.cube, self.layout), - self.map, - err_msg='Incorrect transformation: cube2map', - ) - - npt.assert_raises( - ValueError, - transform.cube2map, - self.map, - self.layout, - ) - - npt.assert_raises(ValueError, transform.cube2map, self.cube, (3, 3)) - - def test_map2cube(self): - """Test map2cube.""" - npt.assert_array_equal( - transform.map2cube(self.map, self.layout), - self.cube, - err_msg='Incorrect transformation: map2cube', - ) - - npt.assert_raises(ValueError, transform.map2cube, self.map, (3, 3)) - - def test_map2matrix(self): - """Test map2matrix.""" - npt.assert_array_equal( - transform.map2matrix(self.map, self.layout), - self.matrix, - err_msg='Incorrect transformation: map2matrix', - ) - - def test_matrix2map(self): - """Test matrix2map.""" - npt.assert_array_equal( - transform.matrix2map(self.matrix, self.map.shape), - self.map, - err_msg='Incorrect transformation: matrix2map', - ) - - def test_cube2matrix(self): - """Test cube2matrix.""" - npt.assert_array_equal( - transform.cube2matrix(self.cube), - self.matrix, - err_msg='Incorrect transformation: cube2matrix', - ) - - def test_matrix2cube(self): - """Test matrix2cube.""" - npt.assert_array_equal( - transform.matrix2cube(self.matrix, self.cube[0].shape), - self.cube, - err_msg='Incorrect transformation: matrix2cube', - ) - - -class TypesTestCase(TestCase): - """Test case for types module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = list(range(5)) - self.data2 = np.arange(5) - self.data3 = np.arange(5).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - - def test_check_float(self): - """Test check_float.""" - npt.assert_array_equal( - types.check_float(1.0), - 1.0, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(1), - 1.0, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(self.data1), - self.data3, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(self.data2), - self.data3, - err_msg='Float check failed', - ) - - npt.assert_raises(TypeError, types.check_float, '1') - - def test_check_int(self): - """Test check_int.""" - npt.assert_array_equal( - types.check_int(1), - 1, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(1.0), - 1, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(self.data1), - self.data2, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(self.data3), - self.data2, - err_msg='Int check failed', - ) - - npt.assert_raises(TypeError, types.check_int, '1') - - def test_check_npndarray(self): - """Test check_npndarray.""" - npt.assert_raises( - TypeError, - types.check_npndarray, - self.data3, - dtype=np.integer, - ) - - -class TestBackend(TestCase): - """Test the backend codes.""" - - def setUp(self): - """Set test parameter values.""" - self.input = np.array([10, 10]) - - @skipIf(LIBRARIES['tensorflow'] is None, 'tensorflow library not installed') - def test_tf_backend(self): - """Test tensorflow backend.""" - xp, backend = get_backend('tensorflow') - if backend != 'tensorflow' or xp != LIBRARIES['tensorflow']: - raise AssertionError('tensorflow get_backend fails!') - tf_input = change_backend(self.input, 'tensorflow') - if ( - get_array_module(LIBRARIES['tensorflow'].ones(1)) != LIBRARIES['tensorflow'] - or get_array_module(tf_input) != LIBRARIES['tensorflow'] - ): - raise AssertionError('tensorflow backend fails!') - - @skipIf(LIBRARIES['cupy'] is None, 'cupy library not installed') - def test_cp_backend(self): - """Test cupy backend.""" - xp, backend = get_backend('cupy') - if backend != 'cupy' or xp != LIBRARIES['cupy']: - raise AssertionError('cupy get_backend fails!') - cp_input = change_backend(self.input, 'cupy') - if ( - get_array_module(LIBRARIES['cupy'].ones(1)) != LIBRARIES['cupy'] - or get_array_module(cp_input) != LIBRARIES['cupy'] - ): - raise AssertionError('cupy backend fails!') - - def test_np_backend(self): - """Test numpy backend.""" - xp, backend = get_backend('numpy') - if backend != 'numpy' or xp != LIBRARIES['numpy']: - raise AssertionError('numpy get_backend fails!') - np_input = change_backend(self.input, 'numpy') - if ( - get_array_module(LIBRARIES['numpy'].ones(1)) != LIBRARIES['numpy'] - or get_array_module(np_input) != LIBRARIES['numpy'] - ): - raise AssertionError('numpy backend fails!') - - def tearDown(self): - """Tear Down of objects.""" - self.input = None diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py deleted file mode 100644 index 99908e02..00000000 --- a/modopt/tests/test_math.py +++ /dev/null @@ -1,496 +0,0 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR MATH. - -This module contains unit tests for the modopt.math module. - -:Author: Samuel Farrens - -""" - -from unittest import TestCase, skipIf, skipUnless - -import numpy as np -import numpy.testing as npt - -from modopt.math import convolve, matrix, metrics, stats - -try: - import astropy -except ImportError: # pragma: no cover - import_astropy = False -else: # pragma: no cover - import_astropy = True -try: - from skimage.metrics import structural_similarity as compare_ssim -except ImportError: # pragma: no cover - import_skimage = False -else: - import_skimage = True - - -class ConvolveTestCase(TestCase): - """Test case for convolve module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(2, 3, 3) - self.data2 = self.data1 + 1 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_convolve_astropy(self): - """Test convolve using astropy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='astropy'), - np.array([ - [210.0, 201.0, 210.0], - [129.0, 120.0, 129.0], - [210.0, 201.0, 210.0], - ]), - err_msg='Incorrect convolution: astropy', - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2, - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2[0], - method='bla', - ) - - def test_convolve_scipy(self): - """Test convolve using scipy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='scipy'), - np.array([ - [14.0, 35.0, 38.0], - [57.0, 120.0, 111.0], - [110.0, 197.0, 158.0], - ]), - err_msg='Incorrect convolution: scipy', - ) - - def test_convolve_stack(self): - """Test convolve_stack.""" - npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2), - np.array([ - [ - [14.0, 35.0, 38.0], - [57.0, 120.0, 111.0], - [110.0, 197.0, 158.0], - ], - [ - [518.0, 845.0, 614.0], - [975.0, 1578.0, 1137.0], - [830.0, 1331.0, 950.0], - ], - ]), - err_msg='Incorrect convolution: stack', - ) - - def test_convolve_stack_rot(self): - """Test convolve_stack rotated.""" - npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2, rot_kernel=True), - np.array([ - [ - [66.0, 115.0, 82.0], - [153.0, 240.0, 159.0], - [90.0, 133.0, 82.0], - ], - [ - [714.0, 1087.0, 730.0], - [1125.0, 1698.0, 1131.0], - [738.0, 1105.0, 730.0], - ], - ]), - err_msg='Incorrect convolution: stack rot', - ) - - -class MatrixTestCase(TestCase): - """Test case for matrix module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(3) - self.data3 = np.arange(6).reshape(2, 3) - np.random.seed(1) - self.pmInstance1 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - verbose=True, - ) - np.random.seed(1) - self.pmInstance2 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - auto_run=False, - verbose=True, - ) - self.pmInstance2.get_spec_rad(max_iter=1) - self.gram_schmidt_out = ( - np.array([ - [0, 1.0, 2.0], - [3.0, 1.2, -6e-1], - [-1.77635684e-15, 0, 0], - ]), - np.array([ - [0, 0.4472136, 0.89442719], - [0.91287093, 0.36514837, -0.18257419], - [-1.0, 0, 0], - ]), - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.pmInstance1 = None - self.pmInstance2 = None - self.gram_schmidt_out = None - - def test_gram_schmidt_orthonormal(self): - """Test gram_schmidt with orthonormal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1), - self.gram_schmidt_out[1], - err_msg='Incorrect Gram-Schmidt: orthonormal', - ) - - npt.assert_raises( - ValueError, - matrix.gram_schmidt, - self.data1, - return_opt='bla', - ) - - def test_gram_schmidt_orthogonal(self): - """Test gram_schmidt with orthogonal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='orthogonal'), - self.gram_schmidt_out[0], - err_msg='Incorrect Gram-Schmidt: orthogonal', - ) - - def test_gram_schmidt_both(self): - """Test gram_schmidt with both outputs.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='both'), - self.gram_schmidt_out, - err_msg='Incorrect Gram-Schmidt: both', - ) - - def test_nuclear_norm(self): - """Test nuclear_norm.""" - npt.assert_almost_equal( - matrix.nuclear_norm(self.data1), - 15.49193338482967, - err_msg='Incorrect nuclear norm', - ) - - def test_project(self): - """Test project.""" - npt.assert_array_equal( - matrix.project(self.data2, self.data2 + 3), - np.array([0, 2.8, 5.6]), - err_msg='Incorrect projection', - ) - - def test_rot_matrix(self): - """Test rot_matrix.""" - npt.assert_allclose( - matrix.rot_matrix(np.pi / 6), - np.array([[0.8660254, -0.5], [0.5, 0.8660254]]), - err_msg='Incorrect rotation matrix', - ) - - def test_rotate(self): - """Test rotate.""" - npt.assert_array_equal( - matrix.rotate(self.data1, np.pi / 2), - np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]), - err_msg='Incorrect rotation', - ) - - npt.assert_raises(ValueError, matrix.rotate, self.data3, np.pi / 2) - - def test_powermethod_converged(self): - """Test PowerMethod converged.""" - npt.assert_almost_equal( - self.pmInstance1.spec_rad, - 0.90429242629600837, - err_msg='Incorrect spectral radius: converged', - ) - - npt.assert_almost_equal( - self.pmInstance1.inv_spec_rad, - 1.1058369736612865, - err_msg='Incorrect inverse spectral radius: converged', - ) - - def test_powermethod_unconverged(self): - """Test PowerMethod unconverged.""" - npt.assert_almost_equal( - self.pmInstance2.spec_rad, - 0.92048833577059219, - err_msg='Incorrect spectral radius: unconverged', - ) - - npt.assert_almost_equal( - self.pmInstance2.inv_spec_rad, - 1.0863798715741946, - err_msg='Incorrect inverse spectral radius: unconverged', - ) - - -class MetricsTestCase(TestCase): - """Test case for metrics module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(49).reshape(7, 7) - self.mask = np.ones(self.data1.shape) - self.ssim_res = 0.8963363560519094 - self.ssim_mask_res = 0.805154442543846 - self.snr_res = 10.134554256920536 - self.psnr_res = 14.860761791850397 - self.mse_res = 0.03265305507330247 - self.nrmse_res = 0.31136678840022625 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.mask = None - self.ssim_res = None - self.ssim_mask_res = None - self.psnr_res = None - self.mse_res = None - self.nrmse_res = None - - @skipIf(import_skimage, 'skimage is installed.') # pragma: no cover - def test_ssim_skimage_error(self): - """Test ssim skimage error.""" - npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - - @skipUnless(import_skimage, 'skimage not installed.') # pragma: no cover - def test_ssim(self): - """Test ssim.""" - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2), - self.ssim_res, - err_msg='Incorrect SSIM result', - ) - - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2, mask=self.mask), - self.ssim_mask_res, - err_msg='Incorrect SSIM result', - ) - - npt.assert_raises( - ValueError, - metrics.ssim, - self.data1, - self.data1, - mask=1, - ) - - def test_snr(self): - """Test snr.""" - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2, mask=self.mask), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - def test_psnr(self): - """Test psnr.""" - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2, mask=self.mask), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2, mask=self.mask), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - def test_nrmse(self): - """Test nrmse.""" - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2, mask=self.mask), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - -class StatsTestCase(TestCase): - """Test case for stats module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(18).reshape(2, 3, 3) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - - @skipIf(import_astropy, 'Astropy is installed.') # pragma: no cover - def test_gaussian_kernel_astropy_error(self): - """Test gaussian_kernel astropy error.""" - npt.assert_raises( - ImportError, - stats.gaussian_kernel, - self.data1.shape, - 1, - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_max(self): - """Test gaussian_kernel with max norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1), - np.array([ - [0.36787944, 0.60653066, 0.36787944], - [0.60653066, 1.0, 0.60653066], - [0.36787944, 0.60653066, 0.36787944], - ]), - err_msg='Incorrect gaussian kernel: max norm', - ) - - npt.assert_raises( - ValueError, - stats.gaussian_kernel, - self.data1.shape, - 1, - norm='bla', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_sum(self): - """Test gaussian_kernel with sum norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='sum'), - np.array([ - [0.07511361, 0.1238414, 0.07511361], - [0.1238414, 0.20417996, 0.1238414], - [0.07511361, 0.1238414, 0.07511361], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_none(self): - """Test gaussian_kernel with no norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='none'), - np.array([ - [0.05854983, 0.09653235, 0.05854983], - [0.09653235, 0.15915494, 0.09653235], - [0.05854983, 0.09653235, 0.05854983], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) - - def test_mad(self): - """Test mad.""" - npt.assert_equal( - stats.mad(self.data1), - 2.0, - err_msg='Incorrect median absolute deviation', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_equal( - stats.mse(self.data1, self.data1 + 2), - 4.0, - err_msg='Incorrect mean squared error', - ) - - def test_psnr_starck(self): - """Test psnr.""" - npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2), - 12.041199826559248, - err_msg='Incorrect PSNR: starck', - ) - - npt.assert_raises( - ValueError, - stats.psnr, - self.data1, - self.data1, - method='bla', - ) - - def test_psnr_wiki(self): - """Test psnr wiki method.""" - npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2, method='wiki'), - 42.110203695399477, - err_msg='Incorrect PSNR: wiki', - ) - - def test_psnr_stack(self): - """Test psnr stack.""" - npt.assert_almost_equal( - stats.psnr_stack(self.data2, self.data2 + 2), - 12.041199826559248, - err_msg='Incorrect PSNR stack', - ) - - npt.assert_raises(ValueError, stats.psnr_stack, self.data1, self.data1) - - def test_sigma_mad(self): - """Test sigma_mad.""" - npt.assert_almost_equal( - stats.sigma_mad(self.data1), - 2.9651999999999998, - err_msg='Incorrect sigma from MAD', - ) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py deleted file mode 100644 index d5547783..00000000 --- a/modopt/tests/test_opt.py +++ /dev/null @@ -1,1071 +0,0 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR OPT. - -This module contains unit tests for the modopt.opt module. - -:Author: Samuel Farrens - -""" - -from builtins import zip -from unittest import TestCase, skipIf, skipUnless - -import numpy as np -import numpy.testing as npt - -from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight - -try: - import sklearn -except ImportError: # pragma: no cover - import_sklearn = False -else: - import_sklearn = True - - -# Basic functions to be used as operators or as dummy functions -func_identity = lambda x_val: x_val -func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - - prox_inst = proximity.Positivity() - prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) - - self.condat3 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, - ) - - self.pogm_all_iter = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - auto_iterate=False, - cost=None, - ) - self.pogm_all_iter.iterate(self.max_iter) - - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) - - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) - - -class CostTestCase(TestCase): - """Test case for cost module.""" - - def setUp(self): - """Set test parameter values.""" - dummy_inst1 = Dummy() - dummy_inst1.cost = func_sq - dummy_inst2 = Dummy() - dummy_inst2.cost = func_cube - - self.inst1 = cost.costObj([dummy_inst1, dummy_inst2]) - self.inst2 = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=2) - # Test that by default cost of False if interval is None - self.inst_none = cost.costObj( - [dummy_inst1, dummy_inst2], - cost_interval=None, - ) - for _ in range(2): - self.inst1.get_cost(2) - for _ in range(6): - self.inst2.get_cost(2) - self.inst_none.get_cost(2) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.inst = None - - def test_cost_object(self): - """Test cost_object.""" - npt.assert_equal( - self.inst1.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst1.get_cost(2), - True, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst_none.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - - npt.assert_equal(self.inst1.cost, 12, err_msg='Incorrect cost value.') - - npt.assert_equal(self.inst2.cost, 12, err_msg='Incorrect cost value.') - - npt.assert_raises(TypeError, cost.costObj, 1) - - npt.assert_raises(ValueError, cost.costObj, [self.dummy, self.dummy]) - - -class GradientTestCase(TestCase): - """Test case for gradient module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.gp = gradient.GradParent( - self.data1, - func_sq, - func_cube, - func_identity, - lambda input_val: 1.0, - data_type=np.floating, - ) - self.gp.grad = self.gp.get_grad(self.data1) - self.gb = gradient.GradBasic( - self.data1, - func_sq, - func_cube, - ) - self.gb.get_grad(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.gp = None - self.gb = None - - def test_grad_parent_operators(self): - """Test GradParent.""" - npt.assert_array_equal( - self.gp.op(self.data1), - np.array([[0, 1.0, 4.0], [9.0, 16.0, 25.0], [36.0, 49.0, 64.0]]), - err_msg='Incorrect gradient operation.', - ) - - npt.assert_array_equal( - self.gp.trans_op(self.data1), - np.array( - [[0, 1.0, 8.0], [27.0, 64.0, 125.0], [216.0, 343.0, 512.0]], - ), - err_msg='Incorrect gradient transpose operation.', - ) - - npt.assert_array_equal( - self.gp.trans_op_op(self.data1), - np.array([ - [0, 1.0, 6.40000000e1], - [7.29000000e2, 4.09600000e3, 1.56250000e4], - [4.66560000e4, 1.17649000e5, 2.62144000e5], - ]), - err_msg='Incorrect gradient transpose operation operation.', - ) - - npt.assert_equal( - self.gp.cost(self.data1), - 1.0, - err_msg='Incorrect cost.', - ) - - npt.assert_raises( - TypeError, - gradient.GradParent, - 1, - func_sq, - func_cube, - ) - - def test_grad_basic_gradient(self): - """Test GradBasic.""" - npt.assert_array_equal( - self.gb.grad, - np.array([ - [0, 0, 8.0], - [2.16000000e2, 1.72800000e3, 8.0e3], - [2.70000000e4, 7.40880000e4, 1.75616000e5], - ]), - err_msg='Incorrect gradient.', - ) - - -class LinearTestCase(TestCase): - """Test case for linear module.""" - - def setUp(self): - """Set test parameter values.""" - self.parent = linear.LinearParent( - func_sq, - func_cube, - ) - self.ident = linear.Identity() - filters = np.arange(8).reshape(2, 2, 2).astype(float) - self.wave = linear.WaveletConvolve(filters) - self.combo = linear.LinearCombo([self.parent, self.parent]) - self.combo_weight = linear.LinearCombo( - [self.parent, self.parent], - [1.0, 1.0], - ) - self.data1 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data2 = np.arange(4).reshape(1, 2, 2).astype(float) - self.data3 = np.arange(8).reshape(1, 2, 2, 2).astype(float) - self.data4 = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) - self.data5 = np.array([[[28.0, 62.0], [68.0, 140.0]]]) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.ident = None - self.combo = None - self.combo_weight = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.dummy = None - - def test_linear_parent(self): - """Test LinearParent.""" - npt.assert_equal( - self.parent.op(2), - 4, - err_msg='Incorrect linear parent operation.', - ) - - npt.assert_equal( - self.parent.adj_op(2), - 8, - err_msg='Incorrect linear parent adjoint operation.', - ) - - npt.assert_raises(TypeError, linear.LinearParent, 0, 0) - - def test_identity(self): - """Test Identity.""" - npt.assert_equal( - self.ident.op(1.0), - 1.0, - err_msg='Incorrect identity operation.', - ) - - npt.assert_equal( - self.ident.adj_op(1.0), - 1.0, - err_msg='Incorrect identity adjoint operation.', - ) - - def test_wavelet_convolve(self): - """Test WaveletConvolve.""" - npt.assert_almost_equal( - self.wave.op(self.data2), - self.data4, - err_msg='Incorrect wavelet convolution operation.', - ) - - npt.assert_almost_equal( - self.wave.adj_op(self.data3), - self.data5, - err_msg='Incorrect wavelet convolution adjoint operation.', - ) - - def test_linear_combo(self): - """Test LinearCombo.""" - npt.assert_equal( - self.combo.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - - npt.assert_equal( - self.combo.adj_op([2, 2]), - 8.0, - err_msg='Incorrect combined linear adjoint operation', - ) - - npt.assert_raises(TypeError, linear.LinearCombo, self.parent) - - npt.assert_raises(ValueError, linear.LinearCombo, []) - - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - - self.dummy.op = func_identity - - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - - def test_linear_combo_weight(self): - """Test LinearCombo with weight .""" - npt.assert_equal( - self.combo_weight.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - - npt.assert_equal( - self.combo_weight.adj_op([2, 2]), - 16.0, - err_msg='Incorrect combined linear adjoint operation', - ) - - npt.assert_raises( - ValueError, - linear.LinearCombo, - [self.parent, self.parent], - [1.0], - ) - - npt.assert_raises( - TypeError, - linear.LinearCombo, - [self.parent, self.parent], - ['1', '1'], - ) - - -class ProximityTestCase(TestCase): - """Test case for proximity module.""" - - def setUp(self): - """Set test parameter values.""" - self.parent = proximity.ProximityParent( - func_sq, - func_double, - ) - self.identity = proximity.IdentityProx() - self.positivity = proximity.Positivity() - weights = np.ones(9).reshape(3, 3).astype(float) * 3 - self.sparsethresh = proximity.SparseThreshold( - linear.Identity(), - weights, - ) - self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard') - self.lowrank_rank = proximity.LowRankMatrix( - 10.0, - initial_rank=1, - thresh_type='hard', - ) - self.lowrank_ngole = proximity.LowRankMatrix( - 10.0, - lowr_type='ngole', - operator=func_double, - ) - self.linear_comp = proximity.LinearCompositionProx( - linear_op=linear.Identity(), - prox_op=self.sparsethresh, - ) - self.combo = proximity.ProximityCombo([self.identity, self.positivity]) - if import_sklearn: - self.owl = proximity.OrderedWeightedL1Norm(weights.flatten()) - self.ridge = proximity.Ridge(linear.Identity(), weights) - self.elasticnet_alpha0 = proximity.ElasticNet( - linear.Identity(), - alpha=0, - beta=weights, - ) - self.elasticnet_beta0 = proximity.ElasticNet( - linear.Identity(), - alpha=weights, - beta=0, - ) - self.one_support = proximity.KSupportNorm(beta=0.2, k_value=1) - self.five_support_norm = proximity.KSupportNorm(beta=3, k_value=5) - self.d_support = proximity.KSupportNorm(beta=3.0 * 2, k_value=19) - self.group_lasso = proximity.GroupLASSO( - weights=np.tile(weights, (4, 1, 1)), - ) - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) - self.data3 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data4 = np.array([ - [ - [2.73843189, 3.14594066, 3.55344943], - [3.9609582, 4.36846698, 4.77597575], - [5.18348452, 5.59099329, 5.99850206], - ], - [ - [8.07085295, 9.2718846, 10.47291625], - [11.67394789, 12.87497954, 14.07601119], - [15.27704284, 16.47807449, 17.67910614], - ], - ]) - self.data5 = np.array([ - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [ - [4.00795282, 4.60438026, 5.2008077], - [5.79723515, 6.39366259, 6.99009003], - [7.58651747, 8.18294492, 8.77937236], - ], - ]) - self.data6 = self.data3 * -1 - self.data7 = self.combo.op(self.data6) - self.data8 = np.empty(2, dtype=np.ndarray) - self.data8[0] = np.array( - [[-0, -1.0, -2.0], [-3.0, -4.0, -5.0], [-6.0, -7.0, -8.0]], - ) - self.data8[1] = np.array( - [[-0, -0, -0], [-0, -0, -0], [-0, -0, -0]], - ) - self.data9 = self.data1 * (1 + 1j) - self.data10 = self.data9 / (2 * 3 + 1) - self.data11 = np.asarray( - [[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]], - ) - self.random_data = 3 * np.random.random( - self.group_lasso.weights[0].shape, - ) - self.random_data_tile = np.tile( - self.random_data, - (self.group_lasso.weights.shape[0], 1, 1), - ) - self.gl_result_data = 2 * self.random_data_tile - 3 - self.gl_result_data = np.array( - (self.gl_result_data * (self.gl_result_data > 0).astype('int')) - / 2, - ) - - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.identity = None - self.positivity = None - self.sparsethresh = None - self.lowrank = None - self.lowrank_rank = None - self.lowrank_ngole = None - self.combo = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.data6 = None - self.data7 = None - self.data8 = None - self.dummy = None - self.random_data = None - self.random_data_tile = None - self.gl_result_data = None - - def test_proximity_parent(self): - """Test ProximityParent.""" - npt.assert_equal( - self.parent.op(3), - 9, - err_msg='Inccoret proximity parent operation.', - ) - - npt.assert_equal( - self.parent.cost(3), - 6, - err_msg='Incorrect proximity parent cost.', - ) - - def test_identity(self): - """Test IdentityProx.""" - npt.assert_equal( - self.identity.op(3), - 3, - err_msg='Incorrect proximity identity operation.', - ) - - npt.assert_equal( - self.identity.cost(3), - 0, - err_msg='Incorrect proximity identity cost.', - ) - - def test_positivity(self): - """Test Positivity.""" - npt.assert_equal( - self.positivity.op(-3), - 0, - err_msg='Incorrect proximity positivity operation.', - ) - - npt.assert_equal( - self.positivity.cost(-3, verbose=True), - 0, - err_msg='Incorrect proximity positivity cost.', - ) - - def test_sparse_threshold(self): - """Test SparseThreshold.""" - npt.assert_array_equal( - self.sparsethresh.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.sparsethresh.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - def test_low_rank_matrix(self): - """Test LowRankMatrix.""" - npt.assert_almost_equal( - self.lowrank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard', - ) - - npt.assert_almost_equal( - self.lowrank_rank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard with rank', - ) - npt.assert_almost_equal( - self.lowrank_ngole.op(self.data3), - self.data5, - err_msg='Incorrect low rank operation: ngole', - ) - - npt.assert_almost_equal( - self.lowrank.cost(self.data3, verbose=True), - 469.39132942464983, - err_msg='Incorrect low rank cost.', - ) - - def test_linear_comp_prox(self): - """Test LinearCompositionProx.""" - npt.assert_array_equal( - self.linear_comp.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.linear_comp.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - def test_proximity_combo(self): - """Test ProximityCombo.""" - for data7, data8 in zip(self.data7, self.data8): - npt.assert_array_equal( - data7, - data8, - err_msg='Incorrect combined operation', - ) - - npt.assert_equal( - self.combo.cost(self.data6), - 0, - err_msg='Incorrect combined cost.', - ) - - npt.assert_raises(TypeError, proximity.ProximityCombo, 1) - - npt.assert_raises(ValueError, proximity.ProximityCombo, []) - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - self.dummy.op = func_identity - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - @skipIf(import_sklearn, 'sklearn is installed.') # pragma: no cover - def test_owl_sklearn_error(self): - """Test OrderedWeightedL1Norm with Scikit-Learn.""" - npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) - - @skipUnless(import_sklearn, 'sklearn not installed.') # pragma: no cover - def test_sparse_owl(self): - """Test OrderedWeightedL1Norm.""" - npt.assert_array_equal( - self.owl.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.owl.cost(self.data1.flatten(), verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - npt.assert_raises( - ValueError, - proximity.OrderedWeightedL1Norm, - np.arange(10), - ) - - def test_ridge(self): - """Test Ridge.""" - npt.assert_array_equal( - self.ridge.op(self.data9), - self.data10, - err_msg='Incorect shrinkage operation.', - ) - - npt.assert_equal( - self.ridge.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost.', - ) - - def test_elastic_net_alpha0(self): - """Test ElasticNet.""" - npt.assert_array_equal( - self.elasticnet_alpha0.op(self.data1), - self.data2, - err_msg='Incorect sparse threshold operation ElasticNet class.', - ) - - npt.assert_equal( - self.elasticnet_alpha0.cost(self.data1), - 108.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - - def test_elastic_net_beta0(self): - """Test ElasticNet with beta=0.""" - npt.assert_array_equal( - self.elasticnet_beta0.op(self.data9), - self.data10, - err_msg='Incorect ridge operation ElasticNet class.', - ) - - npt.assert_equal( - self.elasticnet_beta0.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - - def test_one_support_norm(self): - """Test KSupportNorm with k=1.""" - npt.assert_allclose( - self.one_support.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorect sparse threshold operation for 1-support norm', - rtol=1e-6, - ) - - npt.assert_equal( - self.one_support.cost(self.data1.flatten(), verbose=True), - 259.2, - err_msg='Incorect sparse threshold cost.', - ) - - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - - def test_five_support_norm(self): - """Test KSupportNorm with k=5.""" - npt.assert_allclose( - self.five_support_norm.op(self.data1.flatten()), - self.data11.flatten(), - err_msg='Incorect sparse Ksupport norm operation', - rtol=1e-6, - ) - - npt.assert_equal( - self.five_support_norm.cost(self.data1.flatten(), verbose=True), - 684.0, - err_msg='Incorrect 5-support norm cost.', - ) - - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - - def test_d_support_norm(self): - """Test KSupportNorm with k=19.""" - npt.assert_allclose( - self.d_support.op(self.data9.flatten()), - self.data10.flatten(), - err_msg='Incorect shrinkage operation for d-support norm', - rtol=1e-6, - ) - - npt.assert_almost_equal( - self.d_support.cost(self.data9.flatten(), verbose=True), - 408.0 * 3.0, - err_msg='Incorrect shrinkage cost for d-support norm.', - ) - - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - - def test_group_lasso(self): - """Test GroupLASSO.""" - npt.assert_allclose( - self.group_lasso.op(self.random_data_tile), - self.gl_result_data, - ) - npt.assert_equal( - self.group_lasso.cost(self.random_data_tile), - np.sum(6 * self.random_data_tile), - ) - # Check that for 0 weights operator doesnt change result - self.group_lasso.weights = np.zeros_like(self.group_lasso.weights) - npt.assert_equal( - self.group_lasso.op(self.random_data_tile), - self.random_data_tile, - ) - npt.assert_equal(self.group_lasso.cost(self.random_data_tile), 0) - - -class ReweightTestCase(TestCase): - """Test case for reweight module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1 - self.data2 = np.array( - [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], - ) - self.rw = reweight.cwbReweight(self.data1) - self.rw.reweight(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.rw = None - - def test_cwbreweight(self): - """Test cwbReweight.""" - npt.assert_array_equal( - self.rw.weights, - self.data2, - err_msg='Incorrect CWB re-weighting.', - ) - - npt.assert_raises(ValueError, self.rw.reweight, self.data1[0]) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py deleted file mode 100644 index 7490b98c..00000000 --- a/modopt/tests/test_signal.py +++ /dev/null @@ -1,414 +0,0 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR SIGNAL. - -This module contains unit tests for the modopt.signal module. - -:Author: Samuel Farrens - -""" - -from unittest import TestCase - -import numpy as np -import numpy.testing as npt - -from modopt.signal import filter, noise, positivity, svd, validation, wavelet - - -class FilterTestCase(TestCase): - """Test case for filter module.""" - - def test_guassian_filter(self): - """Test guassian_filter.""" - npt.assert_almost_equal( - filter.gaussian_filter(1, 1), - 0.24197072451914337, - err_msg='Incorrect Gaussian filter', - ) - - npt.assert_almost_equal( - filter.gaussian_filter(1, 1, norm=False), - 0.60653065971263342, - err_msg='Incorrect Gaussian filter', - ) - - def test_mex_hat(self): - """Test mex_hat.""" - npt.assert_almost_equal( - filter.mex_hat(2, 1), - -0.35213905225713371, - err_msg='Incorrect Mexican hat filter', - ) - - def test_mex_hat_dir(self): - """Test mex_hat_dir.""" - npt.assert_almost_equal( - filter.mex_hat_dir(1, 2, 1), - 0.17606952612856686, - err_msg='Incorrect directional Mexican hat filter', - ) - - -class NoiseTestCase(TestCase): - """Test case for noise module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array( - [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], - ) - self.data3 = np.array([ - [1.62434536, 0.38824359, 1.47182825], - [1.92703138, 4.86540763, 2.6984613], - [7.74481176, 6.2387931, 8.3190391], - ]) - self.data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) - self.data5 = np.array( - [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_add_noise_poisson(self): - """Test add_noise with Poisson noise.""" - np.random.seed(1) - npt.assert_array_equal( - noise.add_noise(self.data1, noise_type='poisson'), - self.data2, - err_msg='Incorrect noise: Poisson', - ) - - npt.assert_raises( - ValueError, - noise.add_noise, - self.data1, - noise_type='bla', - ) - - npt.assert_raises(ValueError, noise.add_noise, self.data1, (1, 1)) - - def test_add_noise_gaussian(self): - """Test add_noise with Gaussian noise.""" - np.random.seed(1) - npt.assert_almost_equal( - noise.add_noise(self.data1), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - - np.random.seed(1) - npt.assert_almost_equal( - noise.add_noise(self.data1, sigma=(1, 1, 1)), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - - def test_thresh_hard(self): - """Test thresh with hard threshold.""" - npt.assert_array_equal( - noise.thresh(self.data1, 5), - self.data4, - err_msg='Incorrect threshold: hard', - ) - - npt.assert_raises( - ValueError, - noise.thresh, - self.data1, - 5, - threshold_type='bla', - ) - - def test_thresh_soft(self): - """Test thresh with soft threshold.""" - npt.assert_array_equal( - noise.thresh(self.data1, 5, threshold_type='soft'), - self.data5, - err_msg='Incorrect threshold: soft', - ) - - -class PositivityTestCase(TestCase): - """Test case for positivity module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - 5 - self.data2 = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3]]) - self.data3 = np.array( - [np.arange(5) - 3, np.arange(4) - 2], - dtype=object, - ) - self.data4 = np.array( - [np.array([0, 0, 0, 0, 1]), np.array([0, 0, 0, 1])], - dtype=object, - ) - self.pos_dtype_obj = positivity.positive(self.data3) - self.err = 'Incorrect positivity' - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - def test_positivity(self): - """Test positivity.""" - npt.assert_equal(positivity.positive(-1), 0, err_msg=self.err) - - npt.assert_equal( - positivity.positive(-1.0), - -float(0), - err_msg=self.err, - ) - - npt.assert_equal( - positivity.positive(self.data1), - self.data2, - err_msg=self.err, - ) - - for expected, output in zip(self.data4, self.pos_dtype_obj): - print(expected, output) - npt.assert_array_equal(expected, output, err_msg=self.err) - - npt.assert_raises(TypeError, positivity.positive, '-1') - - -class SVDTestCase(TestCase): - """Test case for svd module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(9, 2).astype(float) - self.data2 = np.arange(32).reshape(16, 2).astype(float) - self.data3 = np.array( - [ - np.array([ - [-0.01744594, -0.61438865], - [-0.08435304, -0.50397984], - [-0.15126014, -0.39357102], - [-0.21816724, -0.28316221], - [-0.28507434, -0.17275339], - [-0.35198144, -0.06234457], - [-0.41888854, 0.04806424], - [-0.48579564, 0.15847306], - [-0.55270274, 0.26888188], - ]), - np.array([42.23492742, 1.10041151]), - np.array([ - [-0.67608034, -0.73682791], - [0.73682791, -0.67608034], - ]), - ], - dtype=object, - ) - self.data4 = np.array([ - [-1.05426832e-16, 1.0], - [2.0, 3.0], - [4.0, 5.0], - [6.0, 7.0], - [8.0, 9.0], - [1.0e1, 1.1e1], - [1.2e1, 1.3e1], - [1.4e1, 1.5e1], - [1.6e1, 1.7e1], - ]) - self.data5 = np.array([ - [0.49815487, 0.54291537], - [2.40863386, 2.62505584], - [4.31911286, 4.70719631], - [6.22959185, 6.78933678], - [8.14007085, 8.87147725], - [10.05054985, 10.95361772], - [11.96102884, 13.03575819], - [13.87150784, 15.11789866], - [15.78198684, 17.20003913], - ]) - self.svd = svd.calculate_svd(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.svd = None - - def test_find_n_pc(self): - """Test find_n_pc.""" - npt.assert_equal( - svd.find_n_pc(svd.svd(self.data2)[0]), - 2, - err_msg='Incorrect number of principal components.', - ) - - npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3)) - - def test_calculate_svd(self): - """Test calculate_svd.""" - npt.assert_almost_equal( - self.svd[0], - np.array(self.data3)[0], - err_msg='Incorrect SVD calculation: U', - ) - - npt.assert_almost_equal( - self.svd[1], - np.array(self.data3)[1], - err_msg='Incorrect SVD calculation: S', - ) - - npt.assert_almost_equal( - self.svd[2], - np.array(self.data3)[2], - err_msg='Incorrect SVD calculation: V', - ) - - def test_svd_thresh(self): - """Test svd_thresh.""" - npt.assert_almost_equal( - svd.svd_thresh(self.data1), - self.data4, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc=1), - self.data5, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc='all'), - self.data1, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_raises(TypeError, svd.svd_thresh, 1) - - npt.assert_raises(ValueError, svd.svd_thresh, self.data1, n_pc='bla') - - def test_svd_thresh_coef(self): - """Test svd_thresh_coef.""" - npt.assert_almost_equal( - svd.svd_thresh_coef(self.data1, lambda x_val: x_val, 0), - self.data1, - err_msg='Incorrect SVD coefficient tresholding', - ) - - npt.assert_raises(TypeError, svd.svd_thresh_coef, self.data1, 0, 0) - - -class ValidationTestCase(TestCase): - """Test case for validation module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - - def test_transpose_test(self): - """Test transpose_test.""" - np.random.seed(2) - npt.assert_equal( - validation.transpose_test( - lambda x_val, y_val: x_val.dot(y_val), - lambda x_val, y_val: x_val.dot(y_val.T), - self.data1.shape, - x_args=self.data1, - ), - None, - ) - - npt.assert_raises( - TypeError, - validation.transpose_test, - 0, - 0, - self.data1.shape, - x_args=self.data1, - ) - - -class WaveletTestCase(TestCase): - """Test case for wavelet module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.arange(36).reshape(4, 3, 3).astype(float) - self.data3 = np.array([ - [ - [6.0, 20, 26.0], - [36.0, 84.0, 84.0], - [90, 164.0, 134.0], - ], - [ - [78.0, 155.0, 134.0], - [225.0, 408.0, 327.0], - [270, 461.0, 350], - ], - [ - [150, 290, 242.0], - [414.0, 732.0, 570], - [450, 758.0, 566.0], - ], - [ - [222.0, 425.0, 350], - [603.0, 1056.0, 813.0], - [630, 1055.0, 782.0], - ], - ]) - - self.data4 = np.array([ - [6496.0, 9796.0, 6544.0], - [9924.0, 14910, 9924.0], - [6544.0, 9796.0, 6496.0], - ]) - - self.data5 = np.array([ - [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], - [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], - [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_filter_convolve(self): - """Test filter_convolve.""" - npt.assert_almost_equal( - wavelet.filter_convolve(self.data1, self.data2), - self.data3, - err_msg='Inccorect filter comvolution.', - ) - - npt.assert_almost_equal( - wavelet.filter_convolve(self.data2, self.data2, filter_rot=True), - self.data4, - err_msg='Inccorect filter comvolution.', - ) - - def test_filter_convolve_stack(self): - """Test filter_convolve_stack.""" - npt.assert_almost_equal( - wavelet.filter_convolve_stack(self.data1, self.data1), - self.data5, - err_msg='Inccorect filter stack comvolution.', - ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..84eb967a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[project] +name="modopt" +description = 'Modular Optimisation tools for soliving inverse problems.' +version = "1.7.2" +requires-python= ">=3.8" + +authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"}, +{name="Chaithya GR", email="chaithyagr@gmail.com"}, +{name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"}, +{name="Philippe Ciuciu", email="philippe.ciuciu@cea.fr"} +] +readme="README.md" +license={file="LICENCE.txt"} + +dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"] + +[project.optional-dependencies] +gpu=["torch", "ptwt"] +doc=["myst-parser", +"nbsphinx", +"nbsphinx-link", +"sphinx-gallery", +"numpydoc", +"sphinxawesome-theme", +"sphinxcontrib-bibtex"] +dev=["black", "ruff"] +test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"] + +[build-system] +requires=["setuptools", "setuptools-scm[toml]", "wheel"] + +[tool.coverage.run] +omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"] + +[tool.coverage.report] +precision = 2 +exclude_lines = ["pragma: no cover", "raise NotImplementedError"] + +[tool.black] + + +[tool.ruff] +exclude = ["examples", "docs"] +[tool.ruff.lint] +select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"] + +ignore = ["F401"] # we like the try: import ... expect: ... + +[tool.ruff.lint.pydocstyle] +convention="numpy" + +[tool.isort] +profile="black" + +[tool.pytest.ini_options] +minversion = "6.0" +norecursedirs = ["tests/test_helpers"] +addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 63a404ba..00000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -importlib_metadata>=3.7.0 -numpy>=1.19.5 -scipy>=1.5.4 -progressbar2>=3.53.1 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index cabd35a0..00000000 --- a/setup.cfg +++ /dev/null @@ -1,91 +0,0 @@ -[aliases] -test=pytest - -[metadata] -description_file = README.rst - -[darglint] -docstring_style = numpy -strictness = short - -[flake8] -ignore = - D107, #Justification: Don't need docstring for __init__ in numpydoc style - RST304, #Justification: Need to use :cite: role for citations - RST210, #Justification: RST210, RST213 Inconsistent with numpydoc - RST213, # documentation for handling *args and **kwargs - W503, #Justification: Have to choose one multiline operator format - WPS202, #Todo: Rethink module size, possibly split large modules - WPS337, #Todo: Consider simplifying multiline conditions. - WPS338, #Todo: Consider changing method order - WPS403, #Todo: Rethink no cover lines - WPS421, #Todo: Review need for print statements - WPS432, #Justification: Mathematical codes require "magic numbers" - WPS433, #Todo: Rethink conditional imports - WPS463, #Todo: Rename get_ methods - WPS615, #Todo: Rename get_ methods -per-file-ignores = - #Justification: Needed for keeping package version and current API - *__init__.py*: F401,F403,WPS347,WPS410,WPS412 - #Todo: Rethink conditional imports - #Todo: How can we bypass mutable constants? - modopt/base/backend.py: WPS229, WPS420, WPS407 - #Todo: Rethink conditional imports - modopt/base/observable.py: WPS420,WPS604 - #Todo: Check string for log formatting - modopt/interface/log.py: WPS323 - #Todo: Rethink conditional imports - modopt/math/convolve.py: WPS301,WPS420 - #Todo: Rethink conditional imports - modopt/math/matrix.py: WPS420 - #Todo: import has bad parenthesis - modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410 - #Todo: x is a too short name. - modopt/opt/algorithms/forward_backward.py: WPS111 - #Todo: Check need for del statement - modopt/opt/algorithms/primal_dual.py: WPS111, WPS420 - #multiline parameters bug with tuples - modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317 - #Todo: Consider changing costObj name - modopt/opt/cost.py: N801, - #Todo: - # - Rethink subscript slice assignment - # - Reduce complexity of KSupportNorm - # - Check bitwise operations - modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508 - #Todo: Consider changing cwbReweight name - modopt/opt/reweight.py: N801 - #Justification: Needed to import matplotlib.pyplot - modopt/plot/cost_plot.py: N802,WPS301 - #Todo: Investigate possible bug in find_n_pc function - #Todo: Investigate darglint error - modopt/signal/svd.py: WPS345, DAR000 - #Todo: Check security of using system executable call - modopt/signal/wavelet.py: S404,S603 - #Todo: Clean up tests - modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604 - #Todo: Import has bad parenthesis - modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301 -#WPS Settings -max-arguments = 25 -max-attributes = 40 -max-cognitive-score = 20 -max-function-expressions = 20 -max-line-complexity = 30 -max-local-variables = 10 -max-methods = 20 -max-module-expressions = 20 -max-string-usages = 20 -max-raises = 5 - -[tool:pytest] -testpaths = - modopt -addopts = - --verbose - --emoji - --flake8 - --cov=modopt - --cov-report=term - --cov-report=xml - --junitxml=pytest.xml diff --git a/setup.py b/setup.py deleted file mode 100644 index c93dd020..00000000 --- a/setup.py +++ /dev/null @@ -1,73 +0,0 @@ -#! /usr/bin/env python -# -*- coding: utf-8 -*- - -from setuptools import setup, find_packages -import os - -# Set the package release version -major = 1 -minor = 6 -patch = 1 - -# Set the package details -name = 'modopt' -version = '.'.join(str(value) for value in (major, minor, patch)) -author = 'Samuel Farrens' -email = 'samuel.farrens@cea.fr' -gh_user = 'cea-cosmic' -url = 'https://github.com/{0}/{1}'.format(gh_user, name) -description = 'Modular Optimisation tools for soliving inverse problems.' -license = 'MIT' - -# Set the package classifiers -python_versions_supported = ['3.6', '3.7', '3.8', '3.9'] -os_platforms_supported = ['Unix', 'MacOS'] - -lc_str = 'License :: OSI Approved :: {0} License' -ln_str = 'Programming Language :: Python' -py_str = 'Programming Language :: Python :: {0}' -os_str = 'Operating System :: {0}' - -classifiers = ( - [lc_str.format(license)] - + [ln_str] - + [py_str.format(ver) for ver in python_versions_supported] - + [os_str.format(ops) for ops in os_platforms_supported] -) - -# Source package description from README.md -this_directory = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: - long_description = f.read() - -# Source package requirements from requirements.txt -with open('requirements.txt') as open_file: - install_requires = open_file.read() - -# Source test requirements from develop.txt -with open('develop.txt') as open_file: - tests_require = open_file.read() - -# Source doc requirements from docs/requirements.txt -with open('docs/requirements.txt') as open_file: - docs_require = open_file.read() - - -setup( - name=name, - author=author, - author_email=email, - version=version, - license=license, - url=url, - description=description, - long_description=long_description, - long_description_content_type='text/markdown', - packages=find_packages(), - install_requires=install_requires, - python_requires='>=3.6', - setup_requires=['pytest-runner'], - tests_require=tests_require, - extras_require={'develop': tests_require + docs_require}, - classifiers=classifiers, -) diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py new file mode 100644 index 00000000..5e8de1b6 --- /dev/null +++ b/src/modopt/__init__.py @@ -0,0 +1,25 @@ +"""MODOPT PACKAGE. + +ModOpt is a series of Modular Optimisation tools for solving inverse problems. + +""" + +from warnings import warn + +from importlib_metadata import version + +from modopt.base import np_adjust, transform, types, observable + +__all__ = ["np_adjust", "transform", "types", "observable"] + +try: + _version = version("modopt") +except Exception: # pragma: no cover + _version = "Unkown" + warn( + "Could not extract package metadata. Make sure the package is " + + "correctly installed.", + stacklevel=1, + ) + +__version__ = _version diff --git a/modopt/base/__init__.py b/src/modopt/base/__init__.py similarity index 68% rename from modopt/base/__init__.py rename to src/modopt/base/__init__.py index 1c0c8b2c..c4c681d7 100644 --- a/modopt/base/__init__.py +++ b/src/modopt/base/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """BASE ROUTINES. This module contains submodules for basic operations such as type @@ -9,4 +7,4 @@ """ -__all__ = ['np_adjust', 'transform', 'types', 'wrappers', 'observable'] +__all__ = ["np_adjust", "transform", "types", "observable"] diff --git a/modopt/base/backend.py b/src/modopt/base/backend.py similarity index 77% rename from modopt/base/backend.py rename to src/modopt/base/backend.py index 1f4e9a72..485f649a 100644 --- a/modopt/base/backend.py +++ b/src/modopt/base/backend.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """BACKEND MODULE. This module contains methods for GPU Compatiblity. @@ -26,22 +24,24 @@ # Handle the compatibility with variable LIBRARIES = { - 'cupy': None, - 'tensorflow': None, - 'numpy': np, + "cupy": None, + "tensorflow": None, + "numpy": np, } -if util.find_spec('cupy') is not None: +if util.find_spec("cupy") is not None: try: import cupy as cp - LIBRARIES['cupy'] = cp + + LIBRARIES["cupy"] = cp except ImportError: pass -if util.find_spec('tensorflow') is not None: +if util.find_spec("tensorflow") is not None: try: from tensorflow.experimental import numpy as tnp - LIBRARIES['tensorflow'] = tnp + + LIBRARIES["tensorflow"] = tnp except ImportError: pass @@ -66,12 +66,12 @@ def get_backend(backend): """ if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None: msg = ( - '{0} backend not possible, please ensure that ' - + 'the optional libraries are installed.\n' - + 'Reverting to numpy.' + "{0} backend not possible, please ensure that " + + "the optional libraries are installed.\n" + + "Reverting to numpy." ) warn(msg.format(backend)) - backend = 'numpy' + backend = "numpy" return LIBRARIES[backend], backend @@ -92,16 +92,16 @@ def get_array_module(input_data): The numpy or cupy module """ - if LIBRARIES['tensorflow'] is not None: - if isinstance(input_data, LIBRARIES['tensorflow'].ndarray): - return LIBRARIES['tensorflow'] - if LIBRARIES['cupy'] is not None: - if isinstance(input_data, LIBRARIES['cupy'].ndarray): - return LIBRARIES['cupy'] + if LIBRARIES["tensorflow"] is not None: + if isinstance(input_data, LIBRARIES["tensorflow"].ndarray): + return LIBRARIES["tensorflow"] + if LIBRARIES["cupy"] is not None: + if isinstance(input_data, LIBRARIES["cupy"].ndarray): + return LIBRARIES["cupy"] return np -def change_backend(input_data, backend='cupy'): +def change_backend(input_data, backend="cupy"): """Move data to device. This method changes the backend of an array. This can be used to copy data @@ -151,13 +151,13 @@ def move_to_cpu(input_data): """ xp = get_array_module(input_data) - if xp == LIBRARIES['numpy']: + if xp == LIBRARIES["numpy"]: return input_data - elif xp == LIBRARIES['cupy']: + elif xp == LIBRARIES["cupy"]: return input_data.get() - elif xp == LIBRARIES['tensorflow']: + elif xp == LIBRARIES["tensorflow"]: return input_data.data.numpy() - raise ValueError('Cannot identify the array type.') + raise ValueError("Cannot identify the array type.") def convert_to_tensor(input_data): @@ -184,9 +184,9 @@ def convert_to_tensor(input_data): """ if not import_torch: raise ImportError( - 'Required version of Torch package not found' - + 'see documentation for details: https://cea-cosmic.' - + 'github.io/ModOpt/#optional-packages', + "Required version of Torch package not found" + + "see documentation for details: https://cea-cosmic." + + "github.io/ModOpt/#optional-packages", ) xp = get_array_module(input_data) @@ -220,9 +220,9 @@ def convert_to_cupy_array(input_data): """ if not import_torch: raise ImportError( - 'Required version of Torch package not found' - + 'see documentation for details: https://cea-cosmic.' - + 'github.io/ModOpt/#optional-packages', + "Required version of Torch package not found" + + "see documentation for details: https://cea-cosmic." + + "github.io/ModOpt/#optional-packages", ) if input_data.is_cuda: diff --git a/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py similarity index 96% rename from modopt/base/np_adjust.py rename to src/modopt/base/np_adjust.py index 6d290e43..10cb5c29 100644 --- a/modopt/base/np_adjust.py +++ b/src/modopt/base/np_adjust.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """NUMPY ADJUSTMENT ROUTINES. This module contains methods for adjusting the default output for certain @@ -154,8 +152,7 @@ def pad2d(input_data, padding): padding = np.array(padding) elif not isinstance(padding, np.ndarray): raise ValueError( - 'Padding must be an integer or a tuple (or list, np.ndarray) ' - + 'of itegers', + "Padding must be an integer or a tuple (or list, np.ndarray) of integers", ) if padding.size == 1: @@ -164,7 +161,7 @@ def pad2d(input_data, padding): pad_x = (padding[0], padding[0]) pad_y = (padding[1], padding[1]) - return np.pad(input_data, (pad_x, pad_y), 'constant') + return np.pad(input_data, (pad_x, pad_y), "constant") def ftr(input_data): diff --git a/modopt/base/observable.py b/src/modopt/base/observable.py similarity index 95% rename from modopt/base/observable.py rename to src/modopt/base/observable.py index 6471ba58..bf8371c3 100644 --- a/modopt/base/observable.py +++ b/src/modopt/base/observable.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """Observable. This module contains observable classes @@ -13,13 +11,13 @@ import numpy as np -class SignalObject(object): +class SignalObject: """Dummy class for signals.""" pass -class Observable(object): +class Observable: """Base class for observable classes. This class defines a simple interface to add or remove observers @@ -33,7 +31,6 @@ class Observable(object): """ def __init__(self, signals): - # Define class parameters self._allowed_signals = [] self._observers = {} @@ -177,7 +174,7 @@ def _remove_observer(self, signal, observer): self._observers[signal].remove(observer) -class MetricObserver(object): +class MetricObserver: """Metric observer. Wrapper of the metric to the observer object notify by the Observable @@ -215,7 +212,6 @@ def __init__( wind=6, eps=1.0e-3, ): - self.name = name self.metric = metric self.mapping = mapping @@ -264,9 +260,7 @@ def is_converge(self): mid_idx = -(self.wind // 2) old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean() current_mean = np.array(self.list_cv_values[mid_idx:]).mean() - normalize_residual_metrics = ( - np.abs(old_mean - current_mean) / np.abs(old_mean) - ) + normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean) self.converge_flag = normalize_residual_metrics < self.eps def retrieve_metrics(self): @@ -287,7 +281,7 @@ def retrieve_metrics(self): time_val -= time_val[0] return { - 'time': time_val, - 'index': self.list_iters, - 'values': self.list_cv_values, + "time": time_val, + "index": self.list_iters, + "values": self.list_cv_values, } diff --git a/modopt/base/transform.py b/src/modopt/base/transform.py similarity index 87% rename from modopt/base/transform.py rename to src/modopt/base/transform.py index 07ce846f..25ed102a 100644 --- a/modopt/base/transform.py +++ b/src/modopt/base/transform.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """DATA TRANSFORM ROUTINES. This module contains methods for transforming data. @@ -53,18 +51,17 @@ def cube2map(data_cube, layout): """ if data_cube.ndim != 3: - raise ValueError('The input data must have 3 dimensions.') + raise ValueError("The input data must have 3 dimensions.") if data_cube.shape[0] != np.prod(layout): raise ValueError( - 'The desired layout must match the number of input ' - + 'data layers.', + "The desired layout must match the number of input " + "data layers.", ) - res = ([ + res = [ np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))]) for elem in range(layout[0]) - ]) + ] return np.vstack(res) @@ -118,20 +115,24 @@ def map2cube(data_map, layout): """ if np.all(np.array(data_map.shape) % np.array(layout)): raise ValueError( - 'The desired layout must be a multiple of the number ' - + 'pixels in the data map.', + "The desired layout must be a multiple of the number " + + "pixels in the data map.", ) d_shape = np.array(data_map.shape) // np.array(layout) - return np.array([ - data_map[( - slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]), - slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]), - )] - for i_elem in range(layout[0]) - for j_elem in range(layout[1]) - ]) + return np.array( + [ + data_map[ + ( + slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]), + slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]), + ) + ] + for i_elem in range(layout[0]) + for j_elem in range(layout[1]) + ] + ) def map2matrix(data_map, layout): @@ -186,9 +187,9 @@ def map2matrix(data_map, layout): image_shape * (i_elem % layout[1] + 1), ) data_matrix.append( - ( - data_map[lower[0]:upper[0], lower[1]:upper[1]] - ).reshape(image_shape ** 2), + (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape( + image_shape**2 + ), ) return np.array(data_matrix).T @@ -232,7 +233,7 @@ def matrix2map(data_matrix, map_shape): # Get the shape and layout of the images image_shape = np.sqrt(data_matrix.shape[0]).astype(int) - layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int') + layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int") # Map objects from matrix data_map = np.zeros(map_shape) @@ -248,7 +249,7 @@ def matrix2map(data_matrix, map_shape): image_shape * (i_elem // layout[1] + 1), image_shape * (i_elem % layout[1] + 1), ) - data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem] + data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem] return data_map.astype(int) @@ -285,7 +286,7 @@ def cube2matrix(data_cube): """ return data_cube.reshape( - [data_cube.shape[0]] + [np.prod(data_cube.shape[1:])], + [data_cube.shape[0], np.prod(data_cube.shape[1:])], ).T @@ -330,4 +331,4 @@ def matrix2cube(data_matrix, im_shape): cube2matrix : complimentary function """ - return data_matrix.T.reshape([data_matrix.shape[1]] + list(im_shape)) + return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)]) diff --git a/modopt/base/types.py b/src/modopt/base/types.py similarity index 74% rename from modopt/base/types.py rename to src/modopt/base/types.py index 88051675..9e9a15b9 100644 --- a/modopt/base/types.py +++ b/src/modopt/base/types.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """TYPE HANDLING ROUTINES. This module contains methods for handing object types. @@ -9,12 +7,10 @@ """ import numpy as np - -from modopt.base.wrappers import add_args_kwargs from modopt.interface.errors import warn -def check_callable(input_obj, add_agrs=True): +def check_callable(input_obj): """Check input object is callable. This method checks if the input operator is a callable funciton and @@ -25,30 +21,14 @@ def check_callable(input_obj, add_agrs=True): ---------- input_obj : callable Callable function - add_agrs : bool, optional - Option to add support for agrs and kwargs (default is ``True``) - - Returns - ------- - function - Function wrapped by ``add_args_kwargs`` Raises ------ TypeError For invalid input type - - See Also - -------- - modopt.base.wrappers.add_args_kwargs : wrapper used - """ if not callable(input_obj): - raise TypeError('The input object must be a callable function.') - - if add_agrs: - input_obj = add_args_kwargs(input_obj) - + raise TypeError("The input object must be a callable function.") return input_obj @@ -89,14 +69,13 @@ def check_float(input_obj): """ if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)): - raise TypeError('Invalid input type.') + raise TypeError("Invalid input type.") if isinstance(input_obj, int): input_obj = float(input_obj) elif isinstance(input_obj, (list, tuple)): input_obj = np.array(input_obj, dtype=float) - elif ( - isinstance(input_obj, np.ndarray) - and (not np.issubdtype(input_obj.dtype, np.floating)) + elif isinstance(input_obj, np.ndarray) and ( + not np.issubdtype(input_obj.dtype, np.floating) ): input_obj = input_obj.astype(float) @@ -139,14 +118,13 @@ def check_int(input_obj): """ if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)): - raise TypeError('Invalid input type.') + raise TypeError("Invalid input type.") if isinstance(input_obj, float): input_obj = int(input_obj) elif isinstance(input_obj, (list, tuple)): input_obj = np.array(input_obj, dtype=int) - elif ( - isinstance(input_obj, np.ndarray) - and (not np.issubdtype(input_obj.dtype, np.integer)) + elif isinstance(input_obj, np.ndarray) and ( + not np.issubdtype(input_obj.dtype, np.integer) ): input_obj = input_obj.astype(int) @@ -178,19 +156,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True): """ if not isinstance(input_obj, np.ndarray): - raise TypeError('Input is not a numpy array.') + raise TypeError("Input is not a numpy array.") - if ( - (not isinstance(dtype, type(None))) - and (not np.issubdtype(input_obj.dtype, dtype)) + if (not isinstance(dtype, type(None))) and ( + not np.issubdtype(input_obj.dtype, dtype) ): raise ( TypeError( - 'The numpy array elements are not of type: {0}'.format(dtype), + f"The numpy array elements are not of type: {dtype}", ), ) if not writeable and verbose and input_obj.flags.writeable: - warn('Making input data immutable.') + warn("Making input data immutable.") input_obj.flags.writeable = writeable diff --git a/modopt/interface/__init__.py b/src/modopt/interface/__init__.py similarity index 75% rename from modopt/interface/__init__.py rename to src/modopt/interface/__init__.py index f9439747..a54f4bf5 100644 --- a/modopt/interface/__init__.py +++ b/src/modopt/interface/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """INTERFACE ROUTINES. This module contains submodules for error handling, logging and IO interaction. @@ -8,4 +6,4 @@ """ -__all__ = ['errors', 'log'] +__all__ = ["errors", "log"] diff --git a/modopt/interface/errors.py b/src/modopt/interface/errors.py similarity index 75% rename from modopt/interface/errors.py rename to src/modopt/interface/errors.py index 0fbe7e71..84031e3c 100644 --- a/modopt/interface/errors.py +++ b/src/modopt/interface/errors.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ERROR HANDLING ROUTINES. This module contains methods for handing warnings and errors. @@ -34,16 +32,16 @@ def warn(warn_string, log=None): """ if import_fail: - warn_txt = 'WARNING' + warn_txt = "WARNING" else: - warn_txt = colored('WARNING', 'yellow') + warn_txt = colored("WARNING", "yellow") # Print warning to stdout. - sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string)) + sys.stderr.write(f"{warn_txt}: {warn_string}\n") # Check if a logging structure is provided. if not isinstance(log, type(None)): - warnings.warn(warn_string) + warnings.warn(warn_string, stacklevel=2) def catch_error(exception, log=None): @@ -61,17 +59,17 @@ def catch_error(exception, log=None): """ if import_fail: - err_txt = 'ERROR' + err_txt = "ERROR" else: - err_txt = colored('ERROR', 'red') + err_txt = colored("ERROR", "red") # Print exception to stdout. - stream_txt = '{0}: {1}\n'.format(err_txt, exception) + stream_txt = f"{err_txt}: {exception}\n" sys.stderr.write(stream_txt) # Check if a logging structure is provided. if not isinstance(log, type(None)): - log_txt = 'ERROR: {0}\n'.format(exception) + log_txt = f"ERROR: {exception}\n" log.exception(log_txt) @@ -91,11 +89,11 @@ def file_name_error(file_name): If file name not specified or file not found """ - if file_name == '' or file_name[0][0] == '-': - raise IOError('Input file name not specified.') + if file_name == "" or file_name[0][0] == "-": + raise OSError("Input file name not specified.") elif not os.path.isfile(file_name): - raise IOError('Input file name {0} not found!'.format(file_name)) + raise OSError(f"Input file name {file_name} not found!") def is_exe(fpath): @@ -136,7 +134,7 @@ def is_executable(exe_name): """ if not isinstance(exe_name, str): - raise TypeError('Executable name must be a string.') + raise TypeError("Executable name must be a string.") fpath, fname = os.path.split(exe_name) @@ -146,11 +144,9 @@ def is_executable(exe_name): else: res = any( is_exe(os.path.join(path, exe_name)) - for path in os.environ['PATH'].split(os.pathsep) + for path in os.environ["PATH"].split(os.pathsep) ) if not res: - message = ( - '{0} does not appear to be a valid executable on this system.' - ) - raise IOError(message.format(exe_name)) + message = "{0} does not appear to be a valid executable on this system." + raise OSError(message.format(exe_name)) diff --git a/modopt/interface/log.py b/src/modopt/interface/log.py similarity index 77% rename from modopt/interface/log.py rename to src/modopt/interface/log.py index 3b2fa77a..50c316b7 100644 --- a/modopt/interface/log.py +++ b/src/modopt/interface/log.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """LOGGING ROUTINES. This module contains methods for handing logging. @@ -30,22 +28,22 @@ def set_up_log(filename, verbose=True): """ # Add file extension. - filename = '{0}.log'.format(filename) + filename = f"{filename}.log" if verbose: - print('Preparing log file:', filename) + print("Preparing log file:", filename) # Capture warnings. logging.captureWarnings(True) # Set output format. formatter = logging.Formatter( - fmt='%(asctime)s %(message)s', - datefmt='%d/%m/%Y %H:%M:%S', + fmt="%(asctime)s %(message)s", + datefmt="%d/%m/%Y %H:%M:%S", ) # Create file handler. - fh = logging.FileHandler(filename=filename, mode='w') + fh = logging.FileHandler(filename=filename, mode="w") fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) @@ -55,7 +53,7 @@ def set_up_log(filename, verbose=True): log.addHandler(fh) # Send opening message. - log.info('The log file has been set-up.') + log.info("The log file has been set-up.") return log @@ -74,10 +72,10 @@ def close_log(log, verbose=True): """ if verbose: - print('Closing log file:', log.name) + print("Closing log file:", log.name) # Send closing message. - log.info('The log file has been closed.') + log.info("The log file has been closed.") # Remove all handlers from log. for log_handler in log.handlers: diff --git a/modopt/math/__init__.py b/src/modopt/math/__init__.py similarity index 64% rename from modopt/math/__init__.py rename to src/modopt/math/__init__.py index a22c0c98..d5ffc67a 100644 --- a/modopt/math/__init__.py +++ b/src/modopt/math/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """MATHEMATICS ROUTINES. This module contains submodules for mathematical applications. @@ -8,4 +6,4 @@ """ -__all__ = ['convolve', 'matrix', 'stats', 'metrics'] +__all__ = ["convolve", "matrix", "stats", "metrics"] diff --git a/modopt/math/convolve.py b/src/modopt/math/convolve.py similarity index 87% rename from modopt/math/convolve.py rename to src/modopt/math/convolve.py index a4322ff2..21dc8b4e 100644 --- a/modopt/math/convolve.py +++ b/src/modopt/math/convolve.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """CONVOLUTION ROUTINES. This module contains methods for convolution. @@ -18,7 +16,7 @@ from astropy.convolution import convolve_fft except ImportError: # pragma: no cover import_astropy = False - warn('astropy not found, will default to scipy for convolution') + warn("astropy not found, will default to scipy for convolution") else: import_astropy = True try: @@ -30,7 +28,7 @@ warn('Using pyFFTW "monkey patch" for scipy.fftpack') -def convolve(input_data, kernel, method='scipy'): +def convolve(input_data, kernel, method="scipy"): """Convolve data with kernel. This method convolves the input data with a given kernel using FFT and @@ -80,29 +78,29 @@ def convolve(input_data, kernel, method='scipy'): """ if input_data.ndim != kernel.ndim: - raise ValueError('Data and kernel must have the same dimensions.') + raise ValueError("Data and kernel must have the same dimensions.") - if method not in {'astropy', 'scipy'}: + if method not in {"astropy", "scipy"}: raise ValueError('Invalid method. Options are "astropy" or "scipy".') if not import_astropy: # pragma: no cover - method = 'scipy' + method = "scipy" - if method == 'astropy': + if method == "astropy": return convolve_fft( input_data, kernel, - boundary='wrap', + boundary="wrap", crop=False, - nan_treatment='fill', + nan_treatment="fill", normalize_kernel=False, ) - elif method == 'scipy': - return scipy.signal.fftconvolve(input_data, kernel, mode='same') + elif method == "scipy": + return scipy.signal.fftconvolve(input_data, kernel, mode="same") -def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'): +def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"): """Convolve stack of data with stack of kernels. This method convolves the input data with a given kernel using FFT and @@ -156,7 +154,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'): if rot_kernel: kernel = rotate_stack(kernel) - return np.array([ - convolve(data_i, kernel_i, method=method) - for data_i, kernel_i in zip(input_data, kernel) - ]) + return np.array( + [ + convolve(data_i, kernel_i, method=method) + for data_i, kernel_i in zip(input_data, kernel) + ] + ) diff --git a/modopt/math/matrix.py b/src/modopt/math/matrix.py similarity index 88% rename from modopt/math/matrix.py rename to src/modopt/math/matrix.py index 939cf41f..b200f15d 100644 --- a/modopt/math/matrix.py +++ b/src/modopt/math/matrix.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """MATRIX ROUTINES. This module contains methods for matrix operations. @@ -15,7 +13,7 @@ from modopt.base.backend import get_array_module, get_backend -def gram_schmidt(matrix, return_opt='orthonormal'): +def gram_schmidt(matrix, return_opt="orthonormal"): r"""Gram-Schmit. This method orthonormalizes the row vectors of the input matrix. @@ -55,7 +53,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'): https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process """ - if return_opt not in {'orthonormal', 'orthogonal', 'both'}: + if return_opt not in {"orthonormal", "orthogonal", "both"}: raise ValueError( 'Invalid return_opt, options are: "orthonormal", "orthogonal" or ' + '"both"', @@ -65,7 +63,6 @@ def gram_schmidt(matrix, return_opt='orthonormal'): e_vec = [] for vector in matrix: - if u_vec: u_now = vector - sum(project(u_i, vector) for u_i in u_vec) else: @@ -77,11 +74,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'): u_vec = np.array(u_vec) e_vec = np.array(e_vec) - if return_opt == 'orthonormal': + if return_opt == "orthonormal": return e_vec - elif return_opt == 'orthogonal': + elif return_opt == "orthogonal": return u_vec - elif return_opt == 'both': + elif return_opt == "both": return u_vec, e_vec @@ -201,7 +198,7 @@ def rot_matrix(angle): return np.around( np.array( [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], - dtype='float', + dtype="float", ), 10, ) @@ -243,22 +240,21 @@ def rotate(matrix, angle): shape = np.array(matrix.shape) if shape[0] != shape[1]: - raise ValueError('Input matrix must be square.') + raise ValueError("Input matrix must be square.") shift = (shape - 1) // 2 index = ( - np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - - shift + np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift ) - new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift + new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift new_index[new_index >= shape[0]] -= shape[0] return matrix[tuple(zip(new_index.T))].reshape(shape.T) -class PowerMethod(object): +class PowerMethod: """Power method class. This method performs implements power method to calculate the spectral @@ -277,6 +273,8 @@ class PowerMethod(object): initialisation (default is ``True``) verbose : bool, optional Optional verbosity (default is ``False``) + rng: int, xp.random.Generator or None (default is ``None``) + Random number generator or seed. Examples -------- @@ -285,9 +283,9 @@ class PowerMethod(object): >>> np.random.seed(1) >>> pm = PowerMethod(lambda x: x.dot(x.T), (3, 3)) >>> np.around(pm.spec_rad, 6) - 0.904292 + 1.0 >>> np.around(pm.inv_spec_rad, 6) - 1.105837 + 1.0 Notes ----- @@ -301,16 +299,17 @@ def __init__( data_shape, data_type=float, auto_run=True, - compute_backend='numpy', + compute_backend="numpy", verbose=False, + rng=None, ): - self._operator = operator self._data_shape = data_shape self._data_type = data_type self._verbose = verbose xp, compute_backend = get_backend(compute_backend) self.xp = xp + self.rng = None self.compute_backend = compute_backend if auto_run: self.get_spec_rad() @@ -327,7 +326,8 @@ def _set_initial_x(self): Random values of the same shape as the input data """ - return self.xp.random.random(self._data_shape).astype(self._data_type) + rng = self.xp.random.default_rng(self.rng) + return rng.random(self._data_shape).astype(self._data_type) def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): """Get spectral radius. @@ -348,32 +348,33 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): # Set (or reset) values of x. x_old = self._set_initial_x() + xp = get_array_module(x_old) + x_old_norm = xp.linalg.norm(x_old) + + x_old /= x_old_norm + # Iterate until the L2 norm of x converges. for i_elem in range(max_iter): - xp = get_array_module(x_old) - x_old_norm = xp.linalg.norm(x_old) - - x_new = self._operator(x_old) / x_old_norm + x_new = self._operator(x_old) x_new_norm = xp.linalg.norm(x_new) - if (xp.abs(x_new_norm - x_old_norm) < tolerance): - message = ( - ' - Power Method converged after {0} iterations!' - ) + x_new /= x_new_norm + + if xp.abs(x_new_norm - x_old_norm) < tolerance: + message = " - Power Method converged after {0} iterations!" if self._verbose: print(message.format(i_elem + 1)) break elif i_elem == max_iter - 1 and self._verbose: - message = ( - ' - Power Method did not converge after {0} iterations!' - ) + message = " - Power Method did not converge after {0} iterations!" print(message.format(max_iter)) xp.copyto(x_old, x_new) + x_old_norm = x_new_norm self.spec_rad = x_new_norm * extra_factor self.inv_spec_rad = 1.0 / self.spec_rad diff --git a/modopt/math/metrics.py b/src/modopt/math/metrics.py similarity index 90% rename from modopt/math/metrics.py rename to src/modopt/math/metrics.py index cf41a9c2..befd4fa4 100644 --- a/modopt/math/metrics.py +++ b/src/modopt/math/metrics.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """METRICS. This module contains classes of different metric functions for optimization. @@ -23,7 +21,7 @@ def min_max_normalize(img): """Min-Max Normalize. - Centre and normalize a given array. + Normalize a given array in the [0,1] range. Parameters ---------- @@ -33,7 +31,7 @@ def min_max_normalize(img): Returns ------- numpy.ndarray - Centred and normalized array + normalized array """ min_img = img.min() @@ -71,15 +69,13 @@ def _preprocess_input(test, ref, mask=None): The SNR """ - test = np.abs(np.copy(test)).astype('float64') - ref = np.abs(np.copy(ref)).astype('float64') + test = np.abs(np.copy(test)).astype("float64") + ref = np.abs(np.copy(ref)).astype("float64") test = min_max_normalize(test) ref = min_max_normalize(ref) if (not isinstance(mask, np.ndarray)) and (mask is not None): - message = ( - 'Mask should be None, or a numpy.ndarray, got "{0}" instead.' - ) + message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.' raise ValueError(message.format(mask)) if mask is None: @@ -119,9 +115,9 @@ def ssim(test, ref, mask=None): """ if not import_skimage: # pragma: no cover raise ImportError( - 'Required version of Scikit-Image package not found' - + 'see documentation for details: https://cea-cosmic.' - + 'github.io/ModOpt/#optional-packages', + "Required version of Scikit-Image package not found" + + "see documentation for details: https://cea-cosmic." + + "github.io/ModOpt/#optional-packages", ) test, ref, mask = _preprocess_input(test, ref, mask) @@ -270,6 +266,6 @@ def nrmse(test, ref, mask=None): ref = mask * ref num = np.sqrt(mse(test, ref)) - deno = np.sqrt(np.mean((np.square(test)))) + deno = np.sqrt(np.mean(np.square(test))) return num / deno diff --git a/modopt/math/stats.py b/src/modopt/math/stats.py similarity index 85% rename from modopt/math/stats.py rename to src/modopt/math/stats.py index 3ac818a7..8583a8c3 100644 --- a/modopt/math/stats.py +++ b/src/modopt/math/stats.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """STATISTICS ROUTINES. This module contains methods for basic statistics. @@ -11,6 +9,8 @@ import numpy as np try: + from packaging import version + from astropy import __version__ as astropy_version from astropy.convolution import Gaussian2DKernel except ImportError: # pragma: no cover import_astropy = False @@ -18,7 +18,7 @@ import_astropy = True -def gaussian_kernel(data_shape, sigma, norm='max'): +def gaussian_kernel(data_shape, sigma, norm="max"): """Gaussian kernel. This method produces a Gaussian kerenal of a specified size and dispersion. @@ -29,9 +29,8 @@ def gaussian_kernel(data_shape, sigma, norm='max'): Desiered shape of the kernel sigma : float Standard deviation of the kernel - norm : {'max', 'sum', 'none'}, optional - Normalisation of the kerenl (options are ``'max'``, ``'sum'`` or - ``'none'``, default is ``'max'``) + norm : {'max', 'sum'}, optional, default='max' + Normalisation of the kernel Returns ------- @@ -60,22 +59,22 @@ def gaussian_kernel(data_shape, sigma, norm='max'): """ if not import_astropy: # pragma: no cover - raise ImportError('Astropy package not found.') + raise ImportError("Astropy package not found.") - if norm not in {'max', 'sum', 'none'}: + if norm not in {"max", "sum"}: raise ValueError('Invalid norm, options are "max", "sum" or "none".') kernel = np.array( Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]), ) - if norm == 'max': + if norm == "max": return kernel / np.max(kernel) - elif norm == 'sum': + elif version.parse(astropy_version) < version.parse("5.2"): return kernel / np.sum(kernel) - elif norm == 'none': + else: return kernel @@ -147,7 +146,7 @@ def mse(data1, data2): return np.mean((data1 - data2) ** 2) -def psnr(data1, data2, method='starck', max_pix=255): +def psnr(data1, data2, method="starck", max_pix=255): r"""Peak Signal-to-Noise Ratio. This method calculates the Peak Signal-to-Noise Ratio between two data @@ -202,23 +201,21 @@ def psnr(data1, data2, method='starck', max_pix=255): 10\log_{10}(\mathrm{MSE})) """ - if method == 'starck': - return ( - 20 * np.log10( - (data1.shape[0] * np.abs(np.max(data1) - np.min(data1))) - / np.linalg.norm(data1 - data2), - ) + if method == "starck": + return 20 * np.log10( + (data1.shape[0] * np.abs(np.max(data1) - np.min(data1))) + / np.linalg.norm(data1 - data2), ) - elif method == 'wiki': - return (20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2))) + elif method == "wiki": + return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2)) raise ValueError( 'Invalid PSNR method. Options are "starck" and "wiki"', ) -def psnr_stack(data1, data2, metric=np.mean, method='starck'): +def psnr_stack(data1, data2, metric=np.mean, method="starck"): """Peak Signa-to-Noise for stack of images. This method calculates the PSNRs for two stacks of 2D arrays. @@ -261,12 +258,11 @@ def psnr_stack(data1, data2, metric=np.mean, method='starck'): """ if data1.ndim != 3 or data2.ndim != 3: - raise ValueError('Input data must be a 3D np.ndarray') + raise ValueError("Input data must be a 3D np.ndarray") - return metric([ - psnr(i_elem, j_elem, method=method) - for i_elem, j_elem in zip(data1, data2) - ]) + return metric( + [psnr(i_elem, j_elem, method=method) for i_elem, j_elem in zip(data1, data2)] + ) def sigma_mad(input_data): diff --git a/modopt/opt/__init__.py b/src/modopt/opt/__init__.py similarity index 59% rename from modopt/opt/__init__.py rename to src/modopt/opt/__init__.py index 2fd3d747..62d1f388 100644 --- a/modopt/opt/__init__.py +++ b/src/modopt/opt/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """OPTIMISATION PROBLEM MODULES. This module contains submodules for solving optimisation problems. @@ -8,4 +6,4 @@ """ -__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight'] +__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"] diff --git a/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py similarity index 53% rename from modopt/opt/algorithms/__init__.py rename to src/modopt/opt/algorithms/__init__.py index e0ac2572..ff79502c 100644 --- a/modopt/opt/algorithms/__init__.py +++ b/src/modopt/opt/algorithms/__init__.py @@ -1,5 +1,4 @@ -# -*- coding: utf-8 -*- -r"""OPTIMISATION ALGOTITHMS. +r"""OPTIMISATION ALGORITHMS. This module contains class implementations of various optimisation algoritms. @@ -45,15 +44,32 @@ """ -from modopt.opt.algorithms.base import SetUp -from modopt.opt.algorithms.forward_backward import (FISTA, POGM, - ForwardBackward, - GenForwardBackward) -from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt, - ADAMGradOpt, - GenericGradOpt, - MomentumGradOpt, - RMSpropGradOpt, - SAGAOptGradOpt, - VanillaGenericGradOpt) -from modopt.opt.algorithms.primal_dual import Condat +from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM +from .primal_dual import Condat +from .gradient_descent import ( + ADAMGradOpt, + AdaGenericGradOpt, + GenericGradOpt, + MomentumGradOpt, + RMSpropGradOpt, + SAGAOptGradOpt, + VanillaGenericGradOpt, +) +from .admm import ADMM, FastADMM + +__all__ = [ + "FISTA", + "ForwardBackward", + "GenForwardBackward", + "POGM", + "Condat", + "ADAMGradOpt", + "AdaGenericGradOpt", + "GenericGradOpt", + "MomentumGradOpt", + "RMSpropGradOpt", + "SAGAOptGradOpt", + "VanillaGenericGradOpt", + "ADMM", + "FastADMM", +] diff --git a/src/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py new file mode 100644 index 00000000..b2f45171 --- /dev/null +++ b/src/modopt/opt/algorithms/admm.py @@ -0,0 +1,340 @@ +"""ADMM Algorithms.""" + +import numpy as np + +from modopt.base.backend import get_array_module +from modopt.opt.algorithms.base import SetUp +from modopt.opt.cost import CostParent + + +class ADMMcostObj(CostParent): + r"""Cost Object for the ADMM problem class. + + Parameters + ---------- + cost_funcs: 2-tuples of callable + f and g function. + A : OperatorBase + First Operator + B : OperatorBase + Second Operator + b : numpy.ndarray + Observed data + **kwargs : dict + Extra parameters for cost operator configuration + + Notes + ----- + Compute :math:`f(u)+g(v) + \tau \| Au +Bv - b\|^2` + + See Also + -------- + CostParent: parent class + """ + + def __init__(self, cost_funcs, A, B, b, tau, **kwargs): + super().__init__(*kwargs) + self.cost_funcs = cost_funcs + self.A = A + self.B = B + self.b = b + self.tau = tau + + def _calc_cost(self, u, v, **kwargs): + """Calculate the cost. + + This method calculates the cost from each of the input operators. + + Parameters + ---------- + u: numpy.ndarray + First primal variable of ADMM + v: numpy.ndarray + Second primal variable of ADMM + + Returns + ------- + float + Cost value + + """ + xp = get_array_module(u) + cost = self.cost_funcs[0](u) + cost += self.cost_funcs[1](v) + cost += self.tau * xp.linalg.norm(self.A.op(u) + self.B.op(v) - self.b) + return cost + + +class ADMM(SetUp): + r"""Fast ADMM Optimisation Algorihm. + + This class implement the ADMM algorithm described in :cite:`Goldstein2014` + (Algorithm 1). + + Parameters + ---------- + u: numpy.ndarray + Initial value for first primal variable of ADMM + v: numpy.ndarray + Initial value for second primal variable of ADMM + mu: numpy.ndarray + Initial value for lagrangian multiplier. + A : modopt.opt.linear.LinearOperator + Linear operator for u + B: modopt.opt.linear.LinearOperator + Linear operator for v + b : numpy.ndarray + Constraint vector + optimizers: tuple + 2-tuple of callable, that are the optimizers for the u and v. + Each callable should access init and obs argument and returns an estimate for: + .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 + .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 + cost_funcs: tuple + 2-tuple of callable, that compute values of H and G. + tau: float, default=1 + Coupling parameter for ADMM. + + Notes + ----- + The algorithm solve the problem: + + .. math:: u, v = \arg\min H(u) + G(v) + \frac\tau2 \|Au + Bv - b \|_2^2 + + with the following augmented lagrangian: + + .. math :: \mathcal{L}_{\tau}(u,v, \lambda) = H(u) + G(v) + +\langle\lambda |Au + Bv -b \rangle + \frac\tau2 \| Au + Bv -b \|^2 + + To allow easy iterative solving, the change of variable + :math:`\mu=\lambda/\tau` is used. Hence, the lagrangian of interest is: + + .. math :: \tilde{\mathcal{L}}_{\tau}(u,v, \mu) = H(u) + G(v) + + \frac\tau2 \left(\|\mu + Au +Bv - b\|^2 - \|\mu\|^2\right) + + See Also + -------- + SetUp: parent class + """ + + def __init__( + self, + u, + v, + mu, + A, + B, + b, + optimizers, + tau=1, + cost_funcs=None, + **kwargs, + ): + super().__init__(**kwargs) + self.A = A + self.B = B + self.b = b + self._opti_H = optimizers[0] + self._opti_G = optimizers[1] + self._tau = tau + if cost_funcs is not None: + self._cost_func = ADMMcostObj(cost_funcs, A, B, b, tau) + else: + self._cost_func = None + + # init iteration variables. + self._u_old = self.xp.copy(u) + self._u_new = self.xp.copy(u) + self._v_old = self.xp.copy(v) + self._v_new = self.xp.copy(v) + self._mu_new = self.xp.copy(mu) + self._mu_old = self.xp.copy(mu) + + def _update(self): + self._u_new = self._opti_H( + init=self._u_old, + obs=self.B.op(self._v_old) + self._u_old - self.b, + ) + tmp = self.A.op(self._u_new) + self._v_new = self._opti_G( + init=self._v_old, + obs=tmp + self._u_old - self.b, + ) + + self._mu_new = self._mu_old + (tmp + self.B.op(self._v_new) - self.b) + + # update cycle + self._u_old = self.xp.copy(self._u_new) + self._v_old = self.xp.copy(self._v_new) + self._mu_old = self.xp.copy(self._mu_new) + + # Test cost function for convergence. + if self._cost_func: + self.converge = self.any_convergence_flag() + self.converge |= self._cost_func.get_cost(self._u_new, self._v_new) + + def iterate(self, max_iter=150): + """Iterate. + + This method calls update until either convergence criteria is met or + the maximum number of iterations is reached. + + Parameters + ---------- + max_iter : int, optional + Maximum number of iterations (default is ``150``) + """ + self._run_alg(max_iter) + + # retrieve metrics results + self.retrieve_outputs() + # rename outputs as attributes + self.u_final = self._u_new + self.x_final = self.u_final # for backward compatibility + self.v_final = self._v_new + + def get_notify_observers_kwargs(self): + """Notify observers. + + Return the mapping between the metrics call and the iterated + variables. + + Returns + ------- + dict + The mapping between the iterated variables + """ + return { + "x_new": self._u_new, + "v_new": self._v_new, + "idx": self.idx, + } + + def retrieve_outputs(self): + """Retrieve outputs. + + Declare the outputs of the algorithms as attributes: x_final, + y_final, metrics. + """ + metrics = {} + for obs in self._observers["cv_metrics"]: + metrics[obs.name] = obs.retrieve_metrics() + self.metrics = metrics + + +class FastADMM(ADMM): + r"""Fast ADMM Optimisation Algorihm. + + This class implement the fast ADMM algorithm + (Algorithm 8 from :cite:`Goldstein2014`) + + Parameters + ---------- + u: numpy.ndarray + Initial value for first primal variable of ADMM + v: numpy.ndarray + Initial value for second primal variable of ADMM + mu: numpy.ndarray + Initial value for lagrangian multiplier. + A : modopt.opt.linear.LinearOperator + Linear operator for u + B: modopt.opt.linear.LinearOperator + Linear operator for v + b : numpy.ndarray + Constraint vector + optimizers: tuple + 2-tuple of callable, that are the optimizers for the u and v. + Each callable should access init and obs argument and returns an estimate for: + .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 + .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 + cost_funcs: tuple + 2-tuple of callable, that compute values of H and G. + tau: float, default=1 + Coupling parameter for ADMM. + eta: float, default=0.999 + Convergence parameter for ADMM. + alpha: float, default=1. + Initial value for the FISTA-like acceleration parameter. + + Notes + ----- + This is an accelerated version of the ADMM algorithm. The convergence hypothesis are + stronger than for the ADMM algorithm. + + See Also + -------- + ADMM: parent class + """ + + def __init__( + self, + u, + v, + mu, + A, + B, + b, + optimizers, + cost_funcs=None, + alpha=1, + eta=0.999, + tau=1, + **kwargs, + ): + super().__init__( + u=u, + v=b, + mu=mu, + A=A, + B=B, + b=b, + optimizers=optimizers, + cost_funcs=cost_funcs, + **kwargs, + ) + self._c_old = np.inf + self._c_new = 0 + self._eta = eta + self._alpha_old = alpha + self._alpha_new = alpha + self._v_hat = self.xp.copy(self._v_new) + self._mu_hat = self.xp.copy(self._mu_new) + + def _update(self): + # Classical ADMM steps + self._u_new = self._opti_H( + init=self._u_old, + obs=self.B.op(self._v_hat) + self._u_old - self.b, + ) + tmp = self.A.op(self._u_new) + self._v_new = self._opti_G( + init=self._v_hat, + obs=tmp + self._u_old - self.b, + ) + + self._mu_new = self._mu_hat + (tmp + self.B.op(self._v_new) - self.b) + + # restarting condition + self._c_new = self.xp.linalg.norm(self._mu_new - self._mu_hat) + self._c_new += self._tau * self.xp.linalg.norm( + self.B.op(self._v_new - self._v_hat), + ) + if self._c_new < self._eta * self._c_old: + self._alpha_new = 1 + np.sqrt(1 + 4 * self._alpha_old**2) + beta = (self._alpha_new - 1) / self._alpha_old + self._v_hat = self._v_new + (self._v_new - self._v_old) * beta + self._mu_hat = self._mu_new + (self._mu_new - self._mu_old) * beta + else: + # reboot to old iteration + self._alpha_new = 1 + self._v_hat = self._v_old + self._mu_hat = self._mu_old + self._c_new = self._c_old / self._eta + + self.xp.copyto(self._u_old, self._u_new) + self.xp.copyto(self._v_old, self._v_new) + self.xp.copyto(self._mu_old, self._mu_new) + # Test cost function for convergence. + if self._cost_func: + self.converge = self.any_convergence_flag() + self.convergd |= self._cost_func.get_cost(self._u_new, self._v_new) diff --git a/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py similarity index 67% rename from modopt/opt/algorithms/base.py rename to src/modopt/opt/algorithms/base.py index 85c36306..f7391063 100644 --- a/modopt/opt/algorithms/base.py +++ b/src/modopt/opt/algorithms/base.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- """Base SetUp for optimisation algorithms.""" from inspect import getmro import numpy as np -from progressbar import ProgressBar +from tqdm.auto import tqdm from modopt.base import backend from modopt.base.observable import MetricObserver, Observable @@ -12,17 +11,17 @@ class SetUp(Observable): - r"""Algorithm Set-Up. + """Algorithm Set-Up. This class contains methods for checking the set-up of an optimisation - algotithm and produces warnings if they do not comply. + algorithm and produces warnings if they do not comply. Parameters ---------- metric_call_period : int, optional Metric call period (default is ``5``) metrics : dict, optional - Metrics to be used (default is ``\{\}``) + Metrics to be used (default is ``None``) verbose : bool, optional Option for verbose output (default is ``False``) progress : bool, optional @@ -34,11 +33,32 @@ class SetUp(Observable): use_gpu : bool, optional Option to use available GPU + Notes + ----- + If provided, the ``metrics`` argument should be a nested dictionary of the + following form:: + + metrics = { + 'metric_name': { + 'metric': callable, + 'mapping': {'x_new': 'test'}, + 'cst_kwargs': {'ref': ref_image}, + 'early_stopping': False, + } + } + + Where ``callable`` is a function with arguments being for instance + ``test`` and ``ref``. The mapping of the argument uses the same keys as the + output of ``get_notify_observer_kwargs``, ``cst_kwargs`` defines constant + arguments that will always be passed to the metric call. + If ``early_stopping`` is True, the metric will be used to check for + convergence of the algorithm, in that case it is recommended to have + ``metric_call_period = 1`` + See Also -------- modopt.base.observable.Observable : parent class modopt.base.observable.MetricObserver : definition of metrics - """ def __init__( @@ -48,7 +68,7 @@ def __init__( verbose=False, progress=True, step_size=None, - compute_backend='numpy', + compute_backend="numpy", **dummy_kwargs, ): self.idx = 0 @@ -58,26 +78,26 @@ def __init__( self.metrics = metrics self.step_size = step_size self._op_parents = ( - 'GradParent', - 'ProximityParent', - 'LinearParent', - 'costObj', + "GradParent", + "ProximityParent", + "LinearParent", + "costObj", ) self.metric_call_period = metric_call_period # Declaration of observers for metrics - super().__init__(['cv_metrics']) + super().__init__(["cv_metrics"]) for name, dic in self.metrics.items(): observer = MetricObserver( name, - dic['metric'], - dic['mapping'], - dic['cst_kwargs'], - dic['early_stopping'], + dic["metric"], + dic["mapping"], + dic["cst_kwargs"], + dic["early_stopping"], ) - self.add_observer('cv_metrics', observer) + self.add_observer("cv_metrics", observer) xp, compute_backend = backend.get_backend(compute_backend) self.xp = xp @@ -90,14 +110,13 @@ def metrics(self): @metrics.setter def metrics(self, metrics): - if isinstance(metrics, type(None)): self._metrics = {} elif isinstance(metrics, dict): self._metrics = metrics else: raise TypeError( - 'Metrics must be a dictionary, not {0}.'.format(type(metrics)), + f"Metrics must be a dictionary, not {type(metrics)}.", ) def any_convergence_flag(self): @@ -111,9 +130,7 @@ def any_convergence_flag(self): True if any convergence criteria met """ - return any( - obs.converge_flag for obs in self._observers['cv_metrics'] - ) + return any(obs.converge_flag for obs in self._observers["cv_metrics"]) def copy_data(self, input_data): """Copy Data. @@ -131,10 +148,12 @@ def copy_data(self, input_data): Copy of input data """ - return self.xp.copy(backend.change_backend( - input_data, - self.compute_backend, - )) + return self.xp.copy( + backend.change_backend( + input_data, + self.compute_backend, + ) + ) def _check_input_data(self, input_data): """Check input data type. @@ -154,7 +173,7 @@ def _check_input_data(self, input_data): """ if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))): raise TypeError( - 'Input data must be a numpy array or backend array', + "Input data must be a numpy array or backend array", ) def _check_param(self, param_val): @@ -174,7 +193,7 @@ def _check_param(self, param_val): """ if not isinstance(param_val, float): - raise TypeError('Algorithm parameter must be a float value.') + raise TypeError("Algorithm parameter must be a float value.") def _check_param_update(self, param_update): """Check algorithm parameter update methods. @@ -192,14 +211,13 @@ def _check_param_update(self, param_update): For invalid input type """ - param_conditions = ( - not isinstance(param_update, type(None)) - and not callable(param_update) + param_conditions = not isinstance(param_update, type(None)) and not callable( + param_update ) if param_conditions: raise TypeError( - 'Algorithm parameter update must be a callabale function.', + "Algorithm parameter update must be a callabale function.", ) def _check_operator(self, operator): @@ -218,7 +236,7 @@ def _check_operator(self, operator): tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)] if not any(parent in tree for parent in self._op_parents): - message = '{0} does not inherit an operator parent.' + message = "{0} does not inherit an operator parent." warn(message.format(str(operator.__class__))) def _compute_metrics(self): @@ -229,7 +247,7 @@ def _compute_metrics(self): """ kwargs = self.get_notify_observers_kwargs() - self.notify_observers('cv_metrics', **kwargs) + self.notify_observers("cv_metrics", **kwargs) def _iterations(self, max_iter, progbar=None): """Iterate method. @@ -240,9 +258,8 @@ def _iterations(self, max_iter, progbar=None): ---------- max_iter : int Maximum number of iterations - progbar : progressbar.bar.ProgressBar - Progress bar (default is ``None``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ for idx in range(max_iter): self.idx = idx @@ -253,7 +270,6 @@ def _iterations(self, max_iter, progbar=None): # We do not call metrics if metrics is empty or metric call # period is None if self.metrics and self.metric_call_period is not None: - metric_conditions = ( self.idx % self.metric_call_period == 0 or self.idx == (max_iter - 1) @@ -265,13 +281,13 @@ def _iterations(self, max_iter, progbar=None): if self.converge: if self.verbose: - print(' - Converged!') + print(" - Converged!") break - if not isinstance(progbar, type(None)): - progbar.update(idx) + if progbar: + progbar.update() - def _run_alg(self, max_iter): + def _run_alg(self, max_iter, progbar=None): """Run algorithm. Run the update step of a given algorithm up to the maximum number of @@ -281,17 +297,34 @@ def _run_alg(self, max_iter): ---------- max_iter : int Maximum number of iterations + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) See Also -------- - progressbar.bar.ProgressBar + tqdm.tqdm """ - if self.progress: - with ProgressBar( - redirect_stdout=True, - max_value=max_iter, - ) as progbar: - self._iterations(max_iter, progbar=progbar) + if self.progress and progbar is None: + with tqdm(total=max_iter) as pb: + self._iterations(max_iter, progbar=pb) + elif progbar: + self._iterations(max_iter, progbar=progbar) else: self._iterations(max_iter) + + def _update(self): + raise NotImplementedError + + def get_notify_observers_kwargs(self): + """Notify Observers. + + Return the mapping between the metrics call and the iterated + variables. + + Raises + ------ + NotImplementedError + This method should be overriden by subclasses. + """ + raise NotImplementedError diff --git a/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py similarity index 86% rename from modopt/opt/algorithms/forward_backward.py rename to src/modopt/opt/algorithms/forward_backward.py index e18f66c3..31927eb0 100644 --- a/modopt/opt/algorithms/forward_backward.py +++ b/src/modopt/opt/algorithms/forward_backward.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Forward-Backward Algorithms.""" import numpy as np @@ -9,7 +8,7 @@ from modopt.opt.linear import Identity -class FISTA(object): +class FISTA: r"""FISTA. This class is inherited by optimisation classes to speed up convergence @@ -52,12 +51,12 @@ class FISTA(object): """ _restarting_strategies = ( - 'adaptive', # option 1 in alg 4 - 'adaptive-i', - 'adaptive-1', - 'adaptive-ii', # option 2 in alg 4 - 'adaptive-2', - 'greedy', # alg 5 + "adaptive", # option 1 in alg 4 + "adaptive-i", + "adaptive-1", + "adaptive-ii", # option 2 in alg 4 + "adaptive-2", + "greedy", # alg 5 None, # no restarting ) @@ -73,26 +72,28 @@ def __init__( r_lazy=4, **kwargs, ): - if isinstance(a_cd, type(None)): - self.mode = 'regular' + self.mode = "regular" self.p_lazy = p_lazy self.q_lazy = q_lazy self.r_lazy = r_lazy elif a_cd > 2: - self.mode = 'CD' + self.mode = "CD" self.a_cd = a_cd self._n = 0 else: raise ValueError( - 'a_cd must either be None (for regular mode) or a number > 2', + "a_cd must either be None (for regular mode) or a number > 2", ) if restart_strategy in self._restarting_strategies: self._check_restart_params( - restart_strategy, min_beta, s_greedy, xi_restart, + restart_strategy, + min_beta, + s_greedy, + xi_restart, ) self.restart_strategy = restart_strategy self.min_beta = min_beta @@ -100,10 +101,10 @@ def __init__( self.xi_restart = xi_restart else: - message = 'Restarting strategy must be one of {0}.' + message = "Restarting strategy must be one of {0}." raise ValueError( message.format( - ', '.join(self._restarting_strategies), + ", ".join(self._restarting_strategies), ), ) self._t_now = 1.0 @@ -155,22 +156,20 @@ def _check_restart_params( if restart_strategy is None: return True - if self.mode != 'regular': + if self.mode != "regular": raise ValueError( - 'Restarting strategies can only be used with regular mode.', + "Restarting strategies can only be used with regular mode.", ) - greedy_params_check = ( - min_beta is None or s_greedy is None or s_greedy <= 1 - ) + greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1 - if restart_strategy == 'greedy' and greedy_params_check: + if restart_strategy == "greedy" and greedy_params_check: raise ValueError( - 'You need a min_beta and an s_greedy > 1 for greedy restart.', + "You need a min_beta and an s_greedy > 1 for greedy restart.", ) if xi_restart is None or xi_restart >= 1: - raise ValueError('You need a xi_restart < 1 for restart.') + raise ValueError("You need a xi_restart < 1 for restart.") return True @@ -210,12 +209,12 @@ def is_restart(self, z_old, x_new, x_old): criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0 if criterion: - if 'adaptive' in self.restart_strategy: + if "adaptive" in self.restart_strategy: self.r_lazy *= self.xi_restart - if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}: + if self.restart_strategy in {"adaptive-ii", "adaptive-2"}: self._t_now = 1 - if self.restart_strategy == 'greedy': + if self.restart_strategy == "greedy": cur_delta = xp.linalg.norm(x_new - x_old) if self._delta0 is None: self._delta0 = self.s_greedy * cur_delta @@ -269,17 +268,17 @@ def update_lambda(self, *args, **kwargs): Implements steps 3 and 4 from algoritm 10.7 in :cite:`bauschke2009`. """ - if self.restart_strategy == 'greedy': + if self.restart_strategy == "greedy": return 2 # Steps 3 and 4 from alg.10.7. self._t_prev = self._t_now - if self.mode == 'regular': - sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy + if self.mode == "regular": + sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5 - elif self.mode == 'CD': + elif self.mode == "CD": self._t_now = (self._n + self.a_cd - 1) / self.a_cd self._n += 1 @@ -344,18 +343,17 @@ def __init__( x, grad, prox, - cost='auto', + cost="auto", beta_param=1.0, lambda_param=1.0, beta_update=None, - lambda_update='fista', + lambda_update="fista", auto_iterate=True, metric_call_period=5, metrics=None, linear=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -376,7 +374,7 @@ def __init__( self._prox = prox self._linear = linear - if cost == 'auto': + if cost == "auto": self._cost_func = costObj([self._grad, self._prox]) else: self._cost_func = cost @@ -384,7 +382,7 @@ def __init__( # Check if there is a linear op, needed for metrics in the FB algoritm if metrics and self._linear is None: raise ValueError( - 'When using metrics, you must pass a linear operator', + "When using metrics, you must pass a linear operator", ) if self._linear is None: @@ -400,7 +398,7 @@ def __init__( # Set the algorithm parameter update methods self._check_param_update(beta_update) self._beta_update = beta_update - if isinstance(lambda_update, str) and lambda_update == 'fista': + if isinstance(lambda_update, str) and lambda_update == "fista": fista = FISTA(**kwargs) self._lambda_update = fista.update_lambda self._is_restart = fista.is_restart @@ -462,12 +460,11 @@ def _update(self): # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new ) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either the convergence criteria is met @@ -477,9 +474,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -499,9 +497,9 @@ def get_notify_observers_kwargs(self): """ return { - 'x_new': self._linear.adj_op(self._x_new), - 'z_new': self._z_new, - 'idx': self.idx, + "x_new": self._linear.adj_op(self._x_new), + "z_new": self._z_new, + "idx": self.idx, } def retrieve_outputs(self): @@ -512,7 +510,7 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -576,7 +574,7 @@ def __init__( x, grad, prox_list, - cost='auto', + cost="auto", gamma_param=1.0, lambda_param=1.0, gamma_update=None, @@ -588,7 +586,6 @@ def __init__( linear=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -601,22 +598,22 @@ def __init__( self._x_old = self.xp.copy(x) # Set the algorithm operators - for operator in [grad, cost] + prox_list: + for operator in [grad, cost, *prox_list]: self._check_operator(operator) self._grad = grad self._prox_list = self.xp.array(prox_list) self._linear = linear - if cost == 'auto': - self._cost_func = costObj([self._grad] + prox_list) + if cost == "auto": + self._cost_func = costObj([self._grad, *prox_list]) else: self._cost_func = cost # Check if there is a linear op, needed for metrics in the FB algoritm if metrics and self._linear is None: raise ValueError( - 'When using metrics, you must pass a linear operator', + "When using metrics, you must pass a linear operator", ) if self._linear is None: @@ -640,9 +637,7 @@ def __init__( self._set_weights(weights) # Set initial z - self._z = self.xp.array([ - self._x_old for i in range(self._prox_list.size) - ]) + self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)]) # Automatically run the algorithm if auto_iterate: @@ -672,25 +667,25 @@ def _set_weights(self, weights): self._prox_list.size, ) elif not isinstance(weights, (list, tuple, np.ndarray)): - raise TypeError('Weights must be provided as a list.') + raise TypeError("Weights must be provided as a list.") weights = self.xp.array(weights) if not np.issubdtype(weights.dtype, np.floating): - raise ValueError('Weights must be list of float values.') + raise ValueError("Weights must be list of float values.") if weights.size != self._prox_list.size: raise ValueError( - 'The number of weights must match the number of proximity ' - + 'operators.', + "The number of weights must match the number of proximity " + + "operators.", ) expected_weight_sum = 1.0 if self.xp.sum(weights) != expected_weight_sum: raise ValueError( - 'Proximity operator weights must sum to 1.0. Current sum of ' - + 'weights = {0}'.format(self.xp.sum(weights)), + "Proximity operator weights must sum to 1.0. Current sum of " + + f"weights = {self.xp.sum(weights)}", ) self._weights = weights @@ -725,9 +720,7 @@ def _update(self): # Update z values. for i in range(self._prox_list.size): - z_temp = ( - 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad - ) + z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad z_prox = self._prox_list[i].op( z_temp, extra_factor=self._gamma / self._weights[i], @@ -750,7 +743,7 @@ def _update(self): if self._cost_func: self.converge = self._cost_func.get_cost(self._x_new) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -760,9 +753,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -782,9 +776,9 @@ def get_notify_observers_kwargs(self): """ return { - 'x_new': self._linear.adj_op(self._x_new), - 'z_new': self._z, - 'idx': self.idx, + "x_new": self._linear.adj_op(self._x_new), + "z_new": self._z, + "idx": self.idx, } def retrieve_outputs(self): @@ -795,7 +789,7 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -815,9 +809,9 @@ class POGM(SetUp): Initial guess for the :math:`y` variable z : numpy.ndarray Initial guess for the :math:`z` variable - grad + grad : GradBasic Gradient operator class - prox + prox : ProximalParent Proximity operator class cost : class instance or str, optional Cost function class instance (default is ``'auto'``); Use ``'auto'`` to @@ -869,7 +863,7 @@ def __init__( z, grad, prox, - cost='auto', + cost="auto", linear=None, beta_param=1.0, sigma_bar=1.0, @@ -878,7 +872,6 @@ def __init__( metrics=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -903,7 +896,7 @@ def __init__( self._grad = grad self._prox = prox self._linear = linear - if cost == 'auto': + if cost == "auto": self._cost_func = costObj([self._grad, self._prox]) else: self._cost_func = cost @@ -916,7 +909,7 @@ def __init__( for param_val in (beta_param, sigma_bar): self._check_param(param_val) if sigma_bar < 0 or sigma_bar > 1: - raise ValueError('The sigma bar parameter needs to be in [0, 1]') + raise ValueError("The sigma bar parameter needs to be in [0, 1]") self._beta = self.step_size or beta_param self._sigma_bar = sigma_bar @@ -942,16 +935,18 @@ def _update(self): """ # Step 4 from alg. 3 self._grad.get_grad(self._x_old) - self._u_new = self._x_old - self._beta * self._grad.grad + # self._u_new = self._x_old - self._beta * self._grad.grad + self._u_new = -self._beta * self._grad.grad + self._u_new += self._x_old # Step 5 from alg. 3 - self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2)) + self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2)) # Step 6 from alg. 3 t_shifted_ratio = (self._t_old - 1) / self._t_new sigma_t_ratio = self._sigma * self._t_old / self._t_new beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi - self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z) + self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z) self._z += self._u_new self._z += t_shifted_ratio * (self._u_new - self._u_old) self._z += sigma_t_ratio * (self._u_new - self._x_old) @@ -964,15 +959,18 @@ def _update(self): # Restarting and gamma-Decreasing # Step 9 from alg. 3 - self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi + # self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi + self._g_new = self._z - self._x_new + self._g_new /= self._xi + self._g_new += self._grad.grad # Step 10 from alg 3. - self._y_new = self._x_old - self._beta * self._g_new + # self._y_new = self._x_old - self._beta * self._g_new + self._y_new = -self._beta * self._g_new + self._y_new += self._x_old # Step 11 from alg. 3 - restart_crit = ( - self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0 - ) + restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0 if restart_crit: self._t_new = 1 self._sigma = 1 @@ -990,12 +988,11 @@ def _update(self): # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new ) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -1005,9 +1002,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -1027,14 +1025,14 @@ def get_notify_observers_kwargs(self): """ return { - 'u_new': self._u_new, - 'x_new': self._linear.adj_op(self._x_new), - 'y_new': self._y_new, - 'z_new': self._z, - 'xi': self._xi, - 'sigma': self._sigma, - 't': self._t_new, - 'idx': self.idx, + "u_new": self._u_new, + "x_new": self._linear.adj_op(self._x_new), + "y_new": self._y_new, + "z_new": self._z, + "xi": self._xi, + "sigma": self._sigma, + "t": self._t_new, + "idx": self.idx, } def retrieve_outputs(self): @@ -1045,6 +1043,6 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics diff --git a/modopt/opt/algorithms/gradient_descent.py b/src/modopt/opt/algorithms/gradient_descent.py similarity index 95% rename from modopt/opt/algorithms/gradient_descent.py rename to src/modopt/opt/algorithms/gradient_descent.py index f3fe4b10..0960be5a 100644 --- a/modopt/opt/algorithms/gradient_descent.py +++ b/src/modopt/opt/algorithms/gradient_descent.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Gradient Descent Algorithms.""" import numpy as np @@ -103,7 +102,7 @@ def __init__( self._check_operator(operator) self._grad = grad self._prox = prox - if cost == 'auto': + if cost == "auto": self._cost_func = costObj([self._grad, self._prox]) else: self._cost_func = cost @@ -157,9 +156,8 @@ def _update(self): self._eta = self._eta_update(self._eta, self.idx) # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new ) def _update_grad_dir(self, grad): @@ -208,10 +206,10 @@ def get_notify_observers_kwargs(self): """ return { - 'x_new': self._x_new, - 'dir_grad': self._dir_grad, - 'speed_grad': self._speed_grad, - 'idx': self.idx, + "x_new": self._x_new, + "dir_grad": self._dir_grad, + "speed_grad": self._speed_grad, + "idx": self.idx, } def retrieve_outputs(self): @@ -222,7 +220,7 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -308,7 +306,7 @@ class RMSpropGradOpt(GenericGradOpt): def __init__(self, *args, gamma=0.5, **kwargs): super().__init__(*args, **kwargs) if gamma < 0 or gamma > 1: - raise ValueError('gamma is outside of range [0,1]') + raise ValueError("gamma is outside of range [0,1]") self._check_param(gamma) self._gamma = gamma @@ -405,9 +403,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs): self._check_param(gamma) self._check_param(beta) if gamma < 0 or gamma >= 1: - raise ValueError('gamma is outside of range [0,1]') + raise ValueError("gamma is outside of range [0,1]") if beta < 0 or beta >= 1: - raise ValueError('beta is outside of range [0,1]') + raise ValueError("beta is outside of range [0,1]") self._gamma = gamma self._beta = beta self._beta_pow = 1 diff --git a/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py similarity index 86% rename from modopt/opt/algorithms/primal_dual.py rename to src/modopt/opt/algorithms/primal_dual.py index c8566969..fee49a25 100644 --- a/modopt/opt/algorithms/primal_dual.py +++ b/src/modopt/opt/algorithms/primal_dual.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Primal-Dual Algorithms.""" from modopt.opt.algorithms.base import SetUp @@ -81,7 +80,7 @@ def __init__( prox, prox_dual, linear=None, - cost='auto', + cost="auto", reweight=None, rho=0.5, sigma=1.0, @@ -96,7 +95,6 @@ def __init__( metrics=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -123,12 +121,14 @@ def __init__( self._linear = Identity() else: self._linear = linear - if cost == 'auto': - self._cost_func = costObj([ - self._grad, - self._prox, - self._prox_dual, - ]) + if cost == "auto": + self._cost_func = costObj( + [ + self._grad, + self._prox, + self._prox_dual, + ] + ) else: self._cost_func = cost @@ -187,22 +187,17 @@ def _update(self): self._grad.get_grad(self._x_old) x_prox = self._prox.op( - self._x_old - self._tau * self._grad.grad - self._tau - * self._linear.adj_op(self._y_old), + self._x_old + - self._tau * self._grad.grad + - self._tau * self._linear.adj_op(self._y_old), ) # Step 2 from eq.9. - y_temp = ( - self._y_old + self._sigma - * self._linear.op(2 * x_prox - self._x_old) - ) + y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old) - y_prox = ( - y_temp - self._sigma - * self._prox_dual.op( - y_temp / self._sigma, - extra_factor=(1.0 / self._sigma), - ) + y_prox = y_temp - self._sigma * self._prox_dual.op( + y_temp / self._sigma, + extra_factor=(1.0 / self._sigma), ) # Step 3 from eq.9. @@ -220,12 +215,11 @@ def _update(self): # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new, self._y_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new, self._y_new ) - def iterate(self, max_iter=150, n_rewightings=1): + def iterate(self, max_iter=150, n_rewightings=1, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -237,14 +231,17 @@ def iterate(self, max_iter=150, n_rewightings=1): Maximum number of iterations (default is ``150``) n_rewightings : int, optional Number of reweightings to perform (default is ``1``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) if not isinstance(self._reweight, type(None)): for _ in range(n_rewightings): self._reweight.reweight(self._linear.op(self._x_new)) - self._run_alg(max_iter) + if progbar: + progbar.reset(total=max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -264,7 +261,7 @@ def get_notify_observers_kwargs(self): The mapping between the iterated variables """ - return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx} + return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx} def retrieve_outputs(self): """Retrieve outputs. @@ -274,6 +271,6 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics diff --git a/modopt/opt/cost.py b/src/modopt/opt/cost.py similarity index 67% rename from modopt/opt/cost.py rename to src/modopt/opt/cost.py index 3cdfcc50..37771f16 100644 --- a/modopt/opt/cost.py +++ b/src/modopt/opt/cost.py @@ -6,6 +6,8 @@ """ +import abc + import numpy as np from modopt.base.backend import get_array_module @@ -13,8 +15,8 @@ from modopt.plot.cost_plot import plotCost -class costObj(object): - """Generic cost function object. +class CostParent(abc.ABC): + """Abstract cost function object. This class updates the cost according to the input operator classes and tests for convergence. @@ -40,7 +42,8 @@ class costObj(object): Notes ----- - The costFunc class must contain a method called ``cost``. + All child classes should implement a ``_calc_cost`` method (returning + a float) or a ``get_cost`` for more complex behavior on convergence test. Examples -------- @@ -71,7 +74,6 @@ class costObj(object): def __init__( self, - operators, initial_cost=1e6, tolerance=1e-4, cost_interval=1, @@ -79,10 +81,6 @@ def __init__( verbose=True, plot_output=None, ): - - self._operators = operators - if not isinstance(operators, type(None)): - self._check_operators() self.cost = initial_cost self._cost_list = [] self._cost_interval = cost_interval @@ -93,30 +91,6 @@ def __init__( self._plot_output = plot_output self._verbose = verbose - def _check_operators(self): - """Check operators. - - This method checks if the input operators have a ``cost`` method. - - Raises - ------ - TypeError - For invalid operators type - ValueError - For operators without ``cost`` method - - """ - if not isinstance(self._operators, (list, tuple, np.ndarray)): - message = ( - 'Input operators must be provided as a list, not {0}' - ) - raise TypeError(message.format(type(self._operators))) - - for op in self._operators: - if not hasattr(op, 'cost'): - raise ValueError('Operators must contain "cost" method.') - op.cost = check_callable(op.cost) - def _check_cost(self): """Check cost function. @@ -137,20 +111,19 @@ def _check_cost(self): # Check if enough cost values have been collected if len(self._test_list) == self._test_range: - # The mean of the first half of the test list t1 = xp.mean( - xp.array(self._test_list[len(self._test_list) // 2:]), + xp.array(self._test_list[len(self._test_list) // 2 :]), axis=0, ) # The mean of the second half of the test list t2 = xp.mean( - xp.array(self._test_list[:len(self._test_list) // 2]), + xp.array(self._test_list[: len(self._test_list) // 2]), axis=0, ) # Calculate the change across the test list if xp.around(t1, decimals=16): - cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)) + cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1) else: cost_diff = 0 @@ -158,15 +131,16 @@ def _check_cost(self): self._test_list = [] if self._verbose: - print(' - CONVERGENCE TEST - ') - print(' - CHANGE IN COST:', cost_diff) - print('') + print(" - CONVERGENCE TEST - ") + print(" - CHANGE IN COST:", cost_diff) + print("") # Check for convergence return cost_diff <= self._tolerance return False + @abc.abstractmethod def _calc_cost(self, *args, **kwargs): """Calculate the cost. @@ -178,14 +152,7 @@ def _calc_cost(self, *args, **kwargs): Positional arguments **kwargs : dict Keyword arguments - - Returns - ------- - float - Cost value - """ - return np.sum([op.cost(*args, **kwargs) for op in self._operators]) def get_cost(self, *args, **kwargs): """Get cost function. @@ -207,8 +174,7 @@ def get_cost(self, *args, **kwargs): """ # Check if the cost should be calculated test_conditions = ( - self._cost_interval is None - or self._iteration % self._cost_interval + self._cost_interval is None or self._iteration % self._cost_interval ) if test_conditions: @@ -216,15 +182,15 @@ def get_cost(self, *args, **kwargs): else: if self._verbose: - print(' - ITERATION:', self._iteration) + print(" - ITERATION:", self._iteration) # Calculate the current cost - self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs) + self.cost = self._calc_cost(*args, verbose=self._verbose, **kwargs) self._cost_list.append(self.cost) if self._verbose: - print(' - COST:', self.cost) - print('') + print(" - COST:", self.cost) + print("") # Test for convergence test_result = self._check_cost() @@ -241,3 +207,108 @@ def plot_cost(self): # pragma: no cover """ plotCost(self._cost_list, self._plot_output) + + +class costObj(CostParent): + """Abstract cost function object. + + This class updates the cost according to the input operator classes and + tests for convergence. + + Parameters + ---------- + opertors : list, tuple or numpy.ndarray + List of operators classes containing ``cost`` method + initial_cost : float, optional + Initial value of the cost (default is ``1e6``) + tolerance : float, optional + Tolerance threshold for convergence (default is ``1e-4``) + cost_interval : int, optional + Iteration interval to calculate cost (default is ``1``). + If ``cost_interval`` is ``None`` the cost is never calculated, + thereby saving on computation time. + test_range : int, optional + Number of cost values to be used in test (default is ``4``) + verbose : bool, optional + Option for verbose output (default is ``True``) + plot_output : str, optional + Output file name for cost function plot + + Examples + -------- + >>> from modopt.opt.cost import * + >>> class dummy(object): + ... def cost(self, x): + ... return x ** 2 + ... + ... + >>> inst = costObj([dummy(), dummy()]) + >>> inst.get_cost(2) + - ITERATION: 1 + - COST: 8 + + False + >>> inst.get_cost(2) + - ITERATION: 2 + - COST: 8 + + False + >>> inst.get_cost(2) + - ITERATION: 3 + - COST: 8 + + False + """ + + def __init__( + self, + operators, + **kwargs, + ): + super().__init__(**kwargs) + + self._operators = operators + if not isinstance(operators, type(None)): + self._check_operators() + + def _check_operators(self): + """Check operators. + + This method checks if the input operators have a ``cost`` method. + + Raises + ------ + TypeError + For invalid operators type + ValueError + For operators without ``cost`` method + + """ + if not isinstance(self._operators, (list, tuple, np.ndarray)): + message = "Input operators must be provided as a list, not {0}" + raise TypeError(message.format(type(self._operators))) + + for op in self._operators: + if not hasattr(op, "cost"): + raise ValueError('Operators must contain "cost" method.') + op.cost = check_callable(op.cost) + + def _calc_cost(self, *args, **kwargs): + """Calculate the cost. + + This method calculates the cost from each of the input operators. + + Parameters + ---------- + *args : tuple + Positional arguments + **kwargs : dict + Keyword arguments + + Returns + ------- + float + Cost value + + """ + return np.sum([op.cost(*args, **kwargs) for op in self._operators]) diff --git a/modopt/opt/gradient.py b/src/modopt/opt/gradient.py similarity index 97% rename from modopt/opt/gradient.py rename to src/modopt/opt/gradient.py index caa8fa9d..fe9b87d8 100644 --- a/modopt/opt/gradient.py +++ b/src/modopt/opt/gradient.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """GRADIENT CLASSES. This module contains classses for defining algorithm gradients. @@ -14,7 +12,7 @@ from modopt.base.types import check_callable, check_float, check_npndarray -class GradParent(object): +class GradParent: """Gradient Parent Class. This class defines the basic methods that will be inherited by specific @@ -71,7 +69,6 @@ def __init__( input_data_writeable=False, verbose=True, ): - self.verbose = verbose self._input_data_writeable = input_data_writeable self._grad_data_type = data_type @@ -100,7 +97,6 @@ def obs_data(self): @obs_data.setter def obs_data(self, input_data): - if self._grad_data_type in {float, np.floating}: input_data = check_float(input_data) check_npndarray( @@ -128,7 +124,6 @@ def op(self): @op.setter def op(self, operator): - self._op = check_callable(operator) @property @@ -147,7 +142,6 @@ def trans_op(self): @trans_op.setter def trans_op(self, operator): - self._trans_op = check_callable(operator) @property @@ -157,7 +151,6 @@ def get_grad(self): @get_grad.setter def get_grad(self, method): - self._get_grad = check_callable(method) @property @@ -167,7 +160,6 @@ def grad(self): @grad.setter def grad(self, input_value): - if self._grad_data_type in {float, np.floating}: input_value = check_float(input_value) self._grad = input_value @@ -179,7 +171,6 @@ def cost(self): @cost.setter def cost(self, method): - self._cost = check_callable(method) def trans_op_op(self, input_data): @@ -243,7 +234,6 @@ class GradBasic(GradParent): """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) self.get_grad = self._get_grad_method self.cost = self._cost_method @@ -289,7 +279,7 @@ def _cost_method(self, *args, **kwargs): """ cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2 - if 'verbose' in kwargs and kwargs['verbose']: - print(' - DATA FIDELITY (X):', cost_val) + if kwargs.get("verbose"): + print(" - DATA FIDELITY (X):", cost_val) return cost_val diff --git a/src/modopt/opt/linear/__init__.py b/src/modopt/opt/linear/__init__.py new file mode 100644 index 00000000..d5c0d21f --- /dev/null +++ b/src/modopt/opt/linear/__init__.py @@ -0,0 +1,21 @@ +"""LINEAR OPERATORS. + +This module contains linear operator classes. + +:Author: Samuel Farrens +:Author: Pierre-Antoine Comby +""" + +from .base import LinearParent, Identity, MatrixOperator, LinearCombo + +from .wavelet import WaveletConvolve, WaveletTransform + + +__all__ = [ + "LinearParent", + "Identity", + "MatrixOperator", + "LinearCombo", + "WaveletConvolve", + "WaveletTransform", +] diff --git a/modopt/opt/linear.py b/src/modopt/opt/linear/base.py similarity index 79% rename from modopt/opt/linear.py rename to src/modopt/opt/linear/base.py index d8679998..af748a73 100644 --- a/modopt/opt/linear.py +++ b/src/modopt/opt/linear/base.py @@ -1,20 +1,12 @@ -# -*- coding: utf-8 -*- - -"""LINEAR OPERATORS. - -This module contains linear operator classes. - -:Author: Samuel Farrens - -""" +"""Base classes for linear operators.""" import numpy as np -from modopt.base.types import check_callable, check_float -from modopt.signal.wavelet import filter_convolve_stack +from modopt.base.types import check_callable +from modopt.base.backend import get_array_module -class LinearParent(object): +class LinearParent: """Linear Operator Parent Class. This class sets the structure for defining linear operator instances. @@ -38,7 +30,6 @@ class LinearParent(object): """ def __init__(self, op, adj_op): - self.op = op self.adj_op = adj_op @@ -49,7 +40,6 @@ def op(self): @op.setter def op(self, operator): - self._op = check_callable(operator) @property @@ -59,7 +49,6 @@ def adj_op(self): @adj_op.setter def adj_op(self, operator): - self._adj_op = check_callable(operator) @@ -75,45 +64,26 @@ class Identity(LinearParent): """ def __init__(self): - self.op = lambda input_data: input_data self.adj_op = self.op + self.cost = lambda *args, **kwargs: 0 -class WaveletConvolve(LinearParent): - """Wavelet Convolution Class. - - This class defines the wavelet transform operators via convolution with - predefined filters. - - Parameters - ---------- - filters: numpy.ndarray - Array of wavelet filter coefficients - method : str, optional - Convolution method (default is ``'scipy'``) - - See Also - -------- - LinearParent : parent class - modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution +class MatrixOperator(LinearParent): + """ + Matrix Operator class. + This class transforms an array into a suitable linear operator. """ - def __init__(self, filters, method='scipy'): + def __init__(self, array): + self.op = lambda x: array @ x + xp = get_array_module(array) - self._filters = check_float(filters) - self.op = lambda input_data: filter_convolve_stack( - input_data, - self._filters, - method=method, - ) - self.adj_op = lambda input_data: filter_convolve_stack( - input_data, - self._filters, - filter_rot=True, - method=method, - ) + if xp.any(xp.iscomplex(array)): + self.adj_op = lambda x: array.T.conjugate() @ x + else: + self.adj_op = lambda x: array.T @ x class LinearCombo(LinearParent): @@ -150,11 +120,9 @@ class LinearCombo(LinearParent): See Also -------- LinearParent : parent class - """ def __init__(self, operators, weights=None): - operators, weights = self._check_inputs(operators, weights) self.operators = operators self.weights = weights @@ -187,14 +155,13 @@ def _check_type(self, input_val): """ if not isinstance(input_val, (list, tuple, np.ndarray)): raise TypeError( - 'Invalid input type, input must be a list, tuple or numpy ' - + 'array.', + "Invalid input type, input must be a list, tuple or numpy " + "array.", ) input_val = np.array(input_val) if not input_val.size: - raise ValueError('Input list is empty.') + raise ValueError("Input list is empty.") return input_val @@ -227,11 +194,10 @@ def _check_inputs(self, operators, weights): operators = self._check_type(operators) for operator in operators: - - if not hasattr(operator, 'op'): + if not hasattr(operator, "op"): raise ValueError('Operators must contain "op" method.') - if not hasattr(operator, 'adj_op'): + if not hasattr(operator, "adj_op"): raise ValueError('Operators must contain "adj_op" method.') operator.op = check_callable(operator.op) @@ -242,12 +208,11 @@ def _check_inputs(self, operators, weights): if weights.size != operators.size: raise ValueError( - 'The number of weights must match the number of ' - + 'operators.', + "The number of weights must match the number of " + "operators.", ) if not np.issubdtype(weights.dtype, np.floating): - raise TypeError('The weights must be a list of float values.') + raise TypeError("The weights must be a list of float values.") return operators, weights diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py new file mode 100644 index 00000000..8012a072 --- /dev/null +++ b/src/modopt/opt/linear/wavelet.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python3 +"""Wavelet operator, using either scipy filter or pywavelet.""" +import warnings + +import numpy as np + +from modopt.base.types import check_float +from modopt.signal.wavelet import filter_convolve_stack + +from .base import LinearParent + +pywt_available = True +try: + import pywt + from joblib import Parallel, cpu_count, delayed +except ImportError: + pywt_available = False + +ptwt_available = True +try: + import ptwt + import torch + import cupy as cp +except ImportError: + ptwt_available = False + + +class WaveletConvolve(LinearParent): + """Wavelet Convolution Class. + + This class defines the wavelet transform operators via convolution with + predefined filters. + + Parameters + ---------- + filters: numpy.ndarray + Array of wavelet filter coefficients + method : str, optional + Convolution method (default is ``'scipy'``) + + See Also + -------- + LinearParent : parent class + modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution + + """ + + def __init__(self, filters, method="scipy"): + self._filters = check_float(filters) + self.op = lambda input_data: filter_convolve_stack( + input_data, + self._filters, + method=method, + ) + self.adj_op = lambda input_data: filter_convolve_stack( + input_data, + self._filters, + filter_rot=True, + method=method, + ) + + +class WaveletTransform(LinearParent): + """ + 2D and 3D wavelet transform class. + + This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU). + + Parameters + ---------- + wavelet_name: str + the wavelet name to be used during the decomposition. + shape: tuple[int,...] + Shape of the input data. The shape should be a tuple of length 2 or 3. + It should not contains coils or batch dimension. + nb_scales: int, default 4 + the number of scales in the decomposition. + mode: str, default "zero" + Boundary Condition mode + compute_backend: str, "numpy" or "cupy", default "numpy" + Backend library to use. "cupy" also requires a working installation of PyTorch + and PyTorch wavelets (ptwt). + + **kwargs: extra kwargs for Pywavelet or Pytorch Wavelet + """ + + def __init__( + self, + wavelet_name, + shape, + level=4, + mode="symmetric", + compute_backend="numpy", + **kwargs, + ): + if compute_backend == "cupy" and ptwt_available: + self.operator = CupyWaveletTransform( + wavelet=wavelet_name, shape=shape, level=level, mode=mode + ) + elif compute_backend == "numpy" and pywt_available: + self.operator = CPUWaveletTransform( + wavelet_name=wavelet_name, shape=shape, level=level, **kwargs + ) + else: + raise ValueError(f"Compute Backend {compute_backend} not available") + + self.op = self.operator.op + self.adj_op = self.operator.adj_op + + @property + def coeffs_shape(self): + """Get the coeffs shapes.""" + return self.operator.coeffs_shape + + +class CPUWaveletTransform(LinearParent): + """ + 2D and 3D wavelet transform class. + + This is a light wrapper around PyWavelet, with multicoil support. + + Parameters + ---------- + wavelet_name: str + the wavelet name to be used during the decomposition. + shape: tuple[int,...] + Shape of the input data. The shape should be a tuple of length 2 or 3. + It should not contains coils or batch dimension. + nb_scales: int, default 4 + the number of scales in the decomposition. + n_batchs: int, default 1 + the number of channel/ batch dimension + n_jobs: int, default 1 + the number of cores to use for multichannel. + backend: str, default "threading" + the backend to use for parallel multichannel linear operation. + verbose: int, default 0 + the verbosity level. + + Attributes + ---------- + nb_scale: int + number of scale decomposed in wavelet space. + n_jobs: int + number of jobs for parallel computation + n_batchs: int + number of coils use f + backend: str + Backend use for parallel computation + verbose: int + Verbosity level + """ + + def __init__( + self, + wavelet_name, + shape, + level=4, + n_batch=1, + n_jobs=1, + decimated=True, + backend="threading", + mode="symmetric", + ): + if not pywt_available: + raise ImportError( + "PyWavelet and/or joblib are not available. " + "Please install it to use WaveletTransform." + ) + if wavelet_name not in pywt.wavelist(kind="all"): + raise ValueError( + "Invalid wavelet name. Availables are ``pywt.waveletlist(kind='all')``" + ) + + self.wavelet = wavelet_name + if isinstance(shape, int): + shape = (shape,) + self.shape = shape + self.n_jobs = n_jobs + self.mode = mode + self.level = level + if not decimated: + raise NotImplementedError( + "Undecimated Wavelet Transform is not implemented yet." + ) + ca, *cds = pywt.wavedecn_shapes( + self.shape, wavelet=self.wavelet, mode=self.mode, level=self.level + ) + self.coeffs_shape = [ca] + [s for cd in cds for s in cd.values()] + + if len(shape) > 1: + self.dwt = pywt.wavedecn + self.idwt = pywt.waverecn + self._pywt_fun = "wavedecn" + else: + self.dwt = pywt.wavedec + self.idwt = pywt.waverec + self._pywt_fun = "wavedec" + + self.n_batch = n_batch + if self.n_batch == 1 and self.n_jobs != 1: + warnings.warn( + "Making n_jobs = 1 for WaveletTransform as n_batchs = 1", stacklevel=1 + ) + self.n_jobs = 1 + self.backend = backend + n_proc = self.n_jobs + if n_proc < 0: + n_proc = cpu_count() + self.n_jobs + 1 + + def op(self, data): + """Define the wavelet operator. + + This method returns the input data convolved with the wavelet filter. + + Parameters + ---------- + data: ndarray or Image + input 2D data array. + + Returns + ------- + coeffs: ndarray + the wavelet coefficients. + """ + if self.n_batch > 1: + coeffs, self.coeffs_slices, self.raw_coeffs_shape = zip( + *Parallel( + n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose + )(delayed(self._op)(data[i]) for i in np.arange(self.n_batch)) + ) + coeffs = np.asarray(coeffs) + else: + coeffs, self.coeffs_slices, self.raw_coeffs_shape = self._op(data) + return coeffs + + def _op(self, data): + """Single coil wavelet transform.""" + return pywt.ravel_coeffs( + self.dwt(data, mode=self.mode, level=self.level, wavelet=self.wavelet) + ) + + def adj_op(self, coeffs): + """Define the wavelet adjoint operator. + + This method returns the reconstructed image. + + Parameters + ---------- + coeffs: ndarray + the wavelet coefficients. + + Returns + ------- + data: ndarray + the reconstructed data. + """ + if self.n_batch > 1: + images = Parallel( + n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose + )( + delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i]) + for i in np.arange(self.n_batch) + ) + images = np.asarray(images) + else: + images = self._adj_op(coeffs) + return images + + def _adj_op(self, coeffs): + """Single coil inverse wavelet transform.""" + return self.idwt( + pywt.unravel_coeffs( + coeffs, self.coeffs_slices, self.raw_coeffs_shape, self._pywt_fun + ), + wavelet=self.wavelet, + mode=self.mode, + ) + + +class TorchWaveletTransform: + """Wavelet transform using pytorch.""" + + wavedec3_keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd") + + def __init__( + self, + shape, + wavelet, + level, + mode, + ): + self.wavelet = wavelet + self.level = level + self.shape = shape + self.mode = mode + self.coeffs_shape = None # will be set after op. + + def op(self, data): + """Apply the wavelet decomposition on. + + Parameters + ---------- + data: torch.Tensor + 2D or 3D, real or complex data with last axes matching shape of + the operator. + + Returns + ------- + list[torch.Tensor] + list of tensor each containing the data of a subband. + """ + if data.shape == self.shape: + data = data[None, ...] # add a batch dimension + + if len(self.shape) == 2: + if torch.is_complex(data): + # 2D Complex + data_ = torch.view_as_real(data) + coeffs_ = ptwt.wavedec2( + data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2) + ) + # flatten list of tuple of tensors to a list of tensors + coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ + torch.view_as_complex(cc.contiguous()) + for c in coeffs_[1:] + for cc in c + ] + + return coeffs + # 2D Real + coeffs_ = ptwt.wavedec2( + data, self.wavelet, level=self.level, mode=self.mode, axes=(-2, -1) + ) + return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c] + + if torch.is_complex(data): + # 3D Complex + data_ = torch.view_as_real(data) + coeffs_ = ptwt.wavedec3( + data_, + self.wavelet, + level=self.level, + mode=self.mode, + axes=(-4, -3, -2), + ) + # flatten list of tuple of tensors to a list of tensors + coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ + torch.view_as_complex(cc.contiguous()) + for c in coeffs_[1:] + for cc in c.values() + ] + + return coeffs + # 3D Real + coeffs_ = ptwt.wavedec3( + data, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2, -1) + ) + return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()] + + def adj_op(self, coeffs): + """Apply the wavelet recomposition. + + Parameters + ---------- + list[torch.Tensor] + list of tensor each containing the data of a subband. + + Returns + ------- + data: torch.Tensor + 2D or 3D, real or complex data with last axes matching shape of the + operator. + + """ + if len(self.shape) == 2: + if torch.is_complex(coeffs[0]): + ## 2D Complex ## + # list of tensor to list of tuple of tensor + coeffs = [torch.view_as_real(coeffs[0])] + [ + tuple(torch.view_as_real(coeffs[i + k]) for k in range(3)) + for i in range(1, len(coeffs) - 2, 3) + ] + data = ptwt.waverec2(coeffs, wavelet=self.wavelet, axes=(-3, -2)) + return torch.view_as_complex(data.contiguous()) + ## 2D Real ## + coeffs_ = [coeffs[0]] + [ + tuple(coeffs[i + k] for k in range(3)) + for i in range(1, len(coeffs) - 2, 3) + ] + data = ptwt.waverec2(coeffs_, wavelet=self.wavelet, axes=(-2, -1)) + return data + + if torch.is_complex(coeffs[0]): + ## 3D Complex ## + # list of tensor to list of tuple of tensor + coeffs = [torch.view_as_real(coeffs[0])] + [ + { + v: torch.view_as_real(coeffs[i + k]) + for k, v in enumerate(self.wavedec3_keys) + } + for i in range(1, len(coeffs) - 6, 7) + ] + data = ptwt.waverec3(coeffs, wavelet=self.wavelet, axes=(-4, -3, -2)) + return torch.view_as_complex(data.contiguous()) + ## 3D Real ## + coeffs_ = [coeffs[0]] + [ + {v: coeffs[i + k] for k, v in enumerate(self.wavedec3_keys)} + for i in range(1, len(coeffs) - 6, 7) + ] + data = ptwt.waverec3(coeffs_, wavelet=self.wavelet, axes=(-3, -2, -1)) + return data + + +class CupyWaveletTransform(LinearParent): + """Wrapper around torch wavelet transform to be compatible with the Modopt API.""" + + def __init__( + self, + shape, + wavelet, + level, + mode, + ): + self.wavelet = wavelet + self.level = level + self.shape = shape + self.mode = mode + + self.operator = TorchWaveletTransform( + shape=shape, wavelet=wavelet, level=level, mode=mode + ) + self.coeffs_shape = None # will be set after op + + def op(self, data): + """Define the wavelet operator. + + This method returns the input data convolved with the wavelet filter. + + Parameters + ---------- + data: cp.ndarray + input 2D data array. + + Returns + ------- + coeffs: ndarray + the wavelet coefficients. + """ + data_ = torch.as_tensor(data) + tensor_list = self.operator.op(data_) + # flatten the list of tensor to a cupy array + # this requires an on device copy... + self.coeffs_shape = [c.shape for c in tensor_list] + n_tot_coeffs = np.sum([np.prod(s) for s in self.coeffs_shape]) + ret = cp.zeros(n_tot_coeffs, dtype=np.complex64) # FIXME get dtype from torch + start = 0 + for t in tensor_list: + stop = start + np.prod(t.shape) + ret[start:stop] = cp.asarray(t.flatten()) + start = stop + + return ret + + def adj_op(self, data): + """Define the wavelet adjoint operator. + + This method returns the reconstructed image. + + Parameters + ---------- + coeffs: cp.ndarray + the wavelet coefficients. + + Returns + ------- + data: ndarray + the reconstructed data. + """ + start = 0 + tensor_list = [None] * len(self.coeffs_shape) + for i, s in enumerate(self.coeffs_shape): + stop = start + np.prod(s) + tensor_list[i] = torch.as_tensor(data[start:stop].reshape(s), device="cuda") + start = stop + ret_tensor = self.operator.adj_op(tensor_list) + return cp.from_dlpack(ret_tensor) diff --git a/modopt/opt/proximity.py b/src/modopt/opt/proximity.py similarity index 89% rename from modopt/opt/proximity.py rename to src/modopt/opt/proximity.py index f8f368ef..dea862ca 100644 --- a/modopt/opt/proximity.py +++ b/src/modopt/opt/proximity.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """PROXIMITY OPERATORS. This module contains classes of proximity operators for optimisation. @@ -22,6 +20,7 @@ else: import_sklearn = True +from modopt.base.backend import get_array_module from modopt.base.transform import cube2matrix, matrix2cube from modopt.base.types import check_callable from modopt.interface.errors import warn @@ -31,7 +30,7 @@ from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast -class ProximityParent(object): +class ProximityParent: """Proximity Operator Parent Class. This class sets the structure for defining proximity operator instances. @@ -47,7 +46,6 @@ class ProximityParent(object): """ def __init__(self, op, cost): - self.op = op self.cost = cost @@ -58,7 +56,6 @@ def op(self): @op.setter def op(self, operator): - self._op = check_callable(operator) @property @@ -78,7 +75,6 @@ def cost(self): @cost.setter def cost(self, method): - self._cost = check_callable(method) @@ -98,9 +94,8 @@ class IdentityProx(ProximityParent): """ def __init__(self): - - self.op = lambda x_val: x_val - self.cost = lambda x_val: 0 + self.op = lambda x_val, *args, **kwargs: x_val + self.cost = lambda x_val, *args, **kwargs: 0 class Positivity(ProximityParent): @@ -116,10 +111,25 @@ class Positivity(ProximityParent): """ def __init__(self): - - self.op = lambda input_data: positive(input_data) self.cost = self._cost_method + def op(self, input_data, *args, **kwargs): + """ + Make the data positive. + + Parameters + ---------- + input_data: np.ndarray + Input array + *args, **kwargs: dummy. + + Returns + ------- + np.ndarray + Positive data. + """ + return positive(input_data) + def _cost_method(self, *args, **kwargs): """Calculate positivity component of the cost. @@ -139,8 +149,8 @@ def _cost_method(self, *args, **kwargs): ``0.0`` """ - if 'verbose' in kwargs and kwargs['verbose']: - print(' - Min (X):', np.min(args[0])) + if kwargs.get("verbose"): + print(" - Min (X):", np.min(args[0])) return 0 @@ -166,8 +176,7 @@ class SparseThreshold(ProximityParent): """ - def __init__(self, linear, weights, thresh_type='soft'): - + def __init__(self, linear, weights, thresh_type="soft"): self._linear = linear self.weights = weights self._thresh_type = thresh_type @@ -215,10 +224,13 @@ def _cost_method(self, *args, **kwargs): Sparsity cost component """ - cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0]))) + xp = get_array_module(args[0]) + cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0]))) + if isinstance(cost_val, xp.ndarray): + cost_val = cost_val.item() - if 'verbose' in kwargs and kwargs['verbose']: - print(' - L1 NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - L1 NORM (X):", cost_val) return cost_val @@ -269,12 +281,11 @@ class LowRankMatrix(ProximityParent): def __init__( self, threshold, - thresh_type='soft', - lowr_type='standard', + thresh_type="soft", + lowr_type="standard", initial_rank=None, operator=None, ): - self.thresh = threshold self.thresh_type = thresh_type self.lowr_type = lowr_type @@ -311,13 +322,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None): """ # Update threshold with extra factor. threshold = self.thresh * extra_factor - if self.lowr_type == 'standard' and self.rank is None and rank is None: + if self.lowr_type == "standard" and self.rank is None and rank is None: data_matrix = svd_thresh( cube2matrix(input_data), threshold, thresh_type=self.thresh_type, ) - elif self.lowr_type == 'standard': + elif self.lowr_type == "standard": data_matrix, update_rank = svd_thresh_coef_fast( cube2matrix(input_data), threshold, @@ -327,7 +338,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None): ) self.rank = update_rank # save for future use - elif self.lowr_type == 'ngole': + elif self.lowr_type == "ngole": data_matrix = svd_thresh_coef( cube2matrix(input_data), self.operator, @@ -335,7 +346,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None): thresh_type=self.thresh_type, ) else: - raise ValueError('lowr_type should be standard or ngole') + raise ValueError("lowr_type should be standard or ngole") # Return updated data. return matrix2cube(data_matrix, input_data.shape[1:]) @@ -361,8 +372,8 @@ def _cost_method(self, *args, **kwargs): """ cost_val = self.thresh * nuclear_norm(cube2matrix(args[0])) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - NUCLEAR NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - NUCLEAR NORM (X):", cost_val) return cost_val @@ -466,7 +477,6 @@ class ProximityCombo(ProximityParent): """ def __init__(self, operators): - operators = self._check_operators(operators) self.operators = operators self.op = self._op_method @@ -502,19 +512,19 @@ def _check_operators(self, operators): """ if not isinstance(operators, (list, tuple, np.ndarray)): raise TypeError( - 'Invalid input type, operators must be a list, tuple or ' - + 'numpy array.', + "Invalid input type, operators must be a list, tuple or " + + "numpy array.", ) operators = np.array(operators) if not operators.size: - raise ValueError('Operator list is empty.') + raise ValueError("Operator list is empty.") for operator in operators: - if not hasattr(operator, 'op'): + if not hasattr(operator, "op"): raise ValueError('Operators must contain "op" method.') - if not hasattr(operator, 'cost'): + if not hasattr(operator, "cost"): raise ValueError('Operators must contain "cost" method.') operator.op = check_callable(operator.op) operator.cost = check_callable(operator.cost) @@ -569,10 +579,12 @@ def _cost_method(self, *args, **kwargs): Combinded cost components """ - return np.sum([ - operator.cost(input_data) - for operator, input_data in zip(self.operators, args[0]) - ]) + return np.sum( + [ + operator.cost(input_data) + for operator, input_data in zip(self.operators, args[0]) + ] + ) class OrderedWeightedL1Norm(ProximityParent): @@ -613,16 +625,16 @@ class OrderedWeightedL1Norm(ProximityParent): def __init__(self, weights): if not import_sklearn: # pragma: no cover raise ImportError( - 'Required version of Scikit-Learn package not found see ' - + 'documentation for details: ' - + 'https://cea-cosmic.github.io/ModOpt/#optional-packages', + "Required version of Scikit-Learn package not found see " + + "documentation for details: " + + "https://cea-cosmic.github.io/ModOpt/#optional-packages", ) if np.max(np.diff(weights)) > 0: - raise ValueError('Weights must be non increasing') + raise ValueError("Weights must be non increasing") self.weights = weights.flatten() if (self.weights < 0).any(): raise ValueError( - 'The weight values must be provided in descending order', + "The weight values must be provided in descending order", ) self.op = self._op_method self.cost = self._cost_method @@ -660,7 +672,9 @@ def _op_method(self, input_data, extra_factor=1.0): # Projection onto the monotone non-negative cone using # isotonic_regression data_abs = isotonic_regression( - data_abs - threshold, y_min=0, increasing=False, + data_abs - threshold, + y_min=0, + increasing=False, ) # Unsorting the data @@ -668,7 +682,7 @@ def _op_method(self, input_data, extra_factor=1.0): data_abs_unsorted[data_abs_sort_idx] = data_abs # Putting the sign back - with np.errstate(invalid='ignore'): + with np.errstate(invalid="ignore"): sign_data = data_squeezed / np.abs(data_squeezed) # Removing NAN caused by the sign @@ -698,8 +712,8 @@ def _cost_method(self, *args, **kwargs): self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]), ) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - OWL NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - OWL NORM (X):", cost_val) return cost_val @@ -730,8 +744,7 @@ class Ridge(ProximityParent): """ - def __init__(self, linear, weights, thresh_type='soft'): - + def __init__(self, linear, weights, thresh_type="soft"): self._linear = linear self.weights = weights self.op = self._op_method @@ -782,8 +795,8 @@ def _cost_method(self, *args, **kwargs): np.abs(self.weights * self._linear.op(args[0]) ** 2), ) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - L2 NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - L2 NORM (X):", cost_val) return cost_val @@ -818,7 +831,6 @@ class ElasticNet(ProximityParent): """ def __init__(self, linear, alpha, beta): - self._linear = linear self.alpha = alpha self.beta = beta @@ -844,8 +856,8 @@ def _op_method(self, input_data, extra_factor=1.0): """ soft_threshold = self.beta * extra_factor - normalization = (self.alpha * 2 * extra_factor + 1) - return thresh(input_data, soft_threshold, 'soft') / normalization + normalization = self.alpha * 2 * extra_factor + 1 + return thresh(input_data, soft_threshold, "soft") / normalization def _cost_method(self, *args, **kwargs): """Calculate Ridge component of the cost. @@ -871,8 +883,8 @@ def _cost_method(self, *args, **kwargs): + np.abs(self.beta * self._linear.op(args[0])), ) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - ELASTIC NET (X):', cost_val) + if kwargs.get("verbose"): + print(" - ELASTIC NET (X):", cost_val) return cost_val @@ -938,7 +950,7 @@ def k_value(self): def k_value(self, k_val): if k_val < 1: raise ValueError( - 'The k parameter should be greater or equal than 1', + "The k parameter should be greater or equal than 1", ) self._k_value = k_val @@ -983,7 +995,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0): alpha_beta = alpha_input - self.beta * extra_factor theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0)) theta = np.nan_to_num(theta) - theta += (alpha_input > (self.beta * extra_factor + 1)) + theta += alpha_input > (self.beta * extra_factor + 1) return theta def _interpolate(self, alpha0, alpha1, sum0, sum1): @@ -993,7 +1005,7 @@ def _interpolate(self, alpha0, alpha1, sum0, sum1): :math:`\sum\theta(\alpha^*)=k` via a linear interpolation. Parameters - ----------- + ---------- alpha0: float A value for wich :math:`\sum\theta(\alpha^0) \leq k` alpha1: float @@ -1074,12 +1086,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0): midpoint = 0 while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]): - midpoint = (first_idx + last_idx) // 2 cnt += 1 if prev_midpoint == midpoint: - # Particular case sum0 = self._compute_theta( data_abs, @@ -1092,11 +1102,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0): extra_factor, ).sum() - if (np.abs(sum0 - self._k_value) <= tolerance): + if np.abs(sum0 - self._k_value) <= tolerance: found = True midpoint = first_idx - if (np.abs(sum1 - self._k_value) <= tolerance): + if np.abs(sum1 - self._k_value) <= tolerance: found = True midpoint = last_idx - 1 # -1 because output is index such that @@ -1141,13 +1151,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0): if found: return ( - midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1, + midpoint, + alpha[midpoint], + alpha[midpoint + 1], + sum0, + sum1, ) raise ValueError( - 'Cannot find the coordinate of alpha (i) such ' - + 'that sum(theta(alpha[i])) =< k and ' - + 'sum(theta(alpha[i+1])) >= k ', + "Cannot find the coordinate of alpha (i) such " + + "that sum(theta(alpha[i])) =< k and " + + "sum(theta(alpha[i+1])) >= k ", ) def _find_alpha(self, input_data, extra_factor=1.0): @@ -1171,15 +1185,13 @@ def _find_alpha(self, input_data, extra_factor=1.0): data_size = input_data.shape[0] # Computes the alpha^i points line 1 in Algorithm 1. - alpha = np.zeros((data_size * 2)) + alpha = np.zeros(data_size * 2) data_abs = np.abs(input_data) - alpha[:data_size] = ( - (self.beta * extra_factor) - / (data_abs + sys.float_info.epsilon) + alpha[:data_size] = (self.beta * extra_factor) / ( + data_abs + sys.float_info.epsilon ) - alpha[data_size:] = ( - (self.beta * extra_factor + 1) - / (data_abs + sys.float_info.epsilon) + alpha[data_size:] = (self.beta * extra_factor + 1) / ( + data_abs + sys.float_info.epsilon ) alpha = np.sort(np.unique(alpha)) @@ -1216,8 +1228,8 @@ def _op_method(self, input_data, extra_factor=1.0): k_max = np.prod(data_shape) if self._k_value > k_max: warn( - 'K value of the K-support norm is greater than the input ' - + 'dimension, its value will be set to {0}'.format(k_max), + "K value of the K-support norm is greater than the input " + + f"dimension, its value will be set to {k_max}", ) self._k_value = k_max @@ -1229,8 +1241,7 @@ def _op_method(self, input_data, extra_factor=1.0): # Computes line 5. in Algorithm 1. rslt = np.nan_to_num( - (input_data.flatten() * theta) - / (theta + self.beta * extra_factor), + (input_data.flatten() * theta) / (theta + self.beta * extra_factor), ) return rslt.reshape(data_shape) @@ -1271,25 +1282,20 @@ def _find_q(self, sorted_data): found = True q_val = 0 - elif ( - (sorted_data[self._k_value - 1:].sum()) - <= sorted_data[self._k_value - 1] - ): + elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]: found = True q_val = self._k_value - 1 while ( - not found and not cnt == self._k_value + not found + and not cnt == self._k_value and (first_idx <= last_idx < self._k_value) ): - q_val = (first_idx + last_idx) // 2 cnt += 1 l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val) - if ( - sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val] - ): + if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]: found = True else: @@ -1324,15 +1330,12 @@ def _cost_method(self, *args, **kwargs): data_abs = data_abs[ix] # Sorted absolute value of the data q_val = self._find_q(data_abs) cost_val = ( - ( - np.sum(data_abs[:q_val] ** 2) * 0.5 - + np.sum(data_abs[q_val:]) ** 2 - / (self._k_value - q_val) - ) * self.beta - ) + np.sum(data_abs[:q_val] ** 2) * 0.5 + + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val) + ) * self.beta - if 'verbose' in kwargs and kwargs['verbose']: - print(' - K-SUPPORT NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - K-SUPPORT NORM (X):", cost_val) return cost_val @@ -1377,7 +1380,7 @@ def __init__(self, weights): self.op = self._op_method self.cost = self._cost_method - def _op_method(self, input_data, extra_factor=1.0): + def _op_method(self, input_data, *args, extra_factor=1.0, **kwargs): """Operator. This method returns the input data thresholded by the weights. @@ -1388,6 +1391,7 @@ def _op_method(self, input_data, extra_factor=1.0): Input data array extra_factor : float Additional multiplication factor (default is ``1.0``) + *args, **kwargs: no effects Returns ------- @@ -1403,7 +1407,7 @@ def _op_method(self, input_data, extra_factor=1.0): (1.0 - self.weights * extra_factor / denominator), ) - def _cost_method(self, input_data): + def _cost_method(self, input_data, *args, **kwargs): """Calculate the group LASSO component of the cost. This method calculate the cost function of the proximable part. @@ -1413,6 +1417,8 @@ def _cost_method(self, input_data): input_data : numpy.ndarray Input array of the sparse code + *args, **kwargs: no effects. + Returns ------- float diff --git a/modopt/opt/reweight.py b/src/modopt/opt/reweight.py similarity index 91% rename from modopt/opt/reweight.py rename to src/modopt/opt/reweight.py index 8c4f2449..4a9bf44b 100644 --- a/modopt/opt/reweight.py +++ b/src/modopt/opt/reweight.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """REWEIGHTING CLASSES. This module contains classes for reweighting optimisation implementations. @@ -13,7 +11,7 @@ from modopt.base.types import check_float -class cwbReweight(object): +class cwbReweight: """Candes, Wakin and Boyd reweighting class. This class implements the reweighting scheme described in @@ -45,7 +43,6 @@ class cwbReweight(object): """ def __init__(self, weights, thresh_factor=1.0, verbose=False): - self.weights = check_float(weights) self.original_weights = np.copy(self.weights) self.thresh_factor = check_float(thresh_factor) @@ -81,7 +78,7 @@ def reweight(self, input_data): """ if self.verbose: - print(' - Reweighting: {0}'.format(self._rw_num)) + print(f" - Reweighting: {self._rw_num}") self._rw_num += 1 @@ -89,7 +86,7 @@ def reweight(self, input_data): if input_data.shape != self.weights.shape: raise ValueError( - 'Input data must have the same shape as the initial weights.', + "Input data must have the same shape as the initial weights.", ) thresh_weights = self.thresh_factor * self.original_weights diff --git a/modopt/plot/__init__.py b/src/modopt/plot/__init__.py similarity index 73% rename from modopt/plot/__init__.py rename to src/modopt/plot/__init__.py index 28d60be6..f31ed596 100644 --- a/modopt/plot/__init__.py +++ b/src/modopt/plot/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """PLOTTING ROUTINES. This module contains submodules for plotting applications. @@ -8,4 +6,4 @@ """ -__all__ = ['cost_plot'] +__all__ = ["cost_plot"] diff --git a/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py similarity index 66% rename from modopt/plot/cost_plot.py rename to src/modopt/plot/cost_plot.py index aa855eaa..7fb7e39b 100644 --- a/modopt/plot/cost_plot.py +++ b/src/modopt/plot/cost_plot.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """PLOTTING ROUTINES. This module contains methods for making plots. @@ -37,20 +35,20 @@ def plotCost(cost_list, output=None): """ if import_fail: - raise ImportError('Matplotlib package not found') + raise ImportError("Matplotlib package not found") else: if isinstance(output, type(None)): - file_name = 'cost_function.png' + file_name = "cost_function.png" else: - file_name = '{0}_cost_function.png'.format(output) + file_name = f"{output}_cost_function.png" plt.figure() - plt.plot(np.log10(cost_list), 'r-') - plt.title('Cost Function') - plt.xlabel('Iteration') - plt.ylabel(r'$\log_{10}$ Cost') + plt.plot(np.log10(cost_list), "r-") + plt.title("Cost Function") + plt.xlabel("Iteration") + plt.ylabel(r"$\log_{10}$ Cost") plt.savefig(file_name) plt.close() - print(' - Saving cost function data to:', file_name) + print(" - Saving cost function data to:", file_name) diff --git a/modopt/signal/__init__.py b/src/modopt/signal/__init__.py similarity index 58% rename from modopt/signal/__init__.py rename to src/modopt/signal/__init__.py index dbc6d053..6bf0912b 100644 --- a/modopt/signal/__init__.py +++ b/src/modopt/signal/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """SIGNAL PROCESSING ROUTINES. This module contains submodules for signal processing. @@ -8,4 +6,4 @@ """ -__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet'] +__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"] diff --git a/modopt/signal/filter.py b/src/modopt/signal/filter.py similarity index 92% rename from modopt/signal/filter.py rename to src/modopt/signal/filter.py index 8e24768c..33c3c105 100644 --- a/modopt/signal/filter.py +++ b/src/modopt/signal/filter.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """FILTER ROUTINES. This module contains methods for distance measurements in cosmology. @@ -73,15 +71,15 @@ def mex_hat(data_point, sigma): Examples -------- >>> from modopt.signal.filter import mex_hat - >>> mex_hat(2, 1) - -0.3521390522571337 + >>> round(mex_hat(2, 1), 15) + -0.352139052257134 """ data_point = check_float(data_point) sigma = check_float(sigma) xs = (data_point / sigma) ** 2 - factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25 + factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25 return factor * (1 - xs) * np.exp(-0.5 * xs) @@ -108,8 +106,8 @@ def mex_hat_dir(data_gauss, data_mex, sigma): Examples -------- >>> from modopt.signal.filter import mex_hat_dir - >>> mex_hat_dir(1, 2, 1) - 0.17606952612856686 + >>> round(mex_hat_dir(1, 2, 1), 16) + 0.1760695261285668 """ data_gauss = check_float(data_gauss) diff --git a/modopt/signal/noise.py b/src/modopt/signal/noise.py similarity index 86% rename from modopt/signal/noise.py rename to src/modopt/signal/noise.py index a59d5553..28307f52 100644 --- a/modopt/signal/noise.py +++ b/src/modopt/signal/noise.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """NOISE ROUTINES. This module contains methods for adding and removing noise from data. @@ -8,14 +6,12 @@ """ -from builtins import zip - import numpy as np from modopt.base.backend import get_array_module -def add_noise(input_data, sigma=1.0, noise_type='gauss'): +def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None): """Add noise to data. This method adds Gaussian or Poisson noise to the input data. @@ -29,6 +25,9 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): default is ``1.0``) noise_type : {'gauss', 'poisson'} Type of noise to be added (default is ``'gauss'``) + rng: np.random.Generator or int + A Random number generator or a seed to initialize one. + Returns ------- @@ -68,9 +67,12 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): array([ 3.24869073, -1.22351283, -1.0563435 , -2.14593724, 1.73081526]) """ + if not isinstance(rng, np.random.Generator): + rng = np.random.default_rng(rng) + input_data = np.array(input_data) - if noise_type not in {'gauss', 'poisson'}: + if noise_type not in {"gauss", "poisson"}: raise ValueError( 'Invalid noise type. Options are "gauss" or "poisson"', ) @@ -78,15 +80,14 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): if isinstance(sigma, (list, tuple, np.ndarray)): if len(sigma) != input_data.shape[0]: raise ValueError( - 'Number of sigma values must match first dimension of input ' - + 'data', + "Number of sigma values must match first dimension of input " + "data", ) - if noise_type == 'gauss': - random = np.random.randn(*input_data.shape) + if noise_type == "gauss": + random = rng.standard_normal(input_data.shape) - elif noise_type == 'poisson': - random = np.random.poisson(np.abs(input_data)) + elif noise_type == "poisson": + random = rng.poisson(np.abs(input_data)) if isinstance(sigma, (int, float)): return input_data + sigma * random @@ -96,7 +97,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): return input_data + noise -def thresh(input_data, threshold, threshold_type='hard'): +def thresh(input_data, threshold, threshold_type="hard"): r"""Threshold data. This method perfoms hard or soft thresholding on the input data. @@ -169,12 +170,12 @@ def thresh(input_data, threshold, threshold_type='hard'): input_data = xp.array(input_data) - if threshold_type not in {'hard', 'soft'}: + if threshold_type not in {"hard", "soft"}: raise ValueError( 'Invalid threshold type. Options are "hard" or "soft"', ) - if threshold_type == 'soft': + if threshold_type == "soft": denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data)) max_value = xp.maximum((1.0 - threshold / denominator), 0) diff --git a/modopt/signal/positivity.py b/src/modopt/signal/positivity.py similarity index 90% rename from modopt/signal/positivity.py rename to src/modopt/signal/positivity.py index e4ec098d..8d7aa46c 100644 --- a/modopt/signal/positivity.py +++ b/src/modopt/signal/positivity.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """POSITIVITY. This module contains a function that retains only positive coefficients in @@ -47,8 +45,8 @@ def pos_recursive(input_data): Positive coefficients """ - if input_data.dtype == 'O': - res = np.array([pos_recursive(elem) for elem in input_data]) + if input_data.dtype == "O": + res = np.array([pos_recursive(elem) for elem in input_data], dtype="object") else: res = pos_thresh(input_data) @@ -97,15 +95,15 @@ def positive(input_data, ragged=False): """ if not isinstance(input_data, (int, float, list, tuple, np.ndarray)): raise TypeError( - 'Invalid data type, input must be `int`, `float`, `list`, ' - + '`tuple` or `np.ndarray`.', + "Invalid data type, input must be `int`, `float`, `list`, " + + "`tuple` or `np.ndarray`.", ) if isinstance(input_data, (int, float)): return pos_thresh(input_data) if ragged: - input_data = np.array(input_data, dtype='object') + input_data = np.array(input_data, dtype="object") else: input_data = np.array(input_data) diff --git a/modopt/signal/svd.py b/src/modopt/signal/svd.py similarity index 84% rename from modopt/signal/svd.py rename to src/modopt/signal/svd.py index 6dcb9eda..cf147503 100644 --- a/modopt/signal/svd.py +++ b/src/modopt/signal/svd.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """SVD ROUTINES. This module contains methods for thresholding singular values. @@ -52,12 +50,12 @@ def find_n_pc(u_vec, factor=0.5): """ if np.sqrt(u_vec.shape[0]) % 1: raise ValueError( - 'Invalid left singular vector. The size of the first ' - + 'dimenion of ``u_vec`` must be perfect square.', + "Invalid left singular vector. The size of the first " + + "dimenion of ``u_vec`` must be perfect square.", ) # Get the shape of the array - array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2) + array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2) # Find the auto correlation of the left singular vector. u_auto = [ @@ -69,13 +67,12 @@ def find_n_pc(u_vec, factor=0.5): ] # Return the required number of principal components. - return np.sum([ - ( - u_val[tuple(zip(array_shape // 2))] ** 2 <= factor - * np.sum(u_val ** 2), - ) - for u_val in u_auto - ]) + return np.sum( + [ + (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),) + for u_val in u_auto + ] + ) def calculate_svd(input_data): @@ -101,17 +98,17 @@ def calculate_svd(input_data): """ if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2): - raise TypeError('Input data must be a 2D np.ndarray.') + raise TypeError("Input data must be a 2D np.ndarray.") return svd( input_data, check_finite=False, - lapack_driver='gesvd', + lapack_driver="gesvd", full_matrices=False, ) -def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): +def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"): """Threshold the singular values. This method thresholds the input data using singular value decomposition. @@ -156,16 +153,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): """ less_than_zero = isinstance(n_pc, int) and n_pc <= 0 - str_not_all = isinstance(n_pc, str) and n_pc != 'all' + str_not_all = isinstance(n_pc, str) and n_pc != "all" - if ( - (not isinstance(n_pc, (int, str, type(None)))) - or less_than_zero - or str_not_all - ): + if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all: raise ValueError( - 'Invalid value for "n_pc", specify a positive integer value or ' - + '"all"', + 'Invalid value for "n_pc", specify a positive integer value or ' + '"all"', ) # Get SVD of input data. @@ -176,15 +168,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): # Find the required number of principal components if not specified. if isinstance(n_pc, type(None)): n_pc = find_n_pc(u_vec, factor=0.1) - print('xxxx', n_pc, u_vec) + print("xxxx", n_pc, u_vec) # If the number of PCs is too large use all of the singular values. - if ( - (isinstance(n_pc, int) and n_pc >= s_values.size) - or (isinstance(n_pc, str) and n_pc == 'all') + if (isinstance(n_pc, int) and n_pc >= s_values.size) or ( + isinstance(n_pc, str) and n_pc == "all" ): n_pc = s_values.size - warn('Using all singular values.') + warn("Using all singular values.") threshold = s_values[n_pc - 1] @@ -192,7 +183,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): s_new = thresh(s_values, threshold, thresh_type) if np.all(s_new == s_values): - warn('No change to singular values.') + warn("No change to singular values.") # Diagonalize the svd s_new = np.diag(s_new) @@ -206,7 +197,7 @@ def svd_thresh_coef_fast( threshold, n_vals=-1, extra_vals=5, - thresh_type='hard', + thresh_type="hard", ): """Threshold the singular values coefficients. @@ -241,7 +232,7 @@ def svd_thresh_coef_fast( ok = False while not ok: (u_vec, s_values, v_vec) = svds(input_data, k=n_vals) - ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1) + ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1 n_vals = min(n_vals + extra_vals, *input_data.shape) s_values = thresh( @@ -259,7 +250,7 @@ def svd_thresh_coef_fast( ) -def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): +def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"): """Threshold the singular values coefficients. This method thresholds the input data using singular value decomposition. @@ -287,7 +278,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): """ if not callable(operator): - raise TypeError('Operator must be a callable function.') + raise TypeError("Operator must be a callable function.") # Get SVD of data matrix u_vec, s_values, v_vec = calculate_svd(input_data) @@ -299,13 +290,12 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): a_matrix = np.dot(s_values, v_vec) # Get the shape of the array - array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2) + array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2) # Compute threshold matrix. - ti = np.array([ - np.linalg.norm(elem) - for elem in operator(matrix2cube(u_vec, array_shape)) - ]) + ti = np.array( + [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))] + ) threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape) # Threshold coefficients. diff --git a/modopt/signal/validation.py b/src/modopt/signal/validation.py similarity index 80% rename from modopt/signal/validation.py rename to src/modopt/signal/validation.py index 422a987b..cdf69b7d 100644 --- a/modopt/signal/validation.py +++ b/src/modopt/signal/validation.py @@ -16,6 +16,7 @@ def transpose_test( x_args=None, y_shape=None, y_args=None, + rng=None, ): """Transpose test. @@ -36,6 +37,8 @@ def transpose_test( Shape of transpose operator input data (default is ``None``) y_args : tuple, optional Arguments to be passed to transpose operator (default is ``None``) + rng: numpy.random.Generator or int or None (default is ``None``) + Initialized random number generator or seed. Raises ------ @@ -54,7 +57,7 @@ def transpose_test( """ if not callable(operator) or not callable(operator_t): - raise TypeError('The input operators must be callable functions.') + raise TypeError("The input operators must be callable functions.") if isinstance(y_shape, type(None)): y_shape = x_shape @@ -62,9 +65,11 @@ def transpose_test( if isinstance(y_args, type(None)): y_args = x_args + if not isinstance(rng, np.random.Generator): + rng = np.random.default_rng(rng) # Generate random arrays. - x_val = np.random.ranf(x_shape) - y_val = np.random.ranf(y_shape) + x_val = rng.random(x_shape) + y_val = rng.random(y_shape) # Calculate mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val)) @@ -73,4 +78,4 @@ def transpose_test( x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args))) # Test the difference between the two. - print(' - | - | =', np.abs(mx_y - x_mty)) + print(" - | - | =", np.abs(mx_y - x_mty)) diff --git a/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py similarity index 88% rename from modopt/signal/wavelet.py rename to src/modopt/signal/wavelet.py index bc4ffc70..b55b78d9 100644 --- a/modopt/signal/wavelet.py +++ b/src/modopt/signal/wavelet.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """WAVELET MODULE. This module contains methods for performing wavelet transformations using @@ -58,20 +56,20 @@ def execute(command_line): """ if not isinstance(command_line, str): - raise TypeError('Command line must be a string.') + raise TypeError("Command line must be a string.") command = command_line.split() process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE) stdout, stderr = process.communicate() - return stdout.decode('utf-8'), stderr.decode('utf-8') + return stdout.decode("utf-8"), stderr.decode("utf-8") def call_mr_transform( input_data, - opt='', - path='./', + opt="", + path="./", remove_files=True, ): # pragma: no cover """Call ``mr_transform``. @@ -127,26 +125,23 @@ def call_mr_transform( """ if not import_astropy: - raise ImportError('Astropy package not found.') + raise ImportError("Astropy package not found.") if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2): - raise ValueError('Input data must be a 2D numpy array.') + raise ValueError("Input data must be a 2D numpy array.") - executable = 'mr_transform' + executable = "mr_transform" # Make sure mr_transform is installed. is_executable(executable) # Create a unique string using the current date and time. - unique_string = ( - datetime.now().strftime('%Y.%m.%d_%H.%M.%S') - + str(getrandbits(128)) - ) + unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128)) # Set the ouput file names. - file_name = '{0}mr_temp_{1}'.format(path, unique_string) - file_fits = '{0}.fits'.format(file_name) - file_mr = '{0}.mr'.format(file_name) + file_name = f"{path}mr_temp_{unique_string}" + file_fits = f"{file_name}.fits" + file_mr = f"{file_name}.mr" # Write the input data to a fits file. fits.writeto(file_fits, input_data) @@ -155,15 +150,15 @@ def call_mr_transform( opt = opt.split() # Prepare command and execute it - command_line = ' '.join([executable] + opt + [file_fits, file_mr]) + command_line = " ".join([executable, *opt, file_fits, file_mr]) stdout, _ = execute(command_line) # Check for errors - if any(word in stdout for word in ('bad', 'Error', 'Sorry')): + if any(word in stdout for word in ("bad", "Error", "Sorry")): remove(file_fits) message = '{0} raised following exception: "{1}"' raise RuntimeError( - message.format(executable, stdout.rstrip('\n')), + message.format(executable, stdout.rstrip("\n")), ) # Retrieve wavelet transformed data. @@ -198,12 +193,12 @@ def trim_filter(filter_array): min_idx = np.min(non_zero_indices, axis=-1) max_idx = np.max(non_zero_indices, axis=-1) - return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1] + return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1] def get_mr_filters( data_shape, - opt='', + opt="", coarse=False, trim=False, ): # pragma: no cover @@ -256,7 +251,7 @@ def get_mr_filters( return mr_filters[:-1] -def filter_convolve(input_data, filters, filter_rot=False, method='scipy'): +def filter_convolve(input_data, filters, filter_rot=False, method="scipy"): """Filter convolve. This method convolves the input image with the wavelet filters. @@ -315,16 +310,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'): axis=0, ) - return np.array([ - convolve(input_data, filt, method=method) for filt in filters - ]) + return np.array([convolve(input_data, filt, method=method) for filt in filters]) def filter_convolve_stack( input_data, filters, filter_rot=False, - method='scipy', + method="scipy", ): """Filter convolve. @@ -366,7 +359,9 @@ def filter_convolve_stack( """ # Return the convolved data cube. - return np.array([ - filter_convolve(elem, filters, filter_rot=filter_rot, method=method) - for elem in input_data - ]) + return np.array( + [ + filter_convolve(elem, filters, filter_rot=filter_rot, method=method) + for elem in input_data + ] + ) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py new file mode 100644 index 00000000..63847764 --- /dev/null +++ b/tests/test_algorithms.py @@ -0,0 +1,277 @@ +"""UNIT TESTS FOR Algorithms. + +This module contains unit tests for the modopt.opt module. + +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" + +import numpy as np +import numpy.testing as npt +from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight +from pytest_cases import ( + fixture, + parametrize, + parametrize_with_cases, +) + + +SKLEARN_AVAILABLE = True +try: + import sklearn +except ImportError: + SKLEARN_AVAILABLE = False + + +rng = np.random.default_rng() + + +@fixture +def idty(): + """Identity function.""" + return lambda x: x + + +@fixture +def reweight_op(): + """Reweight operator.""" + data3 = np.arange(9).reshape(3, 3).astype(float) + 1 + return reweight.cwbReweight(data3) + + +def build_kwargs(kwargs, use_metrics): + """Build the kwargs for each algorithm, replacing placeholders by true values. + + This function has to be call for each test, as direct parameterization somehow + is not working with pytest-xdist and pytest-cases. + It also adds dummy metric measurement to validate the metric api. + """ + update_value = { + "idty": lambda x: x, + "lin_idty": linear.Identity(), + "reweight_op": reweight.cwbReweight( + np.arange(9).reshape(3, 3).astype(float) + 1 + ), + } + new_kwargs = dict() + print(kwargs) + # update the value of the dict is possible. + for key in kwargs: + new_kwargs[key] = update_value.get(kwargs[key], kwargs[key]) + + if use_metrics: + new_kwargs["linear"] = linear.Identity() + new_kwargs["metrics"] = { + "diff": { + "metric": lambda test, ref: np.sum(test - ref), + "mapping": {"x_new": "test"}, + "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))}, + "early_stopping": False, + } + } + + return new_kwargs + + +@parametrize(use_metrics=[True, False]) +class AlgoCases: + r"""Cases for algorithms. + + Most of the test solves the trivial problem + + .. math:: + \\min_x \\frac{1}{2} \\| y - x \\|_2^2 \\quad\\text{s.t.} x \\geq 0 + + More complex and concrete usecases are shown in examples. + """ + + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = data1 + rng.standard_normal(data1.shape) * 1e-6 + max_iter = 20 + + @parametrize( + kwargs=[ + {"beta_update": "idty", "auto_iterate": False, "cost": None}, + {"beta_update": "idty"}, + {"cost": None, "lambda_update": None}, + {"beta_update": "idty", "a_cd": 3}, + {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7}, + {"restart_strategy": "adaptive", "xi_restart": 0.9}, + { + "restart_strategy": "greedy", + "xi_restart": 0.9, + "min_beta": 1.0, + "s_greedy": 1.1, + }, + ] + ) + def case_forward_backward(self, kwargs, idty, use_metrics): + """Forward Backward case.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + algo = algorithms.ForwardBackward( + self.data1, + grad=gradient.GradBasic(self.data1, idty, idty), + prox=proximity.Positivity(), + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + kwargs=[ + { + "cost": None, + "auto_iterate": False, + "gamma_update": "idty", + "beta_update": "idty", + }, + {"gamma_update": "idty", "lambda_update": "idty"}, + {"cost": True}, + {"cost": True, "step_size": 2}, + ] + ) + def case_gen_forward_backward(self, kwargs, use_metrics, idty): + """General FB setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + prox_dual_inst = proximity.IdentityProx() + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + algo = algorithms.GenForwardBackward( + self.data1, + grad=grad_inst, + prox_list=[prox_inst, prox_dual_inst], + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + kwargs=[ + { + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", + "auto_iterate": False, + }, + { + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", + }, + { + "linear": "lin_idty", + "cost": True, + "reweight": "reweight_op", + }, + ] + ) + def case_condat(self, kwargs, use_metrics, idty): + """Condat Vu Algorithm setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + prox_dual_inst = proximity.IdentityProx() + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + + algo = algorithms.Condat( + self.data1, + self.data2, + grad=grad_inst, + prox=prox_inst, + prox_dual=prox_dual_inst, + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}]) + def case_pogm(self, kwargs, use_metrics, idty): + """POGM setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + algo = algorithms.POGM( + u=self.data1, + x=self.data1, + y=self.data1, + z=self.data1, + grad=grad_inst, + prox=prox_inst, + **update_kwargs, + ) + + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + GradDescent=[ + algorithms.VanillaGenericGradOpt, + algorithms.AdaGenericGradOpt, + algorithms.ADAMGradOpt, + algorithms.MomentumGradOpt, + algorithms.RMSpropGradOpt, + algorithms.SAGAOptGradOpt, + ] + ) + def case_grad(self, GradDescent, use_metrics, idty): + """Gradient Descent algorithm test.""" + update_kwargs = build_kwargs({}, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + cost_inst = cost.costObj([grad_inst, prox_inst]) + + algo = GradDescent( + self.data1, + grad=grad_inst, + prox=prox_inst, + cost=cost_inst, + **update_kwargs, + ) + algo.iterate() + return algo, update_kwargs + + @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM]) + def case_admm(self, admm, use_metrics, idty): + """ADMM setup.""" + + def optim1(init, obs): + return obs + + def optim2(init, obs): + return obs + + update_kwargs = build_kwargs({}, use_metrics) + algo = admm( + u=self.data1, + v=self.data1, + mu=np.zeros_like(self.data1), + A=linear.Identity(), + B=linear.Identity(), + b=self.data1, + optimizers=(optim1, optim2), + **update_kwargs, + ) + algo.iterate() + return algo, update_kwargs + + +@parametrize_with_cases("algo, kwargs", cases=AlgoCases) +def test_algo(algo, kwargs): + """Test algorithms.""" + if kwargs.get("auto_iterate") is False: + # algo already run + npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1) + else: + npt.assert_almost_equal(algo.x_final, AlgoCases.data1) + + if kwargs.get("metrics"): + print(algo.metrics) + npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3) diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..62e09095 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,219 @@ +""" +Test for base module. + +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" + +import numpy as np +import numpy.testing as npt +import pytest +from test_helpers import failparam, skipparam + +from modopt.base import backend, np_adjust, transform, types +from modopt.base.backend import LIBRARIES + + +class TestNpAdjust: + """Test for npadjust.""" + + array33 = np.arange(9).reshape((3, 3)) + array233 = np.arange(18).reshape((2, 3, 3)) + arraypad = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 2, 0], + [0, 3, 4, 5, 0], + [0, 6, 7, 8, 0], + [0, 0, 0, 0, 0], + ] + ) + + def test_rotate(self): + """Test rotate.""" + npt.assert_array_equal( + np_adjust.rotate(self.array33), + np.rot90(np.rot90(self.array33)), + err_msg="Incorrect rotation.", + ) + + def test_rotate_stack(self): + """Test rotate_stack.""" + npt.assert_array_equal( + np_adjust.rotate_stack(self.array233), + np.rot90(self.array233, k=2, axes=(1, 2)), + err_msg="Incorrect stack rotation.", + ) + + @pytest.mark.parametrize( + "padding", + [ + 1, + [1, 1], + np.array([1, 1]), + failparam("1", raises=ValueError), + ], + ) + def test_pad2d(self, padding): + """Test pad2d.""" + npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad) + + def test_fancy_transpose(self): + """Test fancy transpose.""" + npt.assert_array_equal( + np_adjust.fancy_transpose(self.array233), + np.array( + [ + [[0, 3, 6], [9, 12, 15]], + [[1, 4, 7], [10, 13, 16]], + [[2, 5, 8], [11, 14, 17]], + ] + ), + err_msg="Incorrect fancy transpose", + ) + + def test_ftr(self): + """Test ftr.""" + npt.assert_array_equal( + np_adjust.ftr(self.array233), + np.array( + [ + [[0, 3, 6], [9, 12, 15]], + [[1, 4, 7], [10, 13, 16]], + [[2, 5, 8], [11, 14, 17]], + ] + ), + err_msg="Incorrect fancy transpose: ftr", + ) + + def test_ftl(self): + """Test fancy transpose left.""" + npt.assert_array_equal( + np_adjust.ftl(self.array233), + np.array( + [ + [[0, 9], [1, 10], [2, 11]], + [[3, 12], [4, 13], [5, 14]], + [[6, 15], [7, 16], [8, 17]], + ] + ), + err_msg="Incorrect fancy transpose: ftl", + ) + + +class TestTransforms: + """Test for the transform module.""" + + cube = np.arange(16).reshape((4, 2, 2)) + map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]) + matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) + layout = (2, 2) + fail_layout = (3, 3) + + @pytest.mark.parametrize( + ("func", "indata", "layout", "outdata"), + [ + (transform.cube2map, cube, layout, map), + failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError), + (transform.map2cube, map, layout, cube), + (transform.map2matrix, map, layout, matrix), + (transform.matrix2map, matrix, matrix.shape, map), + ], + ) + def test_map(self, func, indata, layout, outdata): + """Test cube2map.""" + npt.assert_array_equal( + func(indata, layout), + outdata, + ) + if func.__name__ != "map2matrix": + npt.assert_raises(ValueError, func, indata, self.fail_layout) + + def test_cube2matrix(self): + """Test cube2matrix.""" + npt.assert_array_equal( + transform.cube2matrix(self.cube), + self.matrix, + ) + + def test_matrix2cube(self): + """Test matrix2cube.""" + npt.assert_array_equal( + transform.matrix2cube(self.matrix, self.cube[0].shape), + self.cube, + err_msg="Incorrect transformation: matrix2cube", + ) + + +class TestType: + """Test for type module.""" + + data_list = list(range(5)) # noqa: RUF012 + data_int = np.arange(5) + data_flt = np.arange(5).astype(float) + + @pytest.mark.parametrize( + ("data", "checked"), + [ + (1.0, 1.0), + (1, 1.0), + (data_list, data_flt), + (data_int, data_flt), + failparam("1.0", 1.0, raises=TypeError), + ], + ) + def test_check_float(self, data, checked): + """Test check float.""" + npt.assert_array_equal(types.check_float(data), checked) + + @pytest.mark.parametrize( + ("data", "checked"), + [ + (1.0, 1), + (1, 1), + (data_list, data_int), + (data_flt, data_int), + failparam("1", None, raises=TypeError), + ], + ) + def test_check_int(self, data, checked): + """Test check int.""" + npt.assert_array_equal(types.check_int(data), checked) + + @pytest.mark.parametrize( + ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)] + ) + def test_check_npndarray(self, data, dtype): + """Test check_npndarray.""" + npt.assert_raises( + TypeError, + types.check_npndarray, + data, + dtype=dtype, + ) + + def test_check_callable(self): + """Test callable.""" + npt.assert_raises(TypeError, types.check_callable, 1) + + +@pytest.mark.parametrize( + "backend_name", + [ + skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed") + for name in LIBRARIES + ], +) +def test_tf_backend(backend_name): + """Test Modopt computational backends.""" + xp, checked_backend_name = backend.get_backend(backend_name) + if checked_backend_name != backend_name or xp != LIBRARIES[backend_name]: + raise AssertionError(f"{backend_name} get_backend fails!") + xp_input = backend.change_backend(np.array([10, 10]), backend_name) + if ( + backend.get_array_module(LIBRARIES[backend_name].ones(1)) + != backend.LIBRARIES[backend_name] + or backend.get_array_module(xp_input) != LIBRARIES[backend_name] + ): + raise AssertionError(f"{backend_name} backend fails!") diff --git a/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py new file mode 100644 index 00000000..0ded847a --- /dev/null +++ b/tests/test_helpers/__init__.py @@ -0,0 +1,5 @@ +"""Utilities for tests.""" + +from .utils import Dummy, failparam, skipparam + +__all__ = ["Dummy", "failparam", "skipparam"] diff --git a/tests/test_helpers/utils.py b/tests/test_helpers/utils.py new file mode 100644 index 00000000..41f948a6 --- /dev/null +++ b/tests/test_helpers/utils.py @@ -0,0 +1,27 @@ +""" +Some helper functions for the test parametrization. + +They should be used inside ``@pytest.mark.parametrize`` call. + +:Author: Pierre-Antoine Comby +""" + +import pytest + + +def failparam(*args, raises=None): + """Return a pytest parameterization that should raise an error.""" + if not issubclass(raises, Exception): + raise ValueError("raises should be an expected Exception.") + return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)]) + + +def skipparam(*args, cond=True, reason=""): + """Return a pytest parameterization that should be skip if cond is valid.""" + return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)]) + + +class Dummy: + """Dummy Class.""" + + pass diff --git a/tests/test_math.py b/tests/test_math.py new file mode 100644 index 00000000..5c466e5e --- /dev/null +++ b/tests/test_math.py @@ -0,0 +1,326 @@ +"""UNIT TESTS FOR MATH. + +This module contains unit tests for the modopt.math module. + +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" + +import pytest +from test_helpers import failparam, skipparam + +import numpy as np +import numpy.testing as npt + + +from modopt.math import convolve, matrix, metrics, stats + +try: + import astropy +except ImportError: # pragma: no cover + ASTROPY_AVAILABLE = False +else: # pragma: no cover + ASTROPY_AVAILABLE = True +try: + from skimage.metrics import structural_similarity as compare_ssim +except ImportError: # pragma: no cover + SKIMAGE_AVAILABLE = False +else: + SKIMAGE_AVAILABLE = True + +rng = np.random.default_rng(1) + + +class TestConvolve: + """Test convolve functions.""" + + array233 = np.arange(18).reshape((2, 3, 3)) + array233_1 = array233 + 1 + result_astropy = np.array( + [ + [210.0, 201.0, 210.0], + [129.0, 120.0, 129.0], + [210.0, 201.0, 210.0], + ] + ) + result_scipy = np.array( + [ + [ + [14.0, 35.0, 38.0], + [57.0, 120.0, 111.0], + [110.0, 197.0, 158.0], + ], + [ + [518.0, 845.0, 614.0], + [975.0, 1578.0, 1137.0], + [830.0, 1331.0, 950.0], + ], + ] + ) + + result_rot_kernel = np.array( + [ + [ + [66.0, 115.0, 82.0], + [153.0, 240.0, 159.0], + [90.0, 133.0, 82.0], + ], + [ + [714.0, 1087.0, 730.0], + [1125.0, 1698.0, 1131.0], + [738.0, 1105.0, 730.0], + ], + ] + ) + + @pytest.mark.parametrize( + ("input_data", "kernel", "method", "result"), + [ + skipparam( + array233[0], + array233_1[0], + "astropy", + result_astropy, + cond=not ASTROPY_AVAILABLE, + reason="astropy not available", + ), + failparam( + array233[0], array233_1, "astropy", result_astropy, raises=ValueError + ), + failparam( + array233[0], array233_1[0], "fail!", result_astropy, raises=ValueError + ), + (array233[0], array233_1[0], "scipy", result_scipy[0]), + ], + ) + def test_convolve(self, input_data, kernel, method, result): + """Test convolve function.""" + npt.assert_allclose(convolve.convolve(input_data, kernel, method), result) + + @pytest.mark.parametrize( + ("result", "rot_kernel"), + [ + (result_scipy, False), + (result_rot_kernel, True), + ], + ) + def test_convolve_stack(self, result, rot_kernel): + """Test convolve stack function.""" + npt.assert_allclose( + convolve.convolve_stack( + self.array233, self.array233_1, rot_kernel=rot_kernel + ), + result, + ) + + +class TestMatrix: + """Test matrix module.""" + + array3 = np.arange(3) + array33 = np.arange(9).reshape((3, 3)) + array23 = np.arange(6).reshape((2, 3)) + gram_schmidt_out = ( + np.array( + [ + [0, 1.0, 2.0], + [3.0, 1.2, -6e-1], + [-1.77635684e-15, 0, 0], + ] + ), + np.array( + [ + [0, 0.4472136, 0.89442719], + [0.91287093, 0.36514837, -0.18257419], + [-1.0, 0, 0], + ] + ), + ) + + @pytest.fixture(scope="module") + def pm_instance(self, request): + """Power Method instance.""" + pm = matrix.PowerMethod( + lambda x_val: x_val.dot(x_val.T), + self.array33.shape, + verbose=True, + rng=np.random.default_rng(0), + ) + return pm + + @pytest.mark.parametrize( + ("return_opt", "output"), + [ + ("orthonormal", gram_schmidt_out[1]), + ("orthogonal", gram_schmidt_out[0]), + ("both", gram_schmidt_out), + failparam("fail!", gram_schmidt_out, raises=ValueError), + ], + ) + def test_gram_schmidt(self, return_opt, output): + """Test gram schmidt.""" + npt.assert_allclose( + matrix.gram_schmidt(self.array33, return_opt=return_opt), output + ) + + def test_nuclear_norm(self): + """Test nuclear norm.""" + npt.assert_almost_equal( + matrix.nuclear_norm(self.array33), + 15.49193338482967, + ) + + def test_project(self): + """Test project.""" + npt.assert_array_equal( + matrix.project(self.array3, self.array3 + 3), + np.array([0, 2.8, 5.6]), + ) + + def test_rot_matrix(self): + """Test rot_matrix.""" + npt.assert_allclose( + matrix.rot_matrix(np.pi / 6), + np.array([[0.8660254, -0.5], [0.5, 0.8660254]]), + ) + + def test_rotate(self): + """Test rotate.""" + npt.assert_array_equal( + matrix.rotate(self.array33, np.pi / 2), + np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]), + ) + + npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2) + + def test_power_method(self, pm_instance, value=1): + """Test power method.""" + npt.assert_almost_equal(pm_instance.spec_rad, value) + npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value) + + +class TestMetrics: + """Test metrics module.""" + + data1 = np.arange(49).reshape(7, 7) + mask = np.ones(data1.shape) + ssim_res = 0.8958315888566867 + ssim_mask_res = 0.8023827544418249 + snr_res = 10.134554256920536 + psnr_res = 14.860761791850397 + mse_res = 0.03265305507330247 + nrmse_res = 0.31136678840022625 + + @pytest.mark.skipif(not SKIMAGE_AVAILABLE, reason="skimage not installed") + @pytest.mark.parametrize( + ("data1", "data2", "result", "mask"), + [ + (data1, data1**2, ssim_res, None), + (data1, data1**2, ssim_mask_res, mask), + failparam(data1, data1, None, 1, raises=ValueError), + ], + ) + def test_ssim(self, data1, data2, result, mask): + """Test ssim.""" + npt.assert_almost_equal(metrics.ssim(data1, data2, mask=mask), result) + + @pytest.mark.skipif(SKIMAGE_AVAILABLE, reason="skimage installed") + def test_ssim_fail(self): + """Test ssim.""" + npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) + + @pytest.mark.parametrize( + ("metric", "data", "result", "mask"), + [ + (metrics.snr, data1, snr_res, None), + (metrics.snr, data1, snr_res, mask), + (metrics.psnr, data1, psnr_res, None), + (metrics.psnr, data1, psnr_res, mask), + (metrics.mse, data1, mse_res, None), + (metrics.mse, data1, mse_res, mask), + (metrics.nrmse, data1, nrmse_res, None), + (metrics.nrmse, data1, nrmse_res, mask), + failparam(metrics.snr, data1, snr_res, "maskfail", raises=ValueError), + ], + ) + def test_metric(self, metric, data, result, mask): + """Test snr.""" + npt.assert_almost_equal(metric(data, data**2, mask=mask), result) + + +class TestStats: + """Test stats module.""" + + array33 = np.arange(9).reshape(3, 3) + array233 = np.arange(18).reshape(2, 3, 3) + + @pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not installed") + @pytest.mark.parametrize( + ("norm", "result"), + [ + ( + "max", + np.array( + [ + [0.36787944, 0.60653066, 0.36787944], + [0.60653066, 1.0, 0.60653066], + [0.36787944, 0.60653066, 0.36787944], + ] + ), + ), + ( + "sum", + np.array( + [ + [0.07511361, 0.1238414, 0.07511361], + [0.1238414, 0.20417996, 0.1238414], + [0.07511361, 0.1238414, 0.07511361], + ] + ), + ), + failparam("fail", None, raises=ValueError), + ], + ) + def test_gaussian_kernel(self, norm, result): + """Test Gaussian kernel.""" + npt.assert_allclose( + stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result + ) + + @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed") + def test_import_astropy(self): + """Test missing astropy.""" + npt.assert_raises(ImportError, stats.gaussian_kernel, self.array33.shape, 1) + + def test_mad(self): + """Test mad.""" + npt.assert_equal(stats.mad(self.array33), 2.0) + + def test_sigma_mad(self): + """Test sigma_mad.""" + npt.assert_almost_equal( + stats.sigma_mad(self.array33), + 2.9651999999999998, + ) + + @pytest.mark.parametrize( + ("data1", "data2", "method", "result"), + [ + (array33, array33 + 2, "starck", 12.041199826559248), + failparam(array33, array33, "fail", 0, raises=ValueError), + (array33, array33 + 2, "wiki", 42.110203695399477), + ], + ) + def test_psnr(self, data1, data2, method, result): + """Test PSNR.""" + npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result) + + def test_psnr_stack(self): + """Test psnr stack.""" + npt.assert_almost_equal( + stats.psnr_stack(self.array233, self.array233 + 2), + 12.041199826559248, + ) + + npt.assert_raises(ValueError, stats.psnr_stack, self.array33, self.array33) diff --git a/tests/test_opt.py b/tests/test_opt.py new file mode 100644 index 00000000..2ea58c27 --- /dev/null +++ b/tests/test_opt.py @@ -0,0 +1,572 @@ +"""UNIT TESTS FOR OPT. + +This module contains tests for the modopt.opt module. + +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" + +import numpy as np +import numpy.testing as npt +import pytest +from pytest_cases import parametrize, parametrize_with_cases, case, fixture, fixture_ref + +from modopt.opt import cost, gradient, linear, proximity, reweight + +from test_helpers import Dummy + +SKLEARN_AVAILABLE = True +try: + import sklearn +except ImportError: + SKLEARN_AVAILABLE = False + +PTWT_AVAILABLE = True +try: + import ptwt + import cupy +except ImportError: + PTWT_AVAILABLE = False + +PYWT_AVAILABLE = True +try: + import pywt + import joblib +except ImportError: + PYWT_AVAILABLE = False + +rng = np.random.default_rng() + + +# Basic functions to be used as operators or as dummy functions +def func_identity(x_val, *args, **kwargs): + """Return x.""" + return x_val + + +def func_double(x_val, *args, **kwargs): + """Double x.""" + return x_val * 2 + + +def func_sq(x_val, *args, **kwargs): + """Square x.""" + return x_val**2 + + +def func_cube(x_val, *args, **kwargs): + """Cube x.""" + return x_val**3 + + +@case(tags="cost") +@parametrize( + ("cost_interval", "n_calls", "converged"), + [(1, 1, False), (1, 2, True), (2, 5, False), (None, 6, False)], +) +def case_cost_op(cost_interval, n_calls, converged): + """Case function for costs.""" + dummy_inst1 = Dummy() + dummy_inst1.cost = func_sq + dummy_inst2 = Dummy() + dummy_inst2.cost = func_cube + + cost_obj = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=cost_interval) + + for _ in range(n_calls + 1): + cost_obj.get_cost(2) + return cost_obj, converged + + +@parametrize_with_cases("cost_obj, converged", cases=".", has_tag="cost") +def test_costs(cost_obj, converged): + """Test cost.""" + npt.assert_equal(cost_obj.get_cost(2), converged) + if cost_obj._cost_interval: + npt.assert_equal(cost_obj.cost, 12) + + +def test_raise_cost(): + """Test error raising for cost.""" + npt.assert_raises(TypeError, cost.costObj, 1) + npt.assert_raises(ValueError, cost.costObj, [Dummy(), Dummy()]) + + +@case(tags="grad") +@parametrize(call=("op", "trans_op", "trans_op_op")) +def case_grad_parent(call): + """Case for gradient parent.""" + input_data = np.arange(9).reshape(3, 3) + callables = { + "op": func_sq, + "trans_op": func_cube, + "get_grad": func_identity, + "cost": lambda input_val: 1.0, + } + + grad_op = gradient.GradParent( + input_data, + **callables, + data_type=np.floating, + ) + if call != "trans_op_op": + result = callables[call](input_data) + else: + result = callables["trans_op"](callables["op"](input_data)) + + grad_call = getattr(grad_op, call)(input_data) + return grad_call, result + + +@parametrize_with_cases("grad_values, result", cases=".", has_tag="grad") +def test_grad_op(grad_values, result): + """Test Gradient operator.""" + npt.assert_equal(grad_values, result) + + +@pytest.fixture +def grad_basic(): + """Case for GradBasic.""" + input_data = np.arange(9).reshape(3, 3) + grad_op = gradient.GradBasic( + input_data, + func_sq, + func_cube, + verbose=True, + ) + grad_op.get_grad(input_data) + return grad_op + + +def test_grad_basic(grad_basic): + """Test grad basic.""" + npt.assert_array_equal( + grad_basic.grad, + np.array( + [ + [0, 0, 8.0], + [2.16000000e2, 1.72800000e3, 8.0e3], + [2.70000000e4, 7.40880000e4, 1.75616000e5], + ] + ), + err_msg="Incorrect gradient.", + ) + + +def test_grad_basic_cost(grad_basic): + """Test grad_basic cost.""" + npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3, 3)), 3192.0) + + +def test_grad_op_raises(): + """Test raise error.""" + npt.assert_raises( + TypeError, + gradient.GradParent, + 1, + func_sq, + func_cube, + ) + + +############# +# LINEAR OP # +############# + + +class LinearCases: + """Linear operator cases.""" + + def case_linear_identity(self): + """Case linear operator identity.""" + linop = linear.Identity() + + data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1 + + return linop, data_op, data_adj_op, res_op, res_adj_op + + def case_linear_wavelet_convolve(self): + """Case linear operator wavelet.""" + linop = linear.WaveletConvolve( + filters=np.arange(8).reshape(2, 2, 2).astype(float) + ) + data_op = np.arange(4).reshape(1, 2, 2).astype(float) + data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float) + res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) + res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]]) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + @parametrize( + compute_backend=[ + pytest.param( + "numpy", + marks=pytest.mark.skipif( + not PYWT_AVAILABLE, reason="PyWavelet not available." + ), + ), + pytest.param( + "cupy", + marks=pytest.mark.skipif( + not PTWT_AVAILABLE, reason="Pytorch Wavelet not available." + ), + ), + ] + ) + def case_linear_wavelet_transform(self, compute_backend): + """Case linear wavelet operator.""" + linop = linear.WaveletTransform( + wavelet_name="haar", + shape=(8, 8), + level=2, + ) + data_op = np.arange(64).reshape(8, 8).astype(float) + res_op, slices, shapes = pywt.ravel_coeffs( + pywt.wavedecn(data_op, "haar", level=2) + ) + data_adj_op = linop.op(data_op) + res_adj_op = pywt.waverecn( + pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar" + ) + return linop, data_op, data_adj_op, res_op, res_adj_op + + @parametrize(weights=[[1.0, 1.0], None]) + def case_linear_combo(self, weights): + """Case linear operator combo with weights.""" + parent = linear.LinearParent( + func_sq, + func_cube, + ) + linop = linear.LinearCombo([parent, parent], weights) + + data_op, data_adj_op, res_op, res_adj_op = ( + 2, + np.array([2, 2]), + np.array([4, 4]), + 8.0 * (2 if weights else 1), + ) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + @parametrize(factor=[1, 1 + 1j]) + def case_linear_matrix(self, factor): + """Case linear operator from matrix.""" + linop = linear.MatrixOperator(np.eye(5) * factor) + data_op = np.arange(5) + data_adj_op = np.arange(5) + res_op = np.arange(5) * factor + res_adj_op = np.arange(5) * np.conjugate(factor) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases +) +def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get adj_op relative data.""" + return linop.adj_op, data_adj_op, res_adj_op + + +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases +) +def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get op relative data.""" + return linop.op, data_op, res_op + + +@parametrize( + ("action", "data", "result"), [fixture_ref(lin_op), fixture_ref(lin_adj_op)] +) +def test_linear_operator(action, data, result): + """Test linear operator.""" + npt.assert_almost_equal(action(data), result) + + +dummy_with_op = Dummy() +dummy_with_op.op = lambda x: x + + +@pytest.mark.parametrize( + ("args", "error"), + [ + ([linear.LinearParent(func_sq, func_cube)], TypeError), + ([[]], ValueError), + ([[Dummy()]], ValueError), + ([[dummy_with_op]], ValueError), + ([[]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, [1.0]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, ["1", "1"]], TypeError), + ], +) +def test_linear_combo_errors(args, error): + """Test linear combo_errors.""" + npt.assert_raises(error, linear.LinearCombo, *args) + + +############# +# Proximity # +############# + + +class ProxCases: + """Class containing all proximal operator cases. + + Each case should return 4 parameters: + 1. The proximal operator + 2. test input data + 3. Expected result data + 4. Expected cost value. + """ + + weights = np.ones(9).reshape(3, 3).astype(float) * 3 + array33 = np.arange(9).reshape(3, 3).astype(float) + array33_st = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + array33_st2 = array33_st * -1 + + array33_support = np.asarray([[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]]) + + array233 = np.arange(18).reshape(2, 3, 3).astype(float) + array233_2 = np.array( + [ + [ + [2.73843189, 3.14594066, 3.55344943], + [3.9609582, 4.36846698, 4.77597575], + [5.18348452, 5.59099329, 5.99850206], + ], + [ + [8.07085295, 9.2718846, 10.47291625], + [11.67394789, 12.87497954, 14.07601119], + [15.27704284, 16.47807449, 17.67910614], + ], + ] + ) + array233_3 = np.array( + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [ + [4.00795282, 4.60438026, 5.2008077], + [5.79723515, 6.39366259, 6.99009003], + [7.58651747, 8.18294492, 8.77937236], + ], + ] + ) + + def case_prox_parent(self): + """Case prox parent.""" + return ( + proximity.ProximityParent( + func_sq, + func_double, + ), + 3, + 9, + 6, + ) + + def case_prox_identity(self): + """Case prox identity.""" + return proximity.IdentityProx(), 3, 3, 0 + + def case_prox_positivity(self): + """Case prox positivity.""" + return proximity.Positivity(), -3, 0, 0 + + def case_prox_sparsethresh(self): + """Case prox sparsethreshosld.""" + return ( + proximity.SparseThreshold(linear.Identity(), weights=self.weights), + self.array33, + self.array33_st, + 108, + ) + + @parametrize( + "lowr_type, initial_rank, operator, result, cost", + [ + ("standard", None, None, array233_2, 469.3913294246498), + ("standard", 1, None, array233_2, 469.3913294246498), + ("ngole", None, func_double, array233_3, 469.3913294246498), + ], + ) + def case_prox_lowrank(self, lowr_type, initial_rank, operator, result, cost): + """Case prox lowrank.""" + return ( + proximity.LowRankMatrix( + 10, + lowr_type=lowr_type, + initial_rank=initial_rank, + operator=operator, + thresh_type="hard" if lowr_type == "standard" else "soft", + ), + self.array233, + result, + cost, + ) + + def case_prox_linear_comp(self): + """Case prox linear comp.""" + return ( + proximity.LinearCompositionProx( + linear_op=linear.Identity(), prox_op=self.case_prox_sparsethresh()[0] + ), + self.array33, + self.array33_st, + 108, + ) + + def case_prox_ridge(self): + """Case prox ridge.""" + return ( + proximity.Ridge(linear.Identity(), self.weights), + self.array33 * (1 + 1j), + self.array33 * (1 + 1j) / 7, + 1224, + ) + + @parametrize("alpha, beta", [(0, weights), (weights, 0)]) + def case_prox_elasticnet(self, alpha, beta): + """Case prox elastic net.""" + if np.all(alpha == 0): + data = self.case_prox_sparsethresh()[1:] + else: + data = self.case_prox_ridge()[1:] + return (proximity.ElasticNet(linear.Identity(), alpha, beta), *data) + + @parametrize( + "beta, k_value, data, result, cost", + [ + (0.2, 1, array33.flatten(), array33_st.flatten(), 259.2), + (3, 5, array33.flatten(), array33_support.flatten(), 684.0), + ( + 6.0, + 9, + array33.flatten() * (1 + 1j), + array33.flatten() * (1 + 1j) / 7, + 1224, + ), + ], + ) + def case_prox_Ksupport(self, beta, k_value, data, result, cost): + """Case prox K-support norm.""" + return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost) + + @parametrize(use_weights=[True, False]) + def case_prox_grouplasso(self, use_weights): + """Case GroupLasso proximity.""" + if use_weights: + weights = np.tile(self.weights, (4, 1, 1)) + else: + weights = np.tile(np.zeros((3, 3)), (4, 1, 1)) + + random_data = 3 * rng.random(weights[0].shape) + random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1)) + if use_weights: + gl_result_data = 2 * random_data_tile - 3 + gl_result_data = ( + np.array(gl_result_data * (gl_result_data > 0).astype("int")) / 2 + ) + cost = np.sum(random_data_tile) * 6 + else: + gl_result_data = random_data_tile + cost = 0 + return ( + proximity.GroupLASSO( + weights=weights, + ), + random_data_tile, + gl_result_data, + cost, + ) + + @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.") + def case_prox_owl(self): + """Case prox for Ordered Weighted L1 Norm.""" + return ( + proximity.OrderedWeightedL1Norm(self.weights.flatten()), + self.array33.flatten(), + self.array33_st.flatten(), + 108.0, + ) + + +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_op(operator, input_data, op_result, cost_result): + """Test proximity operator op.""" + npt.assert_almost_equal(operator.op(input_data), op_result) + + +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_cost(operator, input_data, op_result, cost_result): + """Test proximity operator cost.""" + npt.assert_almost_equal(operator.cost(input_data, verbose=True), cost_result) + + +@parametrize( + "arg, error", + [ + (1, TypeError), + ([], ValueError), + ([Dummy()], ValueError), + ([dummy_with_op], ValueError), + ], +) +def test_error_prox_combo(arg, error): + """Test errors for proximity combo.""" + npt.assert_raises(error, proximity.ProximityCombo, arg) + + +@pytest.mark.skipif(SKLEARN_AVAILABLE, reason="sklearn is installed") +def test_fail_sklearn(): + """Test fail OWL with sklearn.""" + npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) + + +@pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn is not installed.") +def test_fail_owl(): + """Test errors for Ordered Weighted L1 Norm.""" + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + np.arange(10), + ) + + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + -np.arange(10), + ) + + +def test_fail_lowrank(): + """Test fail for lowrank.""" + prox_op = proximity.LowRankMatrix(10, lowr_type="fail") + npt.assert_raises(ValueError, prox_op.op, 0) + + +def test_fail_Ksupport_norm(): + """Test fail for K-support norm.""" + npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) + + +def test_reweight(): + """Test for reweight module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + 1 + data2 = np.array( + [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], + ) + + rw = reweight.cwbReweight(data1) + rw.reweight(data1) + + npt.assert_array_equal( + rw.weights, + data2, + err_msg="Incorrect CWB re-weighting.", + ) + + npt.assert_raises(ValueError, rw.reweight, data1[0]) diff --git a/tests/test_signal.py b/tests/test_signal.py new file mode 100644 index 00000000..6dbb0bba --- /dev/null +++ b/tests/test_signal.py @@ -0,0 +1,327 @@ +"""UNIT TESTS FOR SIGNAL. + +This module contains unit tests for the modopt.signal module. + +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" + +import numpy as np +import numpy.testing as npt +import pytest +from test_helpers import failparam + +from modopt.signal import filter, noise, positivity, svd, validation, wavelet + + +class TestFilter: + """Test filter module.""" + + @pytest.mark.parametrize( + ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] + ) + def test_gaussian_filter(self, norm, result): + """Test gaussian filter.""" + npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) + + def test_mex_hat(self): + """Test mexican hat filter.""" + npt.assert_almost_equal( + filter.mex_hat(2, 1), + -0.35213905225713371, + ) + + def test_mex_hat_dir(self): + """Test directional mexican hat filter.""" + npt.assert_almost_equal( + filter.mex_hat_dir(1, 2, 1), + 0.17606952612856686, + ) + + +class TestNoise: + """Test noise module.""" + + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.array( + [[0, 3.0, 4.0], [6.0, 9.0, 8.0], [14.0, 14.0, 17.0]], + ) + data3 = np.array( + [ + [0.3455842, 1.8216181, 2.3304371], + [1.6968428, 4.9053559, 5.4463746], + [5.4630468, 7.5811181, 8.3645724], + ] + ) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + + @pytest.mark.parametrize( + ("data", "noise_type", "sigma", "data_noise"), + [ + (data1, "poisson", 1, data2), + (data1, "gauss", 1, data3), + (data1, "gauss", (1, 1, 1), data3), + failparam(data1, "fail", 1, data1, raises=ValueError), + ], + ) + def test_add_noise(self, data, noise_type, sigma, data_noise): + """Test add_noise.""" + rng = np.random.default_rng(1) + npt.assert_almost_equal( + noise.add_noise(data, sigma=sigma, noise_type=noise_type, rng=rng), + data_noise, + ) + + @pytest.mark.parametrize( + ("threshold_type", "result"), + [("hard", data4), ("soft", data5), failparam("fail", None, raises=ValueError)], + ) + def test_thresh(self, threshold_type, result): + """Test threshold.""" + npt.assert_array_equal( + noise.thresh(self.data1, 5, threshold_type=threshold_type), result + ) + + +class TestPositivity: + """Test positivity module.""" + + data1 = np.arange(9).reshape(3, 3).astype(float) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (-1.0, -float(0)), + (-1, 0), + (data1 - 5, data5), + ( + np.array([np.arange(3) - 1, np.arange(2) - 1], dtype=object), + np.array([np.array([0, 0, 1]), np.array([0, 0])], dtype=object), + ), + failparam("-1", None, raises=TypeError), + ], + ) + def test_positive(self, value, expected): + """Test positive.""" + if isinstance(value, np.ndarray) and value.dtype == "O": + for v, e in zip(positivity.positive(value), expected): + npt.assert_array_equal(v, e) + else: + npt.assert_array_equal(positivity.positive(value), expected) + + +class TestSVD: + """Test for svd module.""" + + @pytest.fixture + def data(self): + """Initialize test data.""" + data1 = np.arange(18).reshape(9, 2).astype(float) + data2 = np.arange(32).reshape(16, 2).astype(float) + data3 = np.array( + [ + np.array( + [ + [-0.01744594, -0.61438865], + [-0.08435304, -0.50397984], + [-0.15126014, -0.39357102], + [-0.21816724, -0.28316221], + [-0.28507434, -0.17275339], + [-0.35198144, -0.06234457], + [-0.41888854, 0.04806424], + [-0.48579564, 0.15847306], + [-0.55270274, 0.26888188], + ] + ), + np.array([42.23492742, 1.10041151]), + np.array( + [ + [-0.67608034, -0.73682791], + [0.73682791, -0.67608034], + ] + ), + ], + dtype=object, + ) + data4 = np.array( + [ + [-1.05426832e-16, 1.0], + [2.0, 3.0], + [4.0, 5.0], + [6.0, 7.0], + [8.0, 9.0], + [1.0e1, 1.1e1], + [1.2e1, 1.3e1], + [1.4e1, 1.5e1], + [1.6e1, 1.7e1], + ] + ) + + data5 = np.array( + [ + [0.49815487, 0.54291537], + [2.40863386, 2.62505584], + [4.31911286, 4.70719631], + [6.22959185, 6.78933678], + [8.14007085, 8.87147725], + [10.05054985, 10.95361772], + [11.96102884, 13.03575819], + [13.87150784, 15.11789866], + [15.78198684, 17.20003913], + ] + ) + return (data1, data2, data3, data4, data5) + + @pytest.fixture + def svd0(self, data): + """Compute SVD of first data sample.""" + return svd.calculate_svd(data[0]) + + def test_find_n_pc(self, data): + """Test find number of principal component.""" + npt.assert_equal( + svd.find_n_pc(svd.svd(data[1])[0]), + 2, + err_msg="Incorrect number of principal components.", + ) + + def test_n_pc_fail_non_square(self): + """Test find_n_pc.""" + npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3)) + + def test_calculate_svd(self, data, svd0): + """Test calculate_svd.""" + errors = [] + for i, name in enumerate("USV"): + try: + npt.assert_almost_equal(svd0[i], data[2][i]) + except AssertionError: + errors.append(name) + if errors: + raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors)) + + @pytest.mark.parametrize( + ("n_pc", "idx_res"), + [(None, 3), (1, 4), ("all", 0), failparam("fail", 1, raises=ValueError)], + ) + def test_svd_thresh(self, data, n_pc, idx_res): + """Test svd_tresh.""" + npt.assert_almost_equal( + svd.svd_thresh(data[0], n_pc=n_pc), + data[idx_res], + ) + + def test_svd_tresh_invalid_type(self): + """Test svd_tresh failure.""" + npt.assert_raises(TypeError, svd.svd_thresh, 1) + + @pytest.mark.parametrize("operator", [lambda x: x, failparam(0, raises=TypeError)]) + def test_svd_thresh_coef(self, data, operator): + """Test svd_tresh_coef.""" + npt.assert_almost_equal( + svd.svd_thresh_coef(data[0], operator, 0), + data[0], + err_msg="Incorrect SVD coefficient tresholding", + ) + + # TODO test_svd_thresh_coef_fast + + +class TestValidation: + """Test validation Module.""" + + array33 = np.arange(9).reshape(3, 3) + + def test_transpose_test(self): + """Test transpose_test.""" + npt.assert_equal( + validation.transpose_test( + lambda x_val, y_val: x_val.dot(y_val), + lambda x_val, y_val: x_val.dot(y_val.T), + self.array33.shape, + x_args=self.array33, + rng=2, + ), + None, + ) + + +class TestWavelet: + """Test Wavelet Module.""" + + @pytest.fixture + def data(self): + """Set test parameter values.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.arange(36).reshape(4, 3, 3).astype(float) + data3 = np.array( + [ + [ + [6.0, 20, 26.0], + [36.0, 84.0, 84.0], + [90, 164.0, 134.0], + ], + [ + [78.0, 155.0, 134.0], + [225.0, 408.0, 327.0], + [270, 461.0, 350], + ], + [ + [150, 290, 242.0], + [414.0, 732.0, 570], + [450, 758.0, 566.0], + ], + [ + [222.0, 425.0, 350], + [603.0, 1056.0, 813.0], + [630, 1055.0, 782.0], + ], + ] + ) + + data4 = np.array( + [ + [6496.0, 9796.0, 6544.0], + [9924.0, 14910, 9924.0], + [6544.0, 9796.0, 6496.0], + ] + ) + + data5 = np.array( + [ + [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], + [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], + [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], + ] + ) + return (data1, data2, data3, data4, data5) + + @pytest.mark.parametrize( + ("idx_data", "idx_filter", "idx_res", "filter_rot"), + [(0, 1, 2, False), (1, 1, 3, True)], + ) + def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot): + """Test filter_convolve.""" + npt.assert_almost_equal( + wavelet.filter_convolve( + data[idx_data], data[idx_filter], filter_rot=filter_rot + ), + data[idx_res], + err_msg="Inccorect filter comvolution.", + ) + + def test_filter_convolve_stack(self, data): + """Test filter_convolve_stack.""" + npt.assert_almost_equal( + wavelet.filter_convolve_stack(data[0], data[0]), + data[4], + err_msg="Inccorect filter stack comvolution.", + )