Source code for jexplore.steps.direct
"""
This module defines single-chain Metropolis-Hastings
steps with direct proposals, i.e. proposals that generate
candidate points by sampling a given distribution.
"""
from typing import Type
import jax
from jexplore.sampling import EpochMH, StateMH
from jexplore.steps.mh import AllChains
from jexplore.tools import distributions as d
[docs]
class Direct(AllChains):
    r"""All-chain Metropolis-Hastings step with a direct proposal.

    Candidate points are obtained by sampling a distribution that is
    instantiated once per epoch (see :py:meth:`build`).

    :param dist: distribution class.
    :param mask: proposal dimensions mask, forwarded to the parent class.
    :param \**opts: options to be passed to the distribution creator.
    """

    dist: Type[d.Distr]
    """distribution class, instantiated at each epoch build"""
    opts: dict
    """keyword options handed to the distribution creator"""
    epoch_dist: d.Distr
    """distribution instance used during the current epoch"""

    def __init__(self, dist: Type[d.Distr], mask: jax.Array | None = None, **opts):
        super().__init__(mask=mask)
        # Only the class and its options are stored here; the instance
        # itself is created later, in ``build``.
        self.dist, self.opts = dist, opts

    def build(self, epoch: EpochMH) -> None:
        """Initialise the step for a new epoch.

        This extends method :py:attr:`jexplore.steps.step.Step.build` by
        instantiating the distribution object for this epoch.

        :param epoch: current epoch.
        """
        super().build(epoch)
        # The distribution dimension is taken from the size of the mask.
        ndim = self.mask.size
        self.epoch_dist = self.dist(dim=ndim, **self.opts)

    def proposal(
        self, key: jax.Array, state: StateMH
    ) -> tuple[jax.Array, StateMH, jax.Array]:
        """Propose a point by sampling this class distribution.

        :param key: PRNG key
        :param state: current state
        :return: the key returned by the sampler, the state with the sampled
            values written on the masked dimensions, and the difference
            between ``leval`` at the current masked point and ``leval`` at
            the proposed point (presumably the log proposal ratio used by
            the acceptance test -- confirm against the caller).
        """
        sampler = self.epoch_dist
        # Evaluate the distribution at the current point, restricted to
        # the dimensions this step acts on.
        current = state.p[:, self.mask]
        leval_current = sampler.leval(current)
        # One sample per entry along the leading axis of ``state.p``.
        key, proposed = sampler.sample(key, shape=state.p.shape[:1])
        leval_proposed = sampler.leval(proposed)
        qxy = leval_current - leval_proposed
        return key, state.set_val(proposed, pslc=self.mask), qxy
[docs]
class Gaussian(Direct):
    r"""
    Sampling from a Gaussian distribution (to be completed).

    The step is a :py:class:`Direct` step whose distribution class is
    fixed to ``d.Normal``.

    :param mask: proposal dimensions mask (default all space)
    :param \**opts: options forwarded unchanged to :py:class:`Direct`, which
        passes them to the ``d.Normal`` creator (alongside ``dim``) when the
        epoch is built. Omitting them keeps the previous behaviour.
    """

    def __init__(self, mask: jax.Array | None = None, **opts):
        # ``dist`` is pinned to the normal distribution; every other keyword
        # is handed to the parent so the distribution can be configured.
        super().__init__(mask=mask, dist=d.Normal, **opts)