"""
This module defines the classes for Metropolis-Hastings
steps.
"""
from typing import cast
import jax
import jax.numpy as jnp
from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.sampling.state import ArrayFn
from jexplore.steps.step import Step
[docs]
class MHStep[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](
Step[Tepoch, Tstate, Tsampling] # type: ignore[type-var]
):
r"""Base class for MH and Model selection step. It specializes
`build` method to define a suitable array of :math:`\beta_i = 1/T_i`.
It also defines a generic MH acceptance step (the `mh` method).
"""
beta: jax.Array
"""Array of temperatures inverses for all chains. Shape is (nchains, 1)"""
[docs]
def build(self, epoch: Tepoch) -> None:
"""Step epoch initialisation method. Extend :py:attr:`jexplore.steps.step.Step.build`
by populating the `betas` attribute.
:param epoch: current epoch.
"""
super().build(epoch)
self.beta = 1.0 / jnp.tile(epoch.sampling.temps, epoch.sampling.nwalker)
self.beta = self.beta.reshape(-1, 1)
[docs]
def compute(self, state: Tstate) -> Tstate:
"""Compute the loglik and logprior values of a state
:param state: input state
:return: new state with populated loglik and logprior values.
"""
return (
state.update_mask(state, pars=self.sampling.inpars)
.compute("ll", cast(ArrayFn, self.sampling.loglik), self.sampling.inpars)
.compute("lp", cast(ArrayFn, self.sampling.logprior), self.sampling.inpars)
)
# pylint: disable=too-many-arguments,too-many-positional-arguments
[docs]
def mh(
self,
key: jax.Array,
state: Tstate,
prop: Tstate,
betas: jax.Array,
qxy: jax.Array,
) -> tuple[Tstate, jax.Array]:
"""Metropolis-Hastings acceptance step
:param key: PRNG key
:param state: current state
:param prop: proposed state
:param betas: (nchains, 1) betas array
:param qxy: transition probabilities (nchains, 1)
:return: new state with accepted changes and boolean mask of the changed chains.
"""
_new: Tstate = self.compute(prop)
acc: jax.Array
_, acc = self.get_accepted(
key, betas * (_new.ll - state.ll) + _new.lp - state.lp + qxy.reshape(-1, 1)
)
acc = acc.reshape(state.ll.shape[0])
return state.update_mask(_new, acc), acc
[docs]
def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]:
"""Step sampling method. This is just a prototype.
:param key: PRNG key
:param state: current state
:return: new state and the boolean mask of the chains modified by the step.
"""
raise NotImplementedError(
"MHStep is an abstract class. Need to implement this method."
)
[docs]
class AllChains[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](MHStep[Tepoch, Tstate, Tsampling]):
"""Full parallel all chains MH or modsel step.
:param mask: proposal dimensions mask (default all space)
"""
mask: jax.Array
"""List of indices that the proposal will act on"""
masked_covs: jax.Array
"""Masked version of the covariance"""
def __init__(self, mask: jax.Array | None = None):
super().__init__()
self.mask = jnp.array([]) if mask is None else mask
[docs]
def build(self, epoch: Tepoch) -> None:
super().build(epoch)
if self.mask.size == 0:
self.mask = jnp.arange(epoch.sampling.dim)
self.masked_covs = jnp.take(
jnp.take(epoch.covs, self.mask, axis=1), self.mask, axis=2
)
[docs]
def proposal(
self, key: jax.Array, state: Tstate
) -> tuple[jax.Array, Tstate, jax.Array]:
"""All chain proposal.
:param key: PRNG key
:param state: current state
:return: the new PRNG key, the proposed state, the transition log probability.
"""
raise NotImplementedError(
"AllChains step is abstract. You need to implement this method."
)
[docs]
def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]:
"""Metropolis Hasting step sampling method. It proposes a new state
calling the `proposal` method and then performs a MH acceptance calling
:py:attr:`jexplore.steps.mh.MHState.mh` method.
:param key: PRNG key
:param state: current state
:return: new state and the boolean mask of the chains modified by the step.
"""
prop: Tstate
qxy: jax.Array
key, prop, qxy = self.proposal(key, state)
return self.mh(key, state, prop, self.beta, qxy)