jexplore.steps.modsel.model_switch

This module defines the classes implementing the model-switching steps used for model selection.

Classes

DrawPseudo

Draws samples from the models' pseudo priors for each chain.

DrawModel

Draws a model for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo prior.

Module Contents

class DrawPseudo[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]

Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]

Draws samples from the models' pseudo priors for each chain.

draw_funcs: list

Static list of pseudo-prior drawing functions, one for each model.

build(epoch)

Epoch initialisation method.

Parameters:

epoch (Tepoch) – current epoch.

step(key, state)

Pseudo-prior drawing step.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]
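The pseudo-prior step can be pictured with the minimal JAX sketch below. It assumes a product-space layout in which every chain stores one parameter block per model, and that each entry of draw_funcs maps a PRNG key and a chain count to fresh pseudo-prior samples. The Gaussian pseudo priors, the `make_gaussian_pseudo` and `pseudo_step` helpers, and the plain `model`/`params` arrays standing in for Tstate are all hypothetical and do not reflect jexplore's actual internals. The blocks of the models a chain is not currently in are redrawn; the block of its current model is left untouched.

```python
import jax
import jax.numpy as jnp


def make_gaussian_pseudo(mean, std):
    """Hypothetical pseudo-prior sampler: (key, n_chains) -> (n_chains, dim)."""
    mean = jnp.asarray(mean)
    std = jnp.asarray(std)

    def draw(key, n_chains):
        eps = jax.random.normal(key, (n_chains, mean.shape[0]))
        return mean + std * eps

    return draw


# Assumed stand-in for draw_funcs: one sampler per model.
# Gaussian pseudo priors are chosen for illustration only.
draw_funcs = [
    make_gaussian_pseudo([0.0], [1.0]),            # model 0: 1 parameter
    make_gaussian_pseudo([0.0, 1.0], [0.5, 0.5]),  # model 1: 2 parameters
]


def pseudo_step(key, model, params):
    """Redraw the parameter blocks of every model a chain is NOT currently in.

    model:  int array (n_chains,), current model index of each chain.
    params: list with one (n_chains, dim_m) array per model.
    Returns the new params and a boolean mask of the modified chains.
    """
    n_chains = model.shape[0]
    keys = jax.random.split(key, len(draw_funcs))
    new_params = []
    modified = jnp.zeros(n_chains, dtype=bool)
    for m, (draw, k) in enumerate(zip(draw_funcs, keys)):
        inactive = model != m  # chains for which block m is a pseudo block
        fresh = draw(k, n_chains)
        # Keep the current block where model m is active, redraw elsewhere.
        new_params.append(jnp.where(inactive[:, None], fresh, params[m]))
        modified = modified | inactive
    return new_params, modified
```

With two or more models the returned mask is all True, since every chain then has at least one inactive block to refresh.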

class DrawModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](reorder=True, draw_pseudo=True)

Bases: jexplore.steps.mh.MHStep[Tepoch, Tstate, Tsampling]

Draws a model for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo prior.
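Written out for a single chain c with parameter blocks theta_1, ..., theta_M, the weight described above can be read as follows. This is a hedged interpretation rather than jexplore's exact formula: it assumes the likelihood is tempered by the chain's inverse temperature beta_c (suggested by the betas attribute populated in build) and that the pseudo-prior term collects the pseudo-prior densities psi_k of the blocks theta_k not used by model m.

```latex
\begin{aligned}
\log w_{c,m} &= \beta_c \log p(y \mid \theta_m, m) + \log p(\theta_m \mid m) + \sum_{k \neq m} \log \psi_k(\theta_k), \\
P(m_c = m) &= \frac{\exp(\log w_{c,m})}{\sum_{j=1}^{M} \exp(\log w_{c,j})}.
\end{aligned}
```

Only differences between the log weights affect the draw, so the normalising sum never has to be computed explicitly; a categorical sampler can work directly on the unnormalised log weights, and beta_c = 1 recovers the untempered draw.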

Parameters:
  • reorder (bool) – reorder the chains by model (temperature by temperature) after the model draw.

  • draw_pseudo (bool) – draw from pseudo prior before drawing the models.
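A hedged JAX sketch of the categorical model draw follows. `draw_models` is a hypothetical helper, not part of jexplore; it assumes the three log terms have already been evaluated for every (chain, candidate model) pair and that only the likelihood is tempered by a per-chain inverse temperature. `jax.random.categorical` accepts unnormalised log probabilities, so no explicit normalisation is needed.

```python
import jax
import jax.numpy as jnp


def draw_models(key, log_lik, log_prior, log_pseudo, betas):
    """Draw one model index per chain from a categorical over models.

    log_lik, log_prior, log_pseudo: (n_chains, n_models) arrays holding, for
    every chain, the corresponding log term evaluated for each candidate model.
    betas: (n_chains,) inverse temperatures (assumed; 1.0 means untempered).
    """
    # Unnormalised log weights; tempering only the likelihood is an assumption.
    logits = betas[:, None] * log_lik + log_prior + log_pseudo
    # categorical samples along the last axis: one model index per chain.
    return jax.random.categorical(key, logits, axis=-1)
```

For instance, with log_lik rows [0, -1e4], log_prior rows [-1e3, 0], zero pseudo-prior terms and betas [1.0, 0.0], the untempered chain selects model 0, while the beta = 0 chain ignores the likelihood and selects model 1.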

reorder: bool

Whether to reorder the chains by model (and temperature) after drawing the models.
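Reordering temperature by temperature can be sketched as a stable sort of the model indices inside each temperature block. The temperature-major chain layout and the `reorder_by_model` helper are assumptions made for illustration; jexplore may store its chains differently. The returned permutation would be applied to every per-chain array in the state.

```python
import jax.numpy as jnp


def reorder_by_model(model, params, n_temps):
    """Sort chains by model index separately within each temperature block.

    Assumes a temperature-major layout: chains
    [t * n_per_temp, (t + 1) * n_per_temp) share temperature t,
    and n_chains is a multiple of n_temps.
    model:  (n_chains,) int array; params: (n_chains, dim) array.
    """
    n_chains = model.shape[0]
    n_per_temp = n_chains // n_temps
    blocks = model.reshape(n_temps, n_per_temp)
    # jnp.argsort is stable by default, so chains in the same model
    # keep their relative order.
    local = jnp.argsort(blocks, axis=-1)
    # Shift block-local indices back to global chain indices.
    offsets = (jnp.arange(n_temps) * n_per_temp)[:, None]
    perm = (local + offsets).reshape(-1)
    return model[perm], params[perm], perm
```

For example, with two temperatures of four chains each and models [1, 0, 1, 0, 0, 1, 0, 1], the permutation is [1, 3, 0, 2, 4, 6, 5, 7] and the reordered models are [0, 0, 1, 1, 0, 0, 1, 1]; chains never move across temperatures.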

draw_pseudo: bool

Whether to draw from the pseudo priors before drawing the models.

build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the betas attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

step(key, state)

Model drawing step.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]