From 0cb059e43926a1af8d1a107175b1e38a0ae404ad Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Thu, 17 Aug 2023 13:49:18 +0200 Subject: [PATCH] [Hierarchical] Remove dependence on constraint_mat (#2968) --- .../mx/model/deepvar_hierarchical/__init__.py | 2 - .../model/deepvar_hierarchical/_estimator.py | 41 ++----------- .../mx/model/deepvar_hierarchical/_network.py | 57 +++++++------------ .../model/cop_deepar/_network.py | 11 ++-- .../test_coherency_error.py | 10 +--- .../deepvar_hierarchical/test_projection.py | 40 +++++++++++-- .../test_reconcile_samples.py | 4 +- 7 files changed, 66 insertions(+), 99 deletions(-) diff --git a/src/gluonts/mx/model/deepvar_hierarchical/__init__.py b/src/gluonts/mx/model/deepvar_hierarchical/__init__.py index 47180f9a84..410d3bbb9e 100755 --- a/src/gluonts/mx/model/deepvar_hierarchical/__init__.py +++ b/src/gluonts/mx/model/deepvar_hierarchical/__init__.py @@ -13,7 +13,6 @@ # Relative imports from ._estimator import ( - constraint_mat, projection_mat, DeepVARHierarchicalEstimator, ) @@ -21,7 +20,6 @@ __all__ = [ "DeepVARHierarchicalEstimator", - "constraint_mat", "projection_mat", "reconcile_samples", "coherency_error", diff --git a/src/gluonts/mx/model/deepvar_hierarchical/_estimator.py b/src/gluonts/mx/model/deepvar_hierarchical/_estimator.py index 6d81af4652..2fe9fcbb7c 100755 --- a/src/gluonts/mx/model/deepvar_hierarchical/_estimator.py +++ b/src/gluonts/mx/model/deepvar_hierarchical/_estimator.py @@ -41,38 +41,6 @@ logger = logging.getLogger(__name__) -def constraint_mat(S: np.ndarray) -> np.ndarray: - """ - Generates the constraint matrix in the equation: Ay = 0 (y being the - values/forecasts of all time series in the hierarchy). - - Parameters - ---------- - S - Summation or aggregation matrix. 
Shape: - (total_num_time_series, num_bottom_time_series) - - Returns - ------- - Numpy ND array - Coefficient matrix of the linear constraints, shape - (num_agg_time_series, num_time_series) - """ - - # Re-arrange S matrix to form A matrix - # S = [S_agg|I_m_K]^T dim:(m,m_K) - # A = [I_magg | -S_agg] dim:(m_agg,m) - - m, m_K = S.shape - m_agg = m - m_K - - # The top `m_agg` rows of the matrix `S` give the aggregation constraint - # matrix. - S_agg = S[:m_agg, :] - A = np.hstack((np.eye(m_agg), -S_agg)) - return A - - def projection_mat( S: np.ndarray, D: Optional[np.ndarray] = None, @@ -85,7 +53,7 @@ def projection_mat( .. math:: P = S (S^T S)^{-1} S^T, if D is None,\\ - P = S (S^T D S)^{-1} S^TD, otherwise. + P = S (S^T D S)^{-1} S^T D, otherwise. Parameters ---------- @@ -304,11 +272,10 @@ def __init__( not coherent_train_samples ), "Cannot project only during training (and not during prediction)" - A = constraint_mat(S.astype(self.dtype)) M = projection_mat(S=S, D=D) + self.S = S ctx = self.trainer.ctx self.M = mx.nd.array(M, ctx=ctx) - self.A = mx.nd.array(A, ctx=ctx) self.num_samples_for_loss = num_samples_for_loss self.likelihood_weight = likelihood_weight self.CRPS_weight = CRPS_weight @@ -322,7 +289,7 @@ def __init__( def create_training_network(self) -> DeepVARHierarchicalTrainingNetwork: return DeepVARHierarchicalTrainingNetwork( M=self.M, - A=self.A, + S=self.S, num_samples_for_loss=self.num_samples_for_loss, likelihood_weight=self.likelihood_weight, CRPS_weight=self.CRPS_weight, @@ -354,7 +321,7 @@ def create_predictor( prediction_network = DeepVARHierarchicalPredictionNetwork( M=self.M, - A=self.A, + S=self.S, log_coherency_error=self.log_coherency_error, coherent_pred_samples=self.coherent_pred_samples, target_dim=self.target_dim, diff --git a/src/gluonts/mx/model/deepvar_hierarchical/_network.py b/src/gluonts/mx/model/deepvar_hierarchical/_network.py index c27dfc11f5..1764aad498 100755 --- a/src/gluonts/mx/model/deepvar_hierarchical/_network.py +++ 
b/src/gluonts/mx/model/deepvar_hierarchical/_network.py @@ -97,34 +97,22 @@ def reconcile_samples( return out -def coherency_error(A: Tensor, samples: Tensor) -> float: +def coherency_error(S: np.ndarray, samples: np.ndarray) -> float: r""" - Computes the maximum relative coherency error among all the aggregated - time series + Computes the maximum relative coherency error .. math:: - \max_i \frac{|y_i - s_i|} {|y_i|}, + \max_i | (S @ y_b)_i - y_i | / |y_i| - where :math:`i` refers to the aggregated time series index, :math:`y_i` is - the (direct) forecast obtained for the :math:`i^{th}` time series - and :math:`s_i` is its aggregated forecast obtained by summing the - corresponding bottom-level forecasts. If :math:`y_i` is zero, then the - absolute difference, :math:`|s_i|`, is used instead. - - This can be comupted as follows given the constraint matrix A: - - .. math:: - - \max \frac{|A \times samples|} {|samples[:r]|}, - - where :math:`r` is the number aggregated time series. + where :math:`y` refers to the `samples` and :math:`y_b` refers to the + samples at the bottom level. Parameters ---------- - A - The constraint matrix A in the equation: Ay = 0 (y being the - values/forecasts of all time series in the hierarchy). + S + The summation matrix S. Shape: + (total_num_time_series, num_bottom_time_series) samples Samples. Shape: `(*batch_shape, target_dim)`. 
@@ -132,23 +120,16 @@ def coherency_error(A: Tensor, samples: Tensor) -> float: ------- Float Coherency error - - """ + samples_bottom_level = samples[..., -S.shape[1] :] - num_agg_ts = A.shape[0] - forecasts_agg_ts = samples.slice_axis( - axis=-1, begin=0, end=num_agg_ts - ).asnumpy() - - abs_err = mx.nd.abs(mx.nd.dot(samples, A, transpose_b=True)).asnumpy() - rel_err = np.where( - forecasts_agg_ts == 0, - abs_err, - abs_err / np.abs(forecasts_agg_ts), + errs = np.abs(samples_bottom_level @ S.T - samples) + rel_errs = np.where( + samples == 0.0, + errs, + errs / np.abs(samples), ) - - return np.max(rel_err) + return rel_errs.max() class DeepVARHierarchicalNetwork(DeepVARNetwork): @@ -156,7 +137,7 @@ class DeepVARHierarchicalNetwork(DeepVARNetwork): def __init__( self, M, - A, + S, num_layers: int, num_cells: int, cell_type: str, @@ -191,7 +172,7 @@ def __init__( ) self.M = M - self.A = A + self.S = S self.seq_axis = seq_axis def get_samples_for_loss(self, distr: Distribution) -> Tensor: @@ -312,7 +293,9 @@ def post_process_samples(self, samples: Tensor) -> Tensor: # Show coherency error: A*X_proj if self.log_coherency_error: - coh_error = coherency_error(self.A, samples=samples_to_return) + coh_error = coherency_error( + S=self.S, samples=samples_to_return.asnumpy() + ) logger.info( "Coherency error of the predicted samples for time step" f" {self.forecast_time_step}: {coh_error}" diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py index ce131ebc8f..72b8329e25 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py @@ -20,10 +20,7 @@ from gluonts.itertools import prod from gluonts.mx.model.deepar import DeepAREstimator from gluonts.mx.model.deepar._network import DeepARPredictionNetwork -from 
gluonts.mx.model.deepvar_hierarchical._estimator import ( - constraint_mat, - projection_mat, -) +from gluonts.mx.model.deepvar_hierarchical._estimator import projection_mat from gluonts.mx.model.deepvar_hierarchical._network import coherency_error from gluonts.mx.distribution import Distribution, EmpiricalDistribution from gluonts.mx import Tensor @@ -91,14 +88,13 @@ def __init__( self.loss_function = loss_function self.dtype = dtype - A = constraint_mat(self.temporal_hierarchy.agg_mat) if naive_reconciliation: M = utils.naive_reconcilation_mat( self.temporal_hierarchy.agg_mat, self.temporal_hierarchy.nodes ) else: M = projection_mat(S=self.temporal_hierarchy.agg_mat) - self.M, self.A = mx.nd.array(M), mx.nd.array(A) + self.M = mx.nd.array(M) self.estimators = estimators @@ -681,7 +677,8 @@ def sampling_decoder( reconciled_samples_at_all_levels = samples_at_all_levels rec_err = coherency_error( - A=self.A, samples=reconciled_samples_at_all_levels + S=self.temporal_hierarchy.agg_mat, + samples=reconciled_samples_at_all_levels.asnumpy(), ) print(f"Reconciliation error: {rec_err}") diff --git a/test/mx/model/deepvar_hierarchical/test_coherency_error.py b/test/mx/model/deepvar_hierarchical/test_coherency_error.py index 623211ab89..0bc68dbac5 100644 --- a/test/mx/model/deepvar_hierarchical/test_coherency_error.py +++ b/test/mx/model/deepvar_hierarchical/test_coherency_error.py @@ -11,14 +11,11 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
-import mxnet as mx import numpy as np import pytest -from gluonts.mx.model.deepvar_hierarchical import ( - constraint_mat, - coherency_error, -) +from gluonts.mx.model.deepvar_hierarchical import coherency_error + TOL = 1e-4 @@ -35,7 +32,6 @@ ) num_bottom_ts = S.shape[1] -A = constraint_mat(S) @pytest.mark.parametrize( @@ -56,4 +52,4 @@ def test_coherency_error(bottom_ts): all_ts = S @ bottom_ts - assert coherency_error(mx.nd.array(A), mx.nd.array(all_ts)) < TOL + assert coherency_error(S, all_ts) < TOL diff --git a/test/mx/model/deepvar_hierarchical/test_projection.py b/test/mx/model/deepvar_hierarchical/test_projection.py index 048cb1f82b..e895b6d624 100644 --- a/test/mx/model/deepvar_hierarchical/test_projection.py +++ b/test/mx/model/deepvar_hierarchical/test_projection.py @@ -16,10 +16,7 @@ import numpy as np import pytest -from gluonts.mx.model.deepvar_hierarchical._estimator import ( - constraint_mat, - projection_mat, -) +from gluonts.mx.model.deepvar_hierarchical._estimator import projection_mat TOL = 1e-12 @@ -37,7 +34,38 @@ ) num_bottom_ts = S.shape[1] -A = constraint_mat(S) + + +def constraint_mat(S: np.ndarray) -> np.ndarray: + """ + Generates the constraint matrix in the equation: Ay = 0 (y being the + values/forecasts of all time series in the hierarchy). + + Parameters + ---------- + S + Summation or aggregation matrix. Shape: + (total_num_time_series, num_bottom_time_series) + + Returns + ------- + Numpy ND array + Coefficient matrix of the linear constraints, shape + (num_agg_time_series, num_time_series) + """ + + # Re-arrange S matrix to form A matrix + # S = [S_agg|I_m_K]^T dim:(m,m_K) + # A = [I_magg | -S_agg] dim:(m_agg,m) + + m, m_K = S.shape + m_agg = m - m_K + + # The top `m_agg` rows of the matrix `S` give the aggregation constraint + # matrix. 
+ S_agg = S[:m_agg, :] + A = np.hstack((np.eye(m_agg), -S_agg)) + return A def null_space_projection_mat( @@ -98,6 +126,6 @@ def null_space_projection_mat( ], ) def test_projection_mat(D): - p1 = null_space_projection_mat(A=A, D=D) + p1 = null_space_projection_mat(A=constraint_mat(S), D=D) p2 = projection_mat(S=S, D=D) assert (np.abs(p1 - p2)).sum() < TOL diff --git a/test/mx/model/deepvar_hierarchical/test_reconcile_samples.py b/test/mx/model/deepvar_hierarchical/test_reconcile_samples.py index fa8fa3f48a..44f8d89b14 100644 --- a/test/mx/model/deepvar_hierarchical/test_reconcile_samples.py +++ b/test/mx/model/deepvar_hierarchical/test_reconcile_samples.py @@ -16,7 +16,6 @@ import pytest from gluonts.mx.model.deepvar_hierarchical import ( - constraint_mat, projection_mat, reconcile_samples, coherency_error, @@ -37,7 +36,6 @@ ) num_bottom_ts = S.shape[1] -A = constraint_mat(S) @pytest.mark.parametrize( @@ -86,4 +84,4 @@ def test_reconciliation_error(samples, D, seq_axis): seq_axis=seq_axis, ) - assert coherency_error(mx.nd.array(A), coherent_samples) < TOL + assert coherency_error(S, coherent_samples.asnumpy()) < TOL