jexplore.steps.tswap

This module defines the class for a Metropolis-Hastings (MH) temperature swap step.

Classes

TSwap

Class for an MH temperature swap step.

Module Contents

class TSwap[Tepoch: jexplore.sampling.mh.EpochMH, Tstate: jexplore.sampling.mh.StateMH, Tsampling: jexplore.sampling.mh.SamplingMH](ngroups=-1, permute=False, duplicate=False, adjust=False)

Bases: jexplore.steps.colored.ColoredMH[Tepoch, Tstate, Tsampling]

Class for an MH temperature swap step.

Parameters:
  • ngroups (int) – number of color groups. If -1 (the default), temperature swapping is fully serial.

  • permute (bool) – if True, permute the values before each iteration

  • duplicate (bool) – if True, duplicate chains within a group so that the color groups are even in size

  • adjust (bool) – if True, adjust the number of groups so that the color groups are even in size
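To make the role of `ngroups` concrete, the following hypothetical sketch lists which temperature pairs would be proposed for swapping in each color sub-step. It is an illustration under stated assumptions, not the jexplore implementation: it assumes the pair \((t - 1, t)\) belongs to color \((t - 1) \bmod G\), in line with the `grouping()` description, and it reads `ngroups=-1` as one pair per color, i.e. \(G = T - 1\) for \(T\) temperatures. `swap_schedule` is a made-up helper. At least two colors are needed, because adjacent pairs share a chain.

```python
def swap_schedule(n_temps: int, ngroups: int) -> list[list[tuple[int, int]]]:
    """Temperature pairs (t - 1, t) proposed for swapping in each color sub-step.

    Hypothetical convention: pair (t - 1, t), t = 1..n_temps - 1, belongs to
    color (t - 1) % G, and ngroups == -1 is read as G = n_temps - 1
    (one pair per color, i.e. fully serial).
    """
    n_pairs = n_temps - 1
    G = n_pairs if ngroups == -1 else ngroups
    if G < 2 and n_pairs > 1:
        # With a single color, adjacent pairs would share a chain.
        raise ValueError("need at least 2 colors to avoid overlapping swaps")
    return [
        [(t - 1, t) for t in range(1, n_temps) if (t - 1) % G == i]
        for i in range(G)
    ]


print(swap_schedule(5, 2))   # two colors, each swapping non-overlapping pairs
print(swap_schedule(5, -1))  # four colors, one pair each
```

With 5 temperatures and `ngroups=2`, the two colors alternate between the pairs (0, 1), (2, 3) and (1, 2), (3, 4), so no chain takes part in two swaps within the same sub-step. With 6 temperatures the two colors hold 3 and 2 pairs; uneven sizes of this kind are presumably what `duplicate` and `adjust` address.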

grouping()

Defines the color groups of chains by grouping temperatures. The \(i\)-th color group contains the chains of all walkers whose temperature index \(t\) satisfies \(t \bmod G = i + 1\); for each of these chains, the complementary group contains the chain with temperature index \(t - 1\).

Returns:

The indexes of the chains of each color, and the indexes of the complementary chains of each color.

Return type:

tuple[list[jax.Array], list[jax.Array]]
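A minimal sketch of the index construction described above, written with `jax.numpy`. It assumes a hypothetical temperature-major layout in which the chain of walker \(w\) at temperature index \(t\) has flat index \(t \cdot W + w\) for \(W\) walkers, assigns temperature \(t \ge 1\) to color \(i\) when \((t - 1) \bmod G = i\) (which matches \(t \bmod G = i + 1\) for \(i < G - 1\)), and pairs each chain with the chain of the same walker at \(t - 1\). `color_groups` is an illustrative name, not part of jexplore.

```python
import jax
import jax.numpy as jnp


def color_groups(
    n_temps: int, n_walkers: int, n_groups: int
) -> tuple[list[jax.Array], list[jax.Array]]:
    """Return (groups, cgroups): flat chain indexes per color and their partners.

    Hypothetical layout: chain index = t * n_walkers + w.
    """
    walkers = jnp.arange(n_walkers)
    groups, cgroups = [], []
    for i in range(n_groups):
        # Temperatures t >= 1 assigned to color i (t = 0 has no partner at t - 1).
        temps = jnp.array(
            [t for t in range(1, n_temps) if (t - 1) % n_groups == i],
            dtype=jnp.int32,
        )
        # All walkers at each selected temperature ...
        group = (temps[:, None] * n_walkers + walkers[None, :]).ravel()
        groups.append(group)
        # ... paired with the same walker one temperature index lower.
        cgroups.append(group - n_walkers)
    return groups, cgroups
```

For example, with 4 temperatures, 2 walkers and \(G = 2\), color 0 covers temperatures 1 and 3 (chains 2, 3, 6, 7, paired with 0, 1, 4, 5) and color 1 covers temperature 2 (chains 4, 5, paired with 2, 3). Within each color the chain sets and partner sets are disjoint, so all pairs of a color can be updated at once.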

colored_step(key, state, group, cgroup)

Temperature swap step for a single color. For each chain in the color group, propose a swap with its complementary chain at the adjacent temperature, and accept or reject it according to the MH condition.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – indexes of the color chains

  • cgroup (jax.Array) – indexes of the complementary chains for this color.

Returns:

the new state, and a boolean mask of the chains that changed

Return type:

tuple[Tstate, jax.Array]
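The sketch below shows one way a single-color swap step can be written in JAX. It assumes a hypothetical `ToyState` with per-chain positions `x`, log-likelihoods `logl` and inverse temperatures `beta`, and a tempered target \(\pi_\beta(x) \propto p(x)\,L(x)^\beta\). Under that assumption, exchanging the positions of chains \(a\) and \(b\) is accepted with probability \(\min\{1, \exp[(\beta_a - \beta_b)(\log L(x_b) - \log L(x_a))]\}\), since the prior terms cancel. `colored_swap` and `ToyState` are illustrative stand-ins for `colored_step` and `Tstate`, not their actual definitions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for Tstate: one row per chain."""
    x: jax.Array     # positions, shape (n_chains, dim)
    logl: jax.Array  # log-likelihood at x, shape (n_chains,)
    beta: jax.Array  # inverse temperature of each chain slot, shape (n_chains,)


def colored_swap(key, state: ToyState, group, cgroup):
    """Propose exchanging positions between chains `group` and `cgroup`."""
    beta_a, beta_b = state.beta[group], state.beta[cgroup]
    logl_a, logl_b = state.logl[group], state.logl[cgroup]
    # Log MH ratio for the exchange under pi_beta ~ prior * L**beta.
    log_ratio = (beta_a - beta_b) * (logl_b - logl_a)
    u = jax.random.uniform(key, shape=group.shape)
    accept = jnp.log(u) < log_ratio
    # Where accepted, each chain takes its partner's position and log-likelihood.
    x_a = jnp.where(accept[:, None], state.x[cgroup], state.x[group])
    x_b = jnp.where(accept[:, None], state.x[group], state.x[cgroup])
    l_a = jnp.where(accept, logl_b, logl_a)
    l_b = jnp.where(accept, logl_a, logl_b)
    x = state.x.at[group].set(x_a).at[cgroup].set(x_b)
    logl = state.logl.at[group].set(l_a).at[cgroup].set(l_b)
    # Mark both members of every accepted pair as changed.
    changed = jnp.zeros(state.logl.shape, dtype=bool)
    changed = changed.at[group].set(accept).at[cgroup].set(accept)
    return state._replace(x=x, logl=logl), changed
```

Because every chain appears at most once across `group` and `cgroup` within a color, all pairs can be accepted or rejected in one vectorized pass without conflicting writes.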