# Source code for jexplore.tools.covariance
"""
Covariance computation
======================
This module provides the basic functions for covariance computation.
"""
import jax
import jax.numpy as jnp
def _default_initial_covariance(nchain, ndim, value=1e-3):
"""Return a set of constant diagonal covariance per chain."""
cov = value * jnp.eye(ndim)
covs = cov[:, :, jnp.newaxis] + jnp.zeros((nchain))
return covs.T # nchain, ndim, ndim
# [docs]
def compute_cov(
samples: jax.Array,
covs: jax.Array | None = None,
cholesky: bool = True,
default: float = 1e-3,
) -> jax.Array:
"""
Compute the covariance from a set of samples
:param array samples: array of the past samples (nchains, dim, npoints).
:param array covs: (nchain, dim, dim) array bearing one covariance matrices
of all chains. Default: None, i.e. the covariance
matrices are computed from samples.
:param bool cholesky: if true returns the array of lower triangular matrices
resulting from cholesky decomposition of the convariance
matrices of all chains. Otherwise it returns the array
of the covariance matrices. Default: True.
:param float default: if covs is none and the samples only have 1 point the
function will return a constant diagonal covariance for
all chains. This parameters provides the value for the
covariance diagonal terms in this case. Defaut: 1e-3.
:return: array of the covariance matricaes of the chains or
of their cholesky lower triangular matrices.
:rtype: array
"""
if covs is not None:
pass
elif samples.shape[2] == 1:
covs = _default_initial_covariance(
samples.shape[0], samples.shape[1], value=default
)
else:
covs = jax.vmap(lambda x: jnp.cov(x, ddof=1))(samples)
if cholesky:
return jnp.linalg.cholesky(covs)
return covs