# Source code for jexplore.sampling.modsel

"""
This module contains the definitions of classes and types
for running Markov chains in product-space model selection.
"""

from dataclasses import dataclass
from typing import cast

import jax
import jax.numpy as jnp
import numpy as np

from jexplore.tools.distributions import Distr, DrawFn

from .base import Sampling
from .mh import EpochMH, SamplingMH, StateMH
from .space import Space
from .state import ArrayFn


@jax.tree_util.register_dataclass
@dataclass
class StateMS(StateMH):
    """Markov chain state used for product-space model selection.

    Extends :py:class:`jexplore.sampling.mh.StateMH` with the model index `k`
    of every chain; the operations on the state attributes are handled by the
    base class methods.

    :param k: model index of each chain, shape ``(nchains, 1)``.
    :param p: current point of each chain, shape ``(nchains, dim)``.
    :param ll: log likelihood value at each chain point, shape ``(nchains, 1)``.
    :param lp: log prior value at each chain point, shape ``(nchains, 1)``.
    """

    #: model index of each chain
    k: jax.Array
# pylint: disable=too-many-instance-attributes
[docs] @dataclass class SamplingMS[Tspace: Space](SamplingMH[Tspace]): r"""Markov sampling defining parameters for model selection. With respect to the parent :py:attr:`jexplore.sampling.base.SamplingMH` class the definition of the sampling includes the following parameters :param nwalker: number of walkers per temperature. :param temps: temperature ladder :param masks: list of masks identifying the coordinates of the space of each one of the models. This must be a list which length is the number of models. Elements of the list can be: #. `None`: all the space coordinates #. integer `d`: first `d` dimension of the space #. tuple (`a`, `b`): defining the correponding coordinates slice #. list of integer: explicit list of coordinates indexes :param loglik: log likelihood function :math:`logL(k, p)`, a list of math:`logL(p)` for each model (in which case :math:`logL(k, p)` is defined by branching on the functions of the list) or `None` (in which case :math:`logL(k, p)` is defined by selecting one component of the all models log likelihood defined by `allmodsll` - thus with a speculative computing approach). :param allmodsll: function :math:`vlogL(p)` returning the values of the log likelihood for all models at the point :math:`p` or `None`. In the latter case it is defined by stacking the return values of :math:`logL(k, p)` defined by `loglik` for all possible `k`. Note: `allmodsll` and `loglik` cannot be both `None`. :param logprior: log prior function :math:`log \pi(k, p)` or list of math:`log \pi(p)` for each model. (in which case :math:`log \pi(k, p)` is defined by branching on the functions of the list) or `None` (in which case :math:`log \pi(k, p)` is defined by selecting one component of the all models log prior defined by `allmodslp` - thus with a speculative computing approach). :param allmodslp: function :math:`vlog \pi(p)` returning the values of the log prior for all models at the point :math:`p` or `None`. 
In the latter case it is defined by stacking the return values of :math:`log \pi(k, p)` defined by `logprior` for all possible `k`. Note: `allmodslp` and `logprior` cannot be both `None`. :param pseudo: list of pseudo prior distributions for all the models. Each element of the list can be: #. A :py:attr:`jexpore.tools.distributions.Distr` object defined on the complementary space of the corresponding model. In which case is taken as it is. #. A :py:attr:`jexpore.tools.distributions.Distr` object defined on the full space. In which case this is masked to match the model mask. #. If it is `None`, of if the model takes tho whole space, a 0. centered delta distribution is used (and the evaluation returns identically 0.). If the whole argument is None (Default), a list of None elements is assumed and the distribution evaluation is identically 0. :param dim: dimensionality of the full space. :param space: :py:attr:`jexplore.sampling.space` object describing the target space. :param inpars: list of input parameter(s) used for the computation of log likelihood and log prior. 
""" masks: list[jax.Array] """mask of parameters indexes of each model space""" cmasks: list[jax.Array] """complementary mask of parameters indexes of each model space""" mod_update: ArrayFn """:math:`f(k, a, b)` updating array `a` of shapce `self.dim` with an array `b` of the same shape but only on the components of the mask of model `k`""" vmod_update: ArrayFn """vectorized version of `self.mod_update` so that it can perform the same update operation on vectors `k`, `a` and `b` of shapes `(seld.nwalker, )`, `(seld.nwalker, self.dim)` and `(seld.nwalker, self.dim)` """ allmodsll: ArrayFn r"""Function :math:`F(p)` returning the log likelihood for all models""" allmodslp: ArrayFn r"""Function :math:`\pi(p)` returning the log prior for all models""" ppleval: ArrayFn r"""Log pseudo prior function :math:`\tilde{\pi}(k, p)`""" allmodspp: ArrayFn r"""Function :math:`\tilde{\pi}(p)` returning the log pseudo prior for all models""" ppdraw: list[DrawFn] r"""List of pseudo prior drawing functions""" nmodels: int """number of models""" # pylint: disable=too-many-arguments,too-many-positional-arguments def __init__( self, nwalker: int, temps: jax.Array, masks: list[list[int] | None | int | tuple[int, int]], loglik: ArrayFn | list[ArrayFn] | None = None, logprior: ArrayFn | list[ArrayFn] | None = None, allmodsll: ArrayFn | None = None, allmodslp: ArrayFn | None = None, pseudo: list[Distr | None] | None = None, dim: int | None = None, space: Tspace | None = None, inpars: list[str] | None = None, ): _dim = Sampling(1, dim, space).dim self.masks, self.cmasks = self.get_masks(masks, _dim) self.mod_update, self.vmod_update = self._get_updates() self.nmodels = len(self.masks) loglik, self.allmodsll = self._get_ldist(loglik, allmodsll) logprior, self.allmodslp = self._get_ldist(logprior, allmodslp) self._get_pseudo(pseudo, _dim) inpars = ["k", "p"] if inpars is None else inpars super().__init__(nwalker, temps, loglik, logprior, dim, space, inpars)
[docs] def to_backend(self) -> dict: """Save the attributes of the object in backend format.""" return super().to_backend() | { "masks": self.masks, "cmasks": self.cmasks, "nmodels": np.array(self.nmodels, dtype=int), }
def _get_ldist( self, permodd: ArrayFn | list[ArrayFn] | None, allmodsd: ArrayFn | None ) -> tuple[ArrayFn, ArrayFn]: """Gets the per-model and the all-models log likelihood and log prior functions. """ if (permodd is None) and (allmodsd is None): raise ValueError( "One among the per model and the all models log distributions should be defined" ) # We get a list of single model functions. Using this to build the per-model # and the all-models distributions if (allmodsd is None) and isinstance(permodd, list): return self.ldist_from_list(permodd, self.masks), self.ldist_from_list( permodd, self.masks, stack=True ) # We get a :math:`F(k, p)` per model function. Using it to define the # all-models function. if allmodsd is None: _mods = jnp.arange(len(self.masks)) return cast(ArrayFn, permodd), cast( ArrayFn, lambda p: jax.vmap(lambda k: cast(ArrayFn, permodd)(k, p))(_mods), ) # We only have the all-models function. Define the per model function with # a speculative approach by taking one component depending on :math:`k` if permodd is None: return cast(ArrayFn, lambda k, p: jnp.take(allmodsd(p), k)), allmodsd # We already have everything, just returning. 
return cast(ArrayFn, permodd), cast(ArrayFn, allmodsd) def _get_pseudo(self, pseudo: list[Distr | None] | None, dim: int): pseudo = [None for _ in self.masks] if pseudo is None else pseudo self.ppdraw = [] for _ind, _pp in enumerate(pseudo): mask = self.cmasks[_ind] # If no pseudo is defined we use a 0 delta distr if (mask.shape[0] == 0) or (_pp is None): self.ppdraw.append( cast(DrawFn, lambda k, s, d=mask.size: (k, jnp.zeros(s + (d,)))) ) continue # if the pseudo is of the same size of the mask we get it a it is if _pp.dim == mask.size: self.ppdraw.append(cast(DrawFn, lambda k, s, pp=_pp: pp.sample(k, s))) continue # If it is the full space (and the mask is not) we restrict it to # if _pp.dim == dim: def _masked_draw(k, s, pp=_pp, msk=mask): _k, _p = pp.sample(k, s) return _k, _p[:, msk] self.ppdraw.append(_masked_draw) continue raise ValueError( "pseudo distributions should be none, defined on the " + "masked space or on the whole space." ) distlist = [ ( cast(ArrayFn, _pp.leval) if (_pp is not None) and (self.cmasks[_ipp].shape[0] > 0) else cast(ArrayFn, lambda p: jnp.zeros(p.shape[0])) ) for _ipp, _pp in enumerate(pseudo) ] self.ppleval = self.ldist_from_list(distlist, self.cmasks) self.allmodspp = self.ldist_from_list(distlist, self.cmasks, stack=True)
[docs] @staticmethod def get_masks( masks: list[list[int] | None | int | tuple[int, int]], dim: int ) -> tuple[list[jax.Array], list[jax.Array]]: """Get model masks as lists of indices. :param masks: list of masks identifying the coordinates of the space of each one of the models. This must be a list which length is the number of models. Elements of the list can be: #. `None`: all the space coordinates #. integer `d`: first `d` dimension of the space #. tuple (`a`, `b`): defining the correponding coordinates slice #. list of integer: explicit list of coordinates indexes :param dim: full space dimension. :return: list of model masks (as array of indices) and list of models complementary masks. """ for _ind, _msk in enumerate(masks): if _msk is None: _msk = (0, dim) if isinstance(_msk, int): _msk = (0, _msk) if isinstance(_msk, tuple): masks[_ind] = cast(list[int], list(range(*_msk))) return [jnp.array(_msk, dtype=int) for _msk in cast(list[list[int]], masks)], [ jnp.array([_ind for _ind in range(dim) if _ind not in _msk], dtype=int) for _msk in cast(list[list[int]], masks) ]
[docs] @staticmethod def ldist_from_list( logdists: list[ArrayFn], masks: list[jax.Array], stack: bool = False ) -> ArrayFn: """ Defines a product space log distribution from a list of a single model distributions. :param logdists: list of single space distributions :math:`log P(p)` :param masks: mask of parameters indexes of each model space :param stack: if true defines the function returning the values for all models by stacking the results of all the functions in the list. :return: full product space log distribution. """ if len(masks) != len(logdists): raise ValueError( "Number of masks and size of the distributions lists should be the same." ) logdists = [ ( cast(ArrayFn, lambda p, ld=_ld, msk=masks[_ind]: ld(p[msk])) if masks[_ind].size > 0 else cast(ArrayFn, lambda p: jnp.array(0.0)) ) for _ind, _ld in enumerate(logdists) ] if stack: return cast(ArrayFn, lambda p: jnp.stack([_ld(p) for _ld in logdists])) return cast(ArrayFn, lambda k, p: jax.lax.switch(k.squeeze(-1), logdists, p))
def _get_updates(self) -> tuple[ArrayFn, ArrayFn]: masked = [ cast(ArrayFn, lambda a, b, m=_msk: a.at[m].set(b[m])) for _msk in self.masks ] def _func(k, a, b): return jax.lax.switch(k.squeeze(-1), masked, a, b) return cast(ArrayFn, _func), cast(ArrayFn, jax.vmap(_func, in_axes=(0, 0, 0)))
[docs] class EpochMS[Tstate: StateMS = StateMS, Tsampling: SamplingMS = SamplingMS]( EpochMH[Tstate, Tsampling] ): """Epoch definition class for model selection. With respect to the parent :py:attr:`jexplore.sampling.epoch.Epoch` class, this class define :py:attr:`jexplore.sampling.mh.StateMS` as default state class and overrides the `complete` method to also compute the log likelihood and log prior of the epoch points (with the model selection signature). :param epoch: `Epoch` object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively this can be a dictionnary with the definition of the epoch samples. :param force_covs: if true forces the recomputing of the covariance matrices. """ statecls = StateMS """State class"""