Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Syama Sundar Rangapuram committed Aug 14, 2023
1 parent 6062a57 commit 5c03a33
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/gluonts/mx/model/deepvar_hierarchical/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def null_space_projection_mat(
if D is None:
return np.eye(num_ts) - A.T @ np.linalg.pinv(A @ A.T) @ A
else:
assert np.all(
np.linalg.eigvals(D) > 0
), "`D` must be positive definite."
D_inv = np.linalg.inv(D)
return (
np.eye(num_ts) - D_inv @ A.T @ np.linalg.pinv(A @ D_inv @ A.T) @ A
Expand Down
27 changes: 24 additions & 3 deletions test/mx/model/deepvar_hierarchical/test_reconcile_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

num_bottom_ts = S.shape[1]
A = constraint_mat(S)
reconciliation_mat = null_space_projection_mat(A)


@pytest.mark.parametrize(
Expand All @@ -53,6 +52,28 @@
100.0 + 2.0 * np.random.standard_normal(size=(10, 32, S.shape[0])),
],
)
@pytest.mark.parametrize(
"D",
[
None,
# Root gets the maximum weight and the two aggregated levels get
# more weight than the leaf level.
np.diag(
[4, 2, 2, 1, 1, 1, 1]
),
# Random diagonal matrix
np.diag(
np.random.rand(S.shape[0])
),
# Random positive definite matrix
np.diag(
np.random.rand(S.shape[0])
) + np.dot(
np.array([[4, 2, 2, 1, 1, 1, 1]]).T,
np.array([[4, 2, 2, 1, 1, 1, 1]])
)
],
)
@pytest.mark.parametrize(
"seq_axis",
[
Expand All @@ -63,9 +84,9 @@
[1, 0],
],
)
def test_reconciliation_error(samples, seq_axis):
def test_reconciliation_error(samples, D, seq_axis):
coherent_samples = reconcile_samples(
reconciliation_mat=mx.nd.array(reconciliation_mat),
reconciliation_mat=mx.nd.array(null_space_projection_mat(A=A, D=D)),
samples=mx.nd.array(samples),
seq_axis=seq_axis,
)
Expand Down

0 comments on commit 5c03a33

Please sign in to comment.