Source code for jexplore.steps.de
"""
This module defines the class for an MH step based on the Differential
Evolution (DE) proposal.
"""
import jax
import jax.numpy as jnp
from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredSC
[docs]
class DEStep[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](ColoredSC[Tepoch, Tstate, Tsampling]):
r"""Class implementing a Differential evolution step
:param gamma: :math:`\gamma` scale parameter
:param ngroups: number of groups. Default 2.
:param permute: if true walkers are permuted at each iteration.
"""
gamma: float
r"""DE proposal :math:`\gamma` parameter"""
sigma: jax.Array
r"""gamma distribution :math:`\sigma = \frac{\gamma}{2\sqrt{D}}`"""
npart: int = 2
def __init__(self, gamma: float = 2.38, ngroups: int = 2, permute: bool = False):
super().__init__(ngroups, permute)
self.gamma = gamma
[docs]
def build(self, epoch: Tepoch) -> None:
r"""Step initialisation method. It extends
:py:attr:`jexplore.steps.colored.Colored.build` by simply adding the
computation of the :math:`\sigma` of the :math:`\gamma` distribution.
:param epoch: current epoch.
"""
super().build(epoch)
self.sigma = self.gamma / jnp.sqrt(epoch.sampling.dim) / 2
[docs]
def sample_gamma(self, key: jax.Array, state: Tstate) -> jax.Array:
r"""Sample :math:`\gamma` from normal distribution
:param key: PRNG key
:param size: output size
:return: samples
"""
return self.sigma * jax.random.normal(key, shape=(state.p.shape[0],))
# pylint: disable=unused-argument
[docs]
def proposal(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[jax.Array, Tstate, jax.Array]:
"""Propose a new state according to the DE proposal algorithm.
:param key: PRNG key
:param state: current state
:return: new state and the boolean mask of the chains modified by the step.
"""
key = jax.random.split(key, 3)
_a, _b = self.get_partners(key[0], state, group, cgroup)
state = state.slice(group)
_x = state.p + self.sample_gamma(key[1], state)[:, None] * (_b - _a)
_prop = state.set_val(_x)
return key[2], _prop, jnp.zeros(_x.shape[0])