Source code for jexplore.steps.colored_alt

"""
This module defines an alternate version of
:py:class:`jexplore.steps.colored.ColoredSC` and of its child classes
:py:class:`jexplore.steps.stretch.Stretch` and :py:class:`jexplore.steps.de.DEStep`.
The main difference is that the partners proposed for the chains of a color
group are all distinct: no complementary chain is shared between two chains.
"""

import jax

from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredSC
from jexplore.steps.de import DEStep
from jexplore.steps.stretch import Stretch


[docs] class ColoredAlt[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](ColoredSC[Tepoch, Tstate, Tsampling]): """Class implementing a MH steps based on stretch proposal :param a: stretch proposal `a` parameter :param ngroups: number of groups. Default 2. :param permute: if true walkers are permuted at each iteration. """
[docs] def get_partners( self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array ) -> jax.Array: """Method for getting partners samples fore each chain of a group. This alternate implementation gets the first self.npart * group.size // ntemps chains out of the cgroup chains for each temperatures, after permuting them and uses them as partners of the group chains of the same temperature. With respect to the base implementation of this method, this one always returns distinct partners for each group. For this reason ngroups must be larger than self.npart + 1 (this has to be implemented by the children classes). :param key: PRNG key :param state: current state :param group: group chains :param cgroup: complementary group chains :return: the parners as an array with shape (self.npars, group.size, dim) """ ntemps = self.sampling.temps.shape[0] partners = jax.random.permutation( key, cgroup.reshape(-1, ntemps), axis=0, independent=True )[: self.npart * group.shape[0] // ntemps, :].reshape( self.npart, group.shape[0] ) return state.p[partners, :]
# pylint: disable=unused-argument
[docs] def proposal( self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array ) -> tuple[jax.Array, Tstate, jax.Array]: """Proposal restricted to chains of one color. This is just a prototype. :param key: PRNG key :param state: current state :param group: current color group :param cgroup: complementary chains for each chain in the group. :return: PNRG key, proposed state, transition probability for each chain.""" raise NotImplementedError( "ColoredSC is an abstract class. Need to implement proposal method." )
[docs] class StretchAlt[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](Stretch[Tepoch, Tstate, Tsampling]): """Class implementing a MH steps based on stretch proposal. This is an alternate version inwhich each chain of a color group receives a distinct partner from the complementary group chains. :param a: stretch proposal `a` parameter :param ngroups: number of groups. Default 2. :param permute: if true walkers are permuted at each iteration. """ get_partners = ColoredAlt[Tepoch, Tstate, Tsampling].get_partners
[docs] class DEStepAlt[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](DEStep[Tepoch, Tstate, Tsampling]): r"""Class implementing a Differential evolution step. This is an alternate version inwhich each chain of a color group receives a distinct couple of partner from the complementary group chains. :param gamma: :math:`\gamma` scale parameter :param ngroups: number of groups. Default 2. :param permute: if true walkers are permuted at each iteration. """ get_partners = ColoredAlt[Tepoch, Tstate, Tsampling].get_partners def __init__(self, gamma: float = 2.38, ngroups: int = 3, permute: bool = False): if ngroups < 3: raise ValueError("DE requires at least 3 colors grouping.") super().__init__(gamma, ngroups, permute)