jexplore.sampling.modsel#

This module contains the definitions of classes and types for running Markov chains for product-space model selection.

Classes#

StateMS

Definition of a Markov chain state for model selection.

SamplingMS

Markov sampling defining parameters for model selection.

EpochMS

Epoch definition class for model selection.

Module Contents#

class StateMS[source]#

Bases: jexplore.sampling.mh.StateMH

Definition of a Markov chain state for model selection. With respect to the parent jexplore.sampling.mh.StateMH class, it adds the model index k for each chain. Operations on this parameter are handled by the base class methods.

Parameters:
  • k – model (space) index for each chain (nchains, 1)

  • p – state point for each chain (nchains, dim)

  • ll – log likelihood values for each chain point (nchains, 1).

  • lp – log prior values for each chain point (nchains, 1).

k: jax.Array#

model index

class SamplingMS[Tspace: jexplore.sampling.space.Space](nwalker, temps, masks, loglik=None, logprior=None, allmodsll=None, allmodslp=None, pseudo=None, dim=None, space=None, inpars=None)[source]#

Bases: jexplore.sampling.mh.SamplingMH[Tspace]

Markov sampling defining parameters for model selection. With respect to the parent jexplore.sampling.mh.SamplingMH class, the definition of the sampling includes the following parameters:

Parameters:
  • nwalker (int) – number of walkers per temperature.

  • temps (jax.Array) – temperature ladder

  • masks (list[list[int] | None | int | tuple[int, int]]) –

    list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:

    1. None: all the space coordinates

    2. integer d: the first d dimensions of the space

    3. tuple (a, b): defining the corresponding coordinate slice

    4. list of integers: explicit list of coordinate indexes

  • loglik (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – log likelihood function \(logL(k, p)\), a list of functions \(logL(p)\), one for each model (in which case \(logL(k, p)\) is defined by branching on the functions of the list), or None (in which case \(logL(k, p)\) is defined by selecting one component of the all-models log likelihood defined by allmodsll, that is, with a speculative computing approach).

  • allmodsll (jexplore.sampling.state.ArrayFn | None) – function \(vlogL(p)\) returning the values of the log likelihood for all models at the point \(p\) or None. In the latter case it is defined by stacking the return values of \(logL(k, p)\) defined by loglik for all possible k. Note: allmodsll and loglik cannot be both None.

  • logprior (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – log prior function \(log \pi(k, p)\), a list of functions \(log \pi(p)\), one for each model (in which case \(log \pi(k, p)\) is defined by branching on the functions of the list), or None (in which case \(log \pi(k, p)\) is defined by selecting one component of the all-models log prior defined by allmodslp, that is, with a speculative computing approach).

  • allmodslp (jexplore.sampling.state.ArrayFn | None) – function \(vlog \pi(p)\) returning the values of the log prior for all models at the point \(p\) or None. In the latter case it is defined by stacking the return values of \(log \pi(k, p)\) defined by logprior for all possible k. Note: allmodslp and logprior cannot be both None.

  • pseudo (list[jexplore.tools.distributions.Distr | None] | None) –

    list of pseudo prior distributions for all the models. Each element of the list can be:

    1. A jexplore.tools.distributions.Distr object defined on the complementary space of the corresponding model, in which case it is taken as it is.

    2. A jexplore.tools.distributions.Distr object defined on the full space, in which case it is masked to match the model mask.

    3. None. In this case, or if the model takes the whole space, a delta distribution centered at 0. is used (and the evaluation returns identically 0.).

    If the whole argument is None (Default), a list of None elements is assumed and the distribution evaluation is identically 0.

  • dim (int | None) – dimensionality of the full space.

  • space (Tspace | None) – jexplore.sampling.space.Space object describing the target space.

  • inpars (list[str] | None) – list of input parameter(s) used for the computation of log likelihood and log prior.
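The relation between the per-model function \(logL(k, p)\) and the all-models function \(vlogL(p)\) described above can be illustrated with a minimal sketch. This is not the library's implementation: the names `allmods_from_single` and `single_from_allmods`, the toy `loglik`, and `NMODELS` are hypothetical and only show how each form could be derived from the other.

```python
import jax
import jax.numpy as jnp

NMODELS = 3  # hypothetical number of models


def loglik(k, p):
    """Toy logL(k, p): Gaussian whose width depends on the model index k."""
    sigma = 1.0 + k
    return -0.5 * jnp.sum((p / sigma) ** 2)


def allmods_from_single(fn, nmodels):
    """Build vlogL(p) by stacking fn(k, p) over every model index k."""
    ks = jnp.arange(nmodels)
    return lambda p: jax.vmap(lambda k: fn(k, p))(ks)


def single_from_allmods(vfn):
    """Build logL(k, p) by evaluating all models and keeping component k
    (the speculative computing approach)."""
    return lambda k, p: vfn(p)[k]


allmodsll = allmods_from_single(loglik, NMODELS)
loglik_spec = single_from_allmods(allmodsll)

p = jnp.array([1.0, 2.0])
values = allmodsll(p)        # one value per model, shape (NMODELS,)
picked = loglik_spec(1, p)   # matches loglik(1, p)
```

The speculative form evaluates every model and discards all but one result, which may be preferable when branching is expensive on the target hardware; the same pattern applies to logprior and allmodslp.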

masks: list[jax.Array]#

mask of parameters indexes of each model space

cmasks: list[jax.Array]#

complementary mask of parameters indexes of each model space

mod_update: jexplore.sampling.state.ArrayFn#

\(f(k, a, b)\) updating an array a of shape (self.dim,) with an array b of the same shape, but only on the components of the mask of model k

vmod_update: jexplore.sampling.state.ArrayFn#

vectorized version of self.mod_update, so that it can perform the same update operation on vectors k, a and b of shapes (self.nwalker,), (self.nwalker, self.dim) and (self.nwalker, self.dim)
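A minimal sketch of a masked update and its vectorized form follows. It assumes boolean masks stacked in a single array (`BOOL_MASKS`, a hypothetical name) so that the update keeps a fixed shape for any k; the library itself stores masks as index arrays and may implement the update differently.

```python
import jax
import jax.numpy as jnp

DIM = 4
# Hypothetical boolean masks, one row per model: True marks the model's coordinates.
BOOL_MASKS = jnp.array([
    [True, True, False, False],   # model 0 uses coordinates 0 and 1
    [True, True, True, True],     # model 1 uses the whole space
])


def mod_update(k, a, b):
    """Return a copy of a whose coordinates in the mask of model k come from b."""
    return jnp.where(BOOL_MASKS[k], b, a)


# Vectorized over walkers: k has shape (nwalker,), a and b have shape (nwalker, DIM).
vmod_update = jax.vmap(mod_update)

k = jnp.array([0, 1, 0])
a = jnp.zeros((3, DIM))
b = jnp.ones((3, DIM))
updated = vmod_update(k, a, b)
```

Here the first and third walkers (model 0) receive only the first two coordinates of b, while the second walker (model 1) is replaced entirely.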

allmodsll: jexplore.sampling.state.ArrayFn#

Function \(vlogL(p)\) returning the log likelihood for all models

allmodslp: jexplore.sampling.state.ArrayFn#

Function \(vlog \pi(p)\) returning the log prior for all models

ppleval: jexplore.sampling.state.ArrayFn#

Log pseudo prior function \(\tilde{\pi}(k, p)\)

allmodspp: jexplore.sampling.state.ArrayFn#

Function \(\tilde{\pi}(p)\) returning the log pseudo prior for all models

ppdraw: list[jexplore.tools.distributions.DrawFn]#

List of pseudo prior drawing functions

nmodels: int#

number of models

to_backend()[source]#

Save the attributes of the object in backend format.

Return type:

dict

static get_masks(masks, dim)[source]#

Get model masks as lists of indices.

Parameters:
  • masks (list[list[int] | None | int | tuple[int, int]]) –

    list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:

    1. None: all the space coordinates

    2. integer d: the first d dimensions of the space

    3. tuple (a, b): defining the corresponding coordinate slice

    4. list of integers: explicit list of coordinate indexes

  • dim (int) – full space dimension.

Returns:

list of model masks (as arrays of indices) and list of model complementary masks.

Return type:

tuple[list[jax.Array], list[jax.Array]]
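The four accepted mask forms can be illustrated with a short sketch. The helper `resolve_masks` is hypothetical and is not the library's get_masks; in particular, it assumes that a tuple (a, b) denotes the half-open slice [a, b), as in Python slicing, which the documentation does not state.

```python
import numpy as np
import jax.numpy as jnp


def resolve_masks(masks, dim):
    """Hypothetical helper: turn mask specifications into index arrays."""
    full = np.arange(dim)
    resolved, complements = [], []
    for spec in masks:
        if spec is None:
            idx = full                      # whole space
        elif isinstance(spec, int):
            idx = full[:spec]               # first d dimensions
        elif isinstance(spec, tuple):
            idx = full[spec[0]:spec[1]]     # assumed half-open slice [a, b)
        else:
            idx = np.asarray(spec)          # explicit list of indices
        resolved.append(jnp.asarray(idx))
        # Complement: coordinates of the full space not used by this model.
        complements.append(jnp.asarray(np.setdiff1d(full, idx)))
    return resolved, complements


masks, cmasks = resolve_masks([None, 2, (1, 3), [0, 3]], dim=4)
```

Mask construction is done with NumPy here because the masks are static configuration; only the resulting index arrays are converted to JAX arrays.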

static ldist_from_list(logdists, masks, stack=False)[source]#

Defines a product space log distribution from a list of single-model distributions.

Parameters:
  • logdists (list[jexplore.sampling.state.ArrayFn]) – list of single space distributions \(log P(p)\)

  • masks (list[jax.Array]) – mask of parameters indexes of each model space

  • stack (bool) – if true defines the function returning the values for all models by stacking the results of all the functions in the list.

Returns:

full product space log distribution.

Return type:

jexplore.sampling.state.ArrayFn
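As an illustration only, the sketch below combines a list of single-model log distributions into either a branching function of (k, p) or a stacked function of p. The function `ldist_from_list_sketch` is a hypothetical stand-in, and it assumes that each single-model function receives only the coordinates selected by its mask; the actual calling convention of the library may differ.

```python
import jax
import jax.numpy as jnp


def ldist_from_list_sketch(logdists, masks, stack=False):
    """Hypothetical stand-in: combine per-model log distributions."""
    # Each branch restricts the full point p to the model's own coordinates.
    branches = [lambda p, f=f, m=m: f(p[m]) for f, m in zip(logdists, masks)]
    if stack:
        # Evaluate every model at p and return a vector of length nmodels.
        return lambda p: jnp.stack([br(p) for br in branches])
    # Branch on the model index k and evaluate only the selected model.
    return lambda k, p: jax.lax.switch(k, branches, p)


masks = [jnp.array([0, 1]), jnp.array([0, 1, 2])]
logdists = [
    lambda q: -0.5 * jnp.sum(q ** 2),   # model 0: unnormalized Gaussian
    lambda q: -jnp.sum(jnp.abs(q)),     # model 1: unnormalized Laplace-like
]

p = jnp.array([1.0, -2.0, 3.0])
logp = ldist_from_list_sketch(logdists, masks)
vlogp = ldist_from_list_sketch(logdists, masks, stack=True)
single = logp(1, p)    # only model 1 is evaluated
stacked = vlogp(p)     # values for all models, shape (2,)
```

`jax.lax.switch` requires every branch to return the same shape and dtype, which holds here because each log distribution returns a scalar.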

class EpochMS[Tstate: StateMS, Tsampling: SamplingMS](epoch, force_covs=False)[source]#

Bases: jexplore.sampling.mh.EpochMH[Tstate, Tsampling]

Epoch definition class for model selection.

With respect to the parent jexplore.sampling.epoch.Epoch class, this class defines jexplore.sampling.modsel.StateMS as the default state class and overrides the complete method to also compute the log likelihood and log prior of the epoch points (with the model selection signature).

Parameters:
  • epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.

  • force_covs (bool) – if true forces the recomputing of the covariance matrices.

statecls#

State class