jexplore.steps.modsel
Model Selection Markov Steps
Submodules
Attributes
TSwapMS | Temperature swap for model selection
Classes
IMDEStep | Class implementing a model selection in-model differential evolution step.
IMStretch | Class implementing a model selection in-model step based on the stretch proposal.
AllChainsWrap | Model selection wrapper for jexplore.steps.mh.AllChains steps.
DrawModel | Draws models for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo-prior.
DrawPseudo | Draws model pseudo-prior samples for each chain.
OrderByModel | Swaps chains to restore model ordering (temperature by temperature).
Package Contents
- class IMDEStep[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](gamma=2.38, ngroups=2, permute=False)
Bases: jexplore.steps.de.DEStep[Tepoch, Tstate, Tsampling]
Class implementing a model selection in-model differential evolution step.
- Parameters:
gamma (float) – \(\gamma\) scale parameter
ngroups (int) – number of groups. Default 2.
permute (bool) – if True, walkers are permuted at each iteration.
- sigmas: jax.Array
sigma parameters for all models
- get_partners
Method for getting partner samples for each chain of a group. The main difference from the base implementation in jexplore.steps.colored.ColoredSC is that here each group chain gets its partners only among the complementary group chains (at the same temperature) that are in the same model.
- Parameters:
key – PRNG key
state – current state
group – group chains
cgroup – complementary group chains
- Returns:
the partners as an array with shape (self.npars, group.size, dim)
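The partner restriction described above can be sketched as a masked categorical draw. This is a minimal sketch, not jexplore's implementation: the helper name `same_model_partners` and its arguments (flat `positions`, per-chain integer `models` and `temps`, index arrays `group`/`cgroup`) are hypothetical stand-ins for what the real method reads from the state.

```python
import jax
import jax.numpy as jnp


def same_model_partners(key, positions, models, temps, group, cgroup, npars=1):
    """Hypothetical helper: draw `npars` partners for each chain in `group`.

    positions: (nchains, dim); models, temps: (nchains,) integer indices;
    group, cgroup: integer index arrays of the two complementary chain groups.
    Returns partner positions with shape (npars, group.size, dim).
    """
    # (group.size, cgroup.size) mask of admissible partners:
    # same model and same temperature as the group chain
    same = (models[group][:, None] == models[cgroup][None, :]) & (
        temps[group][:, None] == temps[cgroup][None, :]
    )
    # Uniform weights over admissible partners, zero probability elsewhere.
    # Assumes every group chain has at least one admissible partner.
    logits = jnp.where(same, 0.0, -jnp.inf)
    idx = jax.random.categorical(key, logits, axis=-1, shape=(npars, group.size))
    return positions[cgroup[idx]]
```

Using `-inf` logits keeps the draw uniform over the admissible partners without any data-dependent shapes, so the function stays compatible with `jax.jit`.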
- build(epoch)
Step initialisation method. It extends jexplore.steps.colored.Colored.build by simply adding the computation of the \(\sigma\) of the \(\gamma\) distribution.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
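The page does not state how \(\sigma\) is computed from \(\gamma\). The default `gamma=2.38` matches ter Braak's classic DE scaling \(\gamma/\sqrt{2d}\). The sketch below *assumes* a per-model version of that rule, using each model's dimension \(d_m\), and pairs it with the textbook DE jump. The names `per_model_sigmas` and `de_proposal` are hypothetical, and jexplore may draw the jump scale differently.

```python
import jax.numpy as jnp


def per_model_sigmas(gamma, model_dims):
    # Assumption: ter Braak-style scaling gamma / sqrt(2 d), one value per model
    return gamma / jnp.sqrt(2.0 * model_dims)


def de_proposal(x, xa, xb, sigma):
    # Textbook DE jump along the difference of two partner chains;
    # x, xa, xb: (nchains, dim), sigma: (nchains,) scale of each chain's model
    return x + sigma[:, None] * (xa - xb)
```

Because the scale shrinks with \(\sqrt{d_m}\), chains in higher-dimensional models take proportionally smaller jumps. Precomputing one value per model in `build` avoids recomputing it at every step.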
- class IMStretch[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](a=2.0, ngroups=2, permute=False)
Bases: jexplore.steps.stretch.Stretch[Tepoch, Tstate, Tsampling]
Class implementing a model selection in-model step based on the stretch proposal.
- Parameters:
a (float) – the \(a\) parameter of the stretch proposal
ngroups (int) – number of groups. Default 2.
permute (bool) – if True, walkers are permuted at each iteration.
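The stretch proposal referred to here is the Goodman–Weare move. The stretch factor \(z\) has density proportional to \(1/\sqrt{z}\) on \([1/a, a]\), and the acceptance ratio carries a factor \(z^{d-1}\). In a model selection setting, \(d\) would naturally be the dimension of each chain's current model. The sketch below assumes a per-chain `dims` array for that purpose (what `mdims` stores is not documented here). The function name `stretch_proposal` is hypothetical.

```python
import jax
import jax.numpy as jnp


def stretch_proposal(key, x, partners, a, dims):
    """Goodman-Weare stretch move with a per-chain dimension.

    x, partners: (nchains, dim); dims: (nchains,) dimension of each chain's model.
    Returns the proposed points and the log acceptance factor (d - 1) log z.
    """
    u = jax.random.uniform(key, (x.shape[0],))
    # Inverse-CDF sample: z has density proportional to 1/sqrt(z) on [1/a, a]
    z = ((a - 1.0) * u + 1.0) ** 2 / a
    # Move each chain along the line through its partner
    y = partners + z[:, None] * (x - partners)
    log_factor = (dims - 1.0) * jnp.log(z)
    return y, log_factor
```

The log factor is added to the log posterior difference in the Metropolis–Hastings acceptance test. Using the model's own dimension rather than the full embedding dimension keeps the move reversible within each model.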
- mdims: jax.Array
- get_partners
Method for getting partner samples for each chain of a group. The main difference from the base implementation in jexplore.steps.colored.ColoredSC is that here each group chain gets its partners only among the complementary group chains (at the same temperature) that are in the same model.
- Parameters:
key – PRNG key
state – current state
group – group chains
cgroup – complementary group chains
- Returns:
the partners as an array with shape (self.npars, group.size, dim)
- build(epoch)
Step initialisation method. It extends jexplore.steps.step.Step.build by adding a call to grouping, which defines the color groups, and checks that the defined color groups are even, raising a BadColoring exception otherwise.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
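The color-group check can be illustrated with a small static helper. This is a sketch under two assumptions: chains are colored round-robin by index (after an optional permutation), and "even" means that every color group has the same size. `color_groups` is a hypothetical name, and `BadColoring` below is a local stand-in for the exception named above, not an import from jexplore.

```python
import numpy as np


class BadColoring(Exception):
    """Local stand-in for the exception named in the docs."""


def color_groups(nchains, ngroups, permute=False, seed=0):
    # Optionally shuffle chains, then assign position i to color i % ngroups
    rng = np.random.default_rng(seed)
    order = rng.permutation(nchains) if permute else np.arange(nchains)
    groups = [order[c::ngroups] for c in range(ngroups)]
    # Assumed meaning of "even": every color group has the same size
    if len({g.size for g in groups}) != 1:
        raise BadColoring(f"{nchains} chains cannot be split into {ngroups} equal groups")
    return groups
```

Equal group sizes keep the array shapes static across groups, which matters when the per-group update is traced once and reused for every color.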
- class AllChainsWrap[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS, Tstep: jexplore.steps.mh.AllChains](step)
Bases: jexplore.steps.mh.AllChains[Tepoch, Tstate, Tsampling]
Model selection wrapper for jexplore.steps.mh.AllChains steps. The step is wrapped so that it is applied only to chains in a model whose (dimension) mask fully contains the mask of the step. Note that the resulting step does not propose changes to the chains' models.
- Parameters:
step (Tstep) – the step to be wrapped
- wrapped_step: Tstep
- active_models: jax.Array
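The containment rule for active models reduces to a boolean implication checked per coordinate. The sketch below assumes, hypothetically, that each model's mask is a row of a `(nmodels, dim)` boolean array and that the wrapped step's mask is a `(dim,)` boolean array. The function name `active_models` mirrors the attribute but is only illustrative.

```python
import jax.numpy as jnp


def active_models(model_masks, step_mask):
    """model_masks: (nmodels, dim) bool; step_mask: (dim,) bool.

    A model is active when its mask contains every dimension the step touches,
    i.e. step_mask implies model_mask on every coordinate.
    """
    # (not step) or model == step implies model, evaluated coordinate-wise
    return jnp.all(model_masks | ~step_mask[None, :], axis=1)
```

Since the masks are fixed for an epoch, this array can be computed once in `build` and reused at every proposal.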
- build(epoch)
Step epoch initialisation method. It extends jexplore.steps.step.Step.build by populating the betas attribute.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
- proposal(key, state)
Propose a new state. The proposed state is identical to the old one for all chains in inactive models (i.e. models whose mask does not contain the mask of the wrapped step) and corresponds to the state proposed by the wrapped step for all other chains.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
- Returns:
new state and the boolean mask of the chains modified by the step.
- Return type:
tuple[jax.Array, Tstate, jax.Array]
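The per-chain merge described for `proposal` can be sketched with `jnp.where`. This simplification treats the state as a plain `(nchains, dim)` position array rather than a `Tstate` object. It also assumes each chain's model index is available as an integer array and that `active` is a per-model boolean array like `active_models`. `merge_proposal` is a hypothetical name.

```python
import jax.numpy as jnp


def merge_proposal(old, proposed, chain_models, active):
    """old, proposed: (nchains, dim); chain_models: (nchains,) model index per chain;
    active: (nmodels,) bool. Returns merged positions and the modified-chain mask.
    """
    # A chain is modified only if its current model is active for the wrapped step
    modified = active[chain_models]
    # Keep the wrapped step's proposal for modified chains, the old state otherwise
    new = jnp.where(modified[:, None], proposed, old)
    return new, modified
```

Computing the wrapped proposal for all chains and then masking keeps every shape static, at the cost of some wasted work on inactive chains.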
- class DrawModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](reorder=True, draw_pseudo=True)
Bases: jexplore.steps.mh.MHStep[Tepoch, Tstate, Tsampling]
Draws models for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo-prior.
- Parameters:
reorder (bool) – reorder chains by model (temp by temp) after running the model draw.
draw_pseudo (bool) – draw from the pseudo-prior before drawing the models.
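The model draw can be sketched as one categorical sample per chain over the summed log weights. This sketch assumes the three terms have already been evaluated for every (chain, model) pair as `(nchains, nmodels)` arrays. Any tempering of the likelihood is left out. `draw_models` is a hypothetical name; `jax.random.categorical` is the standard JAX sampler.

```python
import jax
import jax.numpy as jnp


def draw_models(key, log_like, log_prior, log_pseudo):
    """All inputs: (nchains, nmodels), evaluated for every model at each chain.
    Returns one model index per chain.
    """
    # Unnormalised log weights; categorical normalises them internally
    logits = log_like + log_prior + log_pseudo
    return jax.random.categorical(key, logits, axis=-1)
```

Because `jax.random.categorical` works on unnormalised logits, no log-sum-exp normalisation is needed before sampling.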
- reorder: bool
Reorder the chains by model (and temperature) after drawing models.
- draw_pseudo: bool
Draw from the pseudo-prior before drawing models.
- build(epoch)
Step epoch initialisation method. It extends jexplore.steps.step.Step.build by populating the betas attribute.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
- class DrawPseudo[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]
Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]
Draws model pseudo-prior samples for each chain.
- draw_funcs: list
static list of pseudo-prior drawing functions, one for each model
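A static list of per-model samplers can be applied to every chain by looping over the list in Python and vectorising each function over chains with `jax.vmap`. This sketch assumes each function maps a single PRNG key to a `(dim,)` sample. It also assumes that a sample is drawn for every model, giving an output of shape `(nchains, nmodels, dim)`. `draw_pseudo` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def draw_pseudo(key, draw_funcs, nchains):
    """draw_funcs: static list, one function per model, each mapping a PRNG key
    to a sample of shape (dim,). Returns samples of shape (nchains, nmodels, dim).
    """
    # Independent key streams for each model
    keys = jax.random.split(key, len(draw_funcs))
    per_model = [
        # One key per chain, vectorised over chains: (nchains, dim)
        jax.vmap(f)(jax.random.split(k, nchains))
        for f, k in zip(draw_funcs, keys)
    ]
    return jnp.stack(per_model, axis=1)
```

Keeping `draw_funcs` as a static Python list means the loop is unrolled at trace time, so each model may use a different sampler without any runtime dispatch.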
- class OrderByModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]
Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]
Swap chains to restore model ordering (temp by temp).
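Restoring model ordering temperature by temperature amounts to a stable sort of the walkers along the walker axis. This sketch assumes, hypothetically, that positions are laid out as `(ntemps, nwalkers, dim)` with a matching `(ntemps, nwalkers)` array of model indices. `order_by_model` is an illustrative name, not jexplore's API.

```python
import jax.numpy as jnp


def order_by_model(positions, models):
    """positions: (ntemps, nwalkers, dim); models: (ntemps, nwalkers) int.

    Within each temperature, walkers are permuted so that model indices are
    non-decreasing; a stable sort keeps walkers of the same model in order.
    """
    # jnp.argsort uses a stable sort by default
    perm = jnp.argsort(models, axis=1)
    sorted_models = jnp.take_along_axis(models, perm, axis=1)
    # Broadcast the permutation over the trailing parameter axis
    idx = jnp.broadcast_to(perm[..., None], positions.shape)
    sorted_pos = jnp.take_along_axis(positions, idx, axis=1)
    return sorted_pos, sorted_models, perm
```

Returning `perm` as well lets any other per-chain arrays (log likelihoods, pseudo-prior samples) be reordered consistently.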
- TSwapMS
Temperature swap for model selection