jexplore.sampling
Classes for sampling
Submodules
Classes
- Sampling – Markov sampling defining parameters.
- Epoch – Base Epoch definition class.
- EpochCycle – Class containing the definition of an epoch steps cycle.
- EpochStats – Class storing the aggregated statistics of the epoch steps.
- StepSample – Prototype for a step sampling function.
- EpochMH – Epoch definition class for Metropolis-Hastings.
- SamplingMH – Markov sampling defining parameters for Metropolis-Hastings.
- StateMH – This class provides the definition of a Markov chain state for a Metropolis-Hastings Markov sampler.
- EpochMS – Epoch definition class for model selection.
- SamplingMS – Markov sampling defining parameters for model selection.
- StateMS – Definition of a Markov chain state for model selection.
- Box – Simple rectangular box space.
- State – This class provides the base definition of Markov chain states.
Package Contents
- class Sampling[Tspace: jexplore.sampling.space.Space](nchain, dim=None, space=None)
Markov sampling defining parameters.
In this base class version, the parameters are the dimension of the target space and the number of chains. Child classes may specialize it to specific types of Markov sampling (e.g. jexplore.sampling.mh.SamplingMH).
- Parameters:
nchain (int) – number of chains. The state will then be defined by a (nchain, dim) point.
dim (int | None) – dimension of the target space.
space (Tspace | None) – jexplore.sampling.space object describing the target space.
- dim: int
Dimension of the target space
- nchain: int
Number of chains
- space: Tspace | jexplore.sampling.space.Box
Target space.
- class Epoch[Tstate: jexplore.sampling.state.State, Tsampling: jexplore.sampling.base.Sampling](epoch, force_covs=False)
Base Epoch definition class.
This encapsulates the samples, the chains' covariance matrices, the stats object, the definition of the epoch cycle and the definition of the current sampling. The class provides methods to initialize the epoch, get a specific state, run steps and instantiate an epoch from the result of a sampler run.
- Parameters:
epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
force_covs (bool) – if true, forces the recomputation of the covariance matrices.
- samples: Tstate
Batched state defining the sequence of sampling states.
- covs: jax.Array
Array with shape \((N, D, D)\) containing the covariance matrices of all chains.
- stats: EpochStats[Tstate]
Epoch statistics.
- cycle: EpochCycle[Tstate]
Definition of the epoch cycle.
- sampling: Tsampling
Sampling definition.
- statecls: type[Tstate] | type[jexplore.sampling.state.State]
State class corresponding to this epoch class.
- complete(sampling, compute=None)
Complete the definition and the parameters of the epoch.
- Parameters:
sampling (Tsampling) – sampling parameters.
compute (dict | None) – dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are (func, inpars) tuples, where func is a callable for the computation of the corresponding parameter and inpars is the list of the names of the input parameters.
- Returns:
this object after update.
- Return type:
Self
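The compute argument pairs each missing parameter name with a function and the names of its inputs. As a minimal sketch of that convention only, assuming a plain dict stands in for the state (complete_params and loglik are hypothetical names, not part of jexplore):

```python
import jax.numpy as jnp


def complete_params(state, compute):
    """Hypothetical helper: fill the missing entries of a dict-based state.

    compute maps a parameter name to a (func, inpars) tuple; func is called
    with the listed input parameters, in order.
    """
    out = dict(state)
    for name, (func, inpars) in compute.items():
        if name not in out:  # only parameters that are still missing
            out[name] = func(*[out[n] for n in inpars])
    return out


def loglik(p):
    # Standard normal log likelihood up to a constant, shape (nchain, 1).
    return -0.5 * jnp.sum(p**2, axis=-1, keepdims=True)


state = {"p": jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])}  # 3 chains, dim 2
state = complete_params(state, {"ll": (loglik, ["p"])})
print(state["ll"].shape)  # (3, 1)
```

In the MH subclasses the same mechanism is described as being pre-filled with entries for the log likelihood and log prior.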
- set_cycle(cycle)
Sets this epoch's cycle.
- Parameters:
cycle (EpochCycle[Tstate]) – the EpochCycle value to be set.
- Return type:
None
- run_step(key, ind, state)
Runs a step of the epoch. For a given substep of the epoch cycle, it draws one of the alternative steps and performs it.
- Parameters:
key (jax.Array) – PRNG key
ind (int) – the index of the substep
state (Tstate) – current state
- Returns:
the new state, the step stats and the index of the substep alternative actually performed.
- Return type:
tuple[Tstate, jax.Array, jax.Array]
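One way to picture run_step, assuming the weights and isteps matrices of EpochCycle and a plain (nchain, dim) array in place of a State object: draw an alternative for substep ind with probability proportional to its weight, then dispatch to the selected step with jax.lax.switch. This is a hypothetical sketch (run_step_sketch, shift_step and noop_step are illustrative names), not the library's implementation.

```python
import jax
import jax.numpy as jnp


# Toy steps with the (key, state) -> (state, changed) shape; the "state" here
# is a plain (nchain, dim) array.
def shift_step(key, p):
    return p + 1.0, jnp.ones(p.shape[0], dtype=bool)


def noop_step(key, p):
    return p, jnp.zeros(p.shape[0], dtype=bool)


steps = (shift_step, noop_step)
# Substep 0 has a single alternative (padded to A_max = 2 with weight 0);
# substep 1 picks either step with equal probability.
weights = jnp.array([[1.0, 0.0], [0.5, 0.5]])
isteps = jnp.array([[0, 0], [0, 1]])


def run_step_sketch(key, ind, p):
    """Hypothetical sketch: draw an alternative for substep ind, then run it."""
    k_draw, k_step = jax.random.split(key)
    w = weights[ind]
    alt = jax.random.choice(k_draw, w.shape[0], p=w / w.sum())
    new_p, changed = jax.lax.switch(isteps[ind, alt], steps, k_step, p)
    return new_p, changed, alt


new_p, changed, alt = run_step_sketch(jax.random.PRNGKey(0), 0, jnp.zeros((3, 2)))
```

jax.lax.switch requires every branch to return the same structure, which is why both toy steps return an array of the same shape plus an (nchain,) boolean map.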
- from_run(states, stats, steps)
Gets the stacks of states, steps statistics and steps indexes from an epoch run and updates the epoch with these data.
- Parameters:
states (Tstate) – stack of states.
stats (jax.Array) – stack of steps stats
steps (jax.Array) – stack of steps indexes
- Returns:
this object after updating
- Return type:
Self
- class EpochCycle[Tstate: jexplore.sampling.state.State](steps)
Class containing the definition of an epoch steps cycle.
- Parameters:
steps (list[dict[StepSample[Tstate], float]]) – list of substeps, each one defined by a dictionary of alternative steps with their weights.
- steps: tuple[StepSample[Tstate], ...]
Tuple of all the steps.
- weights: jax.Array
Matrix of weights. This has shape \((S, A_{max})\), where \(S\) is the number of serial substeps in the cycle and \(A_{max}\) is the maximum number of alternatives over all these substeps. Each element of the matrix that actually corresponds to a defined step/alternative contains the weight of the corresponding alternative.
- isteps: jax.Array
\((S, A_{max})\) matrix. Each element of the matrix that actually corresponds to a defined step/alternative contains the index of the corresponding StepSample function in the steps tuple.
- n: int
Number of sequential substeps.
- to_backend()
Dumps the cycle definition into a numpy array. This array has a row for each one of the defined substeps/alternatives, containing:
name: the qualified name of the StepSample function.
incycle: the index of the substep.
weight: the weight of the corresponding substep alternative.
- Return type:
numpy.ndarray
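The padding implied by the \((S, A_{max})\) layout can be sketched as follows, assuming padded entries get weight 0 and step index 0 (the padding value is an assumption; build_cycle and the placeholder step names are hypothetical):

```python
import jax.numpy as jnp


def build_cycle(steps_list):
    """Hypothetical sketch: [{step: weight, ...}, ...] -> (steps, weights, isteps)."""
    steps = []  # distinct step functions, in order of first appearance
    for substep in steps_list:
        for step in substep:
            if step not in steps:
                steps.append(step)
    n = len(steps_list)
    a_max = max(len(substep) for substep in steps_list)
    # Padding entries (no alternative defined) keep weight 0 and step index 0.
    weights = [[0.0] * a_max for _ in range(n)]
    isteps = [[0] * a_max for _ in range(n)]
    for s, substep in enumerate(steps_list):
        for a, (step, w) in enumerate(substep.items()):
            weights[s][a] = w
            isteps[s][a] = steps.index(step)
    return tuple(steps), jnp.array(weights), jnp.array(isteps)


# Placeholder step functions; only their identity matters here.
def walk(key, state):
    return state, None


def stretch(key, state):
    return state, None


def swap(key, state):
    return state, None


steps, weights, isteps = build_cycle([{stretch: 0.8, walk: 0.2}, {swap: 1.0}])
# weights -> [[0.8, 0.2], [1.0, 0.0]], isteps -> [[0, 1], [2, 0]]
```

A zero weight keeps the padded slot from ever being drawn, so its step index only needs to be a valid position in the steps tuple.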
- class EpochStats[Tstate: jexplore.sampling.state.State]
Class storing the aggregated statistics of the epoch steps.
- Parameters:
steps – array of steps performed at each iteration.
counts – counters of the number of executions of each step.
stats – chains acceptance statistics for each step.
- steps: jax.Array
This is an \(I\)-sized array, where \(I\) is the total number of steps performed during the sampling. The entries of the array are the indexes of the corresponding steps in the steps tuple of the epoch's EpochCycle instance.
- counts: jax.Array
This array has the same size as the steps tuple in the epoch's EpochCycle instance. Each element contains the count of how many times the corresponding step was performed during the sampling.
- stats: jax.Array
This array has shape \((N, C)\), where \(C\) is the size of the steps tuple in the epoch's EpochCycle instance (that is, the number of distinct steps) and \(N\) is the number of chains. The values are statistics of how many times each chain was modified by each type of step (i.e. the proposal acceptance rate in the case of MH).
- static from_run(stats, steps, cycle)
Each step returns an \(N\)-sized boolean map of the chains that have been changed by the step and the index identifying the step performed. An epoch run returns the stacking of these arrays for each iteration. This method aggregates and organizes this information and instantiates the corresponding EpochStats object.
- Parameters:
stats (jax.Array) – stacked stats array for all steps in the sampling. This has shape \((N, I)\).
steps (jax.Array) – array with the indices of the steps performed. This is an \(I\)-sized array whose entries are indices \(a = 0, \dots, A_{max} - 1\) identifying which one of the possible alternatives was actually performed at each step.
cycle (EpochCycle[Tstate]) – EpochCycle instance.
- Returns:
the EpochStats corresponding to these data.
- Return type:
EpochStats[Tstate]
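A sketch of the aggregation described above, assuming iteration i runs substep i mod S, so that the global step index is isteps[i mod S, steps[i]], and taking stats as the per-step count of changed chains divided by the step's execution count (aggregate_stats is a hypothetical helper):

```python
import jax
import jax.numpy as jnp


def aggregate_stats(stats, steps, isteps, nsteps):
    """Hypothetical sketch of EpochStats.from_run-style aggregation.

    stats:  (N, I) boolean, True where a chain changed at iteration i.
    steps:  (I,) alternative index drawn at each iteration.
    isteps: (S, A_max) map from (substep, alternative) to global step index.
    nsteps: C, the number of distinct steps.
    Assumes iteration i runs substep i % S.
    """
    substep = jnp.arange(steps.shape[0]) % isteps.shape[0]
    global_steps = isteps[substep, steps]                # (I,)
    counts = jnp.bincount(global_steps, length=nsteps)   # (C,)
    onehot = jax.nn.one_hot(global_steps, nsteps)        # (I, C)
    changed = stats.astype(jnp.float32) @ onehot         # (N, C) change counts
    rates = changed / jnp.maximum(counts, 1)             # guard never-run steps
    return global_steps, counts, rates


isteps = jnp.array([[0, 1], [2, 0]])
steps = jnp.array([1, 0, 0, 0])
stats = jnp.array([[True, False, True, True],
                   [False, False, False, True]])
global_steps, counts, rates = aggregate_stats(stats, steps, isteps, 3)
# global_steps -> [1, 2, 0, 2], counts -> [1, 1, 2]
```

Passing length to jnp.bincount fixes the output size, which keeps the sketch usable under jax.jit.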
- to_backend()
Dumps the data of this EpochStats instance, along with the information from the corresponding EpochCycle instance, to a dictionary in a form suitable for dumping to backend. The output dictionary has keys:
steps: numpy.ndarray cast of the object's steps attribute
counts: numpy.ndarray cast of the object's counts attribute
stats: numpy.ndarray cast of the object's stats attribute
steps_defs: data from the EpochCycle instance as returned by jexplore.sampling.epoch.EpochCycle.to_backend
- Return type:
dict
- class StepSample[Tstate: jexplore.sampling.state.State]
Bases: Protocol
Prototype for a step sampling function.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
- Returns:
the new state and the chain-change stats of the step, with shape (nchain,).
- Return type:
tuple[Tstate, jax.Array]
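A toy function with the StepSample call shape: a random-walk Metropolis step on a standard normal target. ToyState, log_target and the 0.5 proposal scale are illustrative assumptions; a real step would receive a jexplore state object.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a StateMH-like state."""
    p: jax.Array   # (nchain, dim) points
    ll: jax.Array  # (nchain, 1) log target values


def log_target(p):
    # Standard normal log density up to a constant, one value per chain.
    return -0.5 * jnp.sum(p**2, axis=-1, keepdims=True)


def random_walk_step(key, state):
    """(key, state) -> (new_state, changed): random-walk Metropolis step."""
    k_prop, k_acc = jax.random.split(key)
    prop = state.p + 0.5 * jax.random.normal(k_prop, state.p.shape)
    prop_ll = log_target(prop)
    log_u = jnp.log(jax.random.uniform(k_acc, state.ll.shape))
    accept = (log_u < prop_ll - state.ll)[:, 0]  # (nchain,) boolean map
    new_p = jnp.where(accept[:, None], prop, state.p)
    new_ll = jnp.where(accept[:, None], prop_ll, state.ll)
    return ToyState(new_p, new_ll), accept


p0 = jnp.zeros((4, 2))
new_state, changed = jax.jit(random_walk_step)(
    jax.random.PRNGKey(1), ToyState(p0, log_target(p0))
)
```

The returned boolean map is exactly the per-chain "changed" information that EpochStats aggregates into acceptance rates.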
- class EpochMH[Tstate: StateMH, Tsampling: SamplingMH](epoch, force_covs=False)
Bases: jexplore.sampling.epoch.Epoch[Tstate, Tsampling]
Epoch definition class for Metropolis-Hastings.
With respect to the parent jexplore.sampling.epoch.Epoch class, this class defines jexplore.sampling.mh.StateMH as the default state class and overrides the complete method to also compute the log likelihood and log prior of the epoch points.
- Parameters:
epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
force_covs (bool) – if true, forces the recomputation of the covariance matrices.
- complete(sampling, compute=None)
Complete the definition and the parameters of the epoch.
With respect to jexplore.sampling.epoch.Epoch.complete, this implementation of the complete method provides a base for the compute parameter with the instructions to compute the log likelihood and log prior attributes of the epoch samples.
- Parameters:
sampling (Tsampling) – sampling parameters.
compute (dict | None) – dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are (func, inpars) tuples, where func is a callable for the computation of the corresponding parameter and inpars is the list of the names of the input parameters.
- Returns:
this object after update.
- Return type:
Self
- class SamplingMH[Tspace: jexplore.sampling.space.Space](nwalker, temps, loglik, logprior, dim=None, space=None, inpars=None)
Bases: jexplore.sampling.base.Sampling[Tspace]
Markov sampling defining parameters for Metropolis-Hastings sampling. With respect to the parent jexplore.sampling.base.Sampling class, the definition of the sampling includes the following parameters.
- Parameters:
nwalker (int) – number of walkers per temperature.
temps (jax.Array) – temperature ladder
loglik (jexplore.sampling.state.ArrayFn) – log likelihood function
logprior (jexplore.sampling.state.ArrayFn) – log prior function
dim (int | None) – dimension of the target space.
space (Tspace | None) – jexplore.sampling.space object describing the target space.
inpars (list[str] | None) – list of input parameters of loglik and logprior.
The nchain parameter is computed from nwalker and temps.
- nwalker: int
number of walkers
- temps: jax.Array
temperature ladder
- loglik: jexplore.sampling.state.ArrayFn
log likelihood function
- logprior: jexplore.sampling.state.ArrayFn
log prior function
- inpars: list[str]
list of input parameters of loglik and logprior
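Since nchain is derived from nwalker and temps, a natural reading is one block of nwalker chains per temperature. The sketch below assumes that layout, a temperature-major chain ordering, and the common parallel-tempering convention of tempering the likelihood only; none of these details is confirmed by the reference above, and tempered_logpost is a hypothetical name.

```python
import jax.numpy as jnp

nwalker = 3
temps = jnp.array([1.0, 2.0, 4.0])  # temperature ladder
nchain = nwalker * temps.shape[0]    # assumed: nwalker chains per temperature

# Assumed layout: chains grouped by temperature (temperature-major order).
chain_temps = jnp.repeat(temps, nwalker)  # (nchain,)


def tempered_logpost(ll, lp, chain_temps):
    """Likelihood tempered by 1/T, prior left untempered; inputs (nchain, 1)."""
    return ll[:, 0] / chain_temps + lp[:, 0]


values = tempered_logpost(
    jnp.full((nchain, 1), -4.0), jnp.zeros((nchain, 1)), chain_temps
)
# values -> [-4, -4, -4, -2, -2, -2, -1, -1, -1]
```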
- class StateMH
Bases: jexplore.sampling.state.State
This class provides the definition of a Markov chain state for a Metropolis-Hastings Markov sampler. With respect to the parent jexplore.sampling.state.State class, it simply adds log likelihood and log prior values to the state attributes (operations on these parameters are handled by the parent class methods).
- Parameters:
p – status point (nchains, dim)
ll – log likelihood values for each chain point (nchains, 1).
lp – log prior values for each chain point (nchains, 1).
- ll: jax.Array
log likelihood values
- lp: jax.Array
log prior values.
- class EpochMS[Tstate: StateMS, Tsampling: SamplingMS](epoch, force_covs=False)
Bases: jexplore.sampling.mh.EpochMH[Tstate, Tsampling]
Epoch definition class for model selection.
With respect to the parent jexplore.sampling.epoch.Epoch class, this class defines jexplore.sampling.mh.StateMS as the default state class and overrides the complete method to also compute the log likelihood and log prior of the epoch points (with the model selection signature).
- Parameters:
epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
force_covs (bool) – if true, forces the recomputation of the covariance matrices.
- statecls
State class
- class SamplingMS[Tspace: jexplore.sampling.space.Space](nwalker, temps, masks, loglik=None, logprior=None, allmodsll=None, allmodslp=None, pseudo=None, dim=None, space=None, inpars=None)
Bases: jexplore.sampling.mh.SamplingMH[Tspace]
Markov sampling defining parameters for model selection. With respect to the parent jexplore.sampling.mh.SamplingMH class, the definition of the sampling includes the following parameters.
- Parameters:
nwalker (int) – number of walkers per temperature.
temps (jax.Array) – temperature ladder
masks (list[list[int] | None | int | tuple[int, int]]) – list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:
None: all the space coordinates
integer d: first d dimensions of the space
tuple (a, b): defining the corresponding coordinates slice
list of integers: explicit list of coordinate indexes
loglik (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – log likelihood function \(logL(k, p)\), a list of \(logL(p)\) for each model (in which case \(logL(k, p)\) is defined by branching on the functions of the list) or None (in which case \(logL(k, p)\) is defined by selecting one component of the all-models log likelihood defined by allmodsll, thus with a speculative computing approach).
allmodsll (jexplore.sampling.state.ArrayFn | None) – function \(vlogL(p)\) returning the values of the log likelihood for all models at the point \(p\), or None. In the latter case it is defined by stacking the return values of \(logL(k, p)\) defined by loglik for all possible k. Note: allmodsll and loglik cannot both be None.
logprior (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – log prior function \(log \pi(k, p)\), a list of \(log \pi(p)\) for each model (in which case \(log \pi(k, p)\) is defined by branching on the functions of the list) or None (in which case \(log \pi(k, p)\) is defined by selecting one component of the all-models log prior defined by allmodslp, thus with a speculative computing approach).
allmodslp (jexplore.sampling.state.ArrayFn | None) – function \(vlog \pi(p)\) returning the values of the log prior for all models at the point \(p\), or None. In the latter case it is defined by stacking the return values of \(log \pi(k, p)\) defined by logprior for all possible k. Note: allmodslp and logprior cannot both be None.
pseudo (list[jexplore.tools.distributions.Distr | None] | None) – list of pseudo prior distributions for all the models. Each element of the list can be:
A jexplore.tools.distributions.Distr object defined on the complementary space of the corresponding model, in which case it is taken as is.
A jexplore.tools.distributions.Distr object defined on the full space, in which case it is masked to match the model mask.
If an element is None, or if the model takes the whole space, a delta distribution centered at 0 is used (and its evaluation returns identically 0.).
If the whole argument is None (default), a list of None elements is assumed and the distribution evaluation is identically 0.
dim (int | None) – dimensionality of the full space.
space (Tspace | None) – jexplore.sampling.space object describing the target space.
inpars (list[str] | None) – list of input parameter(s) used for the computation of log likelihood and log prior.
- masks: list[jax.Array]
mask of parameter indexes of each model space
- cmasks: list[jax.Array]
complementary mask of parameter indexes of each model space
- mod_update: jexplore.sampling.state.ArrayFn
\(f(k, a, b)\) updating an array a of shape (self.dim,) with an array b of the same shape, but only on the components of the mask of model k
- vmod_update: jexplore.sampling.state.ArrayFn
vectorized version of self.mod_update, so that it can perform the same update operation on k, a and b of shapes (self.nwalker,), (self.nwalker, self.dim) and (self.nwalker, self.dim)
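Model masks have different lengths, which does not mix well with tracing. A sketch of mod_update and vmod_update, assuming the masks are turned into a fixed-shape boolean (nmodels, dim) matrix (bool_masks is a hypothetical name; the library may do this differently):

```python
import jax
import jax.numpy as jnp

dim = 4
masks = [jnp.array([0, 1]), jnp.array([0, 1, 2, 3])]  # two models' coordinates

# Row k is True on the coordinates of model k; the fixed (nmodels, dim) shape
# lets mod_update be traced and vmapped even though masks differ in length.
bool_masks = jnp.zeros((len(masks), dim), dtype=bool)
for k, m in enumerate(masks):
    bool_masks = bool_masks.at[k, m].set(True)


def mod_update(k, a, b):
    """Hypothetical f(k, a, b): b on model k's coordinates, a elsewhere."""
    return jnp.where(bool_masks[k], b, a)


# k: (nwalker,), a and b: (nwalker, dim)
vmod_update = jax.vmap(mod_update)

out = vmod_update(jnp.array([0, 1]), jnp.zeros((2, dim)), jnp.ones((2, dim)))
# out -> [[1, 1, 0, 0], [1, 1, 1, 1]]
```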
- allmodsll: jexplore.sampling.state.ArrayFn
Function \(F(p)\) returning the log likelihood for all models
- allmodslp: jexplore.sampling.state.ArrayFn
Function \(\pi(p)\) returning the log prior for all models
- ppleval: jexplore.sampling.state.ArrayFn
Log pseudo prior function \(\tilde{\pi}(k, p)\)
- allmodspp: jexplore.sampling.state.ArrayFn
Function \(\tilde{\pi}(p)\) returning the log pseudo prior for all models
- ppdraw: list[jexplore.tools.distributions.DrawFn]
List of pseudo prior drawing functions
- nmodels: int
number of models
- static get_masks(masks, dim)
Get model masks as lists of indices.
- Parameters:
masks (list[list[int] | None | int | tuple[int, int]]) – list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:
None: all the space coordinates
integer d: first d dimensions of the space
tuple (a, b): defining the corresponding coordinates slice
list of integers: explicit list of coordinate indexes
dim (int) – full space dimension.
- Returns:
list of model masks (as arrays of indices) and list of model complementary masks.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
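The four mask forms listed above can be normalized as in the following sketch, which assumes that tuple (a, b) means the half-open slice a:b (get_masks_sketch is a hypothetical stand-in for get_masks):

```python
import jax.numpy as jnp


def get_masks_sketch(masks, dim):
    """Hypothetical sketch: normalize mask specs to index arrays plus complements."""
    full = jnp.arange(dim)
    out, cout = [], []
    for m in masks:
        if m is None:                  # whole space
            idx = full
        elif isinstance(m, int):       # first d coordinates
            idx = jnp.arange(m)
        elif isinstance(m, tuple):     # slice (a, b), assumed half-open
            idx = jnp.arange(m[0], m[1])
        else:                          # explicit list of indices
            idx = jnp.array(m)
        out.append(idx)
        cout.append(full[~jnp.isin(full, idx)])  # complementary coordinates
    return out, cout


masks, cmasks = get_masks_sketch([None, 2, (1, 3), [0, 3]], 4)
# masks -> [0 1 2 3], [0 1], [1 2], [0 3]; cmasks -> [], [2 3], [0 3], [1 2]
```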
- static ldist_from_list(logdists, masks, stack=False)
Defines a product space log distribution from a list of single-model distributions.
- Parameters:
logdists (list[jexplore.sampling.state.ArrayFn]) – list of single space distributions \(log P(p)\)
masks (list[jax.Array]) – mask of parameter indexes of each model space
stack (bool) – if true, defines the function returning the values for all models by stacking the results of all the functions in the list.
- Returns:
full product space log distribution.
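A sketch of the two modes of ldist_from_list, assuming each single-model distribution receives only its model's coordinates p[mask]: branching on k with jax.lax.switch, or stacking all models' values when stack is true (ldist_from_list_sketch and the toy distributions are hypothetical):

```python
import jax
import jax.numpy as jnp


def ldist_from_list_sketch(logdists, masks, stack=False):
    """Hypothetical sketch: per-model log P(p) -> product-space log P(k, p)."""
    # Branch k evaluates model k's distribution on model k's coordinates only.
    branches = [lambda p, f=f, m=m: f(p[m]) for f, m in zip(logdists, masks)]
    if stack:
        # Values for all models at p, shape (nmodels,).
        return lambda p: jnp.stack([b(p) for b in branches])
    # Select one model by its index k.
    return lambda k, p: jax.lax.switch(k, branches, p)


logdists = [lambda x: -0.5 * jnp.sum(x**2),   # model 0 on coordinate 0
            lambda x: -jnp.sum(jnp.abs(x))]   # model 1 on coordinates 0 and 1
masks = [jnp.array([0]), jnp.array([0, 1])]
p = jnp.array([2.0, -1.0])

logp = ldist_from_list_sketch(logdists, masks)
allmods = ldist_from_list_sketch(logdists, masks, stack=True)
# logp(0, p) -> -2.0, logp(1, p) -> -3.0, allmods(p) -> [-2.0, -3.0]
```

The default arguments f=f, m=m bind each function and mask at definition time; without them every lambda would see the last pair of the loop.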
- class StateMS
Bases: jexplore.sampling.mh.StateMH
Definition of a Markov chain state for model selection. With respect to the parent jexplore.sampling.mh.StateMH class, it adds the model index k for each chain (operations on these parameters are handled by the base class methods).
- Parameters:
k – model (space) index (nchains, 1)
p – status point (nchains, dim)
ll – log likelihood values for each chain point (nchains, 1).
lp – log prior values for each chain point (nchains, 1).
- k: jax.Array
model index
- class Box(dim=None, size=jnp.inf, box=None, wrapped=None)
Bases: Space
Simple rectangular box space.
- Parameters:
dim (int | None) – dimension of the box. Only used if the box parameter is not provided.
size (float) – size of the box (assumed equal in all dimensions). Only used if the box parameter is not provided. Default: infinity.
box (list[list[float]] | None) – list of lists defining the segment boundaries for each box dimension. If not provided, a box with dim dimensions, each of length size, is considered.
wrapped (list[int] | None) – list of periodic dimension indexes. These dimensions will be considered unbound and the corresponding box intervals will be interpreted as principal domain intervals. Default: empty list.
- bounds: jax.Array
Box bounds
- wrap_dims: list[int]
Wrapped dimensions indexes
- wrap_domain: jax.Array
Wrapped dimensions principal domain
- dim
Dimension of the target space
- inspace(points)
Checks which of a set of points lie in the defined space.
- Parameters:
points (jax.Array) – set of points to be checked; it can have the state shape \((N_{chains}, D)\) or the samples shape \((N_{chains}, D, S)\).
- Returns:
a boolean mask of shape \((N_{chains})\) or \((N_{chains}, S)\) selecting the points that lie in the defined space.
- Return type:
jax.Array
- wrap(points)
Folds the wrapped dimensions of the target space.
- Parameters:
points (jax.Array) – set of points to be processed; it can have the state shape \((N_{chains}, D)\) or the samples shape \((N_{chains}, D, S)\).
- Returns:
the same set of points with the wrapped dimensions folded.
- Return type:
jax.Array
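The inspace and wrap semantics can be sketched on a 2-D box whose second dimension is periodic on \([-\pi, \pi]\). Assumptions: bounds has shape (D, 2), wrapped dimensions are ignored by the inside test, and folding maps x to lo + ((x - lo) mod (hi - lo)). All names are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical 2-D box: dimension 0 bounded on [0, 1], dimension 1 periodic on [-pi, pi].
bounds = jnp.array([[0.0, 1.0], [-jnp.pi, jnp.pi]])  # (D, 2): lower, upper
is_wrapped = jnp.array([False, True])                 # (D,)


def _expand(v, ndim):
    # Reshape a per-dimension vector (D,) to broadcast against (N, D) or (N, D, S).
    return v.reshape((1, -1) + (1,) * (ndim - 2))


def inspace_sketch(points):
    lo = _expand(bounds[:, 0], points.ndim)
    hi = _expand(bounds[:, 1], points.ndim)
    free = _expand(is_wrapped, points.ndim)  # periodic dims never exclude a point
    inside = free | ((points >= lo) & (points <= hi))
    return jnp.all(inside, axis=1)           # (N,) or (N, S)


def wrap_sketch(points):
    lo = _expand(bounds[:, 0], points.ndim)
    width = _expand(bounds[:, 1] - bounds[:, 0], points.ndim)
    folded = lo + jnp.mod(points - lo, width)  # map into [lo, hi)
    return jnp.where(_expand(is_wrapped, points.ndim), folded, points)


pts = jnp.array([[0.5, 4.0], [1.5, 0.0]])
# inspace_sketch(pts) -> [True, False]; wrap_sketch(pts)[0, 1] is about 4 - 2*pi
```

Reducing over axis 1 (the dimension axis) is what lets the same code return \((N_{chains})\) for state-shaped input and \((N_{chains}, S)\) for sample-shaped input.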
- class State
This class provides the base definition of Markov chain states.
This minimal definition corresponds to a single parameter p with shape \((N, D)\) describing the point in the \(D\)-dimensional space for each one of the \(N\) chains of the sampling to be performed.
This can be specialized, in child classes, to different types of sampling by adding extra parameters (e.g. jexplore.sampling.mh.StateMH) with shape \((N, d)\), where \(d\) depends on the parameter.
The class also provides basic state manipulation methods which are generic for all the child classes defined as above. Some of these methods can also deal with “batched” states, i.e. State objects having parameters with shape \((N, d, I)\), where \(I\) is the number of performed sampling iterations, as they are returned by the jax.lax.scan loops of the sampler.
- Parameters:
p – status point (nchains, dim)
- p: jax.Array
status point
- compute(par, func, inpar)
Populates the values of one parameter \(p\) as a function of other parameters of the state.
- Parameters:
par (str) – parameter name.
func (ArrayFn) – function defining the parameter value.
inpar (List[str]) – list of the names of the input parameters of the function.
- Returns:
a new state with the populated values.
- Return type:
Self
- update(update, pars=None)
Updates a set of parameters \(p_i\) of the state by calling an update function:
\[p_i = u(i, p_i)\]
- Parameters:
update (Callable[[str, jax.Array], jax.Array]) – the update function.
pars (List[str] | None) – list of the names of parameters to be updated. None (default): all parameters.
- Returns:
a new state with the updated parameters.
- Return type:
Self
- update_mask(other, mask=True, pars=None)
Imports values of some parameters, for some selected chains, from another state. Chains are selected by a boolean mask (1-D array of size nchains).
- Parameters:
other (Self) – other state. The parameters to update should be defined and have the same shape as in the current state.
mask (jax.Array | bool) – boolean mask for the chains to be updated.
pars (List[str] | None) – list of the names of the attributes to be updated.
- Returns:
a new state with the updated parameters.
- Return type:
Self
- update_slice(other, slc, pars=None)
Imports values of some parameters, for some selected chains, from another state. Chains are selected by an index slice (i.e. a list of integers).
- Parameters:
other (Self) – other state. The parameters to update should be defined and have the same shape as the chains slice.
slc (jax.Array) – list of slice indexes.
pars (List[str] | None) – list of the names of the attributes to be updated.
- Returns:
a new state with the updated parameters.
- Return type:
Self
- set_val(val, cslc=None, pslc=None, par='p')
Updates one state parameter, slicing in both the chain and the parameter dimensions.
- Parameters:
val (jax.Array) – values (should have the same shape as the slice)
cslc (jax.Array | None) – chains slice. Default: all.
pslc (jax.Array | None) – parameter dimensions slice. Default: all.
par (str) – parameter name. Default: p
- Returns:
new state
- Return type:
Self
- slice(slc)
Returns a state with parameter values corresponding to a slice of the current state chains. That is, if the current state has \(N\) chains, we define a slice as a set of indexes \(\{0 \leq i_k < N, k=0,\dots,\tilde{N}-1\}\). This method constructs a new state having parameters
\[\tilde{p}_{k, \alpha} = p_{i_k, \alpha},~~~~~k=0,\dots,\tilde{N}-1, ~~~~~~\alpha=0,\dots,d-1\]
for each \((N, d)\) shaped parameter \(p\) of the current state.
- Parameters:
slc (jax.Array) – the slice as a list of indexes.
- Returns:
the new “sliced” state.
- Return type:
Self
- swap(even, odd, accept)
Swaps all parameter values between pairs of chains.
- Parameters:
even (jax.Array) – set of indexes to be swapped
odd (jax.Array) – other set of indexes to be swapped (should have the same size)
accept (jax.Array) – boolean mask to select swaps that should be performed
- Returns:
new state with swapped values.
- Return type:
Self
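A minimal sketch of swap on a single (N, d) parameter, assuming even and odd are disjoint index arrays of equal length and that the pair (even[j], odd[j]) is exchanged when accept[j] is true (swap_sketch is hypothetical):

```python
import jax.numpy as jnp


def swap_sketch(p, even, odd, accept):
    """Hypothetical sketch of State.swap on one (N, d) parameter.

    Rows even[j] and odd[j] are exchanged when accept[j] is True;
    even and odd are assumed disjoint and of equal length.
    """
    idx = jnp.arange(p.shape[0])
    idx = idx.at[even].set(jnp.where(accept, odd, even))
    idx = idx.at[odd].set(jnp.where(accept, even, odd))
    return p[idx]  # gather rows according to the permuted indexes


p = jnp.arange(8.0).reshape(4, 2)
out = swap_sketch(p, jnp.array([0, 2]), jnp.array([1, 3]), jnp.array([True, False]))
# out -> [[2, 3], [0, 1], [4, 5], [6, 7]]
```

Building a permutation first and gathering once means the same index array could be reused for every parameter of the state (p, ll, lp, ...).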
- get_in_batch(ind=-1, batch=False)
This method assumes that the current state is batched, i.e. its parameters have shape \((N, d, I)\), where \(I\) is the number of performed Markov iterations. In this case the method returns a new state corresponding to one specific iteration.
- Parameters:
ind (int) – index of the iteration.
batch (bool) – if True the returned state will be a 1-iteration batched state.
- Returns:
the selected state.
- Return type:
Self
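jax.lax.scan stacks per-iteration outputs along a new leading axis, so producing the \((N, d, I)\) layout implies moving that axis last. The sketch below assumes this layout and shows iteration selection with and without keeping a length-1 iteration axis (body and get_in_batch_sketch are hypothetical):

```python
import jax
import jax.numpy as jnp


def body(p, _):
    p = p + 1.0
    return p, p  # (carry, per-iteration output)


p0 = jnp.zeros((2, 3))                            # (N, d)
_, hist = jax.lax.scan(body, p0, None, length=4)  # (I, N, d) = (4, 2, 3)
batched = jnp.moveaxis(hist, 0, -1)               # (N, d, I) = (2, 3, 4)


def get_in_batch_sketch(p, ind=-1, batch=False):
    """Hypothetical sketch: pick iteration ind of an (N, d, I) parameter."""
    sel = p[..., ind]                        # (N, d)
    return sel[..., None] if batch else sel  # optionally keep a length-1 I axis
```

Selecting with p[..., ind] and re-adding the axis avoids the empty result that a slice such as ind:ind + 1 would give for the default ind=-1.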
- classmethod from_dict(state_d, mandatory=None, batch=False)
Instantiates a state object from a dictionary. It selects all entries in the dictionary that correspond to the parameters of this class of states, and uses the corresponding values to define the state.
- Parameters:
state_d (dict) – input dictionary.
mandatory (List[str] | None) – list of names of parameters that have to be present in the dictionary. Non-mandatory parameters, if absent, will be replaced by empty arrays. If None, all parameters are mandatory.
batch (bool) – when true, it forces the instantiated state to have a “batched” shape.
- Returns:
the state defined by the dictionary.
- Return type:
Self