# Source code for jexplore.tools.distributions

"""
Distributions
=============

Classes for sampling and evaluating some relevant distributions.
"""

from typing import Callable, Protocol, TypeAlias

import jax
import jax.numpy as jnp

# Type of a drawing function: called with a PRNG key and a sample shape, it
# returns a pair of arrays (the same (key, shape) -> (key, samples) convention
# used by the ``sample`` methods of the distributions in this module).
DrawFn: TypeAlias = Callable[[jax.Array, tuple], tuple[jax.Array, jax.Array]]
"""Prototype definition of a drawing function."""


[docs] class Distr(Protocol): """ Abstract parent distribution class. :param int dim: space dimension """ dim: int def __init__(self, dim: int): self.dim = dim
[docs] def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]: """ Samples the distribution. :param ArrayLike key: PRNG key used as the random key. :param tuple shape: shape of the sample. :return: the actualized PRNG key and the samples with shape shape + (py:attr:`dim`,) :rtype: tuple """ # pylint: disable=unnecessary-ellipsis ...
[docs] def eval(self, x: jax.Array) -> jax.Array: """ Evaluates the distribution on a set of points :param array x: points array, with shape (..., py:attr:`dim`) :return: the values of the distribution on the given points :rtype: Array """ return jnp.exp(self.leval(x))
[docs] def leval(self, x: jax.Array) -> jax.Array: """ Evaluates the log distribution on a set of points :param array x: points array, with shape (..., py:attr:`dim`) :return: the values of the log distribution on the given points :rtype: Array """ # pylint: disable=unnecessary-ellipsis ...
class Uniform(Distr):
    """
    Constant distribution in a box.

    The support is the half-open box ``[minval, maxval)`` along every
    coordinate, which is the same range :func:`jax.random.uniform` samples
    from, so every sampled point has non-zero density.

    :param int dim: space dimension
    :param RealArray minval: minimum (inclusive) value, broadcast-compatible
        with ``(dim,)`` (default 0).
    :param RealArray maxval: maximum (exclusive) value, broadcast-compatible
        with ``(dim,)`` (default 1).
    """

    minval: jax.Array  # lower corner of the box, shape (dim,)
    maxval: jax.Array  # upper corner of the box, shape (dim,)
    lvol: jax.Array  # log of the box volume

    def __init__(
        self, dim: int, minval: jax.Array | float = 0.0, maxval: jax.Array | float = 1.0
    ):
        super().__init__(dim)
        # Broadcast scalar bounds to one value per coordinate.
        self.minval = jnp.array(minval) * jnp.ones(dim)
        self.maxval = jnp.array(maxval) * jnp.ones(dim)
        self.lvol = jnp.sum(jnp.log(self.maxval - self.minval))

    def inbox(self, x: jax.Array) -> jax.Array:
        """Check if values are in the support of the distribution.

        The lower bound is inclusive and the upper bound exclusive, matching
        the documented ``minval``/``maxval`` semantics.

        :param x: values to be checked; the last dimension of the shape must
            be the dimension of the space on which the distribution is
            defined.
        :return: the boolean array of the results of the test, with shape
            ``x.shape[:-1]``.
        """
        return jnp.all((x >= self.minval) & (x < self.maxval), axis=-1)

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """Draw uniform samples in the box.

        :return: the updated PRNG key and samples with shape
            ``shape + (dim,)``.
        """
        key, k_x = jax.random.split(key)
        return key, jax.random.uniform(
            k_x, shape=shape + (self.dim,), minval=self.minval, maxval=self.maxval
        )

    def leval(self, x: jax.Array) -> jax.Array:
        """Log-density: ``-log(volume)`` inside the box, ``-inf`` outside."""
        return jnp.where(self.inbox(x), -self.lvol, -jnp.inf)

    def eval(self, x: jax.Array) -> jax.Array:
        """Density: ``1 / volume`` inside the box, ``0`` outside."""
        return jnp.where(self.inbox(x), jnp.exp(-self.lvol), 0.0)
class Normal(Distr):
    """
    Standard normal distribution: zero mean, identity covariance.

    :param int dim: space dimension
    """

    lnorm: jax.Array  # log normalization constant, (dim / 2) * log(2 pi)

    def __init__(self, dim: int):
        super().__init__(dim)
        log_two_pi = jnp.log(2 * jnp.pi)
        self.lnorm = (self.dim / 2) * log_two_pi

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """Draw standard normal samples.

        :return: the updated PRNG key and samples with shape
            ``shape + (dim,)``.
        """
        next_key, draw_key = jax.random.split(key)
        draws = jax.random.normal(draw_key, shape=shape + (self.dim,))
        return next_key, draws

    def leval(self, x: jax.Array) -> jax.Array:
        """Log-density ``-lnorm - |x|^2 / 2`` over the last axis of ``x``."""
        squared_norm = jnp.sum(x**2, axis=-1)
        return -self.lnorm - 0.5 * squared_norm
class MVNormal(Normal):
    """Multivariate Normal distribution, or a set of ``c`` of them.

    :param cov: covariance matrix with shape ``(d, d)``, or a set of ``c``
        covariance matrices with shape ``(c, d, d)``.
    :param scalar: if true the distribution evaluation is not vectorized
        (this is useful to use the distribution as a likelihood). Only
        allowed with a single covariance matrix.
    :param norm: if true the evaluation returns normalized values.
    :raises ValueError: if ``cov`` does not have 2 or 3 dimensions, or if
        ``scalar`` is requested with more than one covariance matrix.
    """

    cov: jax.Array  # set of covariance matrices, shape (c, d, d)
    icov: jax.Array  # inverse covariance matrices, shape (c, d, d)
    lower: jax.Array  # Cholesky lower factors, shape (c, d, d)
    scalar: bool

    def __init__(self, cov: jax.Array, scalar: bool = False, norm=True):
        cov = jnp.asarray(cov)
        if cov.ndim == 2:
            cov = cov[None, :, :]
        if cov.ndim != 3:
            raise ValueError(
                "The list of covariance matrices should have shape (c, d, d)."
            )
        dim = cov.shape[-1]
        super().__init__(dim)
        self.cov = cov
        self.lower = jnp.linalg.cholesky(self.cov)
        self.icov = jnp.linalg.inv(self.cov)
        if norm:
            # For cov = L L^T, log(det(cov)) / 2 == sum(log(diag(L))). This
            # avoids the overflow/underflow of forming the determinant itself.
            half_logdet = jnp.sum(
                jnp.log(jnp.diagonal(self.lower, axis1=-2, axis2=-1)), axis=-1
            )
            self.lnorm = (dim / 2) * jnp.log(2 * jnp.pi) + half_logdet
        else:
            self.lnorm = jnp.zeros(self.cov.shape[0])
        if scalar and (cov.shape[0] > 1):
            raise ValueError("Scalar evaluation is not possible with more than 1 chain")
        self.scalar = scalar

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """Draw samples from each of the ``c`` normal distributions.

        :return: the updated PRNG key and samples with shape
            ``shape + (c, d)``.
        """
        # NOTE: the key from this split is overwritten below; the split is
        # kept so that the generated random stream stays unchanged.
        key, k_n = jax.random.split(key)
        # Standard normal draws, one set per distribution.
        key, draws = super().sample(k_n, shape + (self.cov.shape[0],))
        # Correlate them with the Cholesky factors: x_c = L_c z_c.
        draws = jnp.einsum("cij,...cj->...ci", self.lower, draws)
        return key, draws

    def leval(self, x: jax.Array) -> jax.Array:
        """Log-density of the distribution(s).

        :param x: in vectorized mode, points with shape ``(..., c, d)``
            (e.g. the output of :py:meth:`sample`); in scalar mode, points
            with shape ``(..., d)``.
        :return: log-densities with shape ``(..., c)`` in vectorized mode,
            ``(...)`` in scalar mode.
        """
        if self.scalar:
            # Insert the single-distribution axis, then drop it again.
            return self._vleval(x[..., None, :])[..., 0]
        return self._vleval(x)

    def _vleval(self, x: jax.Array) -> jax.Array:
        """Log-density for points with shape ``(..., c, d)``."""
        return -self.lnorm - 0.5 * jnp.einsum("...ci,cij,...cj->...c", x, self.icov, x)

    @staticmethod
    def draw_cov(
        key: jax.Array,
        shape: tuple,
        minev: jax.Array | float = 0.5,
        maxev: jax.Array | float = 1.0,
    ) -> tuple[jax.Array, jax.Array]:
        """Draw a random set of positive-definite covariance matrices.

        Each matrix is built as ``Q diag(e) Q^T``, with ``Q`` an orthogonal
        matrix from the QR decomposition of a random symmetric matrix and
        eigenvalues ``e`` drawn uniformly in ``[minev, maxev)`` (so ``minev``
        must be positive for the result to be positive definite).

        :param key: PRNG key
        :param shape: shape of the set of covariances. The last entry of the
            tuple must be the dimension of the space, so that each covariance
            matrix has shape ``(shape[-1], shape[-1])``.
        :param minev: lower bound(s) for the eigenvalues
        :param maxev: upper bound(s) for the eigenvalues
        :return: PRNG key and the set of covariance matrices, with shape
            ``shape + (shape[-1],)``.
        """
        keys = jax.random.split(key, 3)
        m = jax.random.uniform(keys[0], shape + (shape[-1],), minval=0, maxval=1)
        m += jnp.swapaxes(m, -1, -2)
        qr = jnp.linalg.qr(m)
        # Fix the column signs of Q using the signs of R's diagonal.
        m = qr.Q * jnp.diagonal(jnp.sign(qr.R), axis1=-1, axis2=-2).reshape(
            *shape[:-1], 1, shape[-1]
        )
        d = jax.random.uniform(keys[1], shape, minval=minev, maxval=maxev)
        d = d.reshape(*shape, 1)
        # Turn each eigenvalue vector into a diagonal matrix.
        d *= jnp.eye(shape[-1]).reshape(*[1 for _ in shape[:-1]], shape[-1], shape[-1])
        m = jnp.matmul(jnp.matmul(m, d), jnp.swapaxes(m, -1, -2))
        return keys[2], m
class StudentT(Distr):
    r"""
    Multivariate Student-t distribution with identity scale matrix.

    Samples are drawn as :math:`z / \sqrt{g}`, with :math:`z` standard normal
    and :math:`g \sim \chi^2_\nu / \nu`. The corresponding density in
    :math:`d` dimensions is

    .. math::

        p(x) = \frac{\Gamma((\nu + d)/2)}{\Gamma(\nu/2)\,(\nu\pi)^{d/2}}
               \left(1 + \frac{\lVert x \rVert^2}{\nu}\right)^{-(\nu + d)/2}

    :param int dim: space dimension
    :param float nu: student-t :math:`\nu` parameter (degrees of freedom)
    """

    nu: float = 5.0

    def __init__(self, dim: int, nu: float = 5.0):
        super().__init__(dim)
        self.normal = Normal(dim)
        self.nu = nu

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """Draw Student-t samples.

        :return: the updated PRNG key and samples with shape
            ``shape + (dim,)``.
        """
        key, _z = self.normal.sample(key, shape)
        key, _gkey = jax.random.split(key)
        # Gamma(nu/2, 1) * 2/nu is distributed as chi2_nu / nu.
        _g = jax.random.gamma(_gkey, a=self.nu / 2.0, shape=shape + (1,))
        _g *= 2.0 / self.nu
        return key, _z / jnp.sqrt(_g)

    def leval(self, x: jax.Array) -> jax.Array:
        """Log-density of the distribution.

        :param x: points, with shape ``(..., dim)``
        :return: log-density values, with shape ``(...)``
        """
        nu = self.nu
        dim = self.dim
        half_nu_d = (nu + dim) / 2.0
        lnorm = (
            jax.lax.lgamma(jnp.asarray(half_nu_d))
            - jax.lax.lgamma(jnp.asarray(nu / 2.0))
            - (dim / 2.0) * jnp.log(nu * jnp.pi)
        )
        sq_norm = jnp.sum(x**2, axis=-1)
        # log1p keeps accuracy for points close to the origin.
        return lnorm - half_nu_d * jnp.log1p(sq_norm / nu)