"""
This module contains the definitions of classes and types
for running a Metropolis-Hastings Markov chain sampler.
"""
from dataclasses import dataclass
from typing import Self
import jax
import numpy as np
from .base import Sampling
from .epoch import Epoch
from .space import Space
from .state import ArrayFn, State
[docs]
@jax.tree_util.register_dataclass
@dataclass
class StateMH(State):
    """Markov chain state for a Metropolis-Hastings sampler.

    With respect to the parent :py:attr:`jexplore.sampling.state.State`
    class, it simply adds log likelihood and log prior values to the state
    attributes (operations on these parameters are handled by the parent
    class methods).

    :param p: state points (nchains, dim).
    :param ll: log likelihood values for each chain point (nchains, 1).
    :param lp: log prior values for each chain point (nchains, 1).
    """

    # Registered with ``jax.tree_util.register_dataclass`` (see decorator
    # above) so instances can be handled as JAX pytrees; the fields below are
    # appended after those inherited from ``State``.
    ll: jax.Array
    """log likelihood values"""
    lp: jax.Array
    """log prior values."""
[docs]
@dataclass
class SamplingMH[Tspace: Space](Sampling[Tspace]):
"""Markov sampling defining parameters for Metropolis-Hastings
sampling.
With respect to the parent :py:attr:`jexplore.sampling.base.Sampling`
class the definition of the sampling includes the following parameters
:param nwalker: number of walkers per temperature.
:param temps: temperature ladder
:param loglik: log likelihood function
:param logprior: log prior function
The `nchain` parameter is computed from `nwalker` and `temps`.
"""
nwalker: int
"""number of walkers"""
temps: jax.Array
"""temperature ladder"""
loglik: ArrayFn
"""log likelihood function"""
logprior: ArrayFn
"""log prior function"""
inpars: list[str]
"""list of input parameters of loglik and logprior"""
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
nwalker: int,
temps: jax.Array,
loglik: ArrayFn,
logprior: ArrayFn,
dim: int | None = None,
space: Tspace | None = None,
inpars: list[str] | None = None,
):
self.inpars = ["p"] if inpars is None else inpars
self.nwalker = nwalker
self.temps = temps
self.loglik = loglik
self.logprior = logprior
super().__init__(self.nwalker * self.temps.shape[0], dim, space)
[docs]
def to_backend(self) -> dict:
"""Save the attributes of the object in
backend format."""
return super().to_backend() | {
"nwalker": np.array(self.nwalker),
"temps": np.array(self.temps),
}
[docs]
def get_sampler(self, steps=None, backend=None):
"""Return a sampler with default steps and backend.
:param steps: list of `Step`-like instances. If None, use strech
and tswap (if more than 1 temp).
:param backend: a `Backend` instance. If None, use the default one.
:return: a `JaxSampler` instance.
"""
# pylint: disable=import-outside-toplevel, cyclic-import
from jexplore.backends import DefaultBackend
from jexplore.sampler import JaxSampler, Steps
from jexplore.steps import Stretch, TSwap
if steps is None:
steps = [
{TSwap(permute=True).builder: 1.0},
{Stretch(permute=True).builder: 1.0},
]
if backend is None:
backend = DefaultBackend()
backend.reset()
return JaxSampler(self, Steps(steps), backend)
[docs]
def get_epoch(self, p):
"""Return an `Epoch` object from starting sample, to be use by MH
sampler as initial epoch.
:params p: samples array of shape nwalker x ntemp, dim
:return: a `EpochMH` instance
"""
assert p.shape == (self.nwalker * len(self.temps), self.dim)
return EpochMH({"p": p})
[docs]
class EpochMH[Tstate: StateMH = StateMH, Tsampling: SamplingMH = SamplingMH](
Epoch[Tstate, Tsampling]
):
"""Epoch definition class for Metropolis-Hastings.
With respect to the parent :py:attr:`jexplore.sampling.epoch.Epoch` class,
this class define :py:attr:`jexplore.sampling.mh.StateMH` as default state
class and overrides the `complete` method to also compute the log likelihood
and log prior of the epoch points.
:param epoch: `Epoch` object from which the new instance can be derived by
computing the covariance and getting the last sample. Alternatively
this can be a dictionnary with the definition of the epoch samples.
:param force_covs: if true forces the recomputing of the covariance matrices.
"""
statecls: type[Tstate] | type[StateMH] = StateMH
"""State class"""
[docs]
def complete(self, sampling: Tsampling, compute: dict | None = None) -> Self:
"""Complete the definition and the parameters of the epoch.
:param sampling: sampling parameters.
:param compute: dictionnary defining how to compute the missing parameters.
The keys are the names of the parameters to be computed. Values
are `(func, inpars)` tuples, where `func` is a callable for the
computation of the corresponding parameter and `inpars` is the
list of the names if the input parameters.
With respect to :py:attr:`jexplore.sampling.epoch.Epoch.complete`, this
implementation of the `complete` method provides a base for the `compute`
parameter with the instructions to compute log likelihood and log prior
epoch samples attributes.
:return: this object after update.
"""
compute = {} if compute is None else compute
compute = {
"ll": (sampling.loglik, sampling.inpars),
"lp": (sampling.logprior, sampling.inpars),
} | compute
return super().complete(sampling, compute)