jexplore.sampling.modsel
========================

.. py:module:: jexplore.sampling.modsel

.. autoapi-nested-parse::

   This module contains the definitions of classes and types for running Markov chains in
   product space model selection.


Classes
-------

.. autoapisummary::

   jexplore.sampling.modsel.StateMS
   jexplore.sampling.modsel.SamplingMS
   jexplore.sampling.modsel.EpochMS


Module Contents
---------------

.. py:class:: StateMS

   Bases: :py:obj:`jexplore.sampling.mh.StateMH`

   Definition of a Markov chain state for model selection.

   With respect to the parent :py:attr:`jexplore.sampling.mh.StateMH` class, it adds the model
   index `k` for each chain (operations on these parameters are handled by the base class
   methods).

   :param k: space index (nchains, 1)
   :param p: state point (nchains, dim)
   :param ll: log likelihood values for each chain point (nchains, 1).
   :param lp: log prior values for each chain point (nchains, 1).

   .. py:attribute:: k
      :type:  jax.Array

      model index


.. py:class:: SamplingMS[Tspace: jexplore.sampling.space.Space](nwalker, temps, masks, loglik = None, logprior = None, allmodsll = None, allmodslp = None, pseudo = None, dim = None, space = None, inpars = None)

   Bases: :py:obj:`jexplore.sampling.mh.SamplingMH`\ [\ :py:obj:`Tspace`\ ]

   Markov sampling defining parameters for model selection.

   With respect to the parent :py:attr:`jexplore.sampling.mh.SamplingMH` class, the definition
   of the sampling includes the following parameters.

   :param nwalker: number of walkers per temperature.
   :param temps: temperature ladder.
   :param masks: list of masks identifying the coordinates of the space of each one of the
      models. This must be a list whose length is the number of models. Elements of the list
      can be:

      #. `None`: all the space coordinates
      #. integer `d`: first `d` dimensions of the space
      #. tuple (`a`, `b`): defining the corresponding coordinate slice
      #. list of integers: explicit list of coordinate indices

   :param loglik: log likelihood function :math:`logL(k, p)`, a list of :math:`logL(p)` for
      each model (in which case :math:`logL(k, p)` is defined by branching on the functions of
      the list) or `None` (in which case :math:`logL(k, p)` is defined by selecting one
      component of the all-models log likelihood defined by `allmodsll`, thus with a
      speculative computing approach).
   :param allmodsll: function :math:`vlogL(p)` returning the values of the log likelihood for
      all models at the point :math:`p`, or `None`. In the latter case it is defined by
      stacking the return values of :math:`logL(k, p)` defined by `loglik` for all possible
      `k`. Note: `allmodsll` and `loglik` cannot both be `None`.
   :param logprior: log prior function :math:`log \pi(k, p)`, a list of :math:`log \pi(p)` for
      each model (in which case :math:`log \pi(k, p)` is defined by branching on the functions
      of the list) or `None` (in which case :math:`log \pi(k, p)` is defined by selecting one
      component of the all-models log prior defined by `allmodslp`, thus with a speculative
      computing approach).
   :param allmodslp: function :math:`vlog \pi(p)` returning the values of the log prior for
      all models at the point :math:`p`, or `None`. In the latter case it is defined by
      stacking the return values of :math:`log \pi(k, p)` defined by `logprior` for all
      possible `k`. Note: `allmodslp` and `logprior` cannot both be `None`.
   :param pseudo: list of pseudo prior distributions for all the models. Each element of the
      list can be:

      #. A :py:attr:`jexplore.tools.distributions.Distr` object defined on the complementary
         space of the corresponding model, in which case it is taken as is.
      #. A :py:attr:`jexplore.tools.distributions.Distr` object defined on the full space, in
         which case it is masked to match the model mask.
      #. `None`. In this case, or if the model takes the whole space, a delta distribution
         centered at 0. is used (and the evaluation returns identically 0.).

      If the whole argument is `None` (default), a list of `None` elements is assumed and the
      distribution evaluation is identically 0.
   :param dim: dimensionality of the full space.
   :param space: :py:attr:`jexplore.sampling.space.Space` object describing the target space.
   :param inpars: list of input parameter(s) used for the computation of log likelihood and
      log prior.

   .. py:attribute:: masks
      :type:  list[jax.Array]

      mask of parameter indices of each model space


   .. py:attribute:: cmasks
      :type:  list[jax.Array]

      complementary mask of parameter indices of each model space


   .. py:attribute:: mod_update
      :type:  jexplore.sampling.state.ArrayFn

      :math:`f(k, a, b)` updating array `a` of shape `self.dim` with an array `b` of the same
      shape, but only on the components of the mask of model `k`


   .. py:attribute:: vmod_update
      :type:  jexplore.sampling.state.ArrayFn

      vectorized version of `self.mod_update` that performs the same update operation on
      vectors `k`, `a` and `b` of shapes `(self.nwalker, )`, `(self.nwalker, self.dim)` and
      `(self.nwalker, self.dim)`


   .. py:attribute:: allmodsll
      :type:  jexplore.sampling.state.ArrayFn

      Function :math:`F(p)` returning the log likelihood for all models


   .. py:attribute:: allmodslp
      :type:  jexplore.sampling.state.ArrayFn

      Function :math:`\pi(p)` returning the log prior for all models


   .. py:attribute:: ppleval
      :type:  jexplore.sampling.state.ArrayFn

      Log pseudo prior function :math:`\tilde{\pi}(k, p)`


   .. py:attribute:: allmodspp
      :type:  jexplore.sampling.state.ArrayFn

      Function :math:`\tilde{\pi}(p)` returning the log pseudo prior for all models


   .. py:attribute:: ppdraw
      :type:  list[jexplore.tools.distributions.DrawFn]

      List of pseudo prior drawing functions


   .. py:attribute:: nmodels
      :type:  int

      number of models


   .. py:method:: to_backend()

      Save the attributes of the object in backend format.


   .. py:method:: get_masks(masks, dim)
      :staticmethod:


      Get model masks as lists of indices.

      :param masks: list of masks identifying the coordinates of the space of each one of the
         models. This must be a list whose length is the number of models. Elements of the
         list can be:

         #. `None`: all the space coordinates
         #. integer `d`: first `d` dimensions of the space
         #. tuple (`a`, `b`): defining the corresponding coordinate slice
         #. list of integers: explicit list of coordinate indices

      :param dim: full space dimension.
      :return: list of model masks (as arrays of indices) and list of model complementary
         masks.


   .. py:method:: ldist_from_list(logdists, masks, stack = False)
      :staticmethod:


      Defines a product space log distribution from a list of single-model distributions.

      :param logdists: list of single space distributions :math:`log P(p)`
      :param masks: mask of parameter indices of each model space
      :param stack: if true, defines the function returning the values for all models by
         stacking the results of all the functions in the list.
      :return: full product space log distribution.


.. py:class:: EpochMS[Tstate: StateMS, Tsampling: SamplingMS](epoch, force_covs = False)

   Bases: :py:obj:`jexplore.sampling.mh.EpochMH`\ [\ :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Epoch definition class for model selection.

   With respect to the parent :py:attr:`jexplore.sampling.epoch.Epoch` class, this class
   defines :py:attr:`jexplore.sampling.modsel.StateMS` as the default state class and overrides
   the `complete` method to also compute the log likelihood and log prior of the epoch points
   (with the model selection signature).

   :param epoch: `Epoch` object from which the new instance can be derived by computing the
      covariance and getting the last sample. Alternatively this can be a dictionary with the
      definition of the epoch samples.
   :param force_covs: if true, forces the recomputing of the covariance matrices.

   .. py:attribute:: statecls

      State class
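
The four accepted mask forms and the masked update described for `mod_update` can be illustrated with a small sketch. This is not the library's implementation: `normalize_masks` and `masked_update` are hypothetical names, plain Python lists stand in for `jax.Array` values, and the tuple (`a`, `b`) is assumed to be a half-open slice `a:b`, which the documentation above does not state explicitly.

```python
def normalize_masks(masks, dim):
    """Hypothetical stand-in for the behaviour documented for get_masks.

    Returns (masks, cmasks): one explicit index list per model, plus the
    complementary index list for each model.
    """
    all_idx = list(range(dim))
    out, comp = [], []
    for spec in masks:
        if spec is None:
            # None: all the space coordinates
            idx = list(all_idx)
        elif isinstance(spec, int):
            # integer d: first d dimensions of the space
            idx = all_idx[:spec]
        elif isinstance(spec, tuple):
            # tuple (a, b): coordinate slice, ASSUMED half-open a:b
            a, b = spec
            idx = all_idx[a:b]
        else:
            # list of integers: explicit coordinate indices
            idx = [int(i) for i in spec]
        out.append(idx)
        # complementary mask: every coordinate not used by this model
        comp.append([i for i in all_idx if i not in idx])
    return out, comp


def masked_update(k, a, b, masks):
    """Copy b into a only on the components in masks[k] (cf. mod_update)."""
    new = list(a)
    for i in masks[k]:
        new[i] = b[i]
    return new


# Four models on a 4-dimensional space, one per accepted mask form.
example_masks, example_cmasks = normalize_masks([None, 2, (1, 3), [0, 3]], dim=4)

# Model 1 owns coordinates 0 and 1, so only those are overwritten.
updated = masked_update(1, [0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 4.0], example_masks)
```

In this sketch the complementary mask is empty for a model that takes the whole space, which matches the case where the pseudo prior evaluation is identically 0.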