jexplore.steps.modsel.inmodel_single

This module defines the classes for model-selection in-model steps that act separately on each chain.

Classes

AllChainsWrap

Model selection wrapper for jexplore.steps.mh.AllChains steps.

InModelSC

In-model single-chain MH step. It is constructed out of a list of single-chain proposals, each defined on a subspace of the corresponding model space.

Module Contents

class AllChainsWrap[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS, Tstep: jexplore.steps.mh.AllChains](step)

Bases: jexplore.steps.mh.AllChains[Tepoch, Tstate, Tsampling]

Model selection wrapper for jexplore.steps.mh.AllChains steps. The step is wrapped so that it is applied only to chains whose current model has a (dimension) mask that fully contains the mask of the step. Note that the resulting step does not propose changes to the chains’ models.
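The containment condition can be sketched in plain Python as follows. This is an illustration only: it assumes (hypothetically) that each model and the wrapped step expose their dimension masks as boolean sequences of equal length over the full parameter space, and the helper names `mask_contains` and `active_models` are made up for the example (the latter merely echoes the attribute of the same name). A model mask contains the step mask when every dimension touched by the step is also active in the model.

```python
def mask_contains(model_mask, step_mask):
    """True if every dimension set in step_mask is also set in model_mask."""
    # Containment is an elementwise implication: step dimension -> model dimension.
    return all(m or not s for m, s in zip(model_mask, step_mask))


def active_models(model_masks, step_mask):
    """One flag per model: may the wrapped step act on chains in this model?"""
    return [mask_contains(mm, step_mask) for mm in model_masks]


# Hypothetical example: three models over a 4-dimensional parameter space.
model_masks = [
    [True, True, False, False],  # model 0 uses dims 0, 1
    [True, True, True, False],   # model 1 uses dims 0, 1, 2
    [False, True, True, True],   # model 2 uses dims 1, 2, 3
]
step_mask = [True, False, True, False]  # the wrapped step updates dims 0 and 2

print(active_models(model_masks, step_mask))  # [False, True, False]
```

With JAX boolean arrays of shape (n_models, n_dims) and (n_dims,), the same test vectorises as `jnp.all(model_masks | ~step_mask, axis=-1)`; whether the library computes its mask this way is not stated here.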

Parameters:

step (Tstep) – the step to be wrapped

wrapped_step: Tstep
active_models: jax.Array
build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the betas attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

proposal(key, state)

Propose a new state. For chains in inactive models (i.e. models whose mask does not contain the mask of the wrapped step), the proposed state is identical to the old one; for all other chains, it corresponds to the state proposed by the wrapped step.
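The selection rule can be sketched as below. This is a hypothetical, list-based illustration rather than the library's code: `select_proposal`, `chain_models` (current model index of each chain), `active` (one flag per model) and the toy `shift` proposal are all assumptions, and the PRNG key and log-probability outputs are omitted.

```python
def select_proposal(old_states, chain_models, active, wrapped_proposal):
    """Hypothetical sketch of the AllChainsWrap selection rule.

    old_states: per-chain states; chain_models: current model index of each
    chain; active: per-model flags (model mask contains the step mask);
    wrapped_proposal: stand-in for the wrapped step's proposal on one chain.
    """
    # Propose for every chain, as a vectorised step would...
    proposed = [wrapped_proposal(s) for s in old_states]
    # ...then keep the proposal only where the chain's model is active.
    modified = [active[m] for m in chain_models]
    new_states = [p if mod else s for p, s, mod in zip(proposed, old_states, modified)]
    return new_states, modified


def shift(state):
    # Toy stand-in for the wrapped step's proposal.
    return [x + 0.5 for x in state]


old_states = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
chain_models = [0, 1, 1]       # chain 0 is in model 0, chains 1 and 2 in model 1
active = [False, True, False]  # only model 1 contains the step mask

new_states, modified = select_proposal(old_states, chain_models, active, shift)
print(new_states)  # [[0.0, 1.0], [2.5, 3.5], [4.5, 5.5]]
print(modified)    # [False, True, True]
```

Proposing for every chain and then selecting, rather than branching per chain, mirrors how JAX code typically keeps array shapes static (for example with `jnp.where`).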

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

the new PRNG key, the new state, and the boolean mask of the chains modified by the step.

Return type:

tuple[jax.Array, Tstate, jax.Array]

class InModelSC[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS, Tstep: jexplore.steps.single_masked.SingleMasked](steps, stepsargs)

Bases: jexplore.steps.mh.AllChains[Tepoch, Tstate, Tsampling]

In-model single-chain MH step. It is constructed out of a list of single-chain proposals, each defined on a subspace of the corresponding model space.
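The per-model dispatch can be sketched as follows. This is a readability-first illustration under stated assumptions, not the library's implementation: `in_model_proposal`, `step_model0` and `step_model1` are hypothetical, each chain is assumed to carry the index of its current model, and a `None` entry in `steps` is assumed to leave chains in that model unchanged (the meaning of `None` is not specified here).

```python
def in_model_proposal(states, chain_models, steps):
    """Hypothetical sketch: move each chain with the step of its current model."""
    new_states = []
    for state, model in zip(states, chain_models):
        step = steps[model]
        # Assumption: a None entry means "no in-model move" for that model.
        new_states.append(state if step is None else step(state))
    return new_states


def step_model0(s):
    # Toy proposal on model 0's 1-d space.
    return [s[0] + 1.0]


def step_model1(s):
    # Toy proposal acting only on a subspace (the 2nd coordinate) of model 1.
    return [s[0], s[1] * 2.0]


states = [[0.0], [1.0, 1.0], [3.0]]
chain_models = [0, 1, 0]
print(in_model_proposal(states, chain_models, [step_model0, step_model1]))
# [[1.0], [1.0, 2.0], [4.0]]
```

The Python loop is only for clarity; a JAX implementation working on all chains at once would more likely evaluate the per-model proposals and select the result for each chain with a mask.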

Parameters:
  • steps (type[Tstep] | list[type[Tstep] | None]) – either a single (single-chain MH) step class or a list with one step class for each model. If a single class is given, the build method instantiates one step for each model using this step constructor.

  • stepsargs (dict | list[dict] | None) – step instantiation arguments.
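The `steps` and `stepsargs` attribute types (`list[type[Tstep] | None]` and `list[dict]`) suggest that the constructor arguments end up as one entry per model. A hypothetical sketch of such a normalisation follows; `normalise_steps`, `n_models` and `ToyStep` are illustrative names, and the broadcasting rules (a single class or dict reused for every model, `None` arguments becoming empty dicts) are assumptions rather than documented behaviour.

```python
def normalise_steps(steps, stepsargs, n_models):
    """Hypothetical: broadcast `steps` and `stepsargs` to one entry per model."""
    # Assumption: a single step class is reused for every model.
    if not isinstance(steps, list):
        steps = [steps] * n_models
    # Assumption: no arguments -> empty dict per model; one dict -> a copy per model.
    if stepsargs is None:
        stepsargs = [{} for _ in range(n_models)]
    elif isinstance(stepsargs, dict):
        stepsargs = [dict(stepsargs) for _ in range(n_models)]
    if len(steps) != n_models or len(stepsargs) != n_models:
        raise ValueError("expected one step class and one argument dict per model")
    return steps, stepsargs


class ToyStep:
    """Stand-in for a single-chain MH step class."""

    def __init__(self, scale=1.0):
        self.scale = scale


steps, stepsargs = normalise_steps(ToyStep, {"scale": 0.1}, n_models=3)
# Instantiate one step per model, skipping models without a step class.
instances = [cls(**kw) for cls, kw in zip(steps, stepsargs) if cls is not None]
print(len(instances), instances[0].scale)  # 3 0.1
```

Copying the shared dict per model avoids accidental aliasing if one model's arguments are later modified.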

steps: list[type[Tstep] | None]

List of step classes, one per model.

stepsargs: list[dict]

Step instantiation arguments.

build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the betas attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

proposal(key, state)

In-model single-chain proposal.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

the new PRNG key, the proposed state, and the transition log probability.

Return type:

tuple[jax.Array, Tstate, jax.Array]