jexplore.steps.modsel

Model Selection Markov Steps

Attributes

TSwapMS

Temperature swap for model selection

Classes

IMDEStep

Class implementing an in-model differential evolution step for model selection.

IMStretch

Class implementing an in-model step for model selection based on the stretch proposal.

AllChainsWrap

Model selection wrapper for jexplore.steps.mh.AllChains steps.

DrawModel

Draw a model for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo-prior.

DrawPseudo

Draw model pseudo-prior samples for each chain.

OrderByModel

Swap chains to restore the ordering by model (temperature by temperature).

Package Contents

class IMDEStep[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](gamma=2.38, ngroups=2, permute=False)

Bases: jexplore.steps.de.DEStep[Tepoch, Tstate, Tsampling]

Class implementing an in-model differential evolution step for model selection.

Parameters:
  • gamma (float) – \(\gamma\) scale parameter. Default 2.38.

  • ngroups (int) – number of groups. Default 2.

  • permute (bool) – if True, walkers are permuted at each iteration. Default False.
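To illustrate what ngroups and permute control, here is a minimal NumPy sketch of splitting walkers into groups. It assumes an interleaved assignment (chain i goes to group i mod ngroups) and an optional shuffle when permute is set; make_color_groups is a hypothetical helper, not the jexplore grouping code.

```python
import numpy as np

def make_color_groups(nchains, ngroups=2, permute=False, rng=None):
    """Hypothetical helper: split chain indices into `ngroups` groups."""
    idx = np.arange(nchains)
    if permute:
        # Shuffle walkers so that group membership changes between iterations.
        rng = np.random.default_rng() if rng is None else rng
        idx = rng.permutation(idx)
    # Interleaved assignment (an assumption): idx[i] goes to group i % ngroups.
    return [idx[g::ngroups] for g in range(ngroups)]

groups = make_color_groups(8, ngroups=2)
# groups[0] holds chains 0, 2, 4, 6 and groups[1] holds chains 1, 3, 5, 7
```

Each group is then updated using partners taken from the complementary group(s), which keeps the update of one group conditionally independent of its own members.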

sigmas: jax.Array

\(\sigma\) parameters of the \(\gamma\) distribution for all models

get_partners

Get the partner samples for each chain of a group. The main difference from the base implementation in jexplore.steps.colored.ColoredSC is that here each chain of the group draws its partners only from the chains of the complementary group that are at the same temperature and in the same model.

Parameters:
  • key – PRNG key

  • state – current state

  • group – group chains

  • cgroup – complementary group chains

Returns:

the partners as an array with shape (self.npars, group.size, dim)
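The same-model, same-temperature restriction can be sketched in NumPy as follows. This assumes each chain carries a model index and a temperature, and that partners are drawn uniformly from the eligible chains (without replacement when enough are available). same_model_partners and its error on an empty candidate set are hypothetical; the page does not say how jexplore handles a chain with no eligible partner.

```python
import numpy as np

def same_model_partners(rng, positions, models, temps, group, cgroup, npars):
    """Hypothetical sketch: for each chain in `group`, draw `npars` partner
    positions from chains in `cgroup` with the same model and temperature.

    Returns an array of shape (npars, group.size, dim)."""
    dim = positions.shape[1]
    partners = np.empty((npars, group.size, dim))
    for i, c in enumerate(group):
        # Eligible partners: complementary-group chains in the same model
        # and at the same temperature as chain c.
        mask = (models[cgroup] == models[c]) & (temps[cgroup] == temps[c])
        candidates = cgroup[mask]
        if candidates.size == 0:
            raise ValueError(f"chain {c} has no partner in its model/temperature")
        # Sample with replacement only if there are too few candidates.
        picks = rng.choice(candidates, size=npars, replace=candidates.size < npars)
        partners[:, i, :] = positions[picks]
    return partners
```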

build(epoch)

Step initialisation method. It extends jexplore.steps.colored.Colored.build by adding the computation of the \(\sigma\) parameter of the \(\gamma\) distribution.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

sample_gamma(key, state)

Sample \(\gamma\) from a normal distribution.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

samples of \(\gamma\)

Return type:

jax.Array
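This page states that build computes one \(\sigma\) per model and that sample_gamma draws \(\gamma\) from a normal distribution, but it does not give the formulas. The NumPy sketch below assumes the ter Braak scaling \(\sigma_m = \gamma_0 / \sqrt{2 d_m}\), with \(\gamma_0\) the gamma parameter and \(d_m\) the dimension of model m, and a zero-mean normal draw \(\gamma \sim N(0, \sigma_m^2)\) for each chain. model_sigmas and de_proposal are hypothetical names.

```python
import numpy as np

def model_sigmas(gamma, model_dims):
    """Assumed per-model scale: sigma_m = gamma / sqrt(2 * d_m)."""
    return gamma / np.sqrt(2.0 * np.asarray(model_dims, dtype=float))

def de_proposal(rng, x, partners, chain_models, sigmas):
    """Differential evolution move x' = x + g * (x_a - x_b) for each chain.

    x:            (nchains, dim) current positions
    partners:     (2, nchains, dim) partner positions from the same model
    chain_models: (nchains,) model index of each chain
    sigmas:       (nmodels,) per-model scale of the normal draw of g
    """
    # One g per chain, scaled by the sigma of that chain's model.
    g = rng.normal(size=x.shape[0]) * sigmas[chain_models]
    return x + g[:, None] * (partners[0] - partners[1])
```

Because g is symmetric around zero and the two partners enter antisymmetrically, the proposal is symmetric and needs no Hastings correction.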

class IMStretch[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](a=2.0, ngroups=2, permute=False)

Bases: jexplore.steps.stretch.Stretch[Tepoch, Tstate, Tsampling]

Class implementing an in-model step for model selection based on the stretch proposal.

Parameters:
  • a (float) – \(a\) parameter of the stretch proposal. Default 2.0.

  • ngroups (int) – number of groups. Default 2.

  • permute (bool) – if True, walkers are permuted at each iteration. Default False.
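For reference, the Goodman-Weare stretch proposal that the a parameter controls can be sketched in NumPy as below. The stretch factor z is drawn from \(g(z) \propto 1/\sqrt{z}\) on \([1/a, a]\) and the move is \(y = x_j + z (x - x_j)\), with Hastings factor \(z^{d-1}\). In the model selection setting the sketch assumes d is the dimension of each chain's own model (the undocumented mdims attribute may hold these per-model dimensions, but that is a guess). The function names are hypothetical.

```python
import numpy as np

def sample_stretch_z(rng, a, size):
    """Draw z from g(z) ~ 1/sqrt(z) on [1/a, a] by inverse-CDF sampling."""
    u = rng.random(size)
    return ((a - 1.0) * u + 1.0) ** 2 / a

def stretch_proposal(rng, x, partner, chain_dims, a=2.0):
    """Stretch move y = x_j + z * (x - x_j).

    x:          (nchains, dim) current positions
    partner:    (nchains, dim) one partner x_j per chain, from the same model
    chain_dims: (nchains,) dimension of each chain's current model
    Returns the proposals and the log Hastings factor (d - 1) * log z.
    """
    z = sample_stretch_z(rng, a, x.shape[0])
    y = partner + z[:, None] * (x - partner)
    log_factor = (chain_dims - 1) * np.log(z)
    return y, log_factor
```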

mdims: jax.Array

get_partners

Get the partner samples for each chain of a group. The main difference from the base implementation in jexplore.steps.colored.ColoredSC is that here each chain of the group draws its partners only from the chains of the complementary group that are at the same temperature and in the same model.

Parameters:
  • key – PRNG key

  • state – current state

  • group – group chains

  • cgroup – complementary group chains

Returns:

the partners as an array with shape (self.npars, group.size, dim)

build(epoch)

Step initialisation method. It extends jexplore.steps.step.Step.build by calling grouping to define the color groups and checking that the resulting color groups are even, raising a BadColoring exception otherwise.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

class AllChainsWrap[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS, Tstep: jexplore.steps.mh.AllChains](step)

Bases: jexplore.steps.mh.AllChains[Tepoch, Tstate, Tsampling]

Model selection wrapper for jexplore.steps.mh.AllChains steps. The wrapped step is applied only to chains whose current model has a dimension mask that fully contains the mask of the step. Note that the resulting step does not propose changes to the chains’ models.
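The containment rule can be written as a per-model boolean test: a model is active when every dimension the step acts on is also a dimension of the model. A minimal NumPy sketch, assuming boolean dimension masks; active_models_from_masks is a hypothetical name, and the active_models attribute presumably holds a result of this kind, though the page does not say so.

```python
import numpy as np

def active_models_from_masks(model_masks, step_mask):
    """Return one boolean per model: True when the model's dimension mask
    contains every dimension the wrapped step acts on."""
    model_masks = np.asarray(model_masks, dtype=bool)  # (nmodels, dim)
    step_mask = np.asarray(step_mask, dtype=bool)      # (dim,)
    # Containment: wherever the step mask is True, the model mask must be True.
    return np.all(model_masks | ~step_mask, axis=1)

masks = [[True, True, False], [True, False, True], [True, True, True]]
active = active_models_from_masks(masks, [True, True, False])
# models 0 and 2 contain dimensions 0 and 1; model 1 lacks dimension 1
```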

Parameters:

step (Tstep) – the step to be wrapped

wrapped_step: Tstep

active_models: jax.Array

build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the betas attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

proposal(key, state)

Propose a new state. For chains in inactive models (i.e. models whose mask does not contain the mask of the wrapped step) the proposed state is identical to the current one; for all other chains it is the state proposed by the wrapped step.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[jax.Array, Tstate, jax.Array]
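The merge described for proposal amounts to a per-chain select between the current and the wrapped proposal. A minimal NumPy sketch over plain position arrays (a hypothetical simplification of Tstate); it returns only the merged positions and the mask of modified chains and does not reproduce the first element of the documented return tuple.

```python
import numpy as np

def merge_wrapped_proposal(old, proposed, chain_models, active_models):
    """Keep the wrapped step's proposal only for chains in active models.

    old, proposed: (nchains, dim) positions
    chain_models:  (nchains,) model index of each chain
    active_models: (nmodels,) boolean, True for models the step may act on
    Returns the merged positions and the boolean mask of modified chains.
    """
    modified = active_models[chain_models]  # (nchains,)
    merged = np.where(modified[:, None], proposed, old)
    return merged, modified
```

In JAX the same select would typically be written with jax.numpy.where, which keeps the step free of data-dependent control flow.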

class DrawModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](reorder=True, draw_pseudo=True)

Bases: jexplore.steps.mh.MHStep[Tepoch, Tstate, Tsampling]

Draw a model for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo-prior.
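A minimal NumPy sketch of this draw, assuming the three log terms are already evaluated for every chain and every candidate model as arrays of shape (nchains, nmodels). It does not model any tempering of the likelihood, which this page does not describe. It uses the Gumbel-max trick, so the weights never need explicit normalisation; draw_models is a hypothetical name.

```python
import numpy as np

def draw_models(rng, log_like, log_prior, log_pseudo):
    """Draw one model index per chain from a categorical distribution.

    Each input has shape (nchains, nmodels); the unnormalised log weight of
    model m for chain c is log_like + log_prior + log_pseudo.
    """
    logw = log_like + log_prior + log_pseudo
    # Gumbel-max trick: argmax(logw + Gumbel noise) is distributed as a
    # categorical draw with probabilities softmax(logw).
    gumbel = rng.gumbel(size=logw.shape)
    return np.argmax(logw + gumbel, axis=1)
```

In JAX the same draw is available as jax.random.categorical, which takes unnormalised logits.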

Parameters:
  • reorder (bool) – reorder the chains by model (temperature by temperature) after the model draw. Default True.

  • draw_pseudo (bool) – draw from the pseudo-priors before drawing the models. Default True.

reorder: bool

Whether to reorder the chains by model (and temperature) after drawing the models.

draw_pseudo: bool

Whether to draw from the pseudo-priors before drawing the models.

build(epoch)

Step epoch initialisation method. Extends jexplore.steps.step.Step.build by populating the betas attribute.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

step(key, state)

Model drawing step

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]

class DrawPseudo[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]

Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]

Draw model pseudo-prior samples for each chain.

draw_funcs: list

static list of pseudo-prior drawing functions, one per model
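The following NumPy sketch shows how a per-model list of drawing functions can be used in a product-space (Carlin-Chib style) sampler. It assumes that each chain keeps one parameter block per model and that only the blocks of models the chain is not currently in are refreshed from their pseudo-priors. The signature f(rng, n) of the drawing functions and the name draw_pseudo are hypothetical.

```python
import numpy as np

def draw_pseudo(rng, params, chain_models, draw_funcs):
    """Refresh the parameters of non-current models from their pseudo-priors.

    params:       list with one (nchains, d_m) array per model
    chain_models: (nchains,) current model of each chain
    draw_funcs:   list with one callable f(rng, n) -> (n, d_m) per model
    Returns the new parameter list and the boolean mask of modified chains.
    """
    nchains = chain_models.shape[0]
    new_params = []
    for m, (p, draw) in enumerate(zip(params, draw_funcs)):
        fresh = draw(rng, nchains)
        # Chains currently in model m keep their parameters for model m.
        keep = (chain_models == m)[:, None]
        new_params.append(np.where(keep, p, fresh))
    # With more than one model, every chain gets at least one refreshed block.
    modified = np.full(nchains, len(params) > 1)
    return new_params, modified
```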

build(epoch)

Epoch initialisation method.

Parameters:

epoch (Tepoch) – current epoch.

step(key, state)

Pseudo drawing step

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]

class OrderByModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]

Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]

Swap chains to restore the ordering by model (temperature by temperature).
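A minimal NumPy sketch of the reordering as a single permutation. It assumes the chains are stored in ntemps contiguous, equal-size blocks, one per temperature, and sorts each block by model index while leaving the blocks in place. order_by_model_permutation is a hypothetical name; the state would then be reindexed with the returned permutation (e.g. positions[perm]).

```python
import numpy as np

def order_by_model_permutation(chain_models, ntemps):
    """Return a permutation that sorts each temperature block by model.

    Assumes chains are stored in `ntemps` contiguous, equal-size blocks.
    """
    nchains = chain_models.shape[0]
    per_temp = nchains // ntemps
    blocks = chain_models.reshape(ntemps, per_temp)
    # Stable sort inside each block keeps the relative order of chains
    # that are already in the same model.
    local = np.argsort(blocks, axis=1, kind="stable")
    # Shift block-local indices to global chain indices.
    return (local + per_temp * np.arange(ntemps)[:, None]).ravel()

models = np.array([1, 0, 1, 0, 2, 0, 1, 0])
perm = order_by_model_permutation(models, ntemps=2)
# models[perm] is [0, 0, 1, 1, 0, 0, 1, 2]
```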

step(key, state)

Chain reordering step: swap chains so that they are ordered by model, temperature by temperature.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]

TSwapMS

Temperature swap for model selection
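This page does not describe the swap rule used by TSwapMS. For orientation only, here is the standard parallel tempering criterion for exchanging the states of two chains at inverse temperatures \(\beta_i\) and \(\beta_j\): accept with probability \(\min\{1, \exp[(\beta_i - \beta_j)(\log L_j - \log L_i)]\}\). The sketch assumes that \(\log L\) is the log likelihood of each chain's current model. Whether TSwapMS restricts swaps to chains in the same model is not stated here. The function names are hypothetical.

```python
import numpy as np

def pt_swap_log_accept(beta_i, beta_j, loglike_i, loglike_j):
    """Log acceptance probability for exchanging the states of two chains
    at inverse temperatures beta_i and beta_j (standard parallel tempering)."""
    return min(0.0, (beta_i - beta_j) * (loglike_j - loglike_i))

def try_swap(rng, beta_i, beta_j, loglike_i, loglike_j):
    """Return True when the proposed swap is accepted."""
    log_u = np.log(rng.random())
    return log_u < pt_swap_log_accept(beta_i, beta_j, loglike_i, loglike_j)
```

A swap that moves the higher-likelihood state to the colder chain (larger \(\beta\)) is always accepted.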