-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rename SubsetSumFunction to ApproximateGradientSumFunction, use Funct…
…ionNumberGenerator
- Loading branch information
1 parent
5a0aaf0
commit 8f4d25c
Showing
5 changed files
with
178 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
from numbers import Number | ||
import numpy as np | ||
from functools import reduce | ||
from cil.optimisation.utilities import FunctionNumberGenerator | ||
|
||
class Function(object): | ||
|
||
|
@@ -326,63 +327,50 @@ def __add__(self, other): | |
else: | ||
return super(SumFunction, self).__add__(other) | ||
|
||
class SubsetSumFunction(SumFunction): | ||
class ApproximateGradientSumFunction(SumFunction): | ||
|
||
r"""SubsetSumFunction represents the following sum | ||
r"""ApproximateGradientSumFunction represents the following sum | ||
This comment has been minimized.
Sorry, something went wrong. |
||
.. math:: \sum_{i=1}^{n} F_{i} = (F_{1} + F_{2} + ... + F_{n}) | ||
where :math:`n` is the number of subsets. | ||
where :math:`n` is the number of functions. | ||
Parameters: | ||
----------- | ||
functions : list(functions) | ||
A list of functions: :code:`[F_{1}, F_{2}, ..., F_{n}]`. Each function is assumed to be smooth function with an implemented :func:`~Function.gradient` method. | ||
sampling : :obj:`string`, Default = :code:`random` | ||
Selection process for each function in the list. It can be :code:`random` or :code:`sequential`. | ||
replacement : :obj:`boolean`. Default = :code:`True` | ||
The same subset can be selected when :code:`replacement=True`. | ||
suffle : :obj:`string`. Default = :code:`random`. | ||
This is only for the :code:`replacement=False` case. For :code:`suffle="single"`, we permute the list only once. | ||
For :code:`suffle="random"`, we permute the list of subsets after one epoch. | ||
selection : :obj:`string`, Default = :code:`random`. It can be :code:`random`, :code:`random_permutation`, :code:`fixed_permutation`. | ||
Selection method for each function in the list. | ||
- :code:`random`: Every function is selected randomly with replacement. | ||
- :code:`random_permutation`: Every function is selected randomly without replacement. After selecting all the functions in the list, i.e., after one epoch, the list is randomly permuted. | ||
- :code:`fixed_permuation`: Every function is selected randomly without replacement and the list of function is permuted only once. | ||
Example | ||
------- | ||
.. math:: \sum_{i=1}^{n} F_{i}(x) = \sum_{i=1}^{n}\|A_{i} x - b_{i}\|^{2} | ||
>>> f = SubsetSumFunction([LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets)) | ||
>>> list_of_functions = [LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets)) | ||
>>> f = ApproximateGradientSumFunction(list_of_functions) | ||
>>> list_of_functions = [LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets)) | ||
>>> fng = FunctionNumberGenetor(len(list_of_function), sampling_method="fixed_permutation") | ||
>>> f = ApproximateGradientSumFunction(list_of_functions, fng) | ||
""" | ||
|
||
def __init__(self, functions, | ||
sampling = "random", | ||
replacement = True, | ||
suffle = "random"): | ||
|
||
super(SubsetSumFunction, self).__init__(*functions) | ||
|
||
self.sampling = sampling | ||
self.replacement = replacement | ||
self.suffle = suffle | ||
self.data_passes = [0] | ||
self.subsets_used = [] | ||
self.subset_num = -1 | ||
self.index = 0 | ||
|
||
if self.replacement is False: | ||
|
||
# create a list of subsets without replacement, first permutation | ||
self.list_of_subsets = np.random.choice(range(self.num_subsets),self.num_subsets, replace=False) | ||
|
||
if self.suffle not in ["random","single"]: | ||
raise NotImplementedError("Only {} and {} are implemented for the replacement=False case.".format("random","single")) | ||
|
||
|
||
@property | ||
def num_subsets(self): | ||
return self.num_functions | ||
def __init__(self, functions, selection=None): | ||
|
||
super(ApproximateGradientSumFunction, self).__init__(*functions) | ||
|
||
self.functions_used = [] | ||
self.functions_passes = [0] | ||
This comment has been minimized.
Sorry, something went wrong.
zeljkozeljko
|
||
|
||
if isinstance(selection, str): | ||
self.selection = FunctionNumberGenerator(len(functions), sampling_method=selection ) | ||
else: | ||
self.selection = selection | ||
|
||
def __call__(self, x): | ||
|
||
|
@@ -392,51 +380,28 @@ def __call__(self, x): | |
""" | ||
|
||
return super(SubsetSumFunction, self).__call__(x) | ||
return super(ApproximateGradientSumFunction, self).__call__(x) | ||
|
||
def full_gradient(self, x, out=None): | ||
|
||
r""" Computes the full gradient at :code:`x`. It is the sum of all the gradients for each function. """ | ||
return super(SubsetSumFunction, self).gradient(x, out=out) | ||
return super(ApproximateGradientSumFunction, self).gradient(x, out=out) | ||
|
||
def gradient(self, x, out=None): | ||
|
||
""" Computes the gradient for each subset function at :code:`x`.""" | ||
def approximate_gradient(self, function_num, x, out=None): | ||
|
||
""" Computes the gradient for each selected function at :code:`x`.""" | ||
raise NotImplemented | ||
|
||
def gradient(self, x, out=None): | ||
|
||
self.next_function_num() | ||
return self.approximate_gradient(self.function_num, x, out=out) | ||
|
||
def next_subset(self): | ||
def next_function_num(self): | ||
|
||
if self.sampling=="random" : | ||
|
||
if self.replacement is False: | ||
|
||
# list of subsets already permuted | ||
self.subset_num = self.list_of_subsets[self.index] | ||
self.index+=1 | ||
|
||
if self.index == self.num_subsets: | ||
self.index=0 | ||
|
||
# For random suffle, at the end of each epoch, we permute the list again | ||
if self.suffle=="random": | ||
self.list_of_subsets = np.random.choice(range(self.num_subsets),self.num_subsets, replace=False) | ||
|
||
else: | ||
|
||
# at each iteration (not epoch) a subset is randomly selected | ||
self.subset_num = np.random.randint(0, self.num_subsets) | ||
|
||
elif self.sampling=="sequential": | ||
|
||
if self.subset_num + 1 >= self.num_subsets: | ||
self.subset_num = 0 | ||
else: | ||
self.subset_num += 1 | ||
else: | ||
raise NotImplementedError("Only {} and {} are implemented at the moment.".format("random","sequential")) | ||
|
||
self.subsets_used.append(self.subset_num) | ||
# Select the next function using the selection:=function_number_generator | ||
self.function_num = next(self.selection) | ||
self.functions_used.append(self.function_num) | ||
|
||
class ScaledFunction(Function): | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
Wrappers/Python/test/test_ApproximateGradientSumFunction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import unittest | ||
from utils import initialise_tests | ||
from cil.optimisation.operators import MatrixOperator | ||
from cil.optimisation.functions import LeastSquares, ApproximateGradientSumFunction | ||
from cil.optimisation.utilities import FunctionNumberGenerator | ||
from cil.framework import VectorData | ||
import numpy as np | ||
|
||
|
||
initialise_tests() | ||
|
||
class TestApproximateGradientSumFunction(unittest.TestCase): | ||
|
||
def setUp(self): | ||
|
||
np.random.seed(10) | ||
n = 50 | ||
m = 500 | ||
self.n_subsets = 10 | ||
|
||
Anp = np.random.uniform(0,1, (m, n)).astype('float32') | ||
xnp = np.random.uniform(0,1, (n,)).astype('float32') | ||
bnp = Anp.dot(xnp) | ||
|
||
Ai = np.vsplit(Anp, self.n_subsets) | ||
bi = [bnp[i:i+n] for i in range(0, m, n)] | ||
|
||
self.Aop = MatrixOperator(Anp) | ||
self.bop = VectorData(bnp) | ||
ig = self.Aop.domain | ||
self.x_cil = ig.allocate('random') | ||
|
||
self.fi_cil = [] | ||
for i in range(self.n_subsets): | ||
Ai_cil = MatrixOperator(Ai[i]) | ||
bi_cil = VectorData(bi[i]) | ||
self.fi_cil.append(LeastSquares(Ai_cil, bi_cil, c=1.0)) | ||
|
||
self.f = LeastSquares(self.Aop, b=self.bop, c=1.0) | ||
generator = FunctionNumberGenerator(self.n_subsets) | ||
self.f_subset_sum_function = ApproximateGradientSumFunction(self.fi_cil, generator) # default with replacement | ||
|
||
generator = FunctionNumberGenerator(self.n_subsets, sampling_method="random_permutation") | ||
self.f_subset_sum_function_random_suffle = ApproximateGradientSumFunction(self.fi_cil, generator) | ||
|
||
generator = FunctionNumberGenerator(self.n_subsets, sampling_method="fixed_permutation") | ||
self.f_subset_sum_function_single_suffle = ApproximateGradientSumFunction(self.fi_cil, generator) | ||
|
||
def test_call_method(self): | ||
|
||
res1 = self.f(self.x_cil) | ||
res2 = self.f_subset_sum_function(self.x_cil) | ||
np.testing.assert_allclose(res1, res2) | ||
|
||
def test_full_gradient(self): | ||
|
||
res1 = self.f.gradient(self.x_cil) | ||
res2 = self.f_subset_sum_function.full_gradient(self.x_cil) | ||
np.testing.assert_allclose(res1.array, res2.array, atol=1e-3) | ||
|
||
# def test_sampling_sequential(self): | ||
|
||
# # check sequential selection | ||
# for i in range(self.n_subsets): | ||
# self.f_subset_sum_function_sequential.next_subset() | ||
# np.testing.assert_equal(self.f_subset_sum_function_sequential.subset_num, i) | ||
|
||
def test_sampling_random_with_replacement(self): | ||
|
||
# check random selection with replacement | ||
epochs = 3 | ||
choices = [] | ||
for i in range(epochs): | ||
for j in range(self.n_subsets): | ||
self.f_subset_sum_function.next_function_num() | ||
choices.append(self.f_subset_sum_function.function_num) | ||
self.assertTrue( choices == self.f_subset_sum_function.functions_used) | ||
|
||
def test_sampling_random_without_replacement_random_suffle(self): | ||
|
||
# check random selection with no replacement | ||
epochs = 3 | ||
choices = [] | ||
for i in range(epochs): | ||
for j in range(self.n_subsets): | ||
self.f_subset_sum_function_random_suffle.next_function_num() | ||
choices.append(self.f_subset_sum_function_random_suffle.function_num) | ||
self.assertTrue( choices == self.f_subset_sum_function_random_suffle.functions_used) | ||
|
||
def test_sampling_random_without_replacement_single_suffle(self): | ||
|
||
# check random selection with no replacement | ||
epochs = 3 | ||
choices = [] | ||
for i in range(epochs): | ||
for j in range(self.n_subsets): | ||
self.f_subset_sum_function_single_suffle.next_function_num() | ||
choices.append(self.f_subset_sum_function_single_suffle.function_num) | ||
self.assertTrue( choices == self.f_subset_sum_function_single_suffle.functions_used) | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
I'm a bit confused by the use of "represents" here. Would something along the lines of
ApproximateGradientFunction
computes an approximate gradient of the sum function .... be better?