jexplore.tools.distributions#

Classes for sampling and evaluating some relevant distributions.

Attributes#

DrawFn

Prototype definition of a drawing function.

Classes#

Distr

Abstract parent distribution class.

Uniform

Uniform (constant-density) distribution in a box.

Normal

Normal distribution with identity covariance.

MVNormal

Multivariate Normal distribution

StudentT

Multivariate Student-t with identity covariance.

Module Contents#

type DrawFn = Callable[[jax.Array, tuple], tuple[jax.Array, jax.Array]]#

Prototype definition of a drawing function.
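A function matching this prototype takes a PRNG key and a shape tuple and returns a pair of arrays. The sketch below shows one possible conforming function, assuming (by analogy with the sample methods in this module) that the first returned array is the updated key and the second holds the samples. The names DIM and draw_standard_normal are hypothetical and not part of jexplore.

```python
from typing import Callable

import jax

# Local alias mirroring the documented prototype (re-declared here so the
# example is self-contained).
DrawFn = Callable[[jax.Array, tuple], tuple[jax.Array, jax.Array]]

DIM = 3  # hypothetical space dimension


def draw_standard_normal(key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
    """Return a fresh key and standard-normal samples of shape `shape + (DIM,)`."""
    # Split the key: one half is returned to the caller, the other is consumed.
    key, subkey = jax.random.split(key)
    samples = jax.random.normal(subkey, shape + (DIM,))
    return key, samples


draw: DrawFn = draw_standard_normal
new_key, xs = draw(jax.random.PRNGKey(0), (4, 2))
```

Returning the key alongside the samples lets callers chain several draws without reusing a consumed key.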

class Distr(dim)[source]#

Bases: Protocol

Abstract parent distribution class.

Parameters:

dim (int) – space dimension

dim: int#
sample(key, shape)[source]#

Samples the distribution.

Parameters:
  • key (ArrayLike) – PRNG key used as the random key.

  • shape (tuple) – shape of the sample.

Returns:

the updated PRNG key and the samples, with shape shape + (dim,)

Return type:

tuple

eval(x)[source]#

Evaluates the distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the distribution on the given points

Return type:

Array

leval(x)[source]#

Evaluates the log distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the log distribution on the given points

Return type:

Array
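The interface above can be illustrated with a small stand-in class. This is only a sketch of the documented contract: ToyDistr is a hypothetical class, not part of jexplore, and the relation eval = exp(leval) is an assumption suggested by the method descriptions rather than something the source states.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyDistr:
    """Hypothetical stand-in exposing the documented methods (not a jexplore class)."""

    def __init__(self, dim: int):
        self.dim = dim

    def sample(self, key, shape):
        # Split so the caller receives a key that has not been consumed.
        key, subkey = jax.random.split(key)
        return key, jax.random.normal(subkey, shape + (self.dim,))

    def leval(self, x):
        # Independent standard-normal components: sum log-densities over the last axis.
        return norm.logpdf(x).sum(axis=-1)

    def eval(self, x):
        # Assumption: the density is the exponential of the log-density.
        return jnp.exp(self.leval(x))


distr = ToyDistr(dim=2)
key = jax.random.PRNGKey(42)
key, x = distr.sample(key, (5, 3))  # samples have shape (5, 3, 2)
logp = distr.leval(x)               # one value per point: shape (5, 3)
p = distr.eval(x)
```

The last axis of the input indexes the coordinates of a point, so evaluation removes that axis and keeps the leading batch shape.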

class Uniform(dim, minval=0.0, maxval=1.0)[source]#

Bases: Distr

Uniform (constant-density) distribution in a box.

Parameters:
  • dim (int) – space dimension

  • minval (RealArray) – minimum (inclusive) value broadcast-compatible with shape for the range (default 0).

  • maxval (RealArray) – maximum (exclusive) value broadcast-compatible with shape for the range (default 1).

minval: jax.Array#
maxval: jax.Array#
lvol: jax.Array#
inbox(x)[source]#

Check if values are in the support of the distribution

Parameters:

x (jax.Array) – values to be checked; the last dimension of the shape should be the dimension of the space on which the distribution is defined.

Returns:

the boolean array of the results of the test.

Return type:

jax.Array

sample(key, shape)[source]#

Samples the distribution.

Parameters:
  • key (ArrayLike) – PRNG key used as the random key.

  • shape (tuple) – shape of the sample.

Returns:

the updated PRNG key and the samples, with shape shape + (dim,)

Return type:

tuple

leval(x)[source]#

Evaluates the log distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the log distribution on the given points

Return type:

Array

eval(x)[source]#

Evaluates the distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the distribution on the given points

Return type:

Array
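A minimal sketch of how a box-uniform distribution of this kind could be computed, written as free functions rather than as the class itself. It assumes that lvol holds the log-volume of the box and that points outside the support get a log-density of minus infinity; neither detail is confirmed by the text above, and the box corners are hypothetical values.

```python
import jax
import jax.numpy as jnp

dim = 2
minval = jnp.array([0.0, -1.0])  # hypothetical lower corner of the box
maxval = jnp.array([2.0, 1.0])   # hypothetical upper corner of the box

# Log-volume of the box: sum of the log side lengths (assumed meaning of lvol).
lvol = jnp.sum(jnp.log(maxval - minval))


def inbox(x):
    """True where every coordinate lies in [minval, maxval)."""
    return jnp.all((x >= minval) & (x < maxval), axis=-1)


def leval(x):
    """Log-density: -lvol inside the box, -inf outside (assumed convention)."""
    return jnp.where(inbox(x), -lvol, -jnp.inf)


def sample(key, shape):
    """Return a fresh key and uniform samples of shape `shape + (dim,)`."""
    key, subkey = jax.random.split(key)
    x = jax.random.uniform(subkey, shape + (dim,), minval=minval, maxval=maxval)
    return key, x


key, xs = sample(jax.random.PRNGKey(0), (100,))
outside = jnp.array([3.0, 0.0])  # first coordinate exceeds maxval
```

Working with the log-volume keeps the evaluation stable in high dimension, where the volume itself may overflow or underflow.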

class Normal(dim)[source]#

Bases: Distr

Normal distribution with identity covariance.

Parameters:

dim (int) – space dimension

lnorm: jax.Array#
sample(key, shape)[source]#

Samples the distribution.

Parameters:
  • key (ArrayLike) – PRNG key used as the random key.

  • shape (tuple) – shape of the sample.

Returns:

the updated PRNG key and the samples, with shape shape + (dim,)

Return type:

tuple

leval(x)[source]#

Evaluates the log distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the log distribution on the given points

Return type:

Array
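For reference, the log-density of a zero-mean normal distribution with identity covariance in \(d\) dimensions (with \(d\) standing for dim) is given below. Reading lnorm as the constant first term is an assumption based on the attribute name only.

```latex
\[
\log p(x) \;=\; -\frac{d}{2}\,\log(2\pi) \;-\; \frac{1}{2}\,\lVert x \rVert^{2},
\qquad x \in \mathbb{R}^{d}.
\]
```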

class MVNormal(cov, scalar=False, norm=True)[source]#

Bases: Normal

Multivariate Normal distribution

Parameters:
  • cov (jax.Array) – covariance matrix

  • scalar (bool) – if true, the distribution evaluation is not vectorized (useful when the distribution is used as a likelihood).

  • norm (bool) – if true, the evaluation returns normalized values.

cov: jax.Array#

Set of covariance matrices

icov: jax.Array#

Inverse covariance matrices

lower: jax.Array#

Cholesky lower matrices

scalar: bool#
lnorm#
sample(key, shape)[source]#

Samples the distribution.

Parameters:
  • key (ArrayLike) – PRNG key used as the random key.

  • shape (tuple) – shape of the sample.

Returns:

the updated PRNG key and the samples, with shape shape + (dim,)

Return type:

tuple

leval(x)[source]#

Evaluates the log distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the log distribution on the given points

Return type:

Array
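The attributes cov, icov, lower and lnorm suggest the usual Cholesky-based treatment of a multivariate normal. The sketch below shows that standard technique for a single covariance matrix, assuming a zero mean (the constructor takes no mean argument) and assuming norm toggles the additive log-normalization term. It is not the library's code, and the covariance values are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])  # hypothetical covariance matrix
dim = cov.shape[-1]

lower = jnp.linalg.cholesky(cov)  # cov = lower @ lower.T
icov = jnp.linalg.inv(cov)
# Log-normalization: -(d/2) log(2 pi) - (1/2) log det(cov),
# using log det(cov) = 2 * sum(log(diag(lower))).
lnorm = -0.5 * dim * jnp.log(2 * jnp.pi) - jnp.sum(jnp.log(jnp.diag(lower)))


def sample(key, shape):
    """Return a fresh key and samples of shape `shape + (dim,)`."""
    key, subkey = jax.random.split(key)
    z = jax.random.normal(subkey, shape + (dim,))
    # Rows of z are standard normal, so z @ lower.T has covariance `cov`.
    return key, z @ lower.T


def leval(x, norm=True):
    """Log-density at points of shape (..., dim); unnormalized if norm is False."""
    quad = jnp.einsum("...i,ij,...j->...", x, icov, x)
    return -0.5 * quad + (lnorm if norm else 0.0)


key, xs = sample(jax.random.PRNGKey(1), (1000,))
# Cross-check against JAX's own multivariate normal log-density.
reference = multivariate_normal.logpdf(xs, jnp.zeros(dim), cov)
```

Skipping the normalization is harmless when only density ratios or gradients matter, which may be why the flag exists.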

static draw_cov(key, shape, minev=0.5, maxev=1.0)[source]#

Draws a random set of positive-definite covariance matrices.

Parameters:
  • key (jax.Array) – PRNG key

  • shape (tuple) – shape of the set of covariance matrices. The last dimension of the tuple must be the dimension of the space, so that each covariance matrix has shape (shape[-1], shape[-1]).

  • minev (jax.Array | float) – lower bound(s) for the eigenvalues

  • maxev (jax.Array | float) – upper bound(s) for the eigenvalues

Returns:

PRNG key and the set of covariance matrices.

Return type:

tuple[jax.Array, jax.Array]
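One common way to draw positive-definite matrices with bounded eigenvalues is to combine a random orthogonal factor with eigenvalues drawn between the bounds. The sketch below follows that recipe under the shape convention described above. Whether draw_cov works this way is not stated in the source, so draw_cov_sketch is a hypothetical illustration only.

```python
import jax
import jax.numpy as jnp


def draw_cov_sketch(key, shape, minev=0.5, maxev=1.0):
    """Hypothetical sketch: random SPD matrices with eigenvalues in [minev, maxev)."""
    dim = shape[-1]
    key, k_rot, k_ev = jax.random.split(key, 3)
    # Orthogonal factors from the QR decomposition of Gaussian matrices.
    gauss = jax.random.normal(k_rot, shape + (dim,))
    q, _ = jnp.linalg.qr(gauss)
    # One eigenvalue per axis, drawn uniformly between the bounds.
    ev = jax.random.uniform(k_ev, shape, minval=minev, maxval=maxev)
    # cov = Q diag(ev) Q^T, batched over the leading dimensions.
    cov = jnp.einsum("...ij,...j,...kj->...ik", q, ev, q)
    return key, cov


# Four 3x3 covariance matrices: the last entry of `shape` is the space dimension.
key, covs = draw_cov_sketch(jax.random.PRNGKey(0), (4, 3))
eigs = jnp.linalg.eigvalsh(covs)
```

With a strictly positive minev, every generated matrix is positive definite by construction.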

class StudentT(dim, nu=5.0)[source]#

Bases: Distr

Multivariate Student-t with identity covariance.

Parameters:
  • dim (int) – space dimension

  • nu (float) – Student-t \(\nu\) parameter (degrees of freedom)

nu: float = 5.0#
normal#
sample(key, shape)[source]#

Samples the distribution.

Parameters:
  • key (ArrayLike) – PRNG key used as the random key.

  • shape (tuple) – shape of the sample.

Returns:

the updated PRNG key and the samples, with shape shape + (dim,)

Return type:

tuple

leval(x)[source]#

Evaluates the log distribution on a set of points

Parameters:

x (array) – points array, with shape (…, dim)

Returns:

the values of the log distribution on the given points

Return type:

Array
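A sketch of the textbook multivariate Student-t, assuming that "identity covariance" refers to an identity scale matrix and that sampling divides a Gaussian draw by the square root of a scaled chi-square variable (the normal attribute hints at a Gaussian building block, but the source does not confirm it). The example uses dim = 1 so that the result can be compared with JAX's univariate reference.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import t as student_t

dim, nu = 1, 5.0  # dim = 1 allows a comparison with the univariate reference


def leval(x):
    """Log-density of a multivariate Student-t with identity scale matrix."""
    lnorm = (
        gammaln(0.5 * (nu + dim))
        - gammaln(0.5 * nu)
        - 0.5 * dim * jnp.log(nu * jnp.pi)
    )
    sq = jnp.sum(x**2, axis=-1)
    return lnorm - 0.5 * (nu + dim) * jnp.log1p(sq / nu)


def sample(key, shape):
    """Gaussian draw divided by sqrt(chi2 / nu), with one chi-square per sample."""
    key, k_norm, k_chi = jax.random.split(key, 3)
    z = jax.random.normal(k_norm, shape + (dim,))
    # A Gamma(nu / 2) variable times 2 is chi-square with nu degrees of freedom.
    chi2 = 2.0 * jax.random.gamma(k_chi, 0.5 * nu, shape)
    return key, z / jnp.sqrt(chi2 / nu)[..., None]


key, xs = sample(jax.random.PRNGKey(0), (8,))
reference = student_t.logpdf(xs[..., 0], nu)
```

Smaller values of nu give heavier tails, and the density approaches the normal one as nu grows.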