Source code for jexplore.steps.single_masked
"""
This module defines the class for Metropolis-Hastings steps whose
proposals act only on a masked subspace of the chain points.
"""
import jax
import jax.numpy as jnp
from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.mh import AllChains
[docs]
class SingleMasked[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](AllChains[Tepoch, Tstate, Tsampling]):
"""Full parallel all chains MH or modsel step.
:param mask: mask of the subspace on which this step si defined.
"""
[docs]
def sc_masked_proposal(
self, key: jax.Array, p: jax.Array, ch: jax.Array
) -> tuple[jax.Array, jax.Array]:
"""Single chain masked proposal. This is a prototype.
:param key: PRNG key
:param p: single chain point
:param ch: chain index
:return: the proposed chain point, the transition log probability.
"""
raise NotImplementedError(
"AllChains step is abstract. You need to implement this method."
)
[docs]
def proposal(
self, key: jax.Array, state: Tstate
) -> tuple[jax.Array, Tstate, jax.Array]:
"""All chain proposal. This is a prototype.
:param key: PRNG key
:param state: current state
:return: the new PRNG key, the proposed state, the transition log probability.
"""
chains = jnp.arange(state.p.shape[0])
keys = jax.random.split(key, state.p.shape[0] + 1)
prop, qxy = jax.vmap(self.sc_masked_proposal, in_axes=(0, 0, 0))(
keys[1:], state.p[:, self.mask], chains
)
state = state.set_val(prop, pslc=self.mask)
return keys[0], state, qxy