# Source code for jexplore.sampling.mh

"""
This module contains the definitions of classes and types
for running a Metropolis-Hastings Markov sampler.
"""

from dataclasses import dataclass
from typing import Self

import jax
import numpy as np

from .base import Sampling
from .epoch import Epoch
from .space import Space
from .state import ArrayFn, State


@jax.tree_util.register_dataclass
@dataclass
class StateMH(State):
    """Markov chain state used by a Metropolis-Hastings sampler.

    Compared to :py:class:`jexplore.sampling.state.State`, this class only
    adds the log likelihood and log prior values of the chain points to the
    state attributes. Operations on these extra attributes are handled by the
    parent class methods.

    :param p: state points, shape (nchains, dim).
    :param ll: log likelihood value of each chain point, shape (nchains, 1).
    :param lp: log prior value of each chain point, shape (nchains, 1).
    """

    ll: jax.Array
    """Log likelihood values."""

    lp: jax.Array
    """Log prior values."""
[docs] @dataclass class SamplingMH[Tspace: Space](Sampling[Tspace]): """Markov sampling defining parameters for Metropolis-Hastings sampling. With respect to the parent :py:attr:`jexplore.sampling.base.Sampling` class the definition of the sampling includes the following parameters :param nwalker: number of walkers per temperature. :param temps: temperature ladder :param loglik: log likelihood function :param logprior: log prior function The `nchain` parameter is computed from `nwalker` and `temps`. """ nwalker: int """number of walkers""" temps: jax.Array """temperature ladder""" loglik: ArrayFn """log likelihood function""" logprior: ArrayFn """log prior function""" inpars: list[str] """list of input parameters of loglik and logprior""" # pylint: disable=too-many-arguments,too-many-positional-arguments def __init__( self, nwalker: int, temps: jax.Array, loglik: ArrayFn, logprior: ArrayFn, dim: int | None = None, space: Tspace | None = None, inpars: list[str] | None = None, ): self.inpars = ["p"] if inpars is None else inpars self.nwalker = nwalker self.temps = temps self.loglik = loglik self.logprior = logprior super().__init__(self.nwalker * self.temps.shape[0], dim, space)
[docs] def to_backend(self) -> dict: """Save the attributes of the object in backend format.""" return super().to_backend() | { "nwalker": np.array(self.nwalker), "temps": np.array(self.temps), }
[docs] def get_sampler(self, steps=None, backend=None): """Return a sampler with default steps and backend. :param steps: list of `Step`-like instances. If None, use strech and tswap (if more than 1 temp). :param backend: a `Backend` instance. If None, use the default one. :return: a `JaxSampler` instance. """ # pylint: disable=import-outside-toplevel, cyclic-import from jexplore.backends import DefaultBackend from jexplore.sampler import JaxSampler, Steps from jexplore.steps import Stretch, TSwap if steps is None: steps = [ {TSwap(permute=True).builder: 1.0}, {Stretch(permute=True).builder: 1.0}, ] if backend is None: backend = DefaultBackend() backend.reset() return JaxSampler(self, Steps(steps), backend)
[docs] def get_epoch(self, p): """Return an `Epoch` object from starting sample, to be use by MH sampler as initial epoch. :params p: samples array of shape nwalker x ntemp, dim :return: a `EpochMH` instance """ assert p.shape == (self.nwalker * len(self.temps), self.dim) return EpochMH({"p": p})
[docs] class EpochMH[Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH]( Epoch[Tstate, Tsampling] ): """Epoch definition class for Metropolis-Hastings. With respect to the parent :py:attr:`jexplore.sampling.epoch.Epoch` class, this class define :py:attr:`jexplore.sampling.mh.StateMH` as default state class and overrides the `complete` method to also compute the log likelihood and log prior of the epoch points. :param epoch: `Epoch` object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively this can be a dictionnary with the definition of the epoch samples. :param force_covs: if true forces the recomputing of the covariance matrices. """ statecls: type[Tstate] | type[StateMH] = StateMH """State class"""
[docs] def complete(self, sampling: Tsampling, compute: dict | None = None) -> Self: """Complete the definition and the parameters of the epoch. :param sampling: sampling parameters. :param compute: dictionnary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are `(func, inpars)` tuples, where `func` is a callable for the computation of the corresponding parameter and `inpars` is the list of the names if the input parameters. With respect to :py:attr:`jexplore.sampling.epoch.Epoch.complete`, this implementation of the `complete` method provides a base for the `compute` parameter with the instructions to compute log likelihood and log prior epoch samples attributes. :return: this object after update. """ compute = {} if compute is None else compute compute = { "ll": (sampling.loglik, sampling.inpars), "lp": (sampling.logprior, sampling.inpars), } | compute return super().complete(sampling, compute)