Source code for jexplore.sampling.epoch

"""
This module provides the definition of the base
:py:class:`jexplore.sampling.epoch.Epoch` class and of some
ancillary classes and types needed for running
the Markov chain sampler.
"""

# TODO: add a covs computation method and override it for Epoch MH so as to compute covs by temp

from dataclasses import asdict, dataclass
from typing import Protocol, Self, cast

import jax
import jax.numpy as jnp
import numpy as np

from jexplore.sampling.base import Sampling
from jexplore.sampling.state import State
from jexplore.tools.covariance import compute_cov


# pylint: disable=too-few-public-methods
[docs] class StepSample[Tstate: State = State](Protocol): """Prototype for a step sampling function. :param jax.Array key: PRNG key :param Tstate state: current state :return: the new state and the chain changing stats of the step (nchain). :rtype: tuple[Tstate, jax.Array] """ def __call__(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: ...
[docs] @dataclass class EpochCycle[Tstate: State = State]: """Class containing the definition of an epoch steps cycle. :param steps: list of substeps each one defined by a dictionnary of alternative steps with their weight. """ steps: tuple[StepSample[Tstate], ...] """tuple of all the steps.""" weights: jax.Array """Matrix of weights. This has shape :math:`(S, A_{max})` where :math:`S` is the number of serial substeps in the cycle and :math:`A_{max}` is the max number of alternatives in all these substeps. Each element of the matrix that actually corresponds to a defined step/alternative contains the weight of the corresponding alternative.""" isteps: jax.Array """:math:`(S, A_{max})` matrix. Each element of the matrix that actually corresponds to a defined step/alternative contains the index of the corresponding `StepSample` function in the `steps` tuple.""" n: int """number of sequential substeps.""" def __init__(self, steps: list[dict[StepSample[Tstate], float]]): self.n = len(steps) _maxnalt = max(len(_step) for _step in steps) _tot = sum(len(_step) for _step in steps) self.weights = jnp.zeros((self.n, _maxnalt)) self.isteps = (_tot + 1) * jnp.ones((self.n, _maxnalt), dtype=int) _steps: tuple[StepSample[Tstate], ...] = () _totind = 0 for _istep, _step in enumerate(steps): for _ialt, (_alt, _wgt) in enumerate(_step.items()): # pylint: disable=no-member self.isteps = self.isteps.at[_istep, _ialt].set(len(_steps)) self.weights = self.weights.at[_istep, _ialt].set(_wgt) _steps += (_alt,) self.steps = _steps
[docs] def to_backend(self) -> np.ndarray: """dumps the cycle definition into a `numpy` array. This array has a row for each one of the defined substeps/alternatives containing * `name`: the qualified name of the `StepSample` function. * `incycle`: the index of the substep * `weight`: the weight of the corresponding substep alternative. """ return np.array( [ ( self.steps[_ind].__qualname__, _stp, np.array(self.weights)[_stp, _alt], ) for (_stp, _alt), _ind in np.ndenumerate(self.isteps) if _ind < len(self.steps) ], dtype=np.dtype([("name", "<S50"), ("incycle", int), ("weight", float)]), )
# pylint: disable=too-few-public-methods
[docs] @dataclass class EpochStats[Tstate: State = State]: """Class statistics the aggregated statistics of the epoch steps. :param steps: array of steps performed at each iteration. :param counts: counters of the number of executions of each step. :param stats: chains acceptance statistics for each step. """ steps: jax.Array r""" This is :math:`I` sized array Where :math:`I` is the total number of steps performed during the sampling. The entries of the array are the index of the corresponding `steps` tuple in the epoch's `EpochCycle` instance. """ counts: jax.Array """ this array has the same size as the `steps` tuple in the epoch's `EpochCycle` instance. Each element contains the count of how many time the corresponding step was performed during the sampling. """ stats: jax.Array """ this array has shape :math:`(N, C)` where :math:`C` is the size of the `steps` tuple in the epoch's `EpochCycle` instance (that is the number of distinct steps) and :math:`N` the number of chains. The values correspond to statistics of how many time each chain was modified by each type of step (i.e. the proposal acceptance rate in the case of MH). """
[docs] @staticmethod def from_run( stats: jax.Array, steps: jax.Array, cycle: EpochCycle[Tstate] ) -> "EpochStats[Tstate]": r""" Each step returns a :math:`N` sized boolean map of the chains that have been changed by the step and the indexes identifyng the stap performed. An epoch run returns the stacking of these arrays for each iteration. This method aggregate and organize this information and instantiate the corresponding `EpochStats` object. :param stats: stacked stats array for all steps in the sampling. This has shape :math:`(N, I)`. :param steps: array with indices of the steps performed. This is a :math:`I` sized array which enries are indices :math:`a = 0, \dots, A_{max} - 1` identifing which one of the possible alternatives was actually performed at each step. :param cycle: `EpochCycle` instance. :return: the `EpochStats` corresponding to these data. """ def _body(_carry, _ind): _counts, _stats = _carry _istep = cycle.isteps[_ind % cycle.n, steps[_ind]] _counts = _counts.at[_istep].add(1) _stats = _stats.at[_istep, :].add(stats[_ind, :]) return (_counts, _stats), (_istep) (counts, stats), (steps) = jax.lax.scan( _body, ( jnp.zeros(len(cycle.steps)), jnp.zeros((len(cycle.steps), stats.shape[1])), ), jnp.arange(len(steps)), ) return EpochStats[Tstate](steps=steps, stats=stats, counts=counts)
[docs] def to_backend(self) -> dict: """Dump the data of the this `EpochStats` instance, along with the information from the correponding `EpochCycle` instance, to a dictionnary in a form suitable for dumping to backend. The output dictionnary has keys: * `steps`: `numpy.ndarray` cats of the object `steps` attribute * `counts`: `numpy.ndarray` cats of the object `counts` attribute * `stats`: `numpy.ndarray` cats of the object `stats` attribute * `steps_defs`: data from `EpochCycle` instance as returned by \ :py:attr:`jexplore.sampling.epoch.EpochCycle.to_backend` """ return { "steps": np.array(self.steps), "counts": np.array(self.counts), "stats": np.array(self.stats), }
[docs] class Epoch[Tstate: State = State, Tsampling: Sampling = Sampling]: """Base Epoch definition class. This encapsulates samples, a covariance matrix the stats object, the definition of the epoch cycle and the definition of the current sampling. The cass provides methods to initialize the epoch, get a sepecific state, run steps and instantiate an epoch from the result of a sampler run. :param epoch: `Epoch` object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively this can be a dictionnary with the definition of the epoch samples. :param force_covs: if true forces the recomputing of the covariance matrices. """ samples: Tstate """ batched state defining the sampling states sequence. """ covs: jax.Array # covariance matrices (nchains, dim, dim) """ Array with shape :math:`(N, D, D)` containing the list of covariance matrices of all chains. """ stats: EpochStats[Tstate] """epoch statistics.""" cycle: EpochCycle[Tstate] """definition of the epoch cycle""" sampling: Tsampling """sampling definition""" statecls: type[Tstate] | type[State] = State """state class corresponding to this epoch class""" def __init__( self, epoch: Self | dict, force_covs: bool = False, ): if isinstance(epoch, type(self)): self.covs = epoch.covs if force_covs: self.covs = compute_cov(epoch.samples.p, cholesky=False) self.samples = epoch.samples.get_in_batch(ind=-1, batch=True) return if isinstance(epoch, dict): self.samples = cast( Tstate, self.statecls.from_dict(epoch, mandatory=["p"], batch=True) ) self.covs = compute_cov( self.samples.p, covs=epoch.get("covs", None), cholesky=False ) else: raise ValueError("You need to provide an Epoch or a dict.")
[docs] def complete(self, sampling: Tsampling, compute: dict | None = None) -> Self: """Complete the definition and the parameters of the epoch. :param sampling: sampling parameters. :param compute: dictionnary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are `(func, inpars)` tuples, where `func` is a callable for the computation of the corresponding parameter and `inpars` is the list of the names if the input parameters. :return: this object after update. """ self.sampling = sampling compute = {} if compute is None else compute # self.samples = jax.vmap(lambda st: st.compute(compute))(self.samples) for _par, (_func, _inpars) in compute.items(): if getattr(self.samples, _par).shape[1] == 0: _vals = [getattr(self.samples, _inpar) for _inpar in _inpars] _flat = [ _val.transpose(0, 2, 1).reshape(-1, _val.shape[1]) for _val in _vals ] _flat = jax.vmap(_func, in_axes=tuple(0 for _ in _flat))(*_flat) setattr( self.samples, _par, _flat.reshape(_vals[0].shape[0], -1, _vals[0].shape[2]), ) return self
[docs] def set_cycle(self, cycle: EpochCycle[Tstate]) -> None: """Sets this epoch cycle. :param cycle: the `EpochCycle` value to be set. """ self.cycle = cycle
[docs] def run_step( self, key: jax.Array, ind: int, state: Tstate ) -> tuple[Tstate, jax.Array, jax.Array]: """Runs a step of the epoch. For a given substep of the epoch cycle it draws one of the alternative steps and performs it. :param key: PRNG key :param ind: the index of the substep :param state: current state :return: the new state, the step stats and the index of the substep alternative step ctualy performed. """ key, k_p = jax.random.split(key) ialt = jax.random.categorical(k_p, jnp.log(self.cycle.weights[ind, :])) state, stats = jax.lax.switch( self.cycle.isteps[ind, ialt], self.cycle.steps, key, state ) return state, stats, ialt
[docs] def from_run(self, states: Tstate, stats: jax.Array, steps: jax.Array) -> Self: """ Gets the stacks of states, steps statistics and steps indexes from an epoch run and updates the epoch with these data. :param states: stack of states. :param stats: stack of steps stats :param steps: stack of steps indexes :return: this object after updating """ self.samples = states.__class__( **{ _key: jnp.stack(_val, axis=2)[:, :, self.cycle.n - 1 :: self.cycle.n] for _key, _val in asdict(states).items() } ) self.stats = EpochStats[Tstate].from_run(stats, steps, self.cycle) return self
[docs] def to_backend(self, burn: int, stats: bool) -> tuple[int, int, dict | None]: """dumps epoch in backend format""" nsamples = self.samples.p.shape[-1] if burn >= nsamples: return 0, nsamples, None epoch = { "covs": np.array(self.covs), "samples": { _nam: np.array(_vals)[:, :, burn:] for _nam, _vals in asdict(self.samples).items() }, "cycle": self.cycle.to_backend(), "sampling": self.sampling.to_backend(), "class": f"{self.__class__.__module__}.{self.__class__.__qualname__}", } if stats: epoch["stats"] = self.stats.to_backend() return nsamples - burn, burn, epoch