jexplore.steps.tswap
This module defines the class for a Metropolis-Hastings (MH) temperature swap step.
Classes
TSwap | Class for MH temperature swap step
Module Contents
- class TSwap[Tepoch: jexplore.sampling.mh.EpochMH, Tstate: jexplore.sampling.mh.StateMH, Tsampling: jexplore.sampling.mh.SamplingMH](ngroups=-1, permute=False, duplicate=False, adjust=False)
Bases: jexplore.steps.colored.ColoredMH[Tepoch, Tstate, Tsampling]
Class for MH temperature swap step.
- Parameters:
ngroups (int) – number of groups. If -1 (default) the temperature swapping is fully serial.
permute (bool) – if True, permute the values before each iteration
duplicate (bool) – if True, duplicate chains within a group so that all color groups have the same size
adjust (bool) – if True, adjust the number of groups so that all color groups have the same size
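When the number of temperature pairs is not a multiple of the number of groups, the color groups end up with different sizes. One possible reading of `duplicate=True`, sketched below purely as an assumption, is that shorter groups are padded by repeating some of their own chain indexes so that all groups can be stacked into a single array. The helper `pad_groups_by_duplication` is hypothetical and not part of jexplore; a real implementation would also have to make sure that a duplicated chain is not swapped twice.

```python
import numpy as np


def pad_groups_by_duplication(groups):
    """Hypothetical helper: pad ragged index groups to a common length.

    Shorter groups are extended cyclically with their own chain indexes
    (an assumed interpretation of ``duplicate=True``, not the library code).
    """
    size = max(len(g) for g in groups)
    # np.resize fills the larger output with repeated copies of the input
    return np.stack([np.resize(np.asarray(g), size) for g in groups])


groups = [np.array([1, 4, 7]), np.array([2, 5])]
print(pad_groups_by_duplication(groups))
# [[1 4 7]
#  [2 5 2]]
```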
- grouping()
Defines the color chain groups by grouping temperatures. The \(i\)-th color group contains the chains of all walkers whose temperature index \(t\) satisfies \(t \bmod G = i + 1\), where \(G\) is the number of groups; for each of these chains, the complementary group contains the chain at temperature index \(t - 1\).
- Returns:
indexes of the chains of each color, indexes of the complementary chains for each color.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
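The grouping rule can be sketched as follows, under several assumptions that are not confirmed by this page: chains are laid out temperature-major (chain index \(t \cdot W + w\) for \(W\) walkers), the condition is read modulo \(G\) (so the last color takes \(t \equiv 0\)), and `ngroups=-1` is taken to mean one temperature pair per color, \(G = T - 1\). The function name `grouping_sketch` is hypothetical. With \(G \ge 2\), the pairs \((t-1, t)\) in one color never share a chain, which is what allows them to be swapped in parallel.

```python
import jax.numpy as jnp


def grouping_sketch(ntemps: int, nwalkers: int, ngroups: int):
    """Hypothetical re-implementation of the documented grouping rule.

    Assumes chain index = t * nwalkers + w (temperature-major layout).
    Color i holds temperatures t >= 1 with t % G == (i + 1) % G; each such
    chain is paired with the chain of the same walker at temperature t - 1.
    """
    # assumption: -1 means one temperature pair per color (fully serial)
    G = ntemps - 1 if ngroups == -1 else ngroups
    walkers = jnp.arange(nwalkers)
    groups, cgroups = [], []
    for i in range(G):
        temps = jnp.array(
            [t for t in range(1, ntemps) if t % G == (i + 1) % G], dtype=jnp.int32
        )
        # broadcasting builds every (temperature, walker) chain index of this color
        groups.append((temps[:, None] * nwalkers + walkers[None, :]).ravel())
        cgroups.append(((temps[:, None] - 1) * nwalkers + walkers[None, :]).ravel())
    return groups, cgroups


groups, cgroups = grouping_sketch(ntemps=4, nwalkers=2, ngroups=2)
print([g.tolist() for g in groups])   # [[2, 3, 6, 7], [4, 5]]
print([c.tolist() for c in cgroups])  # [[0, 1, 4, 5], [2, 3]]
```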
- colored_step(key, state, group, cgroup)
Temperature swap step for a single color. Proposes a swap between two consecutive temperatures and accepts or rejects it according to the MH acceptance condition.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – indexes of the color chains
cgroup (jax.Array) – indexes of the complementary chains for this color.
- Returns:
new state and boolean mask of the changed chains
- Return type:
tuple[Tstate, jax.Array]
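A minimal JAX sketch of one colored swap, assuming the usual tempered target \(\pi_\beta(x) \propto p(x)\,L(x)^\beta\) with inverse temperature \(\beta\). Exchanging the states of chains \(a\) and \(b\) then has log-acceptance ratio \((\beta_a - \beta_b)\,(\log L(x_b) - \log L(x_a))\), since the untempered prior cancels. The `SketchState` container and `colored_swap_sketch` are hypothetical stand-ins, not the library's `StateMH` or its `colored_step`; the real state may store temperatures and likelihoods differently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SketchState(NamedTuple):
    """Hypothetical state with one row per chain (not jexplore's StateMH)."""
    x: jax.Array        # positions, shape (nchains, ndim)
    loglike: jax.Array  # untempered log-likelihoods, shape (nchains,)
    beta: jax.Array     # inverse temperature of each chain, shape (nchains,)


def colored_swap_sketch(key, state, group, cgroup):
    """Propose exchanging the states of chains group[k] and cgroup[k] for every k."""
    ba, bb = state.beta[group], state.beta[cgroup]
    la, lb = state.loglike[group], state.loglike[cgroup]
    # MH log-acceptance ratio of the swap under the tempered target
    log_alpha = (ba - bb) * (lb - la)
    u = jax.random.uniform(key, shape=log_alpha.shape)
    accept = jnp.log(u) < log_alpha

    # exchange positions and log-likelihoods only for accepted pairs
    new_xa = jnp.where(accept[:, None], state.x[cgroup], state.x[group])
    new_xb = jnp.where(accept[:, None], state.x[group], state.x[cgroup])
    new_la = jnp.where(accept, lb, la)
    new_lb = jnp.where(accept, la, lb)
    x = state.x.at[group].set(new_xa).at[cgroup].set(new_xb)
    loglike = state.loglike.at[group].set(new_la).at[cgroup].set(new_lb)

    # boolean mask of the chains whose state changed
    changed = jnp.zeros(state.x.shape[0], dtype=bool)
    changed = changed.at[group].set(accept).at[cgroup].set(accept)
    return state._replace(x=x, loglike=loglike), changed
```

In this sketch a hot chain (\(\beta = 0.5\)) holding a much better state than its colder partner (\(\beta = 1\)) gets a large positive `log_alpha`, so the swap is essentially always accepted, while the reverse situation is essentially always rejected.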