"""
This module contains the definitions of the classes and types
for running Markov chains for model selection in a product space.
"""
from dataclasses import dataclass
from typing import cast
import jax
import jax.numpy as jnp
import numpy as np
from jexplore.tools.distributions import Distr, DrawFn
from .base import Sampling
from .mh import EpochMH, SamplingMH, StateMH
from .space import Space
from .state import ArrayFn
[docs]
@jax.tree_util.register_dataclass
@dataclass
class StateMS(StateMH):
    """Markov chain state used for model selection.

    Compared with :py:attr:`jexplore.sampling.mh.StateMH`, this state also
    stores the model index `k` of every chain. All array operations on the
    fields (the extra `k` included) are provided by the parent class methods.

    :param k: model (space) index of each chain, shape (nchains, 1).
    :param p: current point of each chain, shape (nchains, dim).
    :param ll: log likelihood at each chain point, shape (nchains, 1).
    :param lp: log prior at each chain point, shape (nchains, 1).
    """

    #: model index of each chain
    k: jax.Array
# pylint: disable=too-many-instance-attributes
[docs]
@dataclass
class SamplingMS[Tspace: Space](SamplingMH[Tspace]):
r"""Markov sampling defining parameters for model selection.
With respect to the parent :py:attr:`jexplore.sampling.base.SamplingMH`
class the definition of the sampling includes the following parameters
:param nwalker: number of walkers per temperature.
:param temps: temperature ladder
:param masks: list of masks identifying the coordinates of the space
of each one of the models. This must be a list which length
is the number of models. Elements of the list can be:
#. `None`: all the space coordinates
#. integer `d`: first `d` dimension of the space
#. tuple (`a`, `b`): defining the correponding coordinates slice
#. list of integer: explicit list of coordinates indexes
:param loglik: log likelihood function :math:`logL(k, p)`, a list of math:`logL(p)`
for each model (in which case :math:`logL(k, p)` is defined by branching
on the functions of the list) or `None` (in which case :math:`logL(k, p)`
is defined by selecting one component of the all models log likelihood
defined by `allmodsll` - thus with a speculative computing approach).
:param allmodsll: function :math:`vlogL(p)` returning the values of the log likelihood
for all models at the point :math:`p` or `None`. In the latter case
it is defined by stacking the return values of :math:`logL(k, p)`
defined by `loglik` for all possible `k`. Note: `allmodsll` and
`loglik` cannot be both `None`.
:param logprior: log prior function :math:`log \pi(k, p)` or list of math:`log \pi(p)`
for each model. (in which case :math:`log \pi(k, p)` is defined by branching
on the functions of the list) or `None` (in which case :math:`log \pi(k, p)`
is defined by selecting one component of the all models log prior
defined by `allmodslp` - thus with a speculative computing approach).
:param allmodslp: function :math:`vlog \pi(p)` returning the values of the log prior
for all models at the point :math:`p` or `None`. In the latter case
it is defined by stacking the return values of :math:`log \pi(k, p)`
defined by `logprior` for all possible `k`. Note: `allmodslp` and
`logprior` cannot be both `None`.
:param pseudo: list of pseudo prior distributions for all the models. Each element of
the list can be:
#. A :py:attr:`jexpore.tools.distributions.Distr` object defined on the
complementary space of the corresponding model. In which case is taken
as it is.
#. A :py:attr:`jexpore.tools.distributions.Distr` object defined on the
full space. In which case this is masked to match the model mask.
#. If it is `None`, of if the model takes tho whole space, a 0. centered
delta distribution is used (and the evaluation returns identically 0.).
If the whole argument is None (Default), a list of None elements is assumed
and the distribution evaluation is identically 0.
:param dim: dimensionality of the full space.
:param space: :py:attr:`jexplore.sampling.space` object
describing the target space.
:param inpars: list of input parameter(s) used for the computation of log
likelihood and log prior.
"""
masks: list[jax.Array]
"""mask of parameters indexes of each model space"""
cmasks: list[jax.Array]
"""complementary mask of parameters indexes of each model space"""
mod_update: ArrayFn
""":math:`f(k, a, b)` updating array `a` of shapce `self.dim` with
an array `b` of the same shape but only on the components of the mask
of model `k`"""
vmod_update: ArrayFn
"""vectorized version of `self.mod_update` so that it can perform
the same update operation on vectors `k`, `a` and `b` of shapes
`(seld.nwalker, )`, `(seld.nwalker, self.dim)` and `(seld.nwalker, self.dim)`
"""
allmodsll: ArrayFn
r"""Function :math:`F(p)` returning the log likelihood for all models"""
allmodslp: ArrayFn
r"""Function :math:`\pi(p)` returning the log prior for all models"""
ppleval: ArrayFn
r"""Log pseudo prior function :math:`\tilde{\pi}(k, p)`"""
allmodspp: ArrayFn
r"""Function :math:`\tilde{\pi}(p)` returning the log pseudo prior for all models"""
ppdraw: list[DrawFn]
r"""List of pseudo prior drawing functions"""
nmodels: int
"""number of models"""
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
nwalker: int,
temps: jax.Array,
masks: list[list[int] | None | int | tuple[int, int]],
loglik: ArrayFn | list[ArrayFn] | None = None,
logprior: ArrayFn | list[ArrayFn] | None = None,
allmodsll: ArrayFn | None = None,
allmodslp: ArrayFn | None = None,
pseudo: list[Distr | None] | None = None,
dim: int | None = None,
space: Tspace | None = None,
inpars: list[str] | None = None,
):
_dim = Sampling(1, dim, space).dim
self.masks, self.cmasks = self.get_masks(masks, _dim)
self.mod_update, self.vmod_update = self._get_updates()
self.nmodels = len(self.masks)
loglik, self.allmodsll = self._get_ldist(loglik, allmodsll)
logprior, self.allmodslp = self._get_ldist(logprior, allmodslp)
self._get_pseudo(pseudo, _dim)
inpars = ["k", "p"] if inpars is None else inpars
super().__init__(nwalker, temps, loglik, logprior, dim, space, inpars)
[docs]
def to_backend(self) -> dict:
"""Save the attributes of the object in
backend format."""
return super().to_backend() | {
"masks": self.masks,
"cmasks": self.cmasks,
"nmodels": np.array(self.nmodels, dtype=int),
}
def _get_ldist(
self, permodd: ArrayFn | list[ArrayFn] | None, allmodsd: ArrayFn | None
) -> tuple[ArrayFn, ArrayFn]:
"""Gets the per-model and the all-models log likelihood and log prior
functions.
"""
if (permodd is None) and (allmodsd is None):
raise ValueError(
"One among the per model and the all models log distributions should be defined"
)
# We get a list of single model functions. Using this to build the per-model
# and the all-models distributions
if (allmodsd is None) and isinstance(permodd, list):
return self.ldist_from_list(permodd, self.masks), self.ldist_from_list(
permodd, self.masks, stack=True
)
# We get a :math:`F(k, p)` per model function. Using it to define the
# all-models function.
if allmodsd is None:
_mods = jnp.arange(len(self.masks))
return cast(ArrayFn, permodd), cast(
ArrayFn,
lambda p: jax.vmap(lambda k: cast(ArrayFn, permodd)(k, p))(_mods),
)
# We only have the all-models function. Define the per model function with
# a speculative approach by taking one component depending on :math:`k`
if permodd is None:
return cast(ArrayFn, lambda k, p: jnp.take(allmodsd(p), k)), allmodsd
# We already have everything, just returning.
return cast(ArrayFn, permodd), cast(ArrayFn, allmodsd)
def _get_pseudo(self, pseudo: list[Distr | None] | None, dim: int):
pseudo = [None for _ in self.masks] if pseudo is None else pseudo
self.ppdraw = []
for _ind, _pp in enumerate(pseudo):
mask = self.cmasks[_ind]
# If no pseudo is defined we use a 0 delta distr
if (mask.shape[0] == 0) or (_pp is None):
self.ppdraw.append(
cast(DrawFn, lambda k, s, d=mask.size: (k, jnp.zeros(s + (d,))))
)
continue
# if the pseudo is of the same size of the mask we get it a it is
if _pp.dim == mask.size:
self.ppdraw.append(cast(DrawFn, lambda k, s, pp=_pp: pp.sample(k, s)))
continue
# If it is the full space (and the mask is not) we restrict it to
#
if _pp.dim == dim:
def _masked_draw(k, s, pp=_pp, msk=mask):
_k, _p = pp.sample(k, s)
return _k, _p[:, msk]
self.ppdraw.append(_masked_draw)
continue
raise ValueError(
"pseudo distributions should be none, defined on the "
+ "masked space or on the whole space."
)
distlist = [
(
cast(ArrayFn, _pp.leval)
if (_pp is not None) and (self.cmasks[_ipp].shape[0] > 0)
else cast(ArrayFn, lambda p: jnp.zeros(p.shape[0]))
)
for _ipp, _pp in enumerate(pseudo)
]
self.ppleval = self.ldist_from_list(distlist, self.cmasks)
self.allmodspp = self.ldist_from_list(distlist, self.cmasks, stack=True)
[docs]
@staticmethod
def get_masks(
masks: list[list[int] | None | int | tuple[int, int]], dim: int
) -> tuple[list[jax.Array], list[jax.Array]]:
"""Get model masks as lists of indices.
:param masks: list of masks identifying the coordinates of the space
of each one of the models. This must be a list which length
is the number of models. Elements of the list can be:
#. `None`: all the space coordinates
#. integer `d`: first `d` dimension of the space
#. tuple (`a`, `b`): defining the correponding coordinates slice
#. list of integer: explicit list of coordinates indexes
:param dim: full space dimension.
:return: list of model masks (as array of indices) and list of models
complementary masks.
"""
for _ind, _msk in enumerate(masks):
if _msk is None:
_msk = (0, dim)
if isinstance(_msk, int):
_msk = (0, _msk)
if isinstance(_msk, tuple):
masks[_ind] = cast(list[int], list(range(*_msk)))
return [jnp.array(_msk, dtype=int) for _msk in cast(list[list[int]], masks)], [
jnp.array([_ind for _ind in range(dim) if _ind not in _msk], dtype=int)
for _msk in cast(list[list[int]], masks)
]
[docs]
@staticmethod
def ldist_from_list(
logdists: list[ArrayFn], masks: list[jax.Array], stack: bool = False
) -> ArrayFn:
"""
Defines a product space log distribution from a list of a single
model distributions.
:param logdists: list of single space distributions :math:`log P(p)`
:param masks: mask of parameters indexes of each model space
:param stack: if true defines the function returning the values for
all models by stacking the results of all the functions
in the list.
:return: full product space log distribution.
"""
if len(masks) != len(logdists):
raise ValueError(
"Number of masks and size of the distributions lists should be the same."
)
logdists = [
(
cast(ArrayFn, lambda p, ld=_ld, msk=masks[_ind]: ld(p[msk]))
if masks[_ind].size > 0
else cast(ArrayFn, lambda p: jnp.array(0.0))
)
for _ind, _ld in enumerate(logdists)
]
if stack:
return cast(ArrayFn, lambda p: jnp.stack([_ld(p) for _ld in logdists]))
return cast(ArrayFn, lambda k, p: jax.lax.switch(k.squeeze(-1), logdists, p))
def _get_updates(self) -> tuple[ArrayFn, ArrayFn]:
masked = [
cast(ArrayFn, lambda a, b, m=_msk: a.at[m].set(b[m])) for _msk in self.masks
]
def _func(k, a, b):
return jax.lax.switch(k.squeeze(-1), masked, a, b)
return cast(ArrayFn, _func), cast(ArrayFn, jax.vmap(_func, in_axes=(0, 0, 0)))
[docs]
class EpochMS[Tstate: StateMS = StateMS, Tsampling: SamplingMS = SamplingMS](
EpochMH[Tstate, Tsampling]
):
"""Epoch definition class for model selection.
With respect to the parent :py:attr:`jexplore.sampling.epoch.Epoch` class,
this class define :py:attr:`jexplore.sampling.mh.StateMS` as default state
class and overrides the `complete` method to also compute the log likelihood
and log prior of the epoch points (with the model selection signature).
:param epoch: `Epoch` object from which the new instance can be derived by
computing the covariance and getting the last sample. Alternatively
this can be a dictionnary with the definition of the epoch samples.
:param force_covs: if true forces the recomputing of the covariance matrices.
"""
statecls = StateMS
"""State class"""