jexplore.sampling.modsel
This module contains the definitions of classes and types for running Markov chains for product-space model selection.
Classes
| StateMS | Definition of a Markov chain state for model selection. |
| SamplingMS | Markov sampling parameters for model selection. |
| EpochMS | Epoch definition class for model selection. |
Module Contents
- class StateMS
Bases: jexplore.sampling.mh.StateMH
Definition of a Markov chain state for model selection. With respect to the parent jexplore.sampling.mh.StateMH class, it adds the model index k for each chain (operations on these parameters are handled by the base class methods).
- Parameters:
k – space index (nchains, 1)
p – current point of each chain (nchains, dim)
ll – log likelihood values for each chain point (nchains, 1).
lp – log prior values for each chain point (nchains, 1).
- k: jax.Array
model index
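For illustration, the field layout of a model-selection state can be mirrored with a plain-Python dataclass. This is a minimal sketch under stated assumptions: `StateMSSketch` is a hypothetical name, nested lists stand in for jax.Array, and the shape check only makes the documented shapes explicit. The real StateMS holds jax.Array fields and inherits its operations from jexplore.sampling.mh.StateMH.

```python
from dataclasses import dataclass


# Hypothetical stand-in for StateMS using nested lists instead of jax.Array.
@dataclass
class StateMSSketch:
    k: list   # model index of each chain, shape (nchains, 1)
    p: list   # point of each chain in the full space, shape (nchains, dim)
    ll: list  # log likelihood of each chain point, shape (nchains, 1)
    lp: list  # log prior of each chain point, shape (nchains, 1)

    def __post_init__(self):
        nchains = len(self.p)
        # k, ll and lp must hold exactly one value per chain.
        for name in ("k", "ll", "lp"):
            column = getattr(self, name)
            if len(column) != nchains or any(len(row) != 1 for row in column):
                raise ValueError(f"{name} must have shape (nchains, 1)")


# Two chains in a 3-dimensional full space, currently in models 0 and 1.
state = StateMSSketch(
    k=[[0], [1]],
    p=[[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]],
    ll=[[-1.5], [-2.0]],
    lp=[[0.0], [-0.5]],
)
```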
- class SamplingMS[Tspace: jexplore.sampling.space.Space](nwalker, temps, masks, loglik=None, logprior=None, allmodsll=None, allmodslp=None, pseudo=None, dim=None, space=None, inpars=None)
Bases: jexplore.sampling.mh.SamplingMH[Tspace]
Markov sampling parameters for model selection. With respect to the parent jexplore.sampling.mh.SamplingMH class, the sampling definition includes the parameters listed below.
- Parameters:
nwalker (int) – number of walkers per temperature.
temps (jax.Array) – temperature ladder
masks (list[list[int] | None | int | tuple[int, int]]) –
list of masks identifying the coordinates of the space of each model. The length of the list is the number of models. Each element of the list can be:
None: all the space coordinates
integer d: the first d dimensions of the space
tuple (a, b): the corresponding slice of coordinates
list of integers: an explicit list of coordinate indices
loglik (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – one of: the log likelihood function \(logL(k, p)\); a list of functions \(logL(p)\), one per model, in which case \(logL(k, p)\) is defined by branching on the functions of the list; or None, in which case \(logL(k, p)\) is defined by selecting the k-th component of the all-models log likelihood defined by allmodsll (a speculative computing approach, since all models are evaluated).
allmodsll (jexplore.sampling.state.ArrayFn | None) – function \(vlogL(p)\) returning the values of the log likelihood for all models at the point \(p\), or None. In the latter case it is defined by stacking the return values of \(logL(k, p)\) defined by loglik for all possible k. Note: allmodsll and loglik cannot both be None.
logprior (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – one of: the log prior function \(log \pi(k, p)\); a list of functions \(log \pi(p)\), one per model, in which case \(log \pi(k, p)\) is defined by branching on the functions of the list; or None, in which case \(log \pi(k, p)\) is defined by selecting the k-th component of the all-models log prior defined by allmodslp (a speculative computing approach, since all models are evaluated).
allmodslp (jexplore.sampling.state.ArrayFn | None) – function \(vlog \pi(p)\) returning the values of the log prior for all models at the point \(p\), or None. In the latter case it is defined by stacking the return values of \(log \pi(k, p)\) defined by logprior for all possible k. Note: allmodslp and logprior cannot both be None.
pseudo (list[jexplore.tools.distributions.Distr | None] | None) –
list of pseudo prior distributions for all the models. Each element of the list can be:
A jexplore.tools.distributions.Distr object defined on the complementary space of the corresponding model, in which case it is used as is.
A jexplore.tools.distributions.Distr object defined on the full space, in which case it is masked to match the model mask.
None: if the element is None, or if the model spans the whole space, a delta distribution centered at 0 is used (and its evaluation returns identically 0).
If the whole argument is None (default), a list of None elements is assumed and the pseudo prior evaluation is identically 0.
dim (int | None) – dimensionality of the full space.
space (Tspace | None) – jexplore.sampling.space object describing the target space.
inpars (list[str] | None) – list of input parameter(s) used for the computation of the log likelihood and log prior.
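The loglik and allmodsll parameters can each be derived from the other, and the same holds for logprior and allmodslp. A plain-Python sketch of both directions follows. The helpers `select_from_all` and `stack_over_models` and the two toy likelihoods are hypothetical and not part of jexplore. The library works on jax.Array values, which this sketch does not reproduce.

```python
def select_from_all(allmods_fn):
    """Build logL(k, p) from vlogL(p) by keeping component k.

    Speculative approach: every model is evaluated even though only the
    value of model k is returned.
    """
    def loglik(k, p):
        return allmods_fn(p)[k]
    return loglik


def stack_over_models(loglik_fn, nmodels):
    """Build vlogL(p) from logL(k, p) by evaluating it for every k."""
    def allmods(p):
        return [loglik_fn(k, p) for k in range(nmodels)]
    return allmods


# Hypothetical toy likelihoods on a 2-dimensional full space:
# model 0 reads only p[0], model 1 reads p[0] and p[1].
def loglik_k(k, p):
    if k == 0:
        return -0.5 * p[0] ** 2
    return -0.5 * (p[0] ** 2 + p[1] ** 2)


allmods = stack_over_models(loglik_k, nmodels=2)
loglik_again = select_from_all(allmods)

p = [1.0, 2.0]
print(allmods(p))          # [-0.5, -2.5]
print(loglik_again(1, p))  # -2.5
```

Selecting a component from the stacked vector always evaluates every model. That cost may be acceptable when the models share most of their computation.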
- masks: list[jax.Array]
mask of parameter indices of each model space
- cmasks: list[jax.Array]
complementary mask of parameter indices of each model space
- mod_update: jexplore.sampling.state.ArrayFn
\(f(k, a, b)\) updating an array a of shape (self.dim,) with an array b of the same shape, only on the components in the mask of model k
- vmod_update: jexplore.sampling.state.ArrayFn
vectorized version of self.mod_update, performing the same update on k, a and b of shapes (self.nwalker,), (self.nwalker, self.dim) and (self.nwalker, self.dim)
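Below is a sketch of the masked update performed by mod_update and its vectorized counterpart. `make_mod_update` and `make_vmod_update` are hypothetical helpers that operate on Python lists. In a JAX implementation, the per-walker loop would more naturally be written with jax.vmap. That is an assumption about the library's code, not a description of it.

```python
def make_mod_update(masks):
    """Return f(k, a, b): a copy of a whose model-k components come from b."""
    def mod_update(k, a, b):
        out = list(a)          # copy, so a itself is left unchanged
        for i in masks[k]:
            out[i] = b[i]
        return out
    return mod_update


def make_vmod_update(mod_update):
    """Row-wise version: ks has shape (n,), A and B have shape (n, dim)."""
    def vmod_update(ks, A, B):
        return [mod_update(k, a, b) for k, a, b in zip(ks, A, B)]
    return vmod_update


masks = [[0], [0, 1, 2]]  # model 0 uses coordinate 0, model 1 uses all three
mod_update = make_mod_update(masks)
vmod_update = make_vmod_update(mod_update)

a = [0.0, 0.0, 0.0]
b = [1.0, 2.0, 3.0]
print(mod_update(0, a, b))                  # [1.0, 0.0, 0.0]
print(vmod_update([0, 1], [a, a], [b, b]))  # [[1.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
```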
- allmodsll: jexplore.sampling.state.ArrayFn
Function \(vlogL(p)\) returning the log likelihood for all models
- allmodslp: jexplore.sampling.state.ArrayFn
Function \(vlog \pi(p)\) returning the log prior for all models
- ppleval: jexplore.sampling.state.ArrayFn
Log pseudo prior function \(\tilde{\pi}(k, p)\)
- allmodspp: jexplore.sampling.state.ArrayFn
Function \(\tilde{\pi}(p)\) returning the log pseudo prior for all models
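The following sketch shows where a pseudo prior enters a product-space log target at (k, p). The likelihood and prior of model k are evaluated on that model's coordinates, and the pseudo prior covers the complementary coordinates. The sketch follows the usual product-space construction and is an assumption about how ppleval is combined with the other terms. It also ignores temperatures and assumes a uniform prior over models. All names below are hypothetical.

```python
import math

# Hypothetical two-model setup on a 2-dimensional full space.
MASKS = [[0], [0, 1]]   # coordinates used by each model
CMASKS = [[1], []]      # complementary coordinates of each model


def normal_logpdf(x, mu=0.0, sigma=1.0):
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def loglik(k, p):
    # Toy likelihood reading only the coordinates of model k.
    return sum(-0.5 * (p[i] - 1.0) ** 2 for i in MASKS[k])


def logprior(k, p):
    # Broad normal prior on the coordinates of model k.
    return sum(normal_logpdf(p[i], 0.0, 10.0) for i in MASKS[k])


def logpseudo(k, p):
    # Pseudo prior on the complementary coordinates only. Model 1 spans the
    # whole space, so its pseudo prior term is identically 0.
    return sum(normal_logpdf(p[i]) for i in CMASKS[k])


def log_target(k, p):
    # Product-space log target (untempered, uniform prior over models).
    return loglik(k, p) + logprior(k, p) + logpseudo(k, p)
```

Under this construction, the pseudo priors do not change the marginal posterior over k, provided each one is normalized on its complementary space. They do affect how often model switches are accepted.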
- ppdraw: list[jexplore.tools.distributions.DrawFn]
List of pseudo prior drawing functions
- nmodels: int
number of models
- static get_masks(masks, dim)
Get model masks as lists of indices.
- Parameters:
dim (int) – full space dimension.
masks (list[list[int] | None | int | tuple[int, int]]) – list of masks identifying the coordinates of the space of each model. The length of the list is the number of models. Each element of the list can be:
None: all the space coordinates
integer d: the first d dimensions of the space
tuple (a, b): the corresponding slice of coordinates
list of integers: an explicit list of coordinate indices
- Returns:
list of model masks (as arrays of indices) and list of model complementary masks.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
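Below is a minimal sketch of how the four mask formats might be resolved into index lists and complementary index lists. `get_masks_sketch` is hypothetical and returns Python lists rather than jax.Array. It also assumes that a tuple (a, b) follows Python slice semantics (a included, b excluded), which the documentation does not state explicitly.

```python
def get_masks_sketch(masks, dim):
    """Resolve mask specifications into (masks, complementary masks)."""
    full = list(range(dim))
    resolved = []
    for spec in masks:
        if spec is None:
            idx = list(full)               # all the space coordinates
        elif isinstance(spec, int):
            idx = list(range(spec))        # first `spec` dimensions
        elif isinstance(spec, tuple):
            a, b = spec
            idx = list(range(a, b))        # slice a:b (assumed end-exclusive)
        else:
            idx = list(spec)               # explicit coordinate indices
        resolved.append(idx)

    cmasks = []
    for idx in resolved:
        chosen = set(idx)
        cmasks.append([i for i in full if i not in chosen])
    return resolved, cmasks


masks, cmasks = get_masks_sketch([None, 2, (1, 3), [0, 2]], dim=4)
print(masks)   # [[0, 1, 2, 3], [0, 1], [1, 2], [0, 2]]
print(cmasks)  # [[], [2, 3], [0, 3], [1, 3]]
```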
- static ldist_from_list(logdists, masks, stack=False)
Defines a product-space log distribution from a list of single-model log distributions.
- Parameters:
logdists (list[jexplore.sampling.state.ArrayFn]) – list of single space distributions \(log P(p)\)
masks (list[jax.Array]) – mask of parameters indexes of each model space
stack (bool) – if True, defines the function returning the values for all models by stacking the results of all the functions in the list.
- Returns:
full product space log distribution.
- Return type:
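The branching and stacking modes of ldist_from_list can be illustrated with a plain-Python sketch. Here, each single-model log distribution receives only the coordinates selected by its mask. `ldist_from_list_sketch` is hypothetical. A JAX version would more likely branch with jax.lax.switch and stack with jax.numpy.stack, but that is an assumption, not a description of the library.

```python
def ldist_from_list_sketch(logdists, masks, stack=False):
    """Combine single-model log densities into a product-space function."""

    def restrict(k, p):
        # Coordinates of the full-space point p used by model k.
        return [p[i] for i in masks[k]]

    if stack:
        def all_models(p):
            # One value per model: [log P_0(p_0), log P_1(p_1), ...]
            return [f(restrict(k, p)) for k, f in enumerate(logdists)]
        return all_models

    def branched(k, p):
        # Evaluate only the log density of model k.
        return logdists[k](restrict(k, p))
    return branched


# Hypothetical single-model log densities and masks on a 2-dimensional space.
logdists = [
    lambda x: -sum(v * v for v in x),
    lambda x: -sum(abs(v) for v in x),
]
masks = [[0], [0, 1]]

logp = ldist_from_list_sketch(logdists, masks)
vlogp = ldist_from_list_sketch(logdists, masks, stack=True)

p = [2.0, -3.0]
print(logp(0, p))  # -4.0
print(vlogp(p))    # [-4.0, -5.0]
```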
- class EpochMS[Tstate: StateMS, Tsampling: SamplingMS](epoch, force_covs=False)
Bases: jexplore.sampling.mh.EpochMH[Tstate, Tsampling]
Epoch definition class for model selection.
With respect to the parent jexplore.sampling.epoch.Epoch class, this class defines jexplore.sampling.modsel.StateMS as the default state class and overrides the complete method to also compute the log likelihood and log prior of the epoch points (with the model selection signature).
- Parameters:
epoch (Self | dict) – Epoch object from which the new instance is derived by computing the covariance and taking the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
force_covs (bool) – if True, forces the recomputation of the covariance matrices.
- statecls
State class