"""
This module defines alternate versions of
:py:class:`jexplore.steps.colored.ColoredSC` and of its child classes
:py:class:`jexplore.steps.stretch.Stretch` and :py:class:`jexplore.steps.de.DEStep`.
The main difference is that, within a color group, every chain receives its own
partners: no partner chain is shared between two chains of the same group.
"""
import jax
from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredSC
from jexplore.steps.de import DEStep
from jexplore.steps.stretch import Stretch
[docs]
class ColoredAlt[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](ColoredSC[Tepoch, Tstate, Tsampling]):
"""Class implementing a MH steps based on stretch proposal
:param a: stretch proposal `a` parameter
:param ngroups: number of groups. Default 2.
:param permute: if true walkers are permuted at each iteration.
"""
[docs]
def get_partners(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> jax.Array:
"""Method for getting partners samples fore each chain of a group. This
alternate implementation gets the first self.npart * group.size // ntemps
chains out of the cgroup chains for each temperatures, after permuting them
and uses them as partners of the group chains of the same temperature.
With respect to the base implementation of this method, this one always
returns distinct partners for each group. For this reason ngroups must
be larger than self.npart + 1 (this has to be implemented by the children
classes).
:param key: PRNG key
:param state: current state
:param group: group chains
:param cgroup: complementary group chains
:return: the parners as an array with shape (self.npars, group.size, dim)
"""
ntemps = self.sampling.temps.shape[0]
partners = jax.random.permutation(
key, cgroup.reshape(-1, ntemps), axis=0, independent=True
)[: self.npart * group.shape[0] // ntemps, :].reshape(
self.npart, group.shape[0]
)
return state.p[partners, :]
# pylint: disable=unused-argument
[docs]
def proposal(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[jax.Array, Tstate, jax.Array]:
"""Proposal restricted to chains of one color. This is just a prototype.
:param key: PRNG key
:param state: current state
:param group: current color group
:param cgroup: complementary chains for each chain in the group.
:return: PNRG key, proposed state, transition probability for each chain."""
raise NotImplementedError(
"ColoredSC is an abstract class. Need to implement proposal method."
)
[docs]
class StretchAlt[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](Stretch[Tepoch, Tstate, Tsampling]):
"""Class implementing a MH steps based on stretch proposal. This is an
alternate version inwhich each chain of a color group receives a distinct
partner from the complementary group chains.
:param a: stretch proposal `a` parameter
:param ngroups: number of groups. Default 2.
:param permute: if true walkers are permuted at each iteration.
"""
get_partners = ColoredAlt[Tepoch, Tstate, Tsampling].get_partners
[docs]
class DEStepAlt[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](DEStep[Tepoch, Tstate, Tsampling]):
r"""Class implementing a Differential evolution step. This is an
alternate version inwhich each chain of a color group receives a distinct
couple of partner from the complementary group chains.
:param gamma: :math:`\gamma` scale parameter
:param ngroups: number of groups. Default 2.
:param permute: if true walkers are permuted at each iteration.
"""
get_partners = ColoredAlt[Tepoch, Tstate, Tsampling].get_partners
def __init__(self, gamma: float = 2.38, ngroups: int = 3, permute: bool = False):
if ngroups < 3:
raise ValueError("DE requires at least 3 colors grouping.")
super().__init__(gamma, ngroups, permute)