jexplore.tools.distributions
Classes for sampling and evaluating some relevant distributions.
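Attributes
- DrawFn – Prototype definition of a drawing function.
Classes
- Distr – Abstract parent distribution class.
- Uniform – Constant distribution in a box.
- Normal – Normal distribution with identity covariance.
- MVNormal – Multivariate normal distribution.
- StudentT – Multivariate Student-t distribution with identity covariance.
Module Contents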
- type DrawFn = Callable[[jax.Array, tuple], tuple[jax.Array, jax.Array]]
Prototype definition of a drawing function.
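A function matching the DrawFn prototype takes a PRNG key and a shape and returns the updated key together with the samples. The sketch below is illustrative only. draw_standard_normal is a hypothetical name, and the sketch assumes that the shape argument is the full shape of the returned samples. The prototype alone does not specify this.

```python
from typing import Callable

import jax

# Same alias as documented above: (key, shape) -> (new_key, samples).
DrawFn = Callable[[jax.Array, tuple], tuple[jax.Array, jax.Array]]


def draw_standard_normal(key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]:
    """Hypothetical drawing function: standard normal samples of the given shape."""
    # Split the key: one part is returned for later calls, the other is consumed here.
    key, subkey = jax.random.split(key)
    samples = jax.random.normal(subkey, shape)
    return key, samples


draw: DrawFn = draw_standard_normal

key = jax.random.PRNGKey(0)
key, x = draw(key, (4, 3))
print(x.shape)  # (4, 3)
```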
- class Distr(dim)
Bases: Protocol
Abstract parent distribution class.
- Parameters:
dim (int) – space dimension
- dim: int
- sample(key, shape)
Samples the distribution.
- Parameters:
key (ArrayLike) – PRNG key used as the random key.
shape (tuple) – shape of the sample.
- Returns:
the updated PRNG key and the samples, with shape shape + (dim,)
- Return type:
tuple
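Because Distr derives from Protocol, any class with a dim attribute and a compatible sample method satisfies it structurally. The following minimal sketch uses a hypothetical stand-in protocol (DistrLike) and a toy class (ToyGaussian) instead of the library's own classes. It shows the key-threading convention, where the key returned by one call is passed to the next.

```python
from typing import Protocol, runtime_checkable

import jax


@runtime_checkable
class DistrLike(Protocol):
    """Hypothetical stand-in mirroring the documented Distr interface."""

    dim: int

    def sample(self, key: jax.Array, shape: tuple) -> tuple[jax.Array, jax.Array]: ...


class ToyGaussian:
    """Toy standard normal in `dim` dimensions (hypothetical, for illustration)."""

    def __init__(self, dim: int):
        self.dim = dim

    def sample(self, key, shape):
        key, subkey = jax.random.split(key)
        # Trailing axis of size dim, as the documented contract requires.
        return key, jax.random.normal(subkey, shape + (self.dim,))


d = ToyGaussian(dim=3)
print(isinstance(d, DistrLike))  # True (structural check on member names only)

key = jax.random.PRNGKey(0)
key, a = d.sample(key, (5,))     # a has shape (5, 3)
key, b = d.sample(key, (2, 4))   # b has shape (2, 4, 3), drawn from the updated key
```

Passing the returned key to the next call, rather than reusing the original key, keeps successive draws independent. Reusing the same key would reproduce the same samples.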
- class Uniform(dim, minval=0.0, maxval=1.0)
Bases: Distr
Constant (uniform) distribution in a box.
- Parameters:
dim (int) – space dimension
minval (RealArray) – minimum (inclusive) value of the range, broadcast-compatible with shape (default 0).
maxval (RealArray) – maximum (exclusive) value of the range, broadcast-compatible with shape (default 1).
- minval: jax.Array
- maxval: jax.Array
- lvol: jax.Array
- inbox(x)
Checks whether values lie in the support of the distribution.
- Parameters:
x (jax.Array) – values to be checked. The last axis must have size dim, the dimension of the space on which the distribution is defined.
- Returns:
a boolean array with the results of the test.
- Return type:
jax.Array
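The support test can be pictured as a coordinate-wise comparison against the box, followed by a reduction over the last axis. This is a minimal sketch and not the library's implementation. The free function inbox is hypothetical. It assumes that minval is inclusive and maxval is exclusive, as stated for the constructor parameters.

```python
import jax.numpy as jnp


def inbox(x, minval, maxval):
    """Hypothetical sketch: True where a point lies in [minval, maxval) on every axis."""
    x = jnp.asarray(x)
    # Compare each coordinate, then require all `dim` coordinates (last axis) to pass.
    inside = (x >= minval) & (x < maxval)
    return jnp.all(inside, axis=-1)


pts = jnp.array([[0.2, 0.5], [1.0, 0.5], [-0.1, 0.3]])
print(inbox(pts, 0.0, 1.0))  # [ True False False]
```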
- sample(key, shape)
Samples the distribution.
- Parameters:
key (ArrayLike) – PRNG key used as the random key.
shape (tuple) – shape of the sample.
- Returns:
the updated PRNG key and the samples, with shape shape + (dim,)
- Return type:
tuple
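Sampling from the box amounts to a call to jax.random.uniform with a trailing axis of size dim. The sketch below assumes that minval and maxval are scalars or length-dim arrays. uniform_sample is a hypothetical helper and not part of jexplore.

```python
import jax
import jax.numpy as jnp


def uniform_sample(key, shape, dim, minval=0.0, maxval=1.0):
    """Hypothetical sketch: returns (new_key, samples of shape shape + (dim,))."""
    key, subkey = jax.random.split(key)
    # minval/maxval broadcast against the trailing `dim` axis (scalars or length-dim arrays).
    x = jax.random.uniform(subkey, shape + (dim,), minval=minval, maxval=maxval)
    return key, x


key = jax.random.PRNGKey(0)
key, x = uniform_sample(
    key, (1000,), 2, minval=jnp.array([0.0, -1.0]), maxval=jnp.array([1.0, 1.0])
)
```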
- class Normal(dim)
Bases: Distr
Normal distribution with identity covariance.
- Parameters:
dim (int) – space dimension
- lnorm: jax.Array
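For the identity-covariance normal, the log density at a point \(x \in \mathbb{R}^{d}\) is \(-\tfrac{1}{2}\lVert x\rVert^2 - \tfrac{d}{2}\log(2\pi)\). The sketch below assumes that lnorm holds the constant term \(-\tfrac{d}{2}\log(2\pi)\). The documentation above does not state this, so treat the sketch as an illustration. The result is checked against jax.scipy.stats.multivariate_normal.logpdf.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def std_normal_logpdf(x):
    """Log density of N(0, I) at points x with shape (..., dim)."""
    x = jnp.asarray(x)
    dim = x.shape[-1]
    lnorm = -0.5 * dim * math.log(2.0 * math.pi)  # assumed meaning of `lnorm`
    return lnorm - 0.5 * jnp.sum(x**2, axis=-1)


x = jnp.array([[0.0, 0.0, 0.0], [1.0, -2.0, 0.5]])
# Reference values from JAX's own multivariate normal implementation.
ref = multivariate_normal.logpdf(x, jnp.zeros(3), jnp.eye(3))
```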
- class MVNormal(cov, scalar=False, norm=True)
Bases: Normal
Multivariate normal distribution.
- Parameters:
cov (jax.Array) – covariance matrix
scalar (bool) – if True, the evaluation of the distribution is not vectorized (useful when the distribution is used as a likelihood).
norm (bool) – if True, the evaluation returns normalized values.
- cov: jax.Array
Set of covariance matrices.
- icov: jax.Array
Inverse covariance matrices.
- lower: jax.Array
Lower-triangular Cholesky factors of the covariance matrices.
- scalar: bool
- lnorm
- sample(key, shape)
Samples the distribution.
- Parameters:
key (ArrayLike) – PRNG key used as the random key.
shape (tuple) – shape of the sample.
- Returns:
the updated PRNG key and the samples, with shape shape + (dim,)
- Return type:
tuple
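A standard way to sample a multivariate normal uses the Cholesky factor \(L\) of the covariance, \(\Sigma = L L^\top\). If \(z \sim \mathcal{N}(0, I)\), then \(Lz \sim \mathcal{N}(0, \Sigma)\). This is presumably the role of the lower attribute, but that is an assumption. The sketch handles a single covariance matrix and assumes zero mean, since the constructor takes no mean. mvnormal_sample is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def mvnormal_sample(key, shape, cov):
    """Hypothetical sketch: zero-mean samples with covariance `cov` (dim x dim)."""
    dim = cov.shape[-1]
    lower = jnp.linalg.cholesky(cov)  # cov = lower @ lower.T
    key, subkey = jax.random.split(key)
    z = jax.random.normal(subkey, shape + (dim,))  # z ~ N(0, I)
    # x = lower @ z for every sample, written as an einsum over the trailing axis.
    x = jnp.einsum("ij,...j->...i", lower, z)
    return key, x


cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])
key, x = mvnormal_sample(jax.random.PRNGKey(0), (200_000,), cov)
emp = x.T @ x / x.shape[0]  # empirical covariance (the mean is zero)
```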
- leval(x)
Evaluates the log density on a set of points.
- Parameters:
x (array) – array of points, with shape (…, dim)
- Returns:
the values of the log density at the given points
- Return type:
Array
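The zero-mean log density is \(\log p(x) = -\tfrac{1}{2} x^\top \Sigma^{-1} x - \tfrac{1}{2}\left(d \log 2\pi + \log\det\Sigma\right)\). The sketch below computes it for a single covariance matrix. It assumes that norm=False simply drops the constant term. mvnormal_leval is a hypothetical name. The result is checked against jax.scipy.stats.multivariate_normal.logpdf.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def mvnormal_leval(x, cov, norm=True):
    """Hypothetical sketch of the zero-mean MVN log density at x with shape (..., dim)."""
    dim = cov.shape[-1]
    icov = jnp.linalg.inv(cov)
    # Quadratic form x^T icov x, evaluated for every point along the leading axes.
    quad = jnp.einsum("...i,ij,...j->...", x, icov, x)
    if not norm:
        return -0.5 * quad  # assumed behaviour of norm=False: constant term omitted
    _, logdet = jnp.linalg.slogdet(cov)
    lnorm = -0.5 * (dim * math.log(2.0 * math.pi) + logdet)
    return lnorm - 0.5 * quad


cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])
x = jnp.array([[0.0, 0.0], [1.0, -0.5], [0.3, 2.0]])
ref = multivariate_normal.logpdf(x, jnp.zeros(2), cov)
```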
- static draw_cov(key, shape, minev=0.5, maxev=1.0)
Draws a random set of positive-definite covariance matrices.
- Parameters:
key (jax.Array) – PRNG key
shape (tuple) – shape of the set of covariance matrices. The last entry of the tuple must be the dimension of the space, so that each covariance matrix has shape (shape[-1], shape[-1]).
minev (jax.Array | float) – lower bound(s) for the eigenvalues
maxev (jax.Array | float) – upper bound(s) for the eigenvalues
- Returns:
PRNG key and the set of covariance matrices.
- Return type:
tuple[jax.Array, jax.Array]
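One common construction of random positive-definite matrices with a bounded spectrum is \(\Sigma = Q\,\mathrm{diag}(\lambda)\,Q^\top\), with \(Q\) orthogonal and each \(\lambda_i\) drawn in \([\text{minev}, \text{maxev})\). The text above does not say whether draw_cov uses this construction. The sketch (draw_cov_sketch, hypothetical) only illustrates how the shape convention and the eigenvalue bounds fit together.

```python
import jax
import jax.numpy as jnp


def draw_cov_sketch(key, shape, minev=0.5, maxev=1.0):
    """Hypothetical sketch: random SPD matrices with eigenvalues in [minev, maxev)."""
    dim = shape[-1]
    key, k_rot, k_ev = jax.random.split(key, 3)
    # Random orthogonal matrices from the QR decomposition of Gaussian matrices;
    # g has shape shape[:-1] + (dim, dim).
    g = jax.random.normal(k_rot, shape + (dim,))
    q, _ = jnp.linalg.qr(g)
    # Eigenvalues drawn uniformly in [minev, maxev), one set per matrix.
    ev = jax.random.uniform(k_ev, shape, minval=minev, maxval=maxev)
    # cov = Q diag(ev) Q^T, batched over the leading axes.
    cov = jnp.einsum("...ij,...j,...kj->...ik", q, ev, q)
    return key, cov


key, covs = draw_cov_sketch(jax.random.PRNGKey(0), (4, 3))
```

Orthogonal factors taken from a plain QR decomposition are not uniformly (Haar) distributed unless the signs are normalized. This does not affect positive definiteness or the eigenvalue bounds.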
- class StudentT(dim, nu=5.0)
Bases: Distr
Multivariate Student-t distribution with identity covariance.
- Parameters:
dim (int) – space dimension
nu (float) – Student-t \(\nu\) (degrees of freedom) parameter
- nu: float = 5.0
- normal
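A multivariate Student-t with identity scale matrix can be obtained by dividing a standard normal vector by \(\sqrt{w/\nu}\), where \(w \sim \chi^2_\nu\) is drawn once per sample. The normal attribute suggests that StudentT builds on a normal sampler, but the exact implementation is assumed here. student_t_sample is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def student_t_sample(key, shape, dim, nu=5.0):
    """Hypothetical sketch: multivariate Student-t samples with identity scale matrix."""
    key, k_z, k_w = jax.random.split(key, 3)
    z = jax.random.normal(k_z, shape + (dim,))  # z ~ N(0, I)
    # w ~ chi-squared(nu), obtained as 2 * Gamma(nu / 2).
    w = 2.0 * jax.random.gamma(k_w, nu / 2.0, shape=shape)
    # One scale factor per sample, shared by all dim coordinates.
    x = z / jnp.sqrt(w / nu)[..., None]
    return key, x


key, x = student_t_sample(jax.random.PRNGKey(0), (200_000,), dim=2, nu=5.0)
```

For \(\nu > 2\), the covariance of such samples is \(\tfrac{\nu}{\nu - 2} I\). For the default \(\nu = 5\), this is \(\tfrac{5}{3} I\).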