Source code for jexplore.steps.single_masked

"""
This module defines the classes for masked Metropolis-Hastings
steps.
"""

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.mh import AllChains


[docs] class SingleMasked[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](AllChains[Tepoch, Tstate, Tsampling]): """Full parallel all chains MH or modsel step. :param mask: mask of the subspace on which this step si defined. """
[docs] def sc_masked_proposal( self, key: jax.Array, p: jax.Array, ch: jax.Array ) -> tuple[jax.Array, jax.Array]: """Single chain masked proposal. This is a prototype. :param key: PRNG key :param p: single chain point :param ch: chain index :return: the proposed chain point, the transition log probability. """ raise NotImplementedError( "AllChains step is abstract. You need to implement this method." )
[docs] def proposal( self, key: jax.Array, state: Tstate ) -> tuple[jax.Array, Tstate, jax.Array]: """All chain proposal. This is a prototype. :param key: PRNG key :param state: current state :return: the new PRNG key, the proposed state, the transition log probability. """ chains = jnp.arange(state.p.shape[0]) keys = jax.random.split(key, state.p.shape[0] + 1) prop, qxy = jax.vmap(self.sc_masked_proposal, in_axes=(0, 0, 0))( keys[1:], state.p[:, self.mask], chains ) state = state.set_val(prop, pslc=self.mask) return keys[0], state, qxy