jexplore.steps
Markov steps
Submodules
Classes
- DEStep – Class implementing a differential evolution step.
- Gaussian – Sampling from a Gaussian distribution (to be completed).
- Uniform – Sampling from a uniform distribution in a box.
- GaussianRandomWalk – Gaussian random walk proposal.
- StudentTRandomWalk – Student-T random walk proposal.
- Stretch – Class implementing an MH step based on the stretch proposal.
- TSwap – Class for the MH temperature swap step.
Package Contents
- class DEStep[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](gamma=2.38, ngroups=2, permute=False)
Bases: jexplore.steps.colored.ColoredSC[Tepoch,Tstate,Tsampling]
Class implementing a differential evolution (DE) step.
- Parameters:
gamma (float) – \(\gamma\) scale parameter. Default 2.38.
ngroups (int) – number of groups. Default 2.
permute (bool) – if true, walkers are permuted at each iteration. Default False.
- gamma: float
DE proposal \(\gamma\) parameter
- sigma: jax.Array
standard deviation of the distribution of \(\gamma\), \(\sigma = \frac{\gamma}{2\sqrt{D}}\), where \(D\) is the dimension of the parameter space
- npart: int = 2
number of partners needed to build the proposal
- build(epoch)
Step initialisation method. It extends jexplore.steps.colored.Colored.build by additionally computing the \(\sigma\) of the \(\gamma\) distribution.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
- sample_gamma(key, state)
Sample \(\gamma\) from a normal distribution with standard deviation \(\sigma\).
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
- Returns:
samples
- Return type:
jax.Array
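As an illustration of the \(\gamma\) sampling, the sketch below draws one \(\gamma\) per chain from a normal distribution with \(\sigma = \gamma/(2\sqrt{D})\). The function `sample_gamma_sketch` is hypothetical; reading \(D\) as the parameter-space dimension, drawing one value per chain and using a zero-mean normal are all assumptions, and the actual method derives its sample size from `state`.

```python
import math

import jax
import jax.numpy as jnp


def sample_gamma_sketch(key: jax.Array, gamma: float, ndim: int, nchains: int) -> jax.Array:
    """Hypothetical stand-in for DEStep.sample_gamma.

    Assumptions: one gamma per chain, drawn from a zero-mean normal
    distribution with sigma = gamma / (2 * sqrt(D)), where D = ndim.
    """
    sigma = gamma / (2.0 * math.sqrt(ndim))
    return sigma * jax.random.normal(key, (nchains,))


# Default gamma = 2.38 in a 4-dimensional space gives sigma = 0.595.
g = sample_gamma_sketch(jax.random.PRNGKey(0), gamma=2.38, ndim=4, nchains=8)
print(g.shape)  # (8,)
```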
- proposal(key, state, group, cgroup)
Propose a new state according to the differential evolution (DE) proposal algorithm.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – indexes of the color chains
cgroup (jax.Array) – indexes of the complementary chains for this color.
- Returns:
new state and the boolean mask of the chains modified by the step.
- Return type:
tuple[jax.Array, Tstate, jax.Array]
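A minimal JAX sketch of a DE move for one color, assuming each chain in `group` is moved along the difference of two distinct partners drawn from `cgroup` (npart = 2), \(x' = x + \gamma\,(x_{r_1} - x_{r_2})\). The function `de_proposal_sketch` and its layout (positions as an `(nchains, ndim)` array instead of a `Tstate`) are hypothetical, and the per-chain \(\gamma\) values are passed in rather than sampled.

```python
import jax
import jax.numpy as jnp


def de_proposal_sketch(key, x, group, cgroup, gammas):
    """Hypothetical DE move for one color.

    x:      (nchains, ndim) positions of all chains (assumed layout)
    group:  indexes of the chains updated by this color
    cgroup: indexes of the complementary chains used as partners
    gammas: one scale factor per chain in `group`
    Returns the new positions and the boolean mask of modified chains.
    """
    keys = jax.random.split(key, group.shape[0])

    # For each updated chain draw npart = 2 distinct partners from cgroup.
    def pick_partners(k):
        return jax.random.choice(k, cgroup, shape=(2,), replace=False)

    partners = jax.vmap(pick_partners)(keys)  # shape (len(group), 2)
    diff = x[partners[:, 0]] - x[partners[:, 1]]
    x_group = x[group] + gammas[:, None] * diff
    mask = jnp.zeros(x.shape[0], dtype=bool).at[group].set(True)
    return x.at[group].set(x_group), mask


x = jnp.arange(12.0).reshape(6, 2)
x_new, mask = de_proposal_sketch(
    jax.random.PRNGKey(0), x, jnp.array([0, 2, 4]), jnp.array([1, 3, 5]), jnp.full((3,), 0.5)
)
```

Only the rows in `group` change; the complementary chains act as fixed partners, which is what keeps the move valid within a colored update.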
- class Gaussian(mask=None)
Bases: Direct
Sampling from a Gaussian distribution (to be completed).
- Parameters:
mask (jax.Array | None) – proposal dimensions mask (default: all dimensions)
- class Uniform(mask=None, minval=0.0, maxval=1.0)
Bases: Direct
Sampling from a uniform distribution in a box.
- Parameters:
mask (jax.Array | None) – proposal dimensions mask (default: all dimensions)
minval (jax.Array | float) – minimum (inclusive) value of the range, broadcast-compatible with the sample shape (default 0).
maxval (jax.Array | float) – maximum (exclusive) value of the range, broadcast-compatible with the sample shape (default 1).
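A sketch of masked direct sampling in a box, assuming positions are stored as an `(nchains, ndim)` array and that unmasked dimensions keep their current values. `uniform_masked_sketch` is a hypothetical name; `minval` and `maxval` are forwarded to `jax.random.uniform`, which broadcasts them against the output shape.

```python
import jax
import jax.numpy as jnp


def uniform_masked_sketch(key, x, mask=None, minval=0.0, maxval=1.0):
    """Hypothetical direct sampling from U[minval, maxval) on masked dimensions.

    x:    (nchains, ndim) current positions (assumed layout)
    mask: (ndim,) boolean array; None means all dimensions are resampled
    """
    if mask is None:
        mask = jnp.ones(x.shape[-1], dtype=bool)
    draw = jax.random.uniform(key, x.shape, minval=minval, maxval=maxval)
    # Keep the current value wherever the mask is False.
    return jnp.where(mask, draw, x)
```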
- class GaussianRandomWalk(mask=None)
Bases: MVRandomWalk
Gaussian random walk proposal.
- Parameters:
mask (jax.Array | None) – proposal dimensions mask (default: all dimensions)
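A sketch of a masked Gaussian random-walk step \(x' = x + L z\), \(z \sim \mathcal{N}(0, I)\), assuming the proposal covariance is supplied through its Cholesky factor \(L\) as a hypothetical `chol` argument. How `MVRandomWalk` actually stores its scale is an assumption, as are the function name and the `(nchains, ndim)` layout.

```python
import jax
import jax.numpy as jnp


def gaussian_rw_sketch(key, x, chol, mask=None):
    """Hypothetical Gaussian random-walk proposal x' = x + L z, z ~ N(0, I).

    x:    (nchains, ndim) current positions (assumed layout)
    chol: (ndim, ndim) lower-triangular factor L of the covariance L L^T
    mask: (ndim,) boolean array of dimensions to move; None moves all of them
    """
    if mask is None:
        mask = jnp.ones(x.shape[-1], dtype=bool)
    z = jax.random.normal(key, x.shape)
    step = z @ chol.T  # row-wise L z
    # Masking is applied after the correlated draw, so unmasked dims stay fixed.
    return jnp.where(mask, x + step, x)
```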
- class StudentTRandomWalk(mask=None, nu=5.0)
Bases: MVRandomWalk
Student-T random walk proposal.
- Parameters:
mask (jax.Array | None) – proposal dimensions mask (default: all dimensions)
nu (float) – degrees of freedom \(\nu\) of the Student-T distribution (default 5.0)
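A multivariate Student-T step can be drawn as \(x' = x + L z \sqrt{\nu / w}\) with \(z \sim \mathcal{N}(0, I)\) and \(w \sim \chi^2_\nu\). The sketch below uses this construction with the same hypothetical `(nchains, ndim)` layout and Cholesky factor `chol`; whether `StudentTRandomWalk` draws a multivariate or a per-dimension t is an assumption.

```python
import jax
import jax.numpy as jnp


def student_t_rw_sketch(key, x, chol, nu=5.0, mask=None):
    """Hypothetical Student-T random-walk proposal x' = x + L z sqrt(nu / w).

    z ~ N(0, I) per dimension and w ~ chi^2_nu shared by all dimensions of a
    chain, which yields a multivariate t step with nu degrees of freedom.
    """
    if mask is None:
        mask = jnp.ones(x.shape[-1], dtype=bool)
    key_z, key_w = jax.random.split(key)
    z = jax.random.normal(key_z, x.shape)
    # chi^2_nu is 2 * Gamma(nu / 2, 1); one draw per chain, broadcast over dims.
    w = 2.0 * jax.random.gamma(key_w, nu / 2.0, (x.shape[0], 1))
    step = (z @ chol.T) * jnp.sqrt(nu / w)
    return jnp.where(mask, x + step, x)
```

For \(\nu > 2\) each component of the step has variance \(\nu / (\nu - 2)\) times the corresponding Gaussian variance, so the default \(\nu = 5\) gives heavier tails at about 5/3 of the Gaussian spread.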
- class Stretch[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](a=2.0, ngroups=2, permute=False)
Bases: jexplore.steps.colored.ColoredSC[Tepoch,Tstate,Tsampling]
Class implementing an MH step based on the stretch proposal.
- Parameters:
a (float) – stretch proposal parameter \(a\). Default 2.0.
ngroups (int) – number of groups. Default 2.
permute (bool) – if true, walkers are permuted at each iteration. Default False.
- npart: int = 1
number of partners needed to build the proposal
- a: float
stretch proposal \(a\) parameter
- sa
- sample_z(key, size)
Sample the stretch factor \(z\) of the stretch proposal.
- Parameters:
key (jax.Array) – PRNG key
size (int) – output size
- Returns:
samples
- Return type:
jax.Array
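In the stretch move of Goodman and Weare, \(z\) is drawn from \(g(z) \propto 1/\sqrt{z}\) on \([1/a, a]\); inverting its CDF gives \(z = ((a-1)u + 1)^2 / a\) with \(u \sim U(0, 1)\). The sketch below assumes `sample_z` uses this distribution; `sample_z_sketch` is a hypothetical name.

```python
import jax


def sample_z_sketch(key, a, size):
    """Hypothetical sampler for g(z) ∝ 1/sqrt(z) on [1/a, a].

    The CDF is F(z) = (sqrt(a z) - 1) / (a - 1); inverting it gives
    z = ((a - 1) u + 1)^2 / a with u ~ U(0, 1).
    """
    u = jax.random.uniform(key, (size,))
    return ((a - 1.0) * u + 1.0) ** 2 / a


z = sample_z_sketch(jax.random.PRNGKey(0), a=2.0, size=100_000)
# For a = 2 the exact mean of z is 7/6 ≈ 1.167.
```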
- proposal(key, state, group, cgroup)
Propose a new state according to the stretch proposal algorithm.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – indexes of the color chains
cgroup (jax.Array) – indexes of the complementary chains for this color.
- Returns:
new state and the boolean mask of the chains modified by the step.
- Return type:
tuple[jax.Array, Tstate, jax.Array]
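A sketch of a stretch move for one color: each chain \(X_k\) in `group` is paired with a partner \(X_j\) drawn from `cgroup` and moved to \(Y = X_j + z\,(X_k - X_j)\), and the MH acceptance then includes the factor \(z^{D-1}\). The hypothetical `stretch_proposal_sketch` takes \(z\) as input and returns a tuple ordered like the documented return type; reading its first element as \((D-1)\log z\) is an assumption.

```python
import jax
import jax.numpy as jnp


def stretch_proposal_sketch(key, x, group, cgroup, z):
    """Hypothetical stretch move for one color.

    x:      (nchains, ndim) positions (assumed layout)
    group:  indexes of the chains updated by this color
    cgroup: indexes of the complementary chains (npart = 1 partner each)
    z:      one stretch factor per chain in `group`
    Returns ((D - 1) log z, new positions, boolean mask of modified chains).
    """
    ndim = x.shape[-1]
    # One partner per chain, drawn with replacement from the complementary set.
    partners = jax.random.choice(key, cgroup, shape=group.shape)
    x_j = x[partners]
    y = x_j + z[:, None] * (x[group] - x_j)
    # The MH acceptance of the stretch move carries the factor z^(D - 1).
    log_factor = (ndim - 1) * jnp.log(z)
    mask = jnp.zeros(x.shape[0], dtype=bool).at[group].set(True)
    return log_factor, x.at[group].set(y), mask
```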
- class TSwap[Tepoch: jexplore.sampling.mh.EpochMH, Tstate: jexplore.sampling.mh.StateMH, Tsampling: jexplore.sampling.mh.SamplingMH](ngroups=-1, permute=False, duplicate=False, adjust=False)
Bases: jexplore.steps.colored.ColoredMH[Tepoch,Tstate,Tsampling]
Class for the MH temperature swap step.
- Parameters:
ngroups (int) – number of groups. If -1 (default), the temperature swapping is fully serial.
permute (bool) – if true, values are permuted before each iteration.
duplicate (bool) – if true, chains are duplicated within groups to make the color groups even.
adjust (bool) – if true, the number of groups is adjusted to make the color groups even.
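- grouping()
Defines the color chain groups by grouping temperatures. The \(i\)-th color group contains the chains of all walkers whose temperature index \(t\) satisfies \(t \equiv i + 1 \pmod{G}\), where \(G\) is the number of groups; for each of these chains, the complementary group contains the chain with temperature index \(t - 1\). Two chains of the same color have temperature indices differing by a nonzero multiple of \(G\), so for \(G \ge 2\) the swap pairs \((t - 1, t)\) of one color are disjoint and do not interfere with each other.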
- Returns:
indexes of the chains of each color, indexes of the complementary chains for each color.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
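A pure-Python sketch of the grouping rule over temperature indices only (the actual method also spans all walkers). The name `temperature_groups_sketch` is hypothetical, and it assumes \(G \ge 2\) and temperature indices \(t = 0, \dots, T-1\).

```python
def temperature_groups_sketch(ntemps, ngroups):
    """Hypothetical color grouping of temperature indices for swap steps.

    Color i holds the indices t >= 1 with t ≡ i + 1 (mod G); its
    complementary group holds the swap partners t - 1.
    """
    groups, cgroups = [], []
    for i in range(ngroups):
        ts = [t for t in range(1, ntemps) if t % ngroups == (i + 1) % ngroups]
        groups.append(ts)
        cgroups.append([t - 1 for t in ts])
    return groups, cgroups


print(temperature_groups_sketch(6, 2))
# ([[1, 3, 5], [2, 4]], [[0, 2, 4], [1, 3]])
```

Every consecutive pair \((t - 1, t)\) appears in exactly one color, and within a color no temperature index is used twice.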
- colored_step(key, state, group, cgroup)
Temperature swap step for a single color. Proposes swapping the states of chains at two consecutive temperatures and accepts or rejects the swap according to the Metropolis–Hastings condition.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – indexes of the color chains
cgroup (jax.Array) – indexes of the complementary chains for this color.
- Returns:
new state and boolean mask of the changed chains
- Return type:
tuple[Tstate, jax.Array]
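A sketch of the swap acceptance for one color, assuming tempered targets \(\pi_\beta(x) \propto p(x)\,L(x)^\beta\) with inverse temperatures \(\beta\). Swapping the states of chains \(a\) and \(b\) is then accepted with probability \(\min\{1, \exp[(\beta_a - \beta_b)(\ell_b - \ell_a)]\}\), where \(\ell = \log L\) and the prior cancels. The function name and the flat arrays for positions, \(\beta\) and \(\ell\) are hypothetical.

```python
import jax
import jax.numpy as jnp


def temperature_swap_sketch(key, x, beta, loglike, group, cgroup):
    """Hypothetical MH swap between chain group[k] and chain cgroup[k] for every k.

    x:       (nchains, ndim) positions
    beta:    (nchains,) inverse temperatures, fixed to the chain slots
    loglike: (nchains,) log-likelihoods of the current positions
    Assumes tempered targets pi_beta(x) ∝ prior(x) * L(x) ** beta.
    """
    b_a, b_b = beta[group], beta[cgroup]
    l_a, l_b = loglike[group], loglike[cgroup]
    # log of pi_a(x_b) pi_b(x_a) / (pi_a(x_a) pi_b(x_b)); the prior cancels.
    log_alpha = (b_a - b_b) * (l_b - l_a)
    accept = jnp.log(jax.random.uniform(key, group.shape)) < log_alpha

    # Exchange positions and log-likelihoods of accepted pairs only.
    x_a = jnp.where(accept[:, None], x[cgroup], x[group])
    x_b = jnp.where(accept[:, None], x[group], x[cgroup])
    ll_a = jnp.where(accept, l_b, l_a)
    ll_b = jnp.where(accept, l_a, l_b)
    x_new = x.at[group].set(x_a).at[cgroup].set(x_b)
    ll_new = loglike.at[group].set(ll_a).at[cgroup].set(ll_b)
    return accept, x_new, ll_new
```

The temperatures stay attached to the chain slots while the states move between them, which is why only positions and log-likelihoods are exchanged.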