Source code for jexplore.steps.step
"""
This module defines the :py:class:`jexplore.steps.step.Step` class
which is the base class for all steps.
"""
from typing import cast
import jax
import jax.numpy as jnp
from jexplore.sampling import Epoch, Sampling, State, StepSample
[docs]
class Step[
Tepoch: Epoch = Epoch, Tstate: State = State, Tsampling: Sampling = Sampling
]:
"""
This is the base step definition class. It provides a generic epoch initialization
method `builder`, some general methods and a placeholder for the step method.
"""
epoch: Tepoch
"""current epoch"""
sampling: Tsampling
"""sampling parameters"""
_wrap: bool = True
"""Wrapping flag. If true will perform space wrapping after the step"""
[docs]
def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]:
"""Step sampling method. This is just a prototype.
:param key: PRNG key
:param state: current state
:return: new state and the boolean mask of the chains modified by the step.
"""
raise NotImplementedError(
"Step is an abstract class. Need to implement this method."
)
def _wrapped_step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]:
state, _stats = self.step(key, state)
if self._wrap:
state = state.update(
lambda _nam, _val: self.sampling.space.wrap(_val), ["p"]
)
return state, _stats
# pylint: disable=unused-argument
[docs]
def build(self, epoch: Tepoch) -> None:
"""Epoch initialisation method.
:param epoch: current epoch.
"""
self.epoch = epoch
self.sampling = cast(Tsampling, epoch.sampling)
[docs]
def builder(self, epoch: Tepoch) -> StepSample[Tstate]:
"""
This method updates the step object state according to
the current epoch, calling the `build` method, and
returning the pointer to the `step` method.
In principle this base implementation is general and
does not need to be overridden or extended (rather extend the
`build` method).
:param epoch: epoch.
:return: pointer to this class `step` method.
"""
# Update Proposal object state according to epoch
self.build(epoch)
# Return the Proposal.propose method
return self._wrapped_step
[docs]
@staticmethod
def get_accepted(key: jax.Array, logacc: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Given a set of log acceptance values, draws a correspondng set of
log uniform values and returns the acceptance mask.
:param key: PRNG key
:param logacc: set of log acceptance values
:return: new PNRG key and acceptance mask."""
key, kacc = jax.random.split(key)
return key, jnp.log(jax.random.uniform(kacc, shape=logacc.shape)) <= logacc