jexplore.steps

Markov steps

Submodules

Classes

DEStep

Class implementing a differential evolution (DE) step.

Gaussian

Sampling from a Gaussian distribution (to be completed).

Uniform

Sampling from a uniform distribution in a box.

GaussianRandomWalk

Gaussian random walk proposal.

StudentTRandomWalk

Student-T random walk proposal.

Stretch

Class implementing an MH step based on the stretch proposal.

TSwap

Class implementing an MH temperature swap step.

Package Contents

class DEStep[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](gamma=2.38, ngroups=2, permute=False)

Bases: jexplore.steps.colored.ColoredSC[Tepoch, Tstate, Tsampling]

Class implementing a differential evolution (DE) step.

Parameters:
  • gamma (float) – \(\gamma\) scale parameter

  • ngroups (int) – number of groups. Default 2.

  • permute (bool) – if true, walkers are permuted at each iteration.
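The `ngroups` and `permute` parameters refer to the colored update scheme inherited from `ColoredSC`. In that scheme, walkers are split into groups ("colors") that are updated in turn, and the remaining walkers serve as proposal partners. The sketch below illustrates the idea under stated assumptions:

- The helper `make_color_groups` is hypothetical.
- The contiguous, near-equal split is a guess, not jexplore's actual grouping code.

```python
import jax
import jax.numpy as jnp


def make_color_groups(key, n_walkers, ngroups=2, permute=False):
    """Split walker indices into `ngroups` colors (hypothetical helper).

    Returns a list of (group, cgroup) pairs: `group` holds the walkers
    updated for that color, `cgroup` the complementary walkers available
    as proposal partners. Assumes ngroups >= 2.
    """
    idx = jnp.arange(n_walkers)
    if permute:
        # Shuffle walkers so that color membership changes between iterations.
        idx = jax.random.permutation(key, idx)
    groups = jnp.array_split(idx, ngroups)
    pairs = []
    for i, group in enumerate(groups):
        # Complementary chains: every walker not in the current color.
        cgroup = jnp.concatenate([g for j, g in enumerate(groups) if j != i])
        pairs.append((group, cgroup))
    return pairs


pairs = make_color_groups(jax.random.PRNGKey(0), n_walkers=8, ngroups=2, permute=True)
```

Updating one color at a time, with partners drawn only from the other colors, keeps the walkers being moved independent of the partners that shape their proposal.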

gamma: float

DE proposal \(\gamma\) parameter

sigma: jax.Array

standard deviation \(\sigma = \frac{\gamma}{2\sqrt{D}}\) of the \(\gamma\) distribution
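As a worked instance of this formula, take the default \(\gamma = 2.38\) and, for illustration, \(D = 10\). Here \(D\) is assumed to be the dimension of the sampled space, which these docs do not state explicitly.

```latex
\sigma = \frac{2.38}{2\sqrt{10}} \approx \frac{2.38}{6.325} \approx 0.376
```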

npart: int = 2

number of partners needed to build the proposal

build(epoch)

Step initialisation method. It extends jexplore.steps.colored.Colored.build by adding the computation of the standard deviation \(\sigma\) of the \(\gamma\) distribution.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

sample_gamma(key, state)

Sample \(\gamma\) from a normal distribution.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

samples

Return type:

jax.Array

proposal(key, state, group, cgroup)

Propose a new state according to the DE proposal algorithm.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – indexes of the color chains

  • cgroup (jax.Array) – indexes of the complementary chains for this color.

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[jax.Array, Tstate, jax.Array]
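The DE proposal moves each chain along the difference of two partner chains taken from the complementary group (`npart = 2`). This is a minimal sketch in the textbook DE-MC form, not jexplore's implementation:

- The per-chain scale uses the documented \(\sigma = \gamma / (2\sqrt{D})\) as its spread.
- Its mean \(\gamma / \sqrt{2D}\) is an assumption borrowed from ter Braak's DE-MC.
- The function `de_proposal` and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def de_proposal(key, x, x_partners, gamma=2.38):
    """Differential-evolution proposal (textbook DE-MC form, a sketch).

    x:          (n, D) positions of the chains being updated.
    x_partners: (m, D) positions of the complementary chains, m >= 2.
    Returns the per-chain scale factors and the proposed positions.
    """
    n, d = x.shape
    m = x_partners.shape[0]
    k_i, k_j, k_g = jax.random.split(key, 3)
    # Two distinct partner indices per chain: j is drawn from m - 1 values
    # and shifted past i, so j != i always holds.
    i = jax.random.randint(k_i, (n,), 0, m)
    j = jax.random.randint(k_j, (n,), 0, m - 1)
    j = j + (j >= i)
    # Spread follows the documented sigma; the mean gamma / sqrt(2 D) is an
    # ASSUMPTION taken from ter Braak's DE-MC, not from jexplore.
    sigma = gamma / (2.0 * jnp.sqrt(d))
    g = gamma / jnp.sqrt(2.0 * d) + sigma * jax.random.normal(k_g, (n,))
    # Move each chain along the difference of its two partners.
    y = x + g[:, None] * (x_partners[i] - x_partners[j])
    return g, y
```

The partner difference is symmetric in distribution, so this proposal needs no Hastings correction. The docs do not say what the first array in the returned tuple holds.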

class Gaussian(mask=None)

Bases: Direct

Sampling from a Gaussian distribution (to be completed).

Parameters:

mask (jax.Array | None) – proposal dimensions mask (default all space)

class Uniform(mask=None, minval=0.0, maxval=1.0)

Bases: Direct

Sampling from a uniform distribution in a box.

Parameters:
  • mask (jax.Array | None) – proposal dimensions mask (default all space)

  • minval (jax.Array | float) – minimum (inclusive) value, broadcast-compatible with the sample shape (default 0).

  • maxval (jax.Array | float) – maximum (exclusive) value, broadcast-compatible with the sample shape (default 1).
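A direct step draws new values without reference to the current ones. The sketch below shows what `Uniform` describes, under these assumptions:

- `mask` is a boolean array over dimensions.
- Unmasked dimensions keep their current values.
- `uniform_direct` is a hypothetical helper.

The `minval`/`maxval` wording matches that of `jax.random.uniform`, so they are presumably passed through to it.

```python
import jax
import jax.numpy as jnp


def uniform_direct(key, x, mask=None, minval=0.0, maxval=1.0):
    """Independent uniform draw in a box, restricted to masked dimensions.

    x:    (n, D) current positions.
    mask: (D,) boolean array; True marks dimensions that are resampled.
          None means all dimensions, as in the documented default.
    """
    sample = jax.random.uniform(key, x.shape, minval=minval, maxval=maxval)
    if mask is None:
        return sample
    # ASSUMPTION: dimensions outside the mask keep their current values.
    return jnp.where(mask, sample, x)
```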

class GaussianRandomWalk(mask=None)

Bases: MVRandomWalk

Gaussian random walk proposal.

Parameters:

mask (jax.Array | None) – proposal dimensions mask (default all space)
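A Gaussian random walk proposes \(y = x + L z\) with \(z \sim \mathcal{N}(0, I)\) and \(L L^\top\) equal to the proposal covariance. This sketch makes several assumptions:

- The covariance is passed in as a hypothetical `cov` argument, because the docs do not say how `MVRandomWalk` obtains it.
- Masked-out dimensions are assumed not to move.
- `gaussian_random_walk` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def gaussian_random_walk(key, x, cov, mask=None):
    """Gaussian random-walk proposal y = x + L z, with L L^T = cov.

    x:    (n, D) current positions.
    cov:  (D, D) proposal covariance (hypothetical argument).
    mask: optional (D,) boolean array of dimensions allowed to move.
    """
    chol = jnp.linalg.cholesky(cov)
    z = jax.random.normal(key, x.shape)
    step = z @ chol.T  # each row is distributed as N(0, cov)
    if mask is not None:
        # ASSUMPTION: masked-out dimensions are not moved.
        step = jnp.where(mask, step, 0.0)
    return x + step
```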

class StudentTRandomWalk(mask=None, nu=5.0)

Bases: MVRandomWalk

Student-T random walk proposal.

Parameters:
  • mask (jax.Array | None) – proposal dimensions mask (default all space)

  • nu (float) – Student-t degrees of freedom \(\nu\) (default 5)
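A multivariate Student-t increment can be drawn as a Gaussian scale mixture. If \(z \sim \mathcal{N}(0, \Sigma)\) and \(w \sim \chi^2_\nu / \nu\), then \(z / \sqrt{w}\) is multivariate t with \(\nu\) degrees of freedom, which gives heavier-tailed jumps than the Gaussian walk. This is a sketch under these assumptions:

- `cov` is a hypothetical scale-matrix argument.
- Masked-out dimensions are assumed to stay put.
- `student_t_random_walk` is hypothetical.

```python
import jax
import jax.numpy as jnp


def student_t_random_walk(key, x, cov, nu=5.0, mask=None):
    """Multivariate Student-t random-walk proposal (sketch).

    x:    (n, D) current positions.
    cov:  (D, D) scale matrix (hypothetical argument).
    nu:   degrees of freedom.
    mask: optional (D,) boolean array of dimensions allowed to move.
    """
    k_z, k_w = jax.random.split(key)
    n, d = x.shape
    chol = jnp.linalg.cholesky(cov)
    z = jax.random.normal(k_z, (n, d)) @ chol.T
    # chi2(nu) / nu == 2 * Gamma(nu / 2, 1) / nu; one mixing weight per chain.
    w = 2.0 * jax.random.gamma(k_w, nu / 2.0, (n,)) / nu
    step = z / jnp.sqrt(w)[:, None]
    if mask is not None:
        step = jnp.where(mask, step, 0.0)  # ASSUMPTION: masked-out dims stay put
    return x + step
```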

class Stretch[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](a=2.0, ngroups=2, permute=False)

Bases: jexplore.steps.colored.ColoredSC[Tepoch, Tstate, Tsampling]

Class implementing an MH step based on the stretch proposal.

Parameters:
  • a (float) – stretch proposal scale parameter \(a\)

  • ngroups (int) – number of groups. Default 2.

  • permute (bool) – if true, walkers are permuted at each iteration.

npart: int = 1

number of partners needed to build the proposal

a: float

stretch proposal scale parameter \(a\)

sa

sample_z(key, size)

Sample from the \(z\) distribution.

Parameters:
  • key (jax.Array) – PRNG key

  • size (int) – output size

Returns:

samples

Return type:

jax.Array
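In the standard stretch move of Goodman and Weare, \(z\) has density \(g(z) \propto 1/\sqrt{z}\) on \([1/a, a]\). Its CDF is \((\sqrt{z} - 1/\sqrt{a}) / (\sqrt{a} - 1/\sqrt{a})\), and inverting it gives \(z = ((a - 1)u + 1)^2 / a\) for \(u \sim U(0, 1)\). The sketch below assumes jexplore follows this standard form; the function name mirrors `sample_z` but is a standalone illustration with `a` passed explicitly.

```python
import jax
import jax.numpy as jnp


def sample_z(key, size, a=2.0):
    """Draw stretch factors from g(z) proportional to 1/sqrt(z) on [1/a, a].

    Inverse-CDF sampling: for u ~ U(0, 1), z = ((a - 1) u + 1)^2 / a,
    so u = 0 maps to 1/a and u = 1 maps to a.
    """
    u = jax.random.uniform(key, (size,))
    return ((a - 1.0) * u + 1.0) ** 2 / a
```

As a sanity check, the mean of this distribution is \((a + 1 + 1/a)/3\), about 1.167 for \(a = 2\).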

proposal(key, state, group, cgroup)

Propose a new state according to the stretch proposal algorithm.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – indexes of the color chains

  • cgroup (jax.Array) – indexes of the complementary chains for this color.

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[jax.Array, Tstate, jax.Array]
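In the textbook stretch move, each chain \(x\) picks one partner \(c\) from the complementary group (`npart = 1`) and proposes \(y = c + z (x - c)\). The MH acceptance then includes the factor \(z^{D-1}\). This is a sketch of that textbook algorithm, not jexplore's code:

- `stretch_proposal` is hypothetical.
- It returns the log factor alongside the proposal.
- The docs do not state whether the first array in jexplore's returned tuple is this factor.

```python
import jax
import jax.numpy as jnp


def stretch_proposal(key, x, x_partners, a=2.0):
    """Goodman & Weare stretch move for the chains of one color (sketch).

    x:          (n, D) positions of the chains being updated.
    x_partners: (m, D) positions of the complementary chains.
    Returns (log_factor, y): the log Hastings factor (D - 1) log z and the
    proposed positions.
    """
    n, d = x.shape
    k_z, k_p = jax.random.split(key)
    # z ~ g(z) proportional to 1/sqrt(z) on [1/a, a], via inverse-CDF sampling.
    u = jax.random.uniform(k_z, (n,))
    z = ((a - 1.0) * u + 1.0) ** 2 / a
    # One partner per chain, drawn uniformly from the complementary group.
    j = jax.random.randint(k_p, (n,), 0, x_partners.shape[0])
    c = x_partners[j]
    # Move along the line joining the partner and the current position.
    y = c + z[:, None] * (x - c)
    log_factor = (d - 1) * jnp.log(z)
    return log_factor, y
```

A chain would then accept \(y\) when \(\log u < \text{log\_factor} + \log \pi(y) - \log \pi(x)\).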

class TSwap[Tepoch: jexplore.sampling.mh.EpochMH, Tstate: jexplore.sampling.mh.StateMH, Tsampling: jexplore.sampling.mh.SamplingMH](ngroups=-1, permute=False, duplicate=False, adjust=False)

Bases: jexplore.steps.colored.ColoredMH[Tepoch, Tstate, Tsampling]

Class implementing an MH temperature swap step.

Parameters:
  • ngroups (int) – number of groups. If -1 (default) the temperature swapping is fully serial.

  • permute (bool) – if true, permute values before each iteration

  • duplicate (bool) – duplicate chains within a group so that color groups are even

  • adjust (bool) – adjust the number of groups so that color groups are even

grouping()

Defines color chain groups by grouping temperatures. The \(i\)-th color group contains the chains of all walkers whose temperature index \(t\) satisfies \(t \mod G = i + 1\), where \(G\) is the number of groups. For each of these chains, the complementary group contains the chain with temperature index \(t - 1\).

Returns:

indexes of the chains of each color, indexes of the complementary chains for each color.

Return type:

tuple[list[jax.Array], list[jax.Array]]
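The grouping rule can be illustrated with the sketch below. It rests on two assumptions not stated in the docs:

- Chains are laid out as index \(t \cdot W + w\) for temperature \(t\) and walker \(w\).
- The documented condition is read as \(t \bmod G = (i + 1) \bmod G\) with \(t \ge 1\), so that every color is non-empty when \(G\) is small and the partner \(t - 1\) exists.

`temperature_groups` is a hypothetical helper.

```python
import jax.numpy as jnp


def temperature_groups(n_temps, n_walkers, n_groups):
    """Chain indices of each color and of their swap partners (sketch).

    ASSUMPTIONS: chain index = t * n_walkers + w, and the documented rule
    "t mod G = i + 1" is read as t % G == (i + 1) % G with t >= 1.
    """
    walkers = jnp.arange(n_walkers)
    groups, cgroups = [], []
    for i in range(n_groups):
        temps = jnp.array(
            [t for t in range(1, n_temps) if t % n_groups == (i + 1) % n_groups],
            dtype=jnp.int32,
        )
        # Every walker at each selected temperature belongs to this color.
        groups.append((temps[:, None] * n_walkers + walkers[None, :]).ravel())
        # Its complementary chain is the same walker one temperature lower.
        cgroups.append(((temps - 1)[:, None] * n_walkers + walkers[None, :]).ravel())
    return groups, cgroups
```

Consider \(G = 2\) with four temperatures. The first color pairs temperatures (1, 0) and (3, 2), and the second pairs (2, 1), so no chain takes part in two swaps of the same color.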

colored_step(key, state, group, cgroup)

Temperature swap step for a single color. Proposes a swap of two consecutive temperatures and accepts or rejects it according to the MH acceptance condition.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – indexes of the color chains

  • cgroup (jax.Array) – indexes of the complementary chains for this color.

Returns:

new state and boolean mask of the changed chains

Return type:

tuple[Tstate, jax.Array]
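The usual MH test for a temperature swap assumes each chain targets \(\pi(x) L(x)^{\beta}\), with only the likelihood tempered; this tempering scheme is an assumption here. Consider two chains at inverse temperatures \(\beta_a, \beta_b\) whose states have log-likelihoods \(\ell_a, \ell_b\). Exchanging their states multiplies the joint density by \(\exp((\beta_a - \beta_b)(\ell_b - \ell_a))\), while the prior factors cancel. `swap_accept` is a hypothetical helper returning the boolean mask of accepted swaps.

```python
import jax
import jax.numpy as jnp


def swap_accept(key, beta_a, beta_b, loglike_a, loglike_b):
    """MH test for exchanging the states of paired chains (sketch).

    ASSUMPTION: each chain targets prior(x) * likelihood(x) ** beta, so the
    acceptance ratio is exp((beta_a - beta_b) * (loglike_b - loglike_a)).
    Returns a boolean mask of accepted swaps.
    """
    log_alpha = (beta_a - beta_b) * (loglike_b - loglike_a)
    u = jax.random.uniform(key, jnp.shape(log_alpha))
    return jnp.log(u) < log_alpha
```

Accepted pairs would then exchange their positions and cached log-likelihoods, for example with `jnp.where` on the returned mask.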