"""
This module provides the definitions of the base
:py:attr:`jexplore.sampling.state.Epoch` class and of some
ancillary classes and types needed for running
the markov chain sampler.
"""
# TODO: add a covariance computation method and override it in the MH epoch
# class so that the covariances are computed per temperature.
from dataclasses import asdict, dataclass, fields
from typing import Protocol, Self, cast

import jax
import jax.numpy as jnp
import numpy as np

from jexplore.sampling.base import Sampling
from jexplore.sampling.state import State
from jexplore.tools.covariance import compute_cov
# pylint: disable=too-few-public-methods
[docs]
class StepSample[Tstate: State = State](Protocol):
"""Prototype for a step sampling function.
:param jax.Array key: PRNG key
:param Tstate state: current state
:return: the new state and the chain changing stats of the step (nchain).
:rtype: tuple[Tstate, jax.Array]
"""
def __call__(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: ...
[docs]
@dataclass
class EpochCycle[Tstate: State = State]:
"""Class containing the definition of an epoch steps cycle.
:param steps: list of substeps each one defined by a dictionnary of
alternative steps with their weight.
"""
steps: tuple[StepSample[Tstate], ...]
"""tuple of all the steps."""
weights: jax.Array
"""Matrix of weights. This has shape :math:`(S, A_{max})` where
:math:`S` is the number of serial substeps in the cycle and
:math:`A_{max}` is the max number of alternatives in all these
substeps. Each element of the matrix that actually corresponds to
a defined step/alternative contains the weight of the corresponding
alternative."""
isteps: jax.Array
""":math:`(S, A_{max})` matrix. Each element of the matrix that
actually corresponds to a defined step/alternative contains the
index of the corresponding `StepSample` function in the `steps`
tuple."""
n: int
"""number of sequential substeps."""
def __init__(self, steps: list[dict[StepSample[Tstate], float]]):
self.n = len(steps)
_maxnalt = max(len(_step) for _step in steps)
_tot = sum(len(_step) for _step in steps)
self.weights = jnp.zeros((self.n, _maxnalt))
self.isteps = (_tot + 1) * jnp.ones((self.n, _maxnalt), dtype=int)
_steps: tuple[StepSample[Tstate], ...] = ()
_totind = 0
for _istep, _step in enumerate(steps):
for _ialt, (_alt, _wgt) in enumerate(_step.items()):
# pylint: disable=no-member
self.isteps = self.isteps.at[_istep, _ialt].set(len(_steps))
self.weights = self.weights.at[_istep, _ialt].set(_wgt)
_steps += (_alt,)
self.steps = _steps
[docs]
def to_backend(self) -> np.ndarray:
"""dumps the cycle definition into a `numpy` array. This array has a row
for each one of the defined substeps/alternatives containing
* `name`: the qualified name of the `StepSample` function.
* `incycle`: the index of the substep
* `weight`: the weight of the corresponding substep alternative.
"""
return np.array(
[
(
self.steps[_ind].__qualname__,
_stp,
np.array(self.weights)[_stp, _alt],
)
for (_stp, _alt), _ind in np.ndenumerate(self.isteps)
if _ind < len(self.steps)
],
dtype=np.dtype([("name", "<S50"), ("incycle", int), ("weight", float)]),
)
# pylint: disable=too-few-public-methods
[docs]
@dataclass
class EpochStats[Tstate: State = State]:
"""Class statistics the aggregated statistics of the epoch steps.
:param steps: array of steps performed at each iteration.
:param counts: counters of the number of executions of each step.
:param stats: chains acceptance statistics for each step.
"""
steps: jax.Array
r"""
This is :math:`I` sized array Where :math:`I` is the total number of
steps performed during the sampling. The entries of the array are
the index of the corresponding `steps` tuple in the epoch's `EpochCycle`
instance.
"""
counts: jax.Array
"""
this array has the same size as the `steps` tuple in the epoch's `EpochCycle`
instance. Each element contains the count of how many time the corresponding
step was performed during the sampling.
"""
stats: jax.Array
"""
this array has shape :math:`(N, C)` where :math:`C` is the size of the `steps`
tuple in the epoch's `EpochCycle` instance (that is the number of distinct steps)
and :math:`N` the number of chains.
The values correspond to statistics of how many time each chain was modified by
each type of step (i.e. the proposal acceptance rate in the case of MH).
"""
[docs]
@staticmethod
def from_run(
stats: jax.Array, steps: jax.Array, cycle: EpochCycle[Tstate]
) -> "EpochStats[Tstate]":
r"""
Each step returns a :math:`N` sized boolean map of the chains that have
been changed by the step and the indexes identifyng the stap performed.
An epoch run returns the stacking of these arrays for each iteration.
This method aggregate and organize this information and instantiate the
corresponding `EpochStats` object.
:param stats: stacked stats array for all steps in the sampling.
This has shape :math:`(N, I)`.
:param steps: array with indices of the steps performed.
This is a :math:`I` sized array which enries are indices
:math:`a = 0, \dots, A_{max} - 1` identifing which one of
the possible alternatives was actually performed at each step.
:param cycle: `EpochCycle` instance.
:return: the `EpochStats` corresponding to these data.
"""
def _body(_carry, _ind):
_counts, _stats = _carry
_istep = cycle.isteps[_ind % cycle.n, steps[_ind]]
_counts = _counts.at[_istep].add(1)
_stats = _stats.at[_istep, :].add(stats[_ind, :])
return (_counts, _stats), (_istep)
(counts, stats), (steps) = jax.lax.scan(
_body,
(
jnp.zeros(len(cycle.steps)),
jnp.zeros((len(cycle.steps), stats.shape[1])),
),
jnp.arange(len(steps)),
)
return EpochStats[Tstate](steps=steps, stats=stats, counts=counts)
[docs]
def to_backend(self) -> dict:
"""Dump the data of the this `EpochStats` instance, along with the
information from the correponding `EpochCycle` instance, to a
dictionnary in a form suitable for dumping to backend. The output
dictionnary has keys:
* `steps`: `numpy.ndarray` cats of the object `steps` attribute
* `counts`: `numpy.ndarray` cats of the object `counts` attribute
* `stats`: `numpy.ndarray` cats of the object `stats` attribute
* `steps_defs`: data from `EpochCycle` instance as returned by \
:py:attr:`jexplore.sampling.epoch.EpochCycle.to_backend`
"""
return {
"steps": np.array(self.steps),
"counts": np.array(self.counts),
"stats": np.array(self.stats),
}
[docs]
class Epoch[Tstate: State = State, Tsampling: Sampling = Sampling]:
"""Base Epoch definition class.
This encapsulates samples, a covariance matrix the stats object, the definition
of the epoch cycle and the definition of the current sampling. The cass provides
methods to initialize the epoch, get a sepecific state, run steps and instantiate
an epoch from the result of a sampler run.
:param epoch: `Epoch` object from which the new instance can be derived by
computing the covariance and getting the last sample. Alternatively
this can be a dictionnary with the definition of the epoch samples.
:param force_covs: if true forces the recomputing of the covariance matrices.
"""
samples: Tstate
"""
batched state defining the sampling states sequence.
"""
covs: jax.Array # covariance matrices (nchains, dim, dim)
"""
Array with shape :math:`(N, D, D)` containing the list of
covariance matrices of all chains.
"""
stats: EpochStats[Tstate]
"""epoch statistics."""
cycle: EpochCycle[Tstate]
"""definition of the epoch cycle"""
sampling: Tsampling
"""sampling definition"""
statecls: type[Tstate] | type[State] = State
"""state class corresponding to this epoch class"""
def __init__(
self,
epoch: Self | dict,
force_covs: bool = False,
):
if isinstance(epoch, type(self)):
self.covs = epoch.covs
if force_covs:
self.covs = compute_cov(epoch.samples.p, cholesky=False)
self.samples = epoch.samples.get_in_batch(ind=-1, batch=True)
return
if isinstance(epoch, dict):
self.samples = cast(
Tstate, self.statecls.from_dict(epoch, mandatory=["p"], batch=True)
)
self.covs = compute_cov(
self.samples.p, covs=epoch.get("covs", None), cholesky=False
)
else:
raise ValueError("You need to provide an Epoch or a dict.")
[docs]
def complete(self, sampling: Tsampling, compute: dict | None = None) -> Self:
"""Complete the definition and the parameters of the epoch.
:param sampling: sampling parameters.
:param compute: dictionnary defining how to compute the missing parameters.
The keys are the names of the parameters to be computed. Values
are `(func, inpars)` tuples, where `func` is a callable for the
computation of the corresponding parameter and `inpars` is the
list of the names if the input parameters.
:return: this object after update.
"""
self.sampling = sampling
compute = {} if compute is None else compute
# self.samples = jax.vmap(lambda st: st.compute(compute))(self.samples)
for _par, (_func, _inpars) in compute.items():
if getattr(self.samples, _par).shape[1] == 0:
_vals = [getattr(self.samples, _inpar) for _inpar in _inpars]
_flat = [
_val.transpose(0, 2, 1).reshape(-1, _val.shape[1]) for _val in _vals
]
_flat = jax.vmap(_func, in_axes=tuple(0 for _ in _flat))(*_flat)
setattr(
self.samples,
_par,
_flat.reshape(_vals[0].shape[0], -1, _vals[0].shape[2]),
)
return self
[docs]
def set_cycle(self, cycle: EpochCycle[Tstate]) -> None:
"""Sets this epoch cycle.
:param cycle: the `EpochCycle` value to be set.
"""
self.cycle = cycle
[docs]
def run_step(
self, key: jax.Array, ind: int, state: Tstate
) -> tuple[Tstate, jax.Array, jax.Array]:
"""Runs a step of the epoch. For a given substep of the epoch cycle
it draws one of the alternative steps and performs it.
:param key: PRNG key
:param ind: the index of the substep
:param state: current state
:return: the new state, the step stats and the index of the substep
alternative step ctualy performed.
"""
key, k_p = jax.random.split(key)
ialt = jax.random.categorical(k_p, jnp.log(self.cycle.weights[ind, :]))
state, stats = jax.lax.switch(
self.cycle.isteps[ind, ialt], self.cycle.steps, key, state
)
return state, stats, ialt
[docs]
def from_run(self, states: Tstate, stats: jax.Array, steps: jax.Array) -> Self:
"""
Gets the stacks of states, steps statistics and steps indexes from an
epoch run and updates the epoch with these data.
:param states: stack of states.
:param stats: stack of steps stats
:param steps: stack of steps indexes
:return: this object after updating
"""
self.samples = states.__class__(
**{
_key: jnp.stack(_val, axis=2)[:, :, self.cycle.n - 1 :: self.cycle.n]
for _key, _val in asdict(states).items()
}
)
self.stats = EpochStats[Tstate].from_run(stats, steps, self.cycle)
return self
[docs]
def to_backend(self, burn: int, stats: bool) -> tuple[int, int, dict | None]:
"""dumps epoch in backend format"""
nsamples = self.samples.p.shape[-1]
if burn >= nsamples:
return 0, nsamples, None
epoch = {
"covs": np.array(self.covs),
"samples": {
_nam: np.array(_vals)[:, :, burn:]
for _nam, _vals in asdict(self.samples).items()
},
"cycle": self.cycle.to_backend(),
"sampling": self.sampling.to_backend(),
"class": f"{self.__class__.__module__}.{self.__class__.__qualname__}",
}
if stats:
epoch["stats"] = self.stats.to_backend()
return nsamples - burn, burn, epoch