-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ApproximateGradientSumFunction and SGFunction #1345
Changes from all commits
9b22353
670e3c8
4f63f10
5526597
277b85d
a29fd55
c0881d0
1ff0bd2
1c015b1
8c3e8a2
136a2fe
f06c603
f867d19
5a0aaf0
8f4d25c
b1b99b9
50aad5d
fd2c75e
a81b01f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
from numbers import Number | ||
import numpy as np | ||
from functools import reduce | ||
from cil.optimisation.utilities import FunctionNumberGenerator | ||
|
||
class Function(object): | ||
|
||
|
@@ -206,11 +207,12 @@ class SumFunction(Function): | |
""" | ||
|
||
def __init__(self, *functions ): | ||
|
||
super(SumFunction, self).__init__() | ||
if len(functions) < 2: | ||
self.num_functions = len(functions) | ||
if self.num_functions < 2: | ||
raise ValueError('At least 2 functions need to be passed') | ||
self.functions = functions | ||
super(SumFunction, self).__init__() | ||
|
||
@property | ||
def L(self): | ||
|
@@ -300,6 +302,11 @@ def gradient(self, x, out=None): | |
else: | ||
out += f.gradient(x) | ||
|
||
def __getitem__(self, ind): | ||
""" Return the function in the :code:`ind` index. | ||
""" | ||
return self.functions[ind] | ||
|
||
def __add__(self, other): | ||
|
||
""" Addition for the SumFunction. | ||
|
@@ -320,7 +327,82 @@ def __add__(self, other): | |
else: | ||
return super(SumFunction, self).__add__(other) | ||
|
||
class ApproximateGradientSumFunction(SumFunction): | ||
|
||
r"""ApproximateGradientSumFunction represents the following sum | ||
|
||
.. math:: \sum_{i=1}^{n} F_{i} = (F_{1} + F_{2} + ... + F_{n}) | ||
|
||
where :math:`n` is the number of functions. | ||
|
||
Parameters: | ||
----------- | ||
functions : list(functions) | ||
A list of functions: :code:`[F_{1}, F_{2}, ..., F_{n}]`. Each function is assumed to be smooth function with an implemented :func:`~Function.gradient` method. | ||
selection : :obj:`string`, Default = :code:`random`. It can be :code:`random`, :code:`random_permutation`, :code:`fixed_permutation`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good idea to allow a string here, but the doc needs to mention the generator option. Also, I recommend referring to the doc of |
||
Selection method for each function in the list. | ||
- :code:`random`: Every function is selected randomly with replacement. | ||
- :code:`random_permutation`: Every function is selected randomly without replacement. After selecting all the functions in the list, i.e., after one epoch, the list is randomly permuted. | ||
- :code:`fixed_permuation`: Every function is selected randomly without replacement and the list of function is permuted only once. | ||
|
||
Example | ||
------- | ||
|
||
.. math:: \sum_{i=1}^{n} F_{i}(x) = \sum_{i=1}^{n}\|A_{i} x - b_{i}\|^{2} | ||
|
||
>>> list_of_functions = [LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets)) | ||
>>> f = ApproximateGradientSumFunction(list_of_functions) | ||
|
||
>>> list_of_functions = [LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets)) | ||
>>> fng = FunctionNumberGenetor(len(list_of_function), sampling_method="fixed_permutation") | ||
>>> f = ApproximateGradientSumFunction(list_of_functions, fng) | ||
|
||
|
||
""" | ||
|
||
def __init__(self, functions, selection=None): | ||
|
||
super(ApproximateGradientSumFunction, self).__init__(*functions) | ||
|
||
self.functions_used = [] | ||
self.functions_passes = [0] | ||
|
||
if isinstance(selection, str): | ||
self.selection = FunctionNumberGenerator(len(functions), sampling_method=selection ) | ||
else: | ||
self.selection = selection | ||
|
||
def __call__(self, x): | ||
|
||
r"""Returns the value of the sum of functions at :math:`x`. | ||
|
||
.. math:: \sum_{i=1}^{n}(F_{i}(x)) = (F_{1}(x) + F_{2}(x) + ... + F_{n}(x)) | ||
|
||
""" | ||
|
||
return super(ApproximateGradientSumFunction, self).__call__(x) | ||
|
||
def full_gradient(self, x, out=None): | ||
|
||
r""" Computes the full gradient at :code:`x`. It is the sum of all the gradients for each function. """ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why Also, currently doc is wrong as it is the average of the gradients of all functions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doc is correct now. still remains why we'd call this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the "old" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can be resolved now |
||
return super(ApproximateGradientSumFunction, self).gradient(x, out=out) | ||
|
||
def approximate_gradient(self, function_num, x, out=None): | ||
|
||
""" Computes the gradient for each selected function at :code:`x`.""" | ||
raise NotImplemented | ||
|
||
def gradient(self, x, out=None): | ||
|
||
self.next_function_num() | ||
return self.approximate_gradient(self.function_num, x, out=out) | ||
|
||
def next_function_num(self): | ||
|
||
# Select the next function using the selection:=function_number_generator | ||
self.function_num = next(self.selection) | ||
self.functions_used.append(self.function_num) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should do a range check and write a nice error message if out of range |
||
|
||
class ScaledFunction(Function): | ||
|
||
r""" ScaledFunction represents the scalar multiplication with a Function. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# -*- coding: utf-8 -*- | ||
# This work is part of the Core Imaging Library (CIL) developed by CCPi | ||
# (Collaborative Computational Project in Tomographic Imaging), with | ||
# substantial contributions by UKRI-STFC and University of Manchester. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from cil.optimisation.functions import ApproximateGradientSumFunction | ||
|
||
class SGFunction(ApproximateGradientSumFunction): | ||
|
||
r""" Stochastic Gradient Function (SGFunction) | ||
|
||
The SGFunction represents the objective function :math:`\sum_{i=1}^{n}F_{i}(x)`, where | ||
:math:`n` denotes the number of subsets. | ||
|
||
Parameters: | ||
----------- | ||
functions : list(functions) | ||
A list of functions: :code:`[F_{1}, F_{2}, ..., F_{n}]`. Each function is assumed to be smooth function with an implemented :func:`~Function.gradient` method. | ||
sampling : :obj:`string`, Default = :code:`random` | ||
Selection process for each function in the list. It can be :code:`random` or :code:`sequential`. | ||
replacement : :obj:`boolean`. Default = :code:`True` | ||
The same subset can be selected when :code:`replacement=True`. | ||
Comment on lines
+31
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. somehow avoid duplication as this is incorrect |
||
|
||
Note | ||
---- | ||
|
||
The :meth:`~SGFunction.gradient` computes the `gradient` of one function from the list :math:`[F_{1},\cdots,F_{n}]`, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. times N |
||
|
||
.. math:: \partial F_{i}(x) . | ||
|
||
The ith function is selected from the :meth:`~SubsetSumFunction.next_subset` method. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. incorrect |
||
|
||
""" | ||
|
||
def __init__(self, functions, selection="random"): | ||
|
||
super(SGFunction, self).__init__(functions, selection) | ||
|
||
def approximate_gradient(self,subset_num, x, out): | ||
|
||
""" Returns the gradient of the selected function at :code:`x`. The function is selected using the :meth:`~SubsetSumFunction.next_subset` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. times N |
||
""" | ||
# Compute new gradient for current function | ||
self.functions[self.function_num].gradient(x, out=out) | ||
out*=self.num_functions | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# -*- coding: utf-8 -*- | ||
# This work is part of the Core Imaging Library (CIL) developed by CCPi | ||
# (Collaborative Computational Project in Tomographic Imaging), with | ||
# substantial contributions by UKRI-STFC and University of Manchester. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
|
||
def FunctionNumberGenerator(num_functions, sampling_method = "random"): | ||
|
||
r""" FunctionNumberGenerator | ||
|
||
The FunctionNumberGenerator selects randomly a number from a list of numbers/functions of length :code:`num_functions`. | ||
|
||
Parameters: | ||
----------- | ||
num_functions : :obj: int | ||
The total number of functions used in our Stochastic estimators e.g., | ||
sampling_method : :obj:`string`, Default = :code:`random` | ||
Selection process for each function in the list. It can be :code:`random`, :code:`random_permutation`, :code:`fixed_permutation`. | ||
- :code:`random`: Every function is selected randomly with replacement. | ||
- :code:`random_permutation`: Every function is selected randomly without replacement. After selecting all the functions in the list, i.e., after one epoch, the list is randomly permuted. | ||
- :code:`fixed_permuation`: Every function is selected randomly without replacement and the list of function is permuted only once. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see why anyone would want to use Probably the only 2 fixed orders that make sense are 0,1,2,3... and 0,n/2,n/3,3n/4,... (with the last appropriate when the functions correspond to projections ordered by angle). Maybe you could call these "fixed_sequential" and "fixed_staggered"? @robbietuk would have code to implement the "staggered" approach |
||
|
||
Example | ||
------- | ||
|
||
>>> fng = FunctionNumberGenerator(10) | ||
>>> print(next(fng)) | ||
|
||
>>> number_of_functions = 10 | ||
>>> fng = FunctionNumberGenerator(number_of_functions, sampling_method="fixed_permutation") | ||
>>> epochs = 2 | ||
>>> generated_numbers=[print(next(fng), end=' ') for _ in range(epochs*number_of_functions)] | ||
|
||
""" | ||
|
||
if not isinstance(num_functions, int): | ||
raise ValueError(" Integer is required for `num_functions`, {} is passed. ".format(num_functions) ) | ||
|
||
# Accepted sampling_methods | ||
default_sampling_methods = ["random", "random_permutation","fixed_permutation"] | ||
|
||
if sampling_method not in default_sampling_methods: | ||
raise NotImplementedError("Only {} are implemented at the moment.".format(default_sampling_methods)) | ||
|
||
replacement=False | ||
if sampling_method=="random": | ||
replacement=True | ||
else: | ||
if sampling_method=="random_permutation": | ||
shuffle="random" | ||
else: | ||
shuffle="single" | ||
|
||
function_num = -1 | ||
index = 0 | ||
|
||
if replacement is False: | ||
# create a list of functions without replacement, first permutation | ||
list_of_functions = np.random.choice(range(num_functions),num_functions, replace=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this variable should be called |
||
|
||
while(True): | ||
|
||
if replacement is False: | ||
|
||
# list of functions already permuted | ||
function_num = list_of_functions[index] | ||
index+=1 | ||
|
||
if index == num_functions: | ||
index=0 | ||
|
||
# For random shuffle, at the end of each epoch, we permute the list again | ||
if shuffle=="random": | ||
list_of_functions = np.random.choice(range(num_functions),num_functions, replace=False) | ||
else: | ||
|
||
# at each iteration (not epoch) function is randomly selected | ||
function_num = np.random.randint(0, num_functions) | ||
|
||
yield function_num | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# -*- coding: utf-8 -*- | ||
# This work is part of the Core Imaging Library (CIL) developed by CCPi | ||
# (Collaborative Computational Project in Tomographic Imaging), with | ||
# substantial contributions by UKRI-STFC and University of Manchester. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .FunctionNumberGenerator import FunctionNumberGenerator |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# -*- coding: utf-8 -*- | ||
# This work is part of the Core Imaging Library (CIL) developed by CCPi | ||
# (Collaborative Computational Project in Tomographic Imaging), with | ||
# substantial contributions by UKRI-STFC and University of Manchester. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from setuptools import setup | ||
import os | ||
import sys | ||
|
||
|
||
|
||
cil_version = '21.2.0' | ||
if '141' != '-1': | ||
cil_version += '.dev141' | ||
|
||
|
||
setup( | ||
name="cil", | ||
version=cil_version, | ||
packages=['cil' , 'cil.io', | ||
'cil.framework', 'cil.optimisation', | ||
'cil.optimisation.functions', | ||
'cil.optimisation.algorithms', | ||
'cil.optimisation.operators', | ||
'cil.optimisation.utilities', | ||
'cil.processors', | ||
'cil.utilities', 'cil.utilities.jupyter', | ||
'cil.plugins', | ||
'cil.plugins.ccpi_regularisation', | ||
'cil.plugins.ccpi_regularisation.functions', | ||
'cil.plugins.tigre', | ||
'cil.plugins.astra', | ||
'cil.plugins.astra.operators', | ||
'cil.plugins.astra.processors', | ||
'cil.plugins.astra.utilities', | ||
'cil.recon'], | ||
|
||
|
||
# metadata for upload to PyPI | ||
author="CCPi developers", | ||
maintainer="Edoardo Pasca", | ||
maintainer_email="edoardo.pasca@stfc.ac.uk", | ||
description='CCPi Core Imaging Library', | ||
license="Apache v2.0", | ||
keywords="Python Framework", | ||
url="http://www.ccpi.ac.uk/cil", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure why we'd want to throw an error here. Seems more useful to handle only 1 function, or even zero.