I've given some thought to the implementation of the DSS. The main issue is the batched calculation of the covariance matrix. This is what I've come up with for numpy and numba.
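For context, the multivariate Dawid–Sebastiani score that the code below computes is, for an observation $y$ with ensemble mean $\mu$ and ensemble covariance $\Sigma$:

```latex
\mathrm{DSS}(F, y) = \log \det \Sigma + (y - \mu)^{\top} \Sigma^{-1} (y - \mu)
```

which corresponds directly to the `log_det + bias_precision` terms in the code.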
```python
import numpy as np
from numba import guvectorize

D = 15  # Variables
N = 50  # Observations
M = 1000  # Ensemble Members

sigma = np.diag(np.ones(D))
sigma[sigma == 0] = 0.2
mu = np.arange(0, D)
observations = np.tile(mu, (N, 1))
samples = np.random.multivariate_normal(mean=mu, cov=sigma, size=(N, M))
samples.shape

# Numpy has the dimension of D as the last index
# This fits the scoringrules notation
n_axis = -3  # forecast axis
m_axis = -2  # ensemble axis
v_axis = -1  # variable axis


def batch_covariance(x, axis: int = 0, rowvar: bool = False):
    """Calculate the batched covariance matrix for a 3d tensor ``x`` along chosen axis."""
    splitted = np.split(x, indices_or_sections=x.shape[axis], axis=axis)
    cov = np.stack([np.cov(i.squeeze(), rowvar=rowvar) for i in splitted])
    return cov


def batch_precision(x, axis: int = 0, rowvar: bool = False):
    """Calculate the batched precision matrix for a 3d tensor ``x`` along chosen axis."""
    batched_cov = batch_covariance(x, axis=axis, rowvar=rowvar)
    batched_prec = np.linalg.inv(batched_cov)  # expects that cov is on the last 2 axes [..., D, D]
    return batched_prec


@guvectorize(
    [
        "void(float32[:], float32[:,:], float32[:])",
        "void(float64[:], float64[:,:], float64[:])",
    ],
    "(d),(m,d)->()",
)
def _dss_gufunc(obs: np.ndarray, fct: np.ndarray, out: np.ndarray):
    M = fct.shape[0]
    bias = obs - (np.sum(fct, axis=0) / M)
    cov = np.cov(fct, rowvar=False).astype(bias.dtype)
    prec = np.linalg.inv(cov).astype(bias.dtype)
    log_det = np.log(np.linalg.det(cov))
    bias_precision = bias.T @ prec @ bias
    out[0] = log_det + bias_precision


# %%time
## Calculation of the DSS in numpy
cov = batch_covariance(samples, n_axis)
precision = batch_precision(samples, n_axis)
log_det = np.log(np.linalg.det(cov))
bias = observations - np.mean(samples, axis=m_axis)
bias_precision = (
    bias[:, None, :] @
    precision @
    bias[:, :, None]
)
bias_precision = bias_precision.squeeze()
dss_numpy = log_det + bias_precision

## numba calculation
dss_numba = _dss_gufunc(observations, samples)
print(np.allclose(dss_numba, dss_numpy))
print(dss_numba)
```
I was assuming that `tensorflow` and `torch` would, by default, offer something like a batched `cov(tensor, axis)`, but it turns out this is only available in `tensorflow_probability` (here). I'm not sure you want to add another (big) dependency. Otherwise, the implementation for the other backends would probably look pretty similar.
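If we want to avoid `tensorflow_probability`, one option is to write the batched covariance with only mean/subtract/matmul, which all backends already provide. A sketch in numpy (the function name and axis convention are illustrative, not existing scoringrules API; the same pattern ports 1:1 to `torch`/`tensorflow` ops):

```python
import numpy as np


def batched_cov(x: np.ndarray, m_axis: int = -2) -> np.ndarray:
    """Sample covariance over the ensemble axis of a [..., M, D] array.

    Uses only mean/subtract/matmul so it translates directly to
    torch/tensorflow; ddof=1 matches np.cov's default.
    """
    x = np.moveaxis(x, m_axis, -2)  # -> [..., M, D]
    m = x.shape[-2]
    centered = x - x.mean(axis=-2, keepdims=True)
    # [..., D, M] @ [..., M, D] -> [..., D, D]
    return np.swapaxes(centered, -1, -2) @ centered / (m - 1)


rng = np.random.default_rng(0)
samples = rng.normal(size=(50, 1000, 15))  # [N, M, D]
cov = batched_cov(samples)
print(cov.shape)  # (50, 15, 15)
# agrees with np.cov applied per forecast:
print(np.allclose(cov[0], np.cov(samples[0], rowvar=False)))  # True
```

This also avoids the Python-level loop over the forecast axis in `batch_covariance`.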
So the general proposal would be:

- Add `array_cov(x)` for the 2-d case and `batched_cov(x, axis)` for the 3-d case to the backend. Same for the precision matrix.
- Add `det(x)` to the backend with the default behaviour of `np.linalg.det()`, i.e. assuming that the matrices of which we want the determinant are on the last 2 axes.
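As a small sanity check of that proposed default: `np.linalg.det` already broadcasts over stacked matrices on the last two axes, so a backend `det(x)` only needs to mirror this convention.

```python
import numpy as np

# det of a stack of matrices: shape (2, 3, 3) -> shape (2,)
stack = np.stack([np.eye(3) * 2.0, np.eye(3) * 3.0])
print(np.linalg.det(stack))  # [ 8. 27.]
```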
Questions:

- How to deal with the other backends?
- Are you fine with adding a numpy/numba version and raising `NotImplementedError` for the other backends in an initial version?
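For the initial version, the stubs could be as simple as the following sketch (class and method names are purely illustrative, not existing scoringrules API):

```python
class TorchBackend:
    """Hypothetical backend stub for the initial version."""

    def batched_cov(self, x, axis):
        # Only the numpy/numba backend implements this so far.
        raise NotImplementedError(
            "batched_cov is currently only implemented for the numpy/numba backend"
        )
```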