jexplore.steps.mh
This module defines the classes for Metropolis-Hastings (MH) steps.
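Classes
MHStep – Base class for MH and model selection steps.
AllChains – Fully parallel MH or model selection step acting on all chains.
Module Contents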
- class MHStep[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH]
Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]
Base class for MH and model selection steps. It specializes the build method to define a suitable array of inverse temperatures \(\beta_i = 1/T_i\), one per chain. It also defines a generic MH acceptance step (the mh method).
- beta: jax.Array
Array of inverse temperatures \(\beta_i = 1/T_i\) for all chains, with shape (nchains, 1).
- build(epoch)
Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the beta attribute.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
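This page does not say how build chooses the temperatures. A common choice in parallel tempering is a geometric ladder, sketched below under that assumption; geometric_betas, nchains and tmax are hypothetical names, not part of jexplore. The result has the (nchains, 1) shape documented for beta, so it broadcasts row by row against other per-chain column arrays.

```python
import jax
import jax.numpy as jnp


def geometric_betas(nchains: int, tmax: float) -> jax.Array:
    """Hypothetical helper: inverse temperatures beta_i = 1 / T_i on the
    geometric ladder T_i = tmax ** (i / (nchains - 1)), shaped (nchains, 1)."""
    if nchains == 1:
        return jnp.ones((1, 1))          # a single chain runs at T = 1
    exponents = jnp.arange(nchains) / (nchains - 1)
    temps = tmax ** exponents            # T_0 = 1 (cold chain), T_{n-1} = tmax
    return (1.0 / temps)[:, None]        # column vector, one row per chain


betas = geometric_betas(4, 8.0)
print(betas[:, 0])  # approximately [1.0, 0.5, 0.25, 0.125]
```

The cold chain (\(\beta = 1\)) samples the target itself, while hotter chains flatten it to move more freely.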
- compute(state)
Compute the loglik and logprior values of a state.
- Parameters:
state (Tstate) – input state.
- Returns:
new state with populated loglik and logprior values.
- Return type:
Tstate
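A minimal sketch of what compute does, assuming a state that stores one row per chain. State, loglik_fn and logprior_fn are hypothetical stand-ins (the fields of the real StateMH are not documented on this page), and the standard normal likelihood and uniform box prior are chosen purely for illustration. Because JAX arrays are immutable, the function returns a new state instead of modifying its input.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical stand-in for StateMH: one row per chain."""
    x: jax.Array         # (nchains, ndim) positions
    loglik: jax.Array    # (nchains, 1)
    logprior: jax.Array  # (nchains, 1)


def loglik_fn(x):
    # Assumed model: standard normal log-likelihood, up to a constant
    return -0.5 * jnp.sum(x**2, axis=-1, keepdims=True)


def logprior_fn(x):
    # Assumed prior: uniform on the box [-10, 10]^ndim, log-density up to a constant
    inside = jnp.all(jnp.abs(x) <= 10.0, axis=-1, keepdims=True)
    return jnp.where(inside, 0.0, -jnp.inf)


def compute(state: State) -> State:
    # Return a new state with both fields filled in for every chain
    return state._replace(loglik=loglik_fn(state.x), logprior=logprior_fn(state.x))


x = jnp.array([[0.0, 0.0], [1.0, 1.0], [20.0, 0.0]])
s = compute(State(x, jnp.zeros((3, 1)), jnp.zeros((3, 1))))
```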
- mh(key, state, prop, betas, qxy)
Metropolis-Hastings acceptance step.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
prop (Tstate) – proposed state
betas (jax.Array) – inverse temperatures, shape (nchains, 1)
qxy (jax.Array) – transition probabilities, shape (nchains, 1)
- Returns:
the new state with the accepted changes applied, and a boolean mask of the chains that changed.
- Return type:
tuple[Tstate, jax.Array]
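A minimal sketch of a tempered MH acceptance, under several assumptions not confirmed by this page: qxy holds the log proposal ratio \(\log q(x \mid x') - \log q(x' \mid x)\) (the proposal method is documented as returning a transition log probability), per-chain quantities are (nchains, 1) columns, and \(\beta\) tempers the likelihood only. The acceptance log-probability is then \(\log\alpha = \beta(\ell' - \ell) + (\log\pi' - \log\pi) + q_{xy}\). State is a hypothetical stand-in for StateMH.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical stand-in for StateMH: one row per chain."""
    x: jax.Array         # (nchains, ndim) positions
    loglik: jax.Array    # (nchains, 1)
    logprior: jax.Array  # (nchains, 1)


def mh(key, state, prop, betas, qxy):
    """Tempered MH acceptance for all chains at once.

    Assumes qxy is the log proposal ratio per chain, shape (nchains, 1),
    and that beta tempers the likelihood only.
    """
    log_alpha = (betas * (prop.loglik - state.loglik)
                 + (prop.logprior - state.logprior) + qxy)
    # Accept chain i when log(u_i) < log(alpha_i), with u_i ~ U[0, 1)
    log_u = jnp.log(jax.random.uniform(key, log_alpha.shape))
    accept = log_u < log_alpha  # (nchains, 1) boolean mask
    # Row-wise select: the (nchains, 1) mask broadcasts over every state field
    new_state = jax.tree_util.tree_map(
        lambda p, s: jnp.where(accept, p, s), prop, state)
    return new_state, accept
```

With \(\beta = 0\) the likelihood term vanishes, so a hot chain accepts moves that a cold chain would almost always reject. Some samplers temper the prior as well; which convention jexplore uses is not stated here.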
- class AllChains[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](mask=None)
Bases: MHStep[Tepoch, Tstate, Tsampling]
Fully parallel MH or model selection (modsel) step acting on all chains at once.
- Parameters:
mask (jax.Array | None) – proposal dimensions mask (default: None, meaning the whole space).
- mask: jax.Array
Array of indices of the dimensions that the proposal acts on.
- masked_covs: jax.Array
Masked version of the covariance.
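One way to obtain a masked covariance from an index mask, sketched under the assumption that the covariances are stored per chain with shape (nchains, ndim, ndim); masked_covariances is a hypothetical helper, not jexplore's implementation. A mask of None falls back to all dimensions, mirroring the documented default.

```python
import jax.numpy as jnp


def masked_covariances(covs, mask=None):
    """Hypothetical helper: restrict per-chain covariances (nchains, ndim, ndim)
    to the dimensions listed in `mask`, an integer index array."""
    ndim = covs.shape[-1]
    idx = jnp.arange(ndim) if mask is None else jnp.asarray(mask)
    # Advanced indexing on the last two axes keeps the rows and columns in idx,
    # giving a (nchains, k, k) sub-block where k = len(idx)
    return covs[:, idx[:, None], idx[None, :]]


covs = jnp.broadcast_to(jnp.diag(jnp.array([1.0, 2.0, 3.0])), (2, 3, 3))
sub = masked_covariances(covs, jnp.array([0, 2]))  # keeps dimensions 0 and 2
```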
- build(epoch)
Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the beta attribute.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
- abstract proposal(key, state)
Proposal acting on all chains at once.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
- Returns:
the new PRNG key, the proposed state and the transition log probability.
- Return type:
tuple[jax.Array, Tstate, jax.Array]
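Since proposal is abstract, subclasses supply the move. A minimal sketch of one possibility, a Gaussian random walk restricted to the masked dimensions, assuming a hypothetical State with one row per chain and per-chain masked covariances of shape (nchains, k, k). It returns the documented triple (new key, proposed state, transition log probability); the walk is symmetric, so the log proposal ratio is zero.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical stand-in for StateMH: one row per chain."""
    x: jax.Array         # (nchains, ndim)
    loglik: jax.Array    # (nchains, 1)
    logprior: jax.Array  # (nchains, 1)


def gaussian_proposal(key, state, mask, masked_covs):
    """Hypothetical all-chains random-walk proposal acting only on `mask`."""
    key, sub = jax.random.split(key)
    nchains, k = state.x.shape[0], mask.shape[0]
    z = jax.random.normal(sub, (nchains, k))
    chol = jnp.linalg.cholesky(masked_covs)       # (nchains, k, k)
    step = jnp.einsum("nij,nj->ni", chol, z)      # draws from N(0, cov) per chain
    # Only the masked columns move; loglik/logprior are now stale and
    # would be refreshed by compute before the acceptance test
    prop = state._replace(x=state.x.at[:, mask].add(step))
    qxy = jnp.zeros((nchains, 1))                 # symmetric walk: log ratio is 0
    return key, prop, qxy


s = State(jnp.zeros((4, 3)), jnp.zeros((4, 1)), jnp.zeros((4, 1)))
mask = jnp.array([0, 2])
covs = jnp.broadcast_to(jnp.eye(2), (4, 2, 2))
new_key, p, q = gaussian_proposal(jax.random.PRNGKey(1), s, mask, covs)
```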
- step(key, state)
Metropolis-Hastings step sampling method. It proposes a new state by calling the proposal method and then performs an MH acceptance by calling the jexplore.steps.mh.MHStep.mh method.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
- Returns:
new state and the boolean mask of the chains modified by the step.
- Return type:
tuple[Tstate, jax.Array]
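The propose-then-accept flow of step can be sketched end to end as follows. Everything here is illustrative: State, compute and step are hypothetical stand-ins, the target is an assumed standard normal with a flat prior, the proposal is a symmetric Gaussian random walk, and \(\beta\) is assumed to temper only the likelihood. jexplore's actual classes split these pieces across compute, proposal and mh.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical stand-in for StateMH: one row per chain."""
    x: jax.Array         # (nchains, ndim)
    loglik: jax.Array    # (nchains, 1)
    logprior: jax.Array  # (nchains, 1)


def compute(x):
    # Assumed target: standard normal likelihood and flat prior
    loglik = -0.5 * jnp.sum(x**2, axis=-1, keepdims=True)
    return State(x, loglik, jnp.zeros_like(loglik))


def step(key, state, betas, scale=0.5):
    """Sketch of an all-chains MH step: propose, evaluate, accept or reject."""
    key_prop, key_acc = jax.random.split(key)
    # 1. Proposal: symmetric Gaussian random walk on every chain at once
    prop = compute(state.x + scale * jax.random.normal(key_prop, state.x.shape))
    qxy = jnp.zeros_like(state.loglik)  # log proposal ratio is 0 for this walk
    # 2. Tempered MH acceptance, one uniform draw per chain
    log_alpha = (betas * (prop.loglik - state.loglik)
                 + (prop.logprior - state.logprior) + qxy)
    accept = jnp.log(jax.random.uniform(key_acc, log_alpha.shape)) < log_alpha
    new_state = jax.tree_util.tree_map(
        lambda p, s: jnp.where(accept, p, s), prop, state)
    return new_state, accept


fast_step = jax.jit(step)
betas = jnp.array([[1.0], [0.5], [0.25], [0.125]])
state = compute(jnp.zeros((4, 2)))
key = jax.random.PRNGKey(0)
for _ in range(100):
    key, sub = jax.random.split(key)
    state, accepted = fast_step(sub, state, betas)
```

Returning the acceptance mask alongside the state, as the documented step does, lets the caller track per-chain acceptance rates without recomputing anything.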