jexplore.steps.rwalk
This module defines Metropolis-Hastings steps that update all chains at once using random walk proposals.
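For orientation, a Metropolis-Hastings accept/reject step applied to all chains at once can be sketched as below. This is a minimal illustration, not the jexplore implementation: the function name mh_accept and its plain-array signature are hypothetical. It assumes the proposal supplies the log of the ratio between the backward and forward transition probabilities, which is the quantity the proposal methods in this module return.

```python
import jax
import jax.numpy as jnp


def mh_accept(key, x, x_new, logp, logp_new, log_q_ratio):
    """Hypothetical vectorised Metropolis-Hastings accept/reject.

    x, x_new: (nchains, dim) current and proposed points.
    logp, logp_new: (nchains,) log target density at x and x_new.
    log_q_ratio: (nchains,) log[q(x | x_new) / q(x_new | x)], i.e. backward
        over forward transition probability (zero for symmetric proposals).
    """
    key, sub = jax.random.split(key)
    # Log acceptance probability, capped at 0 (acceptance probability 1)
    log_alpha = jnp.minimum(0.0, logp_new - logp + log_q_ratio)
    log_u = jnp.log(jax.random.uniform(sub, shape=logp.shape))
    accept = log_u < log_alpha
    # Keep the proposed point where accepted, the old point elsewhere
    x_out = jnp.where(accept[:, None], x_new, x)
    logp_out = jnp.where(accept, logp_new, logp)
    return key, x_out, logp_out, accept
```

Working in log space avoids overflow when the target densities differ by many orders of magnitude.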
Classes
- MVRandomWalk – Generic multivariate random walk proposal.
- GaussianRandomWalk – Gaussian random walk proposal.
- StudentTRandomWalk – Student-T random walk proposal.
Module Contents
- class MVRandomWalk(dist, mask=None, scale=1.19, **opts)
Bases: jexplore.steps.direct.Direct
Generic multivariate random walk proposal.
- Parameters:
dist (Type[jexplore.tools.distributions.Distr]) – distribution class
mask (jax.Array | None) – mask selecting the proposal dimensions (default: None, i.e. all dimensions).
scale (float) – random walk scaling factor (default: 1.19).
**opts – options to be passed to the distribution creator.
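One plausible reading of the mask parameter is a boolean array over the dim coordinates that restricts which coordinates the random walk moves. The sketch below assumes exactly that; apply_mask is a hypothetical helper, and jexplore may interpret the mask differently (for example by proposing directly in the reduced subspace).

```python
import jax.numpy as jnp


def apply_mask(x, step, mask=None):
    """Hypothetical: move only the coordinates selected by `mask`.

    x, step: (nchains, dim) current points and random walk increments.
    mask: (dim,) boolean array, or None to move every coordinate.
    """
    if mask is None:
        return x + step
    # The (dim,) mask broadcasts over the chain axis; masked-out coordinates stay fixed
    return x + jnp.where(mask, step, 0.0)
```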
- lower: jax.Array
Cholesky decompositions of the covariance matrices, one per chain, with shape (nchains, dim, dim).
- cd_const: jax.Array
Effective scaling factor: the scale attribute divided by the square root of the space dimensionality.
- scale: float
Random walk scaling factor.
- build(epoch)
Step epoch initialisation method. This extends jexplore.steps.step.Step.build, which sets the epoch and sampling attributes, by also defining the lower and cd_const attributes.
- Parameters:
epoch (jexplore.sampling.EpochMH) – current epoch.
- Return type:
None
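The two attributes defined by build can be sketched as follows, assuming the per-chain covariance matrices are already available (how build obtains them is not stated here). rw_scaling is a hypothetical helper, not part of jexplore, and it returns cd_const as a Python float rather than a jax.Array.

```python
import math

import jax.numpy as jnp


def rw_scaling(cov, scale=1.19):
    """Hypothetical helper mirroring the `lower` and `cd_const` attributes.

    cov: (nchains, dim, dim) positive-definite covariance matrices.
    Returns (lower, cd_const) with cov[c] == lower[c] @ lower[c].T.
    """
    dim = cov.shape[-1]
    # jnp.linalg.cholesky batches over the leading (chain) axis
    lower = jnp.linalg.cholesky(cov)
    # Effective scaling factor: scale / sqrt(dim)
    cd_const = scale / math.sqrt(dim)
    return lower, cd_const
```

Dividing by the square root of the dimensionality keeps the typical step length comparable as dim grows, since a step built from dim independent unit-variance components has squared length of order dim.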
- proposal(key, state)
Performs random walk steps on all chains, using samples drawn from jexplore.steps.rwalk.MVRandomWalk.epoch_dist.
- Parameters:
key (jax.Array) – PRNG key used as the random key.
state (jexplore.sampling.StateMH) – a (nwalkers * ntemps, dim) array representing the state point of the MCMC sampler. This corresponds to the jexplore.sampling.state.State.p attribute of a jexplore.sampling.state.State object.
- Returns:
the updated PRNG key, the new state point (nwalkers * ntemps, dim), and a (nwalkers * ntemps) array with the log of the ratio between the backward and forward transition probabilities.
- Return type:
tuple[jax.Array, jexplore.sampling.StateMH, jax.Array]
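A vectorised random walk proposal of the kind described above can be sketched as follows. This is a hedged illustration over plain arrays rather than StateMH objects; rw_proposal and its sample_z argument (a stand-in for drawing from epoch_dist) are hypothetical. Because the increment distribution is assumed symmetric about zero, the backward and forward transition densities are equal and the returned log ratio is zero.

```python
import jax
import jax.numpy as jnp


def rw_proposal(key, x, lower, cd_const, sample_z):
    """Hypothetical random walk proposal applied to all chains at once.

    x: (nchains, dim) current points; lower: (nchains, dim, dim) Cholesky
    factors; cd_const: scalar step scale; sample_z(key, shape) draws
    standardised increments from a distribution symmetric about zero.
    """
    key, sub = jax.random.split(key)
    z = sample_z(sub, x.shape)  # (nchains, dim)
    # Correlate the increments chain by chain: step[c] = lower[c] @ z[c]
    step = cd_const * jnp.einsum("cij,cj->ci", lower, z)
    # Symmetric increments give q(x | x_new) == q(x_new | x), so log ratio is 0
    log_ratio = jnp.zeros(x.shape[0], dtype=x.dtype)
    return key, x + step, log_ratio
```

For example, Gaussian increments correspond to passing `lambda k, s: jax.random.normal(k, s)` as sample_z.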
- class GaussianRandomWalk(mask=None)
Bases: MVRandomWalk
Gaussian random walk proposal.
- Parameters:
mask (jax.Array | None) – mask selecting the proposal dimensions (default: None, i.e. all dimensions).
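As a numerical illustration of what a Gaussian random walk does with lower and cd_const, the snippet below assumes increments of the form cd_const * lower @ z with z ~ N(0, I) and checks that the step covariance approaches cd_const**2 times the target covariance. The covariance values are illustrative only, and the increment form is an assumption rather than a statement about jexplore internals.

```python
import math

import jax
import jax.numpy as jnp

# Illustrative target covariance for a single 2-D chain
cov = jnp.array([[4.0, 2.0], [2.0, 3.0]])
lower = jnp.linalg.cholesky(cov)
cd_const = 1.19 / math.sqrt(cov.shape[0])  # default scale divided by sqrt(dim)

# Gaussian increments: each row is cd_const * lower @ z with z ~ N(0, I)
z = jax.random.normal(jax.random.PRNGKey(0), shape=(200_000, 2))
steps = cd_const * z @ lower.T

# Empirical covariance of the steps; should be close to cd_const**2 * cov
emp_cov = jnp.cov(steps, rowvar=False)
```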
- class StudentTRandomWalk(mask=None, nu=5.0)
Bases: MVRandomWalk
Student-T random walk proposal.
- Parameters:
mask (jax.Array | None) – mask selecting the proposal dimensions (default: None, i.e. all dimensions).
nu (float) – Student-T degrees-of-freedom parameter nu (default: 5.0).
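The source does not say whether StudentTRandomWalk draws a multivariate Student-T vector or independent per-coordinate T variables. The sketch below assumes the multivariate form, in which every coordinate of a row shares one chi-square scale; mv_student_t is a hypothetical helper, and independent coordinates could instead be drawn with jax.random.t. For nu > 2 these draws have covariance nu / (nu - 2) times the identity, so with the default nu = 5 each coordinate spreads about 1.29 times wider than a standard Gaussian, and the heavier tails make occasional long jumps more likely.

```python
import jax
import jax.numpy as jnp


def mv_student_t(key, shape, nu=5.0):
    """Hypothetical standard multivariate Student-T sampler, one vector per row.

    z = g / sqrt(w / nu) with g ~ N(0, I) and w ~ chi^2(nu) shared across the
    coordinates of each row; chi^2(nu) is sampled as 2 * Gamma(nu / 2).
    """
    k_g, k_w = jax.random.split(key)
    g = jax.random.normal(k_g, shape)
    w = 2.0 * jax.random.gamma(k_w, nu / 2.0, shape=shape[:-1])
    # Broadcast the per-row scale over the last (coordinate) axis
    return g / jnp.sqrt(w / nu)[..., None]
```

Since these draws are symmetric about zero, a random walk built on them keeps a zero log ratio between backward and forward transition probabilities.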