# Source code for jexplore.steps.rwalk

"""
This module contains definitions of Metropolis-Hastings all-chains
steps with random-walk proposals.
"""

from typing import Type

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMH, StateMH
from jexplore.steps.direct import Direct
from jexplore.tools import distributions as d


class MVRandomWalk(Direct):
    r"""
    Generic multivariate random walk proposal.

    Each proposal adds to the masked coordinates of every chain a step
    ``cd_const * lower[c] @ y[c]``, where ``y`` is drawn from the epoch
    distribution and ``lower[c]`` is the Cholesky factor of chain ``c``'s
    masked covariance.

    :param dist: distribution class used to draw the random walk steps.
    :param mask: proposal dimensions mask (default all space).
    :param scale: random walk scaling factor.
    :param \**opts: options to be passed to the distribution creator.
    """

    lower: jax.Array
    """covariance cholesky decomposition matrices `(nchains, dim, dim)`."""

    cd_const: jax.Array
    """ effective scaling factor. This is the `scale` attribute divided by
    the square root of the dimensionality of the proposal subspace."""

    scale: float
    """random walk scaling factor."""

    def __init__(
        self,
        dist: Type[d.Distr],
        mask: jax.Array | None = None,
        scale: float = 1.19,
        **opts,
    ):
        super().__init__(mask=mask, dist=dist, **opts)
        self.scale = scale

    def build(self, epoch: EpochMH) -> None:
        """Step epoch initialisation method.

        This extends :py:attr:`jexplore.steps.step.Step.build` by defining
        the `lower` and `cd_const` attributes.

        :param epoch: current epoch.
        """
        super().build(epoch)
        # One Cholesky factor per chain, computed here so that every
        # proposal made during the epoch reuses it.
        self.lower = jnp.linalg.cholesky(self.masked_covs)
        # The trailing dimension of `lower` is the number of coordinates the
        # walk actually moves (it is the contracted axis of the einsum in
        # `proposal`). `self.mask.size` only matches it for an index-array
        # mask; for a boolean mask it would count the full space instead.
        ndim = self.lower.shape[-1]
        self.cd_const = self.scale / jnp.sqrt(ndim)

    def proposal(
        self, key: jax.Array, state: StateMH
    ) -> tuple[jax.Array, StateMH, jax.Array]:
        """
        Perform one random walk step on all chains.

        Samples from :py:attr:`jexplore.steps.rwalk.MVRandomWalk.epoch_dist`
        are correlated with the per-chain Cholesky factor `lower`, scaled by
        `cd_const` and added to the masked coordinates of the current point.

        :param key: PRNG key used as the random key.
        :param state: current sampler state. Its
            :py:attr:`jexplore.sampling.state.State.p` attribute is the
            (nwalkers * ntemps, dim) array of chain positions.
        :return: the updated PRNG key, the new state (with the masked
            coordinates of `p` moved) and a (nwalkers * ntemps) array with
            the log of the ratio between the backward and the forward
            transition probabilities.
        """
        nchains = state.p.shape[0]
        key, y = self.epoch_dist.sample(key, shape=(nchains,))
        # Per chain c: step[c] = cd_const * lower[c] @ y[c].
        step = self.cd_const * jnp.einsum("cij,cj->ci", self.lower, y)
        state = state.set_val(step + state.p[:, self.mask], pslc=self.mask)
        # The same `lower` is used for forward and backward moves within an
        # epoch, so (assuming the step distribution is symmetric about zero)
        # the proposal is symmetric and its log ratio vanishes.
        return key, state, jnp.zeros(nchains)
class GaussianRandomWalk(MVRandomWalk):
    """
    Gaussian random walk proposal.

    :param mask: proposal dimensions mask (default all space)
    :param scale: random walk scaling factor (default: 1.19, the same
        default as :py:class:`MVRandomWalk`).
    """

    def __init__(self, mask: jax.Array | None = None, scale: float = 1.19):
        # `scale` is forwarded so the step size can be tuned; the default
        # keeps the previous behaviour for existing callers.
        super().__init__(mask=mask, dist=d.Normal, scale=scale)
class StudentTRandomWalk(MVRandomWalk):
    """
    Student-T random walk proposal.

    :param mask: proposal dimensions mask (default all space)
    :param nu: Student-T nu parameter (default: 5)
    :param scale: random walk scaling factor (default: 1.19, the same
        default as :py:class:`MVRandomWalk`).
    """

    def __init__(
        self,
        mask: jax.Array | None = None,
        nu: float = 5.0,
        scale: float = 1.19,
    ):
        # `scale` is appended after `nu` so existing positional calls
        # `StudentTRandomWalk(mask, nu)` keep working.
        super().__init__(mask=mask, dist=d.StudentT, scale=scale, nu=nu)