jexplore.steps.mh

This module defines the classes for Metropolis-Hastings (MH) steps.

Classes

MHStep

Base class for MH and model-selection steps.

AllChains

Fully parallel MH or model-selection step acting on all chains.

Module Contents

class MHStep[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH]

Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]

Base class for MH and model-selection steps. It specializes the build method to define a suitable array of inverse temperatures \(\beta_i = 1/T_i\). It also defines a generic MH acceptance step (the mh method).

beta: jax.Array

Array of inverse temperatures for all chains, with shape (nchains, 1).
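As an illustration of the layout of beta, the sketch below builds a temperature ladder and stores the inverse temperatures with shape (nchains, 1), so they broadcast against per-chain quantities. The geometric spacing, the function name geometric_betas and the t_max parameter are hypothetical; the ladder that build actually constructs may differ.

```python
import jax.numpy as jnp


def geometric_betas(nchains: int, t_max: float) -> jnp.ndarray:
    """Hypothetical ladder: temperatures geometrically spaced in [1, t_max]."""
    temps = jnp.geomspace(1.0, t_max, nchains)  # T_0 = 1 is the untempered chain
    betas = 1.0 / temps  # beta_i = 1 / T_i
    return betas.reshape(nchains, 1)  # shape (nchains, 1), as documented for beta


beta = geometric_betas(nchains=4, t_max=8.0)
print(beta.shape)  # (4, 1)
```

Keeping the trailing axis of length 1 lets beta multiply an (nchains, 1) log-likelihood difference or an (nchains, ndim) array without explicit reshaping.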

build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the beta attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

compute(state)

Compute the loglik and logprior values of a state.

Returns:

new state with populated loglik and logprior values.

Parameters:

state (Tstate) – input state

Return type:

Tstate
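A minimal sketch of what compute does, under stated assumptions: ToyState is a hypothetical stand-in for jexplore.sampling.StateMH with fields x, loglik and logprior, and the target (a standard normal likelihood with a flat prior on a box) is a toy example. Both terms are evaluated for every chain at once with jax.vmap.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for jexplore.sampling.StateMH (field names assumed)."""
    x: jax.Array         # parameters, shape (nchains, ndim)
    loglik: jax.Array    # shape (nchains, 1)
    logprior: jax.Array  # shape (nchains, 1)


def toy_loglik(x: jax.Array) -> jax.Array:
    # Standard normal log-likelihood (up to a constant) for a single chain.
    return -0.5 * jnp.sum(x**2)


def toy_logprior(x: jax.Array) -> jax.Array:
    # Flat prior on the box [-10, 10]^ndim, -inf outside it.
    inside = jnp.all(jnp.abs(x) <= 10.0)
    return jnp.where(inside, 0.0, -jnp.inf)


def compute(state: ToyState) -> ToyState:
    # Vectorize over chains, then restore the trailing axis: shape (nchains, 1).
    loglik = jax.vmap(toy_loglik)(state.x)[:, None]
    logprior = jax.vmap(toy_logprior)(state.x)[:, None]
    return state._replace(loglik=loglik, logprior=logprior)


x0 = jnp.array([[0.0, 0.0], [1.0, 2.0], [20.0, 0.0]])
empty = jnp.zeros((3, 1))
state = compute(ToyState(x=x0, loglik=empty, logprior=empty))
```

The third chain lies outside the prior support, so its logprior is -inf and any MH move into such a point would be rejected.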

mh(key, state, prop, betas, qxy)

Metropolis-Hastings acceptance step.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • prop (Tstate) – proposed state

  • betas (jax.Array) – inverse temperatures, shape (nchains, 1)

  • qxy (jax.Array) – transition probabilities, shape (nchains, 1)

Returns:

the new state with the accepted changes applied, and a boolean mask of the chains that changed.

Return type:

tuple[Tstate, jax.Array]
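A tempered MH acceptance with the same argument order as mh can be sketched as below. This is not jexplore's implementation: it assumes a hypothetical state with fields x, loglik and logprior, that only the log-likelihood is tempered by beta, and that qxy is already the log of the proposal ratio q(x | y) / q(y | x). Chain i is accepted when log u_i < beta_i (logL(y) - logL(x)) + (log pi(y) - log pi(x)) + qxy_i, with u_i uniform on [0, 1).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for jexplore.sampling.StateMH (field names assumed)."""
    x: jax.Array         # shape (nchains, ndim)
    loglik: jax.Array    # shape (nchains, 1)
    logprior: jax.Array  # shape (nchains, 1)


def mh(key, state, prop, betas, qxy):
    # Tempered log acceptance ratio per chain, shape (nchains, 1).
    # Assumption: only the likelihood is tempered; qxy is a log proposal ratio.
    log_alpha = (
        betas * (prop.loglik - state.loglik)
        + (prop.logprior - state.logprior)
        + qxy
    )
    # Accept chain i when log(u_i) < log_alpha_i.
    log_u = jnp.log(jax.random.uniform(key, log_alpha.shape))
    accepted = log_u < log_alpha  # boolean, shape (nchains, 1)
    # Take proposed values where accepted, current ones elsewhere (broadcasts over ndim).
    new_state = jax.tree_util.tree_map(
        lambda new, old: jnp.where(accepted, new, old), prop, state
    )
    return new_state, accepted[:, 0]  # mask of changed chains, shape (nchains,)


state = ToyState(x=jnp.zeros((2, 3)), loglik=jnp.zeros((2, 1)), logprior=jnp.zeros((2, 1)))
prop = ToyState(
    x=jnp.ones((2, 3)),
    loglik=jnp.array([[100.0], [0.0]]),      # chain 0: much better fit
    logprior=jnp.array([[0.0], [-jnp.inf]]),  # chain 1: outside the prior support
)
betas = jnp.array([[1.0], [0.5]])
new_state, changed = mh(jax.random.PRNGKey(0), state, prop, betas, jnp.zeros((2, 1)))
```

Working in log space avoids overflow in exp, and a proposal with logprior = -inf is rejected deterministically because -inf is never greater than log u.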

abstract step(key, state)

Step sampling method. This is only a prototype; subclasses must implement it.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]

class AllChains[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](mask=None)

Bases: MHStep[Tepoch, Tstate, Tsampling]

Fully parallel MH or model-selection step acting on all chains.

Parameters:

mask (jax.Array | None) – mask of the dimensions the proposal acts on (default: the whole parameter space)

mask: jax.Array

List of indices of the dimensions the proposal acts on.

masked_covs: jax.Array

Masked version of the covariance.
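One way to read mask and masked_covs is sketched below. It assumes that mask holds integer dimension indices, that mask=None selects every dimension, and that the masked covariance is the square sub-block of a full (ndim, ndim) covariance restricted to those dimensions. The helpers resolve_mask and mask_covariance are hypothetical, and jexplore may store masked_covs differently (for example, per chain).

```python
import jax.numpy as jnp


def resolve_mask(mask, ndim):
    # Hypothetical default: mask=None means the proposal acts on every dimension.
    return jnp.arange(ndim) if mask is None else jnp.asarray(mask)


def mask_covariance(cov, mask):
    # Keep only the rows and columns of the dimensions listed in mask.
    return cov[jnp.ix_(mask, mask)]


cov = jnp.array(
    [
        [4.0, 1.0, 0.0, 0.0],
        [1.0, 3.0, 0.0, 0.0],
        [0.0, 0.0, 2.0, 0.5],
        [0.0, 0.0, 0.5, 1.0],
    ]
)
mask = resolve_mask(jnp.array([2, 3]), ndim=4)
masked_cov = mask_covariance(cov, mask)  # [[2.0, 0.5], [0.5, 1.0]]
full_cov = mask_covariance(cov, resolve_mask(None, ndim=4))  # identical to cov
```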

build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the beta attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

abstract proposal(key, state)

Proposal acting on all chains.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

the new PRNG key, the proposed state, and the transition log probability.

Return type:

tuple[jax.Array, Tstate, jax.Array]
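A concrete proposal with the documented return triple (new key, proposed state, transition log probability) might look like the correlated Gaussian random walk below. The random-walk choice, the hypothetical ToyState container, and the extra mask, masked_cov and scale arguments are assumptions for illustration; a subclass of AllChains would take these from its own attributes. Because the walk is symmetric, the log transition ratio returned is zero for every chain.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for jexplore.sampling.StateMH; only x is used here."""
    x: jax.Array  # shape (nchains, ndim)


def proposal(key, state, mask, masked_cov, scale=0.1):
    # Split the key: one subkey is consumed here, the other is returned.
    key, subkey = jax.random.split(key)
    nchains = state.x.shape[0]
    # Correlated Gaussian jumps on the masked dimensions: cov = L L^T, jump = L z.
    chol = jnp.linalg.cholesky(masked_cov)
    z = jax.random.normal(subkey, (nchains, mask.shape[0]))
    jumps = scale * z @ chol.T
    # Unmasked dimensions are left untouched.
    x_new = state.x.at[:, mask].add(jumps)
    # Symmetric random walk: q(x|y) = q(y|x), so the log ratio is zero.
    qxy = jnp.zeros((nchains, 1))
    return key, state._replace(x=x_new), qxy


key = jax.random.PRNGKey(0)
state = ToyState(x=jnp.zeros((3, 4)))
mask = jnp.array([1, 3])
new_key, prop, qxy = proposal(key, state, mask, jnp.eye(2))
```

Returning the split key, rather than reusing the input key, keeps successive calls statistically independent, which matches the documented contract of handing back a new PRNG key.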

step(key, state)

Metropolis-Hastings step sampling method. It proposes a new state by calling the proposal method and then performs an MH acceptance by calling the jexplore.steps.mh.MHStep.mh method.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]
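The propose-then-accept structure of step can be sketched end to end as below, together with one use of the returned boolean mask: counting per-chain acceptance rates. Everything here is a self-contained toy under assumptions (hypothetical ToyState fields, a standard normal target, a symmetric random-walk proposal, likelihood-only tempering, and a geometric beta ladder), not jexplore's classes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for jexplore.sampling.StateMH (field names assumed)."""
    x: jax.Array         # (nchains, ndim)
    loglik: jax.Array    # (nchains, 1)
    logprior: jax.Array  # (nchains, 1)


def compute(state):
    # Toy target: standard normal likelihood, flat prior.
    loglik = -0.5 * jnp.sum(state.x**2, axis=1, keepdims=True)
    return state._replace(loglik=loglik, logprior=jnp.zeros_like(loglik))


def proposal(key, state, scale):
    # Symmetric Gaussian random walk on all dimensions: log q ratio is zero.
    key, subkey = jax.random.split(key)
    x_new = state.x + scale * jax.random.normal(subkey, state.x.shape)
    return key, compute(state._replace(x=x_new)), jnp.zeros_like(state.loglik)


def mh(key, state, prop, betas, qxy):
    # Assumption: only the likelihood is tempered; qxy is a log proposal ratio.
    log_alpha = betas * (prop.loglik - state.loglik) + (prop.logprior - state.logprior) + qxy
    accepted = jnp.log(jax.random.uniform(key, log_alpha.shape)) < log_alpha
    new = jax.tree_util.tree_map(lambda p, s: jnp.where(accepted, p, s), prop, state)
    return new, accepted[:, 0]


@jax.jit
def step(key, state, betas):
    # Propose for all chains at once, then accept or reject each chain independently.
    key, prop, qxy = proposal(key, state, scale=0.5)
    return mh(key, state, prop, betas, qxy)


nchains, ndim, nsteps = 4, 2, 200
betas = 1.0 / jnp.geomspace(1.0, 8.0, nchains)[:, None]  # shape (nchains, 1)
state = compute(ToyState(jnp.zeros((nchains, ndim)), jnp.zeros((nchains, 1)), jnp.zeros((nchains, 1))))
key = jax.random.PRNGKey(42)
n_accepted = jnp.zeros(nchains)
for _ in range(nsteps):
    key, subkey = jax.random.split(key)
    state, changed = step(subkey, state, betas)
    n_accepted = n_accepted + changed  # the mask counts accepted moves per chain
print(n_accepted / nsteps)  # per-chain acceptance rate
```

Hotter chains (smaller beta) see a flattened likelihood, so they typically accept more often than the beta = 1 chain.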