# Source code for jexplore.steps.de

"""
This module defines the class for a Metropolis-Hastings (MH) step based on the
differential evolution (DE) proposal.
"""

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredSC


[docs] class DEStep[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](ColoredSC[Tepoch, Tstate, Tsampling]): r"""Class implementing a Differential evolution step :param gamma: :math:`\gamma` scale parameter :param ngroups: number of groups. Default 2. :param permute: if true walkers are permuted at each iteration. """ gamma: float r"""DE proposal :math:`\gamma` parameter""" sigma: jax.Array r"""gamma distribution :math:`\sigma = \frac{\gamma}{2\sqrt{D}}`""" npart: int = 2 def __init__(self, gamma: float = 2.38, ngroups: int = 2, permute: bool = False): super().__init__(ngroups, permute) self.gamma = gamma
[docs] def build(self, epoch: Tepoch) -> None: r"""Step initialisation method. It extends :py:attr:`jexplore.steps.colored.Colored.build` by simply adding the computation of the :math:`\sigma` of the :math:`\gamma` distribution. :param epoch: current epoch. """ super().build(epoch) self.sigma = self.gamma / jnp.sqrt(epoch.sampling.dim) / 2
[docs] def sample_gamma(self, key: jax.Array, state: Tstate) -> jax.Array: r"""Sample :math:`\gamma` from normal distribution :param key: PRNG key :param size: output size :return: samples """ return self.sigma * jax.random.normal(key, shape=(state.p.shape[0],))
# pylint: disable=unused-argument
[docs] def proposal( self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array ) -> tuple[jax.Array, Tstate, jax.Array]: """Propose a new state according to the DE proposal algorithm. :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ key = jax.random.split(key, 3) _a, _b = self.get_partners(key[0], state, group, cgroup) state = state.slice(group) _x = state.p + self.sample_gamma(key[1], state)[:, None] * (_b - _a) _prop = state.set_val(_x) return key[2], _prop, jnp.zeros(_x.shape[0])