Source code for jexplore.steps.tswap

"""
This module defines the class for a Metropolis-Hastings temperature
swap step.
"""

import jax
import jax.numpy as jnp

from jexplore.sampling.mh import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredMH


[docs] class TSwap[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](ColoredMH[Tepoch, Tstate, Tsampling]): """Class for MH temperature swap step :param ngroups: number of groups. If -1 (default) the temperature swapping is fully serial. :param permute: when true permute values before each iteration :param duplicate: duplicate chains in group to make colors groups even :param adjust: adjust the number of group to have even colors groups """ _wrap = False def __init__( self, ngroups: int = -1, permute: bool = False, duplicate: bool = False, adjust: bool = False, ): super().__init__( ngroups=ngroups, permute=permute, duplicate=duplicate, adjust=adjust )
[docs] def grouping(self) -> tuple[list[jax.Array], list[jax.Array]]: r"""Defines colors chain groups by grouping temperatures. The :math:`i` th color group contains chains for all walkers and with temperature indices :math:`t \mod G = i + 1` and for each of these chains the complementary group of chains contains the chain with temperature :math:`t - 1`. :return: indexes of the chains of each color, indexes of the complementary chains for each color.""" if self.ngroups == -1: self.ngroups = self.sampling.temps.shape[0] - 1 # TODO: need to implement adjust and duplicate for TSwap # TODO: probably need to add some logging if things are adjusted if self.adjust or self.duplicate: raise NotImplementedError("Adjusting grouping not implemented yet.") ichain = jnp.arange(self.sampling.nchain) ntemps = self.sampling.temps.shape[0] nwalkers = self.sampling.nwalker grps = jnp.arange(self.ngroups) return ( [ ichain.reshape(nwalkers, ntemps)[:, 1 + _igrp :: self.ngroups].flatten() for _igrp in grps ], [ ichain.reshape(nwalkers, ntemps)[:, _igrp : -1 : self.ngroups].flatten() for _igrp in grps ], )
[docs] def colored_step( self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array ) -> tuple[Tstate, jax.Array]: """Temperature swap step for a single color. Propose a swap of two consecutive temperatures and accept/reject it according to MH condition. :param key: PRNG key :param state: current state :param group: indexes of the color chains :param cgroup: indexes of the complementary chains for this color. :return: new state and boolean mask of the changed chains""" _, acc = self.get_accepted( key, (self.beta[group] - self.beta[cgroup]) * (state.ll[cgroup] - state.ll[group]), ) acc = acc.reshape(state.ll[group].shape[0]) return state.swap(group, cgroup, acc), acc