Source code for jexplore.steps.mh

"""
This module defines the classes for Metropolis-Hastings
steps.
"""

from typing import cast

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.sampling.state import ArrayFn
from jexplore.steps.step import Step


[docs] class MHStep[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ]( Step[Tepoch, Tstate, Tsampling] # type: ignore[type-var] ): r"""Base class for MH and Model selection step. It specializes `build` method to define a suitable array of :math:`\beta_i = 1/T_i`. It also defines a generic MH acceptance step (the `mh` method). """ beta: jax.Array """Array of temperatures inverses for all chains. Shape is (nchains, 1)"""
[docs] def build(self, epoch: Tepoch) -> None: """Step epoch initialisation method. Extend :py:attr:`jexplore.steps.step.Step.build` by populating the `betas` attribute. :param epoch: current epoch. """ super().build(epoch) self.beta = 1.0 / jnp.tile(epoch.sampling.temps, epoch.sampling.nwalker) self.beta = self.beta.reshape(-1, 1)
[docs] def compute(self, state: Tstate) -> Tstate: """Compute the loglik and logprior values of a state :param state: input state :return: new state with populated loglik and logprior values. """ return ( state.update_mask(state, pars=self.sampling.inpars) .compute("ll", cast(ArrayFn, self.sampling.loglik), self.sampling.inpars) .compute("lp", cast(ArrayFn, self.sampling.logprior), self.sampling.inpars) )
# pylint: disable=too-many-arguments,too-many-positional-arguments
[docs] def mh( self, key: jax.Array, state: Tstate, prop: Tstate, betas: jax.Array, qxy: jax.Array, ) -> tuple[Tstate, jax.Array]: """Metropolis-Hastings acceptance step :param key: PRNG key :param state: current state :param prop: proposed state :param betas: (nchains, 1) betas array :param qxy: transition probabilities (nchains, 1) :return: new state with accepted changes and boolean mask of the changed chains. """ _new: Tstate = self.compute(prop) acc: jax.Array _, acc = self.get_accepted( key, betas * (_new.ll - state.ll) + _new.lp - state.lp + qxy.reshape(-1, 1) ) acc = acc.reshape(state.ll.shape[0]) return state.update_mask(_new, acc), acc
[docs] def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: """Step sampling method. This is just a prototype. :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ raise NotImplementedError( "MHStep is an abstract class. Need to implement this method." )
[docs] class AllChains[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](MHStep[Tepoch, Tstate, Tsampling]): """Full parallel all chains MH or modsel step. :param mask: proposal dimensions mask (default all space) """ mask: jax.Array """List of indices that the proposal will act on""" masked_covs: jax.Array """Masked version of the covariance""" def __init__(self, mask: jax.Array | None = None): super().__init__() self.mask = jnp.array([]) if mask is None else mask
[docs] def build(self, epoch: Tepoch) -> None: super().build(epoch) if self.mask.size == 0: self.mask = jnp.arange(epoch.sampling.dim) self.masked_covs = jnp.take( jnp.take(epoch.covs, self.mask, axis=1), self.mask, axis=2 )
[docs] def proposal( self, key: jax.Array, state: Tstate ) -> tuple[jax.Array, Tstate, jax.Array]: """All chain proposal. :param key: PRNG key :param state: current state :return: the new PRNG key, the proposed state, the transition log probability. """ raise NotImplementedError( "AllChains step is abstract. You need to implement this method." )
[docs] def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: """Metropolis Hasting step sampling method. It proposes a new state calling the `proposal` method and then performs a MH acceptance calling :py:attr:`jexplore.steps.mh.MHState.mh` method. :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ prop: Tstate qxy: jax.Array key, prop, qxy = self.proposal(key, state) return self.mh(key, state, prop, self.beta, qxy)