# Source code for jexplore.steps.tswap
"""
This module defines the class for a Metropolis-Hasting temperature
swap step.
"""
import jax
import jax.numpy as jnp
from jexplore.sampling.mh import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredMH
[docs]
class TSwap[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](ColoredMH[Tepoch, Tstate, Tsampling]):
"""Class for MH temperature swap step
:param ngroups: number of groups. If -1 (default) the temperature swapping
is fully serial.
:param permute: when true permute values before each iteration
:param duplicate: duplicate chains in group to make colors groups even
:param adjust: adjust the number of group to have even colors groups
"""
_wrap = False
def __init__(
self,
ngroups: int = -1,
permute: bool = False,
duplicate: bool = False,
adjust: bool = False,
):
super().__init__(
ngroups=ngroups, permute=permute, duplicate=duplicate, adjust=adjust
)
[docs]
def grouping(self) -> tuple[list[jax.Array], list[jax.Array]]:
r"""Defines colors chain groups by grouping temperatures. The :math:`i` th
color group contains chains for all walkers and with temperature indices
:math:`t \mod G = i + 1` and for each of these chains the complementary
group of chains contains the chain with temperature :math:`t - 1`.
:return: indexes of the chains of each color, indexes of the complementary
chains for each color."""
if self.ngroups == -1:
self.ngroups = self.sampling.temps.shape[0] - 1
# TODO: need to implement adjust and duplicate for TSwap
# TODO: probably need to add some logging if things are adjusted
if self.adjust or self.duplicate:
raise NotImplementedError("Adjusting grouping not implemented yet.")
ichain = jnp.arange(self.sampling.nchain)
ntemps = self.sampling.temps.shape[0]
nwalkers = self.sampling.nwalker
grps = jnp.arange(self.ngroups)
return (
[
ichain.reshape(nwalkers, ntemps)[:, 1 + _igrp :: self.ngroups].flatten()
for _igrp in grps
],
[
ichain.reshape(nwalkers, ntemps)[:, _igrp : -1 : self.ngroups].flatten()
for _igrp in grps
],
)
[docs]
def colored_step(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[Tstate, jax.Array]:
"""Temperature swap step for a single color. Propose a swap of two consecutive
temperatures and accept/reject it according to MH condition.
:param key: PRNG key
:param state: current state
:param group: indexes of the color chains
:param cgroup: indexes of the complementary chains for this color.
:return: new state and boolean mask of the changed chains"""
_, acc = self.get_accepted(
key,
(self.beta[group] - self.beta[cgroup])
* (state.ll[cgroup] - state.ll[group]),
)
acc = acc.reshape(state.ll[group].shape[0])
return state.swap(group, cgroup, acc), acc