# Source code for jexplore.steps.step

"""
This module defines the :py:class:`jexplore.steps.step.Step` class,
which is the base class for all steps.
"""

from typing import cast

import jax
import jax.numpy as jnp

from jexplore.sampling import Epoch, Sampling, State, StepSample


[docs] class Step[ Tepoch: Epoch = Epoch, Tstate: State = State, Tsampling: Sampling = Sampling ]: """ This is the base step definition class. It provides a generic epoch initialization method `builder`, some general methods and a placeholder for the step method. """ epoch: Tepoch """current epoch""" sampling: Tsampling """sampling parameters""" _wrap: bool = True """Wrapping flag. If true will perform space wrapping after the step"""
[docs] def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: """Step sampling method. This is just a prototype. :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ raise NotImplementedError( "Step is an abstract class. Need to implement this method." )
def _wrapped_step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: state, _stats = self.step(key, state) if self._wrap: state = state.update( lambda _nam, _val: self.sampling.space.wrap(_val), ["p"] ) return state, _stats # pylint: disable=unused-argument
[docs] def build(self, epoch: Tepoch) -> None: """Epoch initialisation method. :param epoch: current epoch. """ self.epoch = epoch self.sampling = cast(Tsampling, epoch.sampling)
[docs] def builder(self, epoch: Tepoch) -> StepSample[Tstate]: """ This method updates the step object state according to the current epoch, calling the `build` method, and returning the pointer to the `step` method. In principle this base implementation is general and does not need to be overridden or extended (rather extend the `build` method). :param epoch: epoch. :return: pointer to this class `step` method. """ # Update Proposal object state according to epoch self.build(epoch) # Return the Proposal.propose method return self._wrapped_step
[docs] @staticmethod def get_accepted(key: jax.Array, logacc: jax.Array) -> tuple[jax.Array, jax.Array]: """Given a set of log acceptance values, draws a correspondng set of log uniform values and returns the acceptance mask. :param key: PRNG key :param logacc: set of log acceptance values :return: new PNRG key and acceptance mask.""" key, kacc = jax.random.split(key) return key, jnp.log(jax.random.uniform(kacc, shape=logacc.shape)) <= logacc