Skip to content

Commit

Permalink
[Hierarchical] Remove dependence on constraint_mat (#2968)
Browse files Browse the repository at this point in the history
  • Loading branch information
rshyamsundar authored Aug 17, 2023
1 parent 6058ccb commit 0cb059e
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 99 deletions.
2 changes: 0 additions & 2 deletions src/gluonts/mx/model/deepvar_hierarchical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@

# Relative imports
from ._estimator import (
constraint_mat,
projection_mat,
DeepVARHierarchicalEstimator,
)
from ._network import reconcile_samples, coherency_error

__all__ = [
"DeepVARHierarchicalEstimator",
"constraint_mat",
"projection_mat",
"reconcile_samples",
"coherency_error",
Expand Down
41 changes: 4 additions & 37 deletions src/gluonts/mx/model/deepvar_hierarchical/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,38 +41,6 @@
logger = logging.getLogger(__name__)


def constraint_mat(S: np.ndarray) -> np.ndarray:
"""
Generates the constraint matrix in the equation: Ay = 0 (y being the
values/forecasts of all time series in the hierarchy).
Parameters
----------
S
Summation or aggregation matrix. Shape:
(total_num_time_series, num_bottom_time_series)
Returns
-------
Numpy ND array
Coefficient matrix of the linear constraints, shape
(num_agg_time_series, num_time_series)
"""

# Re-arrange S matrix to form A matrix
# S = [S_agg|I_m_K]^T dim:(m,m_K)
# A = [I_magg | -S_agg] dim:(m_agg,m)

m, m_K = S.shape
m_agg = m - m_K

# The top `m_agg` rows of the matrix `S` give the aggregation constraint
# matrix.
S_agg = S[:m_agg, :]
A = np.hstack((np.eye(m_agg), -S_agg))
return A


def projection_mat(
S: np.ndarray,
D: Optional[np.ndarray] = None,
Expand All @@ -85,7 +53,7 @@ def projection_mat(
.. math::
P = S (S^T S)^{-1} S^T, if D is None,\\
P = S (S^T D S)^{-1} S^TD, otherwise.
P = S (S^T D S)^{-1} S^T D, otherwise.
Parameters
----------
Expand Down Expand Up @@ -304,11 +272,10 @@ def __init__(
not coherent_train_samples
), "Cannot project only during training (and not during prediction)"

A = constraint_mat(S.astype(self.dtype))
M = projection_mat(S=S, D=D)
self.S = S
ctx = self.trainer.ctx
self.M = mx.nd.array(M, ctx=ctx)
self.A = mx.nd.array(A, ctx=ctx)
self.num_samples_for_loss = num_samples_for_loss
self.likelihood_weight = likelihood_weight
self.CRPS_weight = CRPS_weight
Expand All @@ -322,7 +289,7 @@ def __init__(
def create_training_network(self) -> DeepVARHierarchicalTrainingNetwork:
return DeepVARHierarchicalTrainingNetwork(
M=self.M,
A=self.A,
S=self.S,
num_samples_for_loss=self.num_samples_for_loss,
likelihood_weight=self.likelihood_weight,
CRPS_weight=self.CRPS_weight,
Expand Down Expand Up @@ -354,7 +321,7 @@ def create_predictor(

prediction_network = DeepVARHierarchicalPredictionNetwork(
M=self.M,
A=self.A,
S=self.S,
log_coherency_error=self.log_coherency_error,
coherent_pred_samples=self.coherent_pred_samples,
target_dim=self.target_dim,
Expand Down
57 changes: 20 additions & 37 deletions src/gluonts/mx/model/deepvar_hierarchical/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,66 +97,47 @@ def reconcile_samples(
return out


def coherency_error(A: Tensor, samples: Tensor) -> float:
def coherency_error(S: np.ndarray, samples: np.ndarray) -> float:
r"""
Computes the maximum relative coherency error among all the aggregated
time series
Computes the maximum relative coherency error
.. math::
\max_i \frac{|y_i - s_i|} {|y_i|},
\max_i | (S @ y_b)_i - y_i | / y_i
where :math:`i` refers to the aggregated time series index, :math:`y_i` is
the (direct) forecast obtained for the :math:`i^{th}` time series
and :math:`s_i` is its aggregated forecast obtained by summing the
corresponding bottom-level forecasts. If :math:`y_i` is zero, then the
absolute difference, :math:`|s_i|`, is used instead.
This can be comupted as follows given the constraint matrix A:
.. math::
\max \frac{|A \times samples|} {|samples[:r]|},
where :math:`r` is the number aggregated time series.
where :math:`y` refers to the `samples` and :math:`y_b` refers to the
samples at the bottom level.
Parameters
----------
A
The constraint matrix A in the equation: Ay = 0 (y being the
values/forecasts of all time series in the hierarchy).
S
The summation matrix S. Shape:
(total_num_time_series, num_bottom_time_series)
samples
Samples. Shape: `(*batch_shape, target_dim)`.
Returns
-------
Float
Coherency error
"""
samples_bottom_level = samples[..., -S.shape[1] :]

num_agg_ts = A.shape[0]
forecasts_agg_ts = samples.slice_axis(
axis=-1, begin=0, end=num_agg_ts
).asnumpy()

abs_err = mx.nd.abs(mx.nd.dot(samples, A, transpose_b=True)).asnumpy()
rel_err = np.where(
forecasts_agg_ts == 0,
abs_err,
abs_err / np.abs(forecasts_agg_ts),
errs = np.abs(samples_bottom_level @ S.T - samples)
rel_errs = np.where(
samples == 0.0,
errs,
errs / np.abs(samples),
)

return np.max(rel_err)
return rel_errs.max()


class DeepVARHierarchicalNetwork(DeepVARNetwork):
@validated()
def __init__(
self,
M,
A,
S,
num_layers: int,
num_cells: int,
cell_type: str,
Expand Down Expand Up @@ -191,7 +172,7 @@ def __init__(
)

self.M = M
self.A = A
self.S = S
self.seq_axis = seq_axis

def get_samples_for_loss(self, distr: Distribution) -> Tensor:
Expand Down Expand Up @@ -312,7 +293,9 @@ def post_process_samples(self, samples: Tensor) -> Tensor:

# Show coherency error: A*X_proj
if self.log_coherency_error:
coh_error = coherency_error(self.A, samples=samples_to_return)
coh_error = coherency_error(
S=self.S, samples=samples_to_return.asnumpy()
)
logger.info(
"Coherency error of the predicted samples for time step"
f" {self.forecast_time_step}: {coh_error}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
from gluonts.itertools import prod
from gluonts.mx.model.deepar import DeepAREstimator
from gluonts.mx.model.deepar._network import DeepARPredictionNetwork
from gluonts.mx.model.deepvar_hierarchical._estimator import (
constraint_mat,
projection_mat,
)
from gluonts.mx.model.deepvar_hierarchical._estimator import projection_mat
from gluonts.mx.model.deepvar_hierarchical._network import coherency_error
from gluonts.mx.distribution import Distribution, EmpiricalDistribution
from gluonts.mx import Tensor
Expand Down Expand Up @@ -91,14 +88,13 @@ def __init__(
self.loss_function = loss_function
self.dtype = dtype

A = constraint_mat(self.temporal_hierarchy.agg_mat)
if naive_reconciliation:
M = utils.naive_reconcilation_mat(
self.temporal_hierarchy.agg_mat, self.temporal_hierarchy.nodes
)
else:
M = projection_mat(S=self.temporal_hierarchy.agg_mat)
self.M, self.A = mx.nd.array(M), mx.nd.array(A)
self.M = mx.nd.array(M)

self.estimators = estimators

Expand Down Expand Up @@ -681,7 +677,8 @@ def sampling_decoder(
reconciled_samples_at_all_levels = samples_at_all_levels

rec_err = coherency_error(
A=self.A, samples=reconciled_samples_at_all_levels
S=self.temporal_hierarchy.agg_mat,
samples=reconciled_samples_at_all_levels.asnumpy(),
)
print(f"Reconciliation error: {rec_err}")

Expand Down
10 changes: 3 additions & 7 deletions test/mx/model/deepvar_hierarchical/test_coherency_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import mxnet as mx
import numpy as np
import pytest

from gluonts.mx.model.deepvar_hierarchical import (
constraint_mat,
coherency_error,
)
from gluonts.mx.model.deepvar_hierarchical import coherency_error


TOL = 1e-4

Expand All @@ -35,7 +32,6 @@
)

num_bottom_ts = S.shape[1]
A = constraint_mat(S)


@pytest.mark.parametrize(
Expand All @@ -56,4 +52,4 @@
def test_coherency_error(bottom_ts):
all_ts = S @ bottom_ts

assert coherency_error(mx.nd.array(A), mx.nd.array(all_ts)) < TOL
assert coherency_error(S, all_ts) < TOL
40 changes: 34 additions & 6 deletions test/mx/model/deepvar_hierarchical/test_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import numpy as np
import pytest

from gluonts.mx.model.deepvar_hierarchical._estimator import (
constraint_mat,
projection_mat,
)
from gluonts.mx.model.deepvar_hierarchical._estimator import projection_mat


TOL = 1e-12
Expand All @@ -37,7 +34,38 @@
)

num_bottom_ts = S.shape[1]
A = constraint_mat(S)


def constraint_mat(S: np.ndarray) -> np.ndarray:
"""
Generates the constraint matrix in the equation: Ay = 0 (y being the
values/forecasts of all time series in the hierarchy).
Parameters
----------
S
Summation or aggregation matrix. Shape:
(total_num_time_series, num_bottom_time_series)
Returns
-------
Numpy ND array
Coefficient matrix of the linear constraints, shape
(num_agg_time_series, num_time_series)
"""

# Re-arrange S matrix to form A matrix
# S = [S_agg|I_m_K]^T dim:(m,m_K)
# A = [I_magg | -S_agg] dim:(m_agg,m)

m, m_K = S.shape
m_agg = m - m_K

# The top `m_agg` rows of the matrix `S` give the aggregation constraint
# matrix.
S_agg = S[:m_agg, :]
A = np.hstack((np.eye(m_agg), -S_agg))
return A


def null_space_projection_mat(
Expand Down Expand Up @@ -98,6 +126,6 @@ def null_space_projection_mat(
],
)
def test_projection_mat(D):
p1 = null_space_projection_mat(A=A, D=D)
p1 = null_space_projection_mat(A=constraint_mat(S), D=D)
p2 = projection_mat(S=S, D=D)
assert (np.abs(p1 - p2)).sum() < TOL
4 changes: 1 addition & 3 deletions test/mx/model/deepvar_hierarchical/test_reconcile_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytest

from gluonts.mx.model.deepvar_hierarchical import (
constraint_mat,
projection_mat,
reconcile_samples,
coherency_error,
Expand All @@ -37,7 +36,6 @@
)

num_bottom_ts = S.shape[1]
A = constraint_mat(S)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -86,4 +84,4 @@ def test_reconciliation_error(samples, D, seq_axis):
seq_axis=seq_axis,
)

assert coherency_error(mx.nd.array(A), coherent_samples) < TOL
assert coherency_error(S, coherent_samples.asnumpy()) < TOL

0 comments on commit 0cb059e

Please sign in to comment.