Source code for jexplore.steps.stretch

"""
This module defines the class for an MH step based on the stretch proposal.
"""

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredSC


[docs] class Stretch[ Tepoch: EpochMH = EpochMH, Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH, ](ColoredSC[Tepoch, Tstate, Tsampling]): """Class implementing a MH steps based on stretch proposal :param a: stretch proposal `a` parameter :param ngroups: number of groups. Default 2. :param permute: if true walkers are permuted at each iteration. """ npart: int = 1 a: float """stretch proposal `a` parameter""" def __init__(self, a: float = 2.0, ngroups: int = 2, permute: bool = False): super().__init__(ngroups, permute) self.sa = jnp.sqrt(a)
[docs] def sample_z(self, key: jax.Array, size: int) -> jax.Array: """Sample the z distribution :param key: PRNG key :param size: output size :return: samples """ u = jax.random.uniform(key, shape=(size,)) z = (u * (self.sa - 1.0 / self.sa) + 1.0 / self.sa) ** 2 return z
# pylint: disable=unused-argument def _get_lqxy(self, z: jax.Array, state: Tstate, group: jax.Array) -> jax.Array: """ Gets the transition probability. :param z: the (cgrop.size,) array of stretch proposal z parameters. :param state: current state :param group: color group list of indexes :return: the (cgroup.size,) array of transition probabilities. """ return (self.sampling.dim - 1) * jnp.log(z) # pylint: disable=unused-argument
[docs] def proposal( self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array ) -> tuple[jax.Array, Tstate, jax.Array]: """Propose a new state according to the stretch proposal algorithm. :param key: PRNG key :param state: current state :param group: indexes of the color chains :param cgroup: indexes of the complementary chains for this color. :return: new state and the boolean mask of the chains modified by the step. """ key = jax.random.split(key, 3) _y = self.get_partners(key[0], state, group, cgroup)[0, :, :] # Here is the actual stretch prosal _x = state.p[group, :] z = self.sample_z(key[1], size=_y.shape[0]) _y = _y + z[:, None] * (_x - _y) _prop = state.slice(group).set_val(_y) return key[2], _prop, self._get_lqxy(z, state, group)