Source code for jexplore.steps.direct

"""
This module defines single chain Metropolis-Hastings
steps with direct proposals: i.e. proposals proposing points
by sampling a given distribution.
"""

from typing import Type

import jax

from jexplore.sampling import EpochMH, StateMH
from jexplore.steps.mh import AllChains
from jexplore.tools import distributions as d


class Direct(AllChains):
    r"""All-chain Metropolis-Hastings step with a direct proposal.

    The proposal draws candidate points by sampling a distribution. The
    candidate is drawn independently of the current point; the current point
    is only evaluated to build the proposal log-ratio.

    :param dist: distribution class, instantiated once per epoch in
        :py:meth:`build`.
    :param mask: proposal dimensions mask, forwarded to the parent step
        (default all space).
    :param \**opts: options to be passed to the distribution creator.
    """

    #: distribution class
    dist: Type[d.Distr]
    #: distribution creator options
    opts: dict
    #: distribution instance for the epoch
    epoch_dist: d.Distr

    def __init__(
        self, dist: Type[d.Distr], mask: jax.Array | None = None, **opts
    ):
        super().__init__(mask=mask)
        self.dist = dist
        self.opts = opts

    def build(self, epoch: EpochMH) -> None:
        """Step epoch initialisation method.

        This extends :py:meth:`jexplore.steps.step.Step.build` by
        instantiating the distribution object used during this epoch.

        :param epoch: current epoch.
        """
        super().build(epoch)
        # NOTE(review): ``dim`` is ``self.mask.size``. If ``mask`` is a
        # boolean mask over all dimensions, this is the full dimensionality
        # rather than the number of selected dimensions -- confirm against
        # how AllChains normalises ``mask``.
        self.epoch_dist = self.dist(dim=self.mask.size, **self.opts)

    def proposal(
        self, key: jax.Array, state: StateMH
    ) -> tuple[jax.Array, StateMH, jax.Array]:
        """Propose new points by sampling this class distribution.

        :param key: PRNG key.
        :param state: current state.
        :return: a tuple ``(key, new_state, qxy)``:

            - the updated PRNG key returned by the distribution sampler;
            - the state obtained with ``state.set_val(prop, pslc=self.mask)``,
              where ``prop`` holds the sampled points;
            - ``qxy = log q(x) - log q(y)``, the log proposal ratio, where
              ``x`` is the current (masked) point and ``y`` the proposed one.
        """
        # Log density of the current points, restricted to the masked dims.
        log_q_current = self.epoch_dist.leval(state.p[:, self.mask])
        # One draw per row of ``state.p`` (its leading axis; presumably the
        # chain axis, since dimensions are indexed on axis 1 above).
        key, prop = self.epoch_dist.sample(key, shape=state.p.shape[:1])
        log_q_proposed = self.epoch_dist.leval(prop)
        qxy = log_q_current - log_q_proposed
        return key, state.set_val(prop, pslc=self.mask), qxy
class Uniform(Direct):
    """
    Direct proposal drawing points uniformly inside a box.

    :param mask: proposal dimensions mask (default all space).
    :param minval: lower bound (inclusive) of the box, broadcast-compatible
        with the sample shape (default 0).
    :param maxval: upper bound (exclusive) of the box, broadcast-compatible
        with the sample shape (default 1).
    """

    def __init__(
        self,
        mask: jax.Array | None = None,
        minval: jax.Array | float = 0.0,
        maxval: jax.Array | float = 1.0,
    ):
        # The box bounds travel as distribution-creator options and are
        # handed to ``d.Uniform`` when the step is built.
        super().__init__(d.Uniform, mask=mask, minval=minval, maxval=maxval)
class Gaussian(Direct):
    r"""
    Direct proposal drawing points from a Gaussian distribution.

    Gaussian-specific parameters are still to be completed; in the meantime
    any creator option can be passed through ``**opts``.

    :param mask: proposal dimensions mask (default all space).
    :param \**opts: extra options forwarded to the
        :py:class:`jexplore.tools.distributions.Normal` creator, in addition
        to ``dim``, which is set when the step is built.
    """

    def __init__(self, mask: jax.Array | None = None, **opts):
        # With no extra options this is identical to the previous
        # behaviour: ``d.Normal`` is created with ``dim`` only.
        super().__init__(mask=mask, dist=d.Normal, **opts)