Source code for jexplore.steps.stretch
"""
This module defines the class for a Metropolis-Hastings (MH) step based on the stretch proposal.
"""
import jax
import jax.numpy as jnp
from jexplore.sampling import EpochMH, SamplingMH, StateMH
from jexplore.steps.colored import ColoredSC
[docs]
class Stretch[
Tepoch: EpochMH = EpochMH,
Tstate: StateMH = StateMH,
Tsampling: SamplingMH = SamplingMH,
](ColoredSC[Tepoch, Tstate, Tsampling]):
"""Class implementing a MH steps based on stretch proposal
:param a: stretch proposal `a` parameter
:param ngroups: number of groups. Default 2.
:param permute: if true walkers are permuted at each iteration.
"""
npart: int = 1
a: float
"""stretch proposal `a` parameter"""
def __init__(self, a: float = 2.0, ngroups: int = 2, permute: bool = False):
super().__init__(ngroups, permute)
self.sa = jnp.sqrt(a)
[docs]
def sample_z(self, key: jax.Array, size: int) -> jax.Array:
"""Sample the z distribution
:param key: PRNG key
:param size: output size
:return: samples
"""
u = jax.random.uniform(key, shape=(size,))
z = (u * (self.sa - 1.0 / self.sa) + 1.0 / self.sa) ** 2
return z
# pylint: disable=unused-argument
def _get_lqxy(self, z: jax.Array, state: Tstate, group: jax.Array) -> jax.Array:
"""
Gets the transition probability.
:param z: the (cgrop.size,) array of stretch proposal z parameters.
:param state: current state
:param group: color group list of indexes
:return: the (cgroup.size,) array of transition probabilities.
"""
return (self.sampling.dim - 1) * jnp.log(z)
# pylint: disable=unused-argument
[docs]
def proposal(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[jax.Array, Tstate, jax.Array]:
"""Propose a new state according to the stretch proposal algorithm.
:param key: PRNG key
:param state: current state
:param group: indexes of the color chains
:param cgroup: indexes of the complementary chains for this color.
:return: new state and the boolean mask of the chains modified by the step.
"""
key = jax.random.split(key, 3)
_y = self.get_partners(key[0], state, group, cgroup)[0, :, :]
# Here is the actual stretch prosal
_x = state.p[group, :]
z = self.sample_z(key[1], size=_y.shape[0])
_y = _y + z[:, None] * (_x - _y)
_prop = state.slice(group).set_val(_y)
return key[2], _prop, self._get_lqxy(z, state, group)