"""
Distributions
=============
Classes for sampling from, and evaluating the (log-)density of, a few common probability distributions.
"""
from typing import Callable, Protocol, TypeAlias
import jax
import jax.numpy as jnp
DrawFn: TypeAlias = Callable[[jax.Array, tuple], tuple[jax.Array, jax.Array]]
"""Prototype definition of a drawing function."""
[docs]
class Distr(Protocol):
"""
Abstract parent distribution class.
:param int dim: space dimension
"""
dim: int
def __init__(self, dim: int):
self.dim = dim
[docs]
def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
"""
Samples the distribution.
:param ArrayLike key: PRNG key used as the random key.
:param tuple shape: shape of the sample.
:return: the actualized PRNG key and the samples with shape shape + (py:attr:`dim`,)
:rtype: tuple
"""
# pylint: disable=unnecessary-ellipsis
...
[docs]
def eval(self, x: jax.Array) -> jax.Array:
"""
Evaluates the distribution on a set of points
:param array x: points array, with shape (..., py:attr:`dim`)
:return: the values of the distribution on the given points
:rtype: Array
"""
return jnp.exp(self.leval(x))
[docs]
def leval(self, x: jax.Array) -> jax.Array:
"""
Evaluates the log distribution on a set of points
:param array x: points array, with shape (..., py:attr:`dim`)
:return: the values of the log distribution on the given points
:rtype: Array
"""
# pylint: disable=unnecessary-ellipsis
...
class Normal(Distr):
    """
    Standard multivariate normal distribution (zero mean, identity covariance).

    :param int dim: space dimension
    """

    # Log of the normalizing constant (2 pi)^(dim / 2).
    lnorm: jax.Array

    def __init__(self, dim: int):
        super().__init__(dim)
        log_two_pi = jnp.log(2 * jnp.pi)
        self.lnorm = (self.dim / 2) * log_two_pi

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """
        Draw standard-normal samples.

        :param Array key: PRNG key used as the source of randomness.
        :param tuple shape: batch shape of the sample.
        :return: an updated PRNG key and samples of shape ``shape + (dim,)``.
        :rtype: tuple
        """
        next_key, draw_key = jax.random.split(key)
        out_shape = shape + (self.dim,)
        return next_key, jax.random.normal(draw_key, shape=out_shape)

    def leval(self, x: jax.Array) -> jax.Array:
        """
        Log-density ``-dim/2 * log(2 pi) - |x|^2 / 2`` over the last axis.

        :param Array x: points, with shape ``(..., dim)``.
        :return: the log-density values, with shape ``x.shape[:-1]``.
        :rtype: Array
        """
        sq_norm = jnp.sum(jnp.square(x), axis=-1)
        return -self.lnorm - 0.5 * sq_norm
class MVNormal(Normal):
    """Zero-mean multivariate normal distribution with arbitrary covariance.

    Several independent distributions ("chains") can be stacked: ``cov`` may
    be a single ``(d, d)`` matrix or a stack of ``c`` matrices ``(c, d, d)``.

    :param cov: covariance matrix with shape ``(d, d)``, or stack of
        covariance matrices with shape ``(c, d, d)``. Each matrix is
        Cholesky-factorized, so it must be symmetric positive definite.
    :param scalar: if true the distribution evaluation is not vectorized over
        chains: points have shape ``(..., d)`` instead of ``(..., c, d)``
        (this is useful to use the distribution as a likelihood). Only
        allowed with a single chain.
    :param norm: if true the evaluation returns normalized values; otherwise
        the normalization constant is dropped.
    :raises ValueError: if ``cov`` is not ``(d, d)`` or ``(c, d, d)`` with
        square matrices, or if ``scalar`` is requested with ``c > 1``.
    """

    cov: jax.Array
    """Set of covariance matrices, shape ``(c, d, d)``."""
    icov: jax.Array
    """Inverse covariance matrices, shape ``(c, d, d)``."""
    lower: jax.Array
    """Cholesky lower matrices, shape ``(c, d, d)``."""
    scalar: bool

    def __init__(self, cov: jax.Array, scalar: bool = False, norm: bool = True):
        cov = jnp.asarray(cov)
        if cov.ndim == 2:
            # A single matrix is handled as a stack of one chain.
            cov = cov[None, :, :]
        if cov.ndim != 3 or cov.shape[-1] != cov.shape[-2]:
            raise ValueError(
                "The list of covariance matrices should have shape (c, d, d)."
            )
        # Validate before paying for the factorizations below.
        if scalar and (cov.shape[0] > 1):
            raise ValueError("Scalar evaluation is not possible with more than 1 chain")
        super().__init__(cov.shape[-1])
        self.cov = cov
        self.scalar = scalar
        self.lower = jnp.linalg.cholesky(self.cov)
        self.icov = jnp.linalg.inv(self.cov)
        if norm:
            # For cov = L L^T, log(det(cov)) / 2 == sum(log(diag(L))). This
            # avoids forming det(cov), which under/overflows in high dimension.
            half_logdet = jnp.sum(
                jnp.log(jnp.diagonal(self.lower, axis1=-2, axis2=-1)), axis=-1
            )
            self.lnorm = (self.dim / 2) * jnp.log(2 * jnp.pi) + half_logdet
        else:
            self.lnorm = jnp.zeros(self.cov.shape[0])

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """Draw samples, one per chain.

        :param key: PRNG key used as the source of randomness.
        :param shape: batch shape of the sample.
        :return: an updated PRNG key and samples of shape
            ``shape + (c, d)``.
        """
        # Only k_n is used: Normal.sample already derives the returned key from
        # it. The other half of the split is discarded, as before, so seeded
        # draws stay reproducible.
        _, k_n = jax.random.split(key)
        key, white = super().sample(k_n, shape + (self.cov.shape[0],))
        # x_c = L_c z_c has covariance L_c L_c^T = cov_c.
        return key, jnp.einsum("cij,...cj->...ci", self.lower, white)

    def leval(self, x: jax.Array) -> jax.Array:
        """Log-density of each chain.

        :param x: points with shape ``(..., c, d)``, or ``(..., d)`` in
            scalar mode.
        :return: values with shape ``(..., c)``, or ``(...)`` in scalar mode
            (a 0-d array for a single point).
        """
        if self.scalar:
            # Insert the (single) chain axis, then drop it from the result.
            return self._vleval(x[..., None, :])[..., 0]
        return self._vleval(x)

    def _vleval(self, x: jax.Array) -> jax.Array:
        """Chain-vectorized log-density for points of shape ``(..., c, d)``."""
        quad = jnp.einsum("...ci,cij,...cj->...c", x, self.icov, x)
        return -self.lnorm - 0.5 * quad

    @staticmethod
    def draw_cov(
        key: jax.Array,
        shape: tuple,
        minev: jax.Array | float = 0.5,
        maxev: jax.Array | float = 1.0,
    ) -> tuple[jax.Array, jax.Array]:
        """Draw a random set of symmetric positive definite covariance matrices.

        Each matrix is ``Q diag(ev) Q^T`` with ``Q`` orthogonal and the
        eigenvalues ``ev`` drawn uniformly between ``minev`` and ``maxev``.

        :param key: PRNG key
        :param shape: shape of the set of covariances. The last entry of the
            tuple is the dimension of the space, so each covariance matrix
            has shape ``(shape[-1], shape[-1])`` and the result has shape
            ``shape + (shape[-1],)``.
        :param minev: lower bound(s) for the eigenvalues
        :param maxev: upper bound(s) for the eigenvalues
        :return: a fresh PRNG key and the set of covariance matrices.
        """
        n = shape[-1]
        k_m, k_ev, k_out = jax.random.split(key, 3)
        m = jax.random.uniform(k_m, shape + (n,), minval=0, maxval=1)
        m += jnp.swapaxes(m, -1, -2)
        qr = jnp.linalg.qr(m)
        # Flip column signs so that the matching R has a positive diagonal.
        signs = jnp.sign(jnp.diagonal(qr.R, axis1=-2, axis2=-1))
        q = qr.Q * signs[..., None, :]
        ev = jax.random.uniform(k_ev, shape, minval=minev, maxval=maxev)
        # (..., n, 1) * (n, n) broadcasts to a stack of diagonal matrices.
        d = ev[..., None] * jnp.eye(n)
        cov = jnp.matmul(jnp.matmul(q, d), jnp.swapaxes(q, -1, -2))
        return k_out, cov
class StudentT(Distr):
    r"""
    Multivariate Student-t distribution with zero location and identity scale.

    In dimension :math:`d` its log-density is

    .. math::

        \log p(x) = \log\Gamma\left(\frac{\nu + d}{2}\right)
        - \log\Gamma\left(\frac{\nu}{2}\right)
        - \frac{d}{2}\log(\nu\pi)
        - \frac{\nu + d}{2}\log\left(1 + \frac{\|x\|^2}{\nu}\right)

    :param int dim: space dimension
    :param float nu: student-t :math:`\nu` parameter (degrees of freedom)
    """

    nu: float = 5.0
    normal: Normal

    def __init__(self, dim: int, nu: float = 5.0):
        super().__init__(dim)
        self.normal = Normal(dim)
        self.nu = nu

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
        """
        Draw samples as ``z / sqrt(g)``, with ``z`` standard normal and
        ``g = chi2_nu / nu``.

        :param key: PRNG key used as the source of randomness.
        :param shape: batch shape of the sample.
        :return: an updated PRNG key and samples of shape ``shape + (dim,)``.
        """
        key, _z = self.normal.sample(key, shape)
        key, _gkey = jax.random.split(key)
        # Gamma(nu/2) * 2/nu == chi2_nu / nu; the trailing axis of size 1 makes
        # all dim coordinates of a point share the same g.
        _g = jax.random.gamma(_gkey, a=self.nu / 2.0, shape=shape + (1,))
        _g *= 2.0 / self.nu
        return key, _z / jnp.sqrt(_g)

    def leval(self, x: jax.Array) -> jax.Array:
        """
        Evaluate the log-density on a set of points.

        :param x: points, with shape ``(..., dim)``.
        :return: the log-density values, with shape ``x.shape[:-1]``.
        """
        # Local import: keeps the module's import block unchanged.
        from jax.scipy.special import gammaln

        half_nu = self.nu / 2.0
        half_nu_d = (self.nu + self.dim) / 2.0
        # Computed on each call so that it follows any later change to nu.
        lnorm = (
            gammaln(half_nu)
            - gammaln(half_nu_d)
            + (self.dim / 2.0) * jnp.log(self.nu * jnp.pi)
        )
        sq_norm = jnp.sum(x**2, axis=-1)
        return -lnorm - half_nu_d * jnp.log1p(sq_norm / self.nu)