# Source code for jexplore.tools.covariance

"""
Covariance computation
======================

This module provides the basic functions for covariance computation.
"""

import jax
import jax.numpy as jnp


def _default_initial_covariance(nchain, ndim, value=1e-3):
    """Return a set of constant diagonal covariance per chain."""
    cov = value * jnp.eye(ndim)
    covs = cov[:, :, jnp.newaxis] + jnp.zeros((nchain))
    return covs.T  # nchain, ndim, ndim


[docs] def compute_cov( samples: jax.Array, covs: jax.Array | None = None, cholesky: bool = True, default: float = 1e-3, ) -> jax.Array: """ Compute the covariance from a set of samples :param array samples: array of the past samples (nchains, dim, npoints). :param array covs: (nchain, dim, dim) array bearing one covariance matrices of all chains. Default: None, i.e. the covariance matrices are computed from samples. :param bool cholesky: if true returns the array of lower triangular matrices resulting from cholesky decomposition of the convariance matrices of all chains. Otherwise it returns the array of the covariance matrices. Default: True. :param float default: if covs is none and the samples only have 1 point the function will return a constant diagonal covariance for all chains. This parameters provides the value for the covariance diagonal terms in this case. Defaut: 1e-3. :return: array of the covariance matricaes of the chains or of their cholesky lower triangular matrices. :rtype: array """ if covs is not None: pass elif samples.shape[2] == 1: covs = _default_initial_covariance( samples.shape[0], samples.shape[1], value=default ) else: covs = jax.vmap(lambda x: jnp.cov(x, ddof=1))(samples) if cholesky: return jnp.linalg.cholesky(covs) return covs