jexplore.sampling#

Classes for sampling

Classes#

Sampling

Markov sampling defining parameters.

Epoch

Base Epoch definition class.

EpochCycle

Class containing the definition of an epoch steps cycle.

EpochStats

Class holding the aggregated statistics of the epoch steps.

StepSample

Prototype for a step sampling function.

EpochMH

Epoch definition class for Metropolis-Hastings.

SamplingMH

Markov sampling defining parameters for Metropolis-Hastings sampling.

StateMH

This class provides the definition of a Markov chain state for a Metropolis-Hastings Markov sampler.

EpochMS

Epoch definition class for model selection.

SamplingMS

Markov sampling defining parameters for model selection.

StateMS

Definition of a markov chain state for model selection.

Box

Simple rectangular box space.

State

This class provides the base definition of markov chains state.

Package Contents#

class Sampling[Tspace: jexplore.sampling.space.Space](nchain, dim=None, space=None)[source]#

Markov sampling defining parameters.

In this base class version the parameters are the dimension of the target space and the number of chains. Child classes may specialize to specific types of Markov sampling (e.g. jexplore.sampling.mh.SamplingMH).

Parameters:
  • nchain (int) – number of chains.

  • dim (int | None) – dimension of the target space.

  • space (Tspace | None) – jexplore.sampling.space object describing the target space.

The state is then defined by a point of shape (nchain, dim).

dim: int#

Dimension of the target space

nchain: int#

Number of chains

space: Tspace | jexplore.sampling.space.Box#

Target space.

to_backend()[source]#

Save the attributes of the object in backend format.

Return type:

dict

class Epoch[Tstate: jexplore.sampling.state.State, Tsampling: jexplore.sampling.base.Sampling](epoch, force_covs=False)[source]#

Base Epoch definition class.

This encapsulates samples, a covariance matrix, the stats object, the definition of the epoch cycle and the definition of the current sampling. The class provides methods to initialize the epoch, get a specific state, run steps and instantiate an epoch from the result of a sampler run.

Parameters:
  • epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively this can be a dictionary with the definition of the epoch samples.

  • force_covs (bool) – if true, forces the covariance matrices to be recomputed.

samples: Tstate#

batched state defining the sampling states sequence.

covs: jax.Array#

Array with shape \((N, D, D)\) containing the list of covariance matrices of all chains.

stats: EpochStats[Tstate]#

epoch statistics.

cycle: EpochCycle[Tstate]#

definition of the epoch cycle

sampling: Tsampling#

sampling definition

statecls: type[Tstate] | type[jexplore.sampling.state.State]#

state class corresponding to this epoch class

complete(sampling, compute=None)[source]#

Complete the definition and the parameters of the epoch.

Parameters:
  • sampling (Tsampling) – sampling parameters.

  • compute (dict | None) – dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are (func, inpars) tuples, where func is a callable for the computation of the corresponding parameter and inpars is the list of the names of the input parameters.

Returns:

this object after update.

Return type:

Self

set_cycle(cycle)[source]#

Sets this epoch cycle.

Parameters:

cycle (EpochCycle[Tstate]) – the EpochCycle value to be set.

Return type:

None

run_step(key, ind, state)[source]#

Runs a step of the epoch. For a given substep of the epoch cycle it draws one of the alternative steps and performs it.

Parameters:
  • key (jax.Array) – PRNG key

  • ind (int) – the index of the substep

  • state (Tstate) – current state

Returns:

the new state, the step stats and the index of the substep alternative actually performed.

Return type:

tuple[Tstate, jax.Array, jax.Array]
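The weighted draw described above can be pictured with the short sketch below. It is only an illustration under stated assumptions: the helper name `draw_alternative` is hypothetical, and a NumPy `Generator` stands in for the JAX PRNG key that `run_step` actually receives.

```python
import numpy as np

def draw_alternative(rng, weights_row):
    """Pick one alternative index with probability proportional to its weight."""
    probs = weights_row / weights_row.sum()
    return int(rng.choice(len(weights_row), p=probs))

rng = np.random.default_rng(0)
# One substep with three alternatives; only alternative 1 has a non-zero weight.
row = np.array([0.0, 2.0, 0.0])
picked = draw_alternative(rng, row)
```

Alternatives with zero weight are never selected, which is why padding a row with zeros is harmless in this sketch.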

from_run(states, stats, steps)[source]#

Gets the stacks of states, steps statistics and steps indexes from an epoch run and updates the epoch with these data.

Parameters:
  • states (Tstate) – stack of states.

  • stats (jax.Array) – stack of steps stats

  • steps (jax.Array) – stack of steps indexes

Returns:

this object after updating

Return type:

Self

to_backend(burn, stats)[source]#

dumps epoch in backend format

Parameters:
  • burn (int)

  • stats (bool)

Return type:

tuple[int, int, dict | None]

class EpochCycle[Tstate: jexplore.sampling.state.State](steps)[source]#

Class containing the definition of an epoch steps cycle.

Parameters:

steps (list[dict[StepSample[Tstate], float]]) – list of substeps, each one defined by a dictionary of alternative steps with their weights.

steps: tuple[StepSample[Tstate], Ellipsis]#

tuple of all the steps.

weights: jax.Array#

Matrix of weights. This has shape \((S, A_{max})\) where \(S\) is the number of serial substeps in the cycle and \(A_{max}\) is the max number of alternatives in all these substeps. Each element of the matrix that actually corresponds to a defined step/alternative contains the weight of the corresponding alternative.

isteps: jax.Array#

\((S, A_{max})\) matrix. Each element of the matrix that actually corresponds to a defined step/alternative contains the index of the corresponding StepSample function in the steps tuple.
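As a rough illustration of how a list of `{step: weight}` dictionaries could be flattened into the `steps`, `weights` and `isteps` layout described above, consider the sketch below. The helper `build_cycle`, the placeholder step functions and the padding values (weight 0 and index -1 for undefined alternatives) are assumptions made for this example, not the library's implementation; NumPy is used in place of `jax.numpy`.

```python
import numpy as np

def build_cycle(steps):
    """Flatten a list of {step_fn: weight} dicts into (steps, weights, isteps)."""
    flat = []                                   # distinct step functions, first-seen order
    n_sub = len(steps)                          # S: number of serial substeps
    a_max = max(len(alts) for alts in steps)    # A_max: widest set of alternatives
    weights = np.zeros((n_sub, a_max))          # padded entries keep weight 0 (assumption)
    isteps = np.full((n_sub, a_max), -1)        # -1 marks an undefined alternative (assumption)
    for s, alts in enumerate(steps):
        for a, (fn, w) in enumerate(alts.items()):
            if fn not in flat:
                flat.append(fn)
            weights[s, a] = w
            isteps[s, a] = flat.index(fn)
    return tuple(flat), weights, isteps

# Placeholder step functions, used only as dictionary keys here.
def stretch(state): ...
def gauss(state): ...
def tswap(state): ...

flat, weights, isteps = build_cycle([{stretch: 0.7, gauss: 0.3}, {tswap: 1.0}])
```

Here the first substep has two alternatives and the second has one, so both matrices have shape (2, 2) and the unused slot of the second row is padding.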

n: int#

number of sequential substeps.

to_backend()[source]#

dumps the cycle definition into a numpy array. This array has a row for each one of the defined substeps/alternatives containing

  • name: the qualified name of the StepSample function.

  • incycle: the index of the substep

  • weight: the weight of the corresponding substep alternative.

Return type:

numpy.ndarray

class EpochStats[Tstate: jexplore.sampling.state.State][source]#

Class holding the aggregated statistics of the epoch steps.

Parameters:
  • steps – array of steps performed at each iteration.

  • counts – counters of the number of executions of each step.

  • stats – chains acceptance statistics for each step.

steps: jax.Array#

This is an \(I\)-sized array, where \(I\) is the total number of steps performed during the sampling. Each entry is the index of the performed step in the steps tuple of the epoch’s EpochCycle instance.

counts: jax.Array#

this array has the same size as the steps tuple in the epoch’s EpochCycle instance. Each element contains the count of how many times the corresponding step was performed during the sampling.

stats: jax.Array#

this array has shape \((N, C)\) where \(C\) is the size of the steps tuple in the epoch’s EpochCycle instance (that is, the number of distinct steps) and \(N\) is the number of chains. The values correspond to statistics of how many times each chain was modified by each type of step (i.e. the proposal acceptance rate in the case of MH).

static from_run(stats, steps, cycle)[source]#

Each step returns an \(N\)-sized boolean map of the chains that have been changed by the step and the index identifying the step performed. An epoch run returns the stacking of these arrays for each iteration. This method aggregates and organizes this information and instantiates the corresponding EpochStats object.

Parameters:
  • stats (jax.Array) – stacked stats array for all steps in the sampling. This has shape \((N, I)\).

  • steps (jax.Array) – array with indices of the steps performed. This is an \(I\)-sized array whose entries are indices \(a = 0, \dots, A_{max} - 1\) identifying which one of the possible alternatives was actually performed at each step.

  • cycle (EpochCycle[Tstate]) – EpochCycle instance.

Returns:

the EpochStats corresponding to these data.

Return type:

EpochStats[Tstate]
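The aggregation can be sketched as follows. This is a simplified illustration, not the library's code: it assumes the `steps` array already holds the index of each performed step in the cycle's steps tuple (the real method receives alternative indices and resolves them through the `EpochCycle`), it reports per-chain change rates, and the helper name `aggregate` is hypothetical.

```python
import numpy as np

def aggregate(stats, steps, n_steps):
    """stats: (N, I) boolean change map, steps: (I,) step index per iteration."""
    counts = np.bincount(steps, minlength=n_steps)            # (C,) executions per step
    onehot = steps[:, None] == np.arange(n_steps)[None, :]    # (I, C) step membership
    changed = stats.astype(float) @ onehot.astype(float)      # (N, C) changes per chain and step
    rates = changed / np.maximum(counts, 1)                   # guard against unused steps
    return counts, rates

# Two chains, four iterations alternating between step 0 and step 1.
stats = np.array([[True, False, True, True],
                  [False, False, True, False]])
steps = np.array([0, 1, 0, 1])
counts, rates = aggregate(stats, steps, n_steps=2)
```

For an MH step, each entry of `rates` then reads as the acceptance rate of one chain for one type of step.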

to_backend()[source]#

Dump the data of this EpochStats instance, along with the information from the corresponding EpochCycle instance, to a dictionary in a form suitable for dumping to backend. The output dictionary has keys:

  • steps: numpy.ndarray cast of the object's steps attribute

  • counts: numpy.ndarray cast of the object's counts attribute

  • stats: numpy.ndarray cast of the object's stats attribute

  • steps_defs: data from EpochCycle instance as returned by jexplore.sampling.epoch.EpochCycle.to_backend

Return type:

dict

class StepSample[Tstate: jexplore.sampling.state.State][source]#

Bases: Protocol

Prototype for a step sampling function.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

the new state and the chain changing stats of the step (nchain).

Return type:

tuple[Tstate, jax.Array]

class EpochMH[Tstate: StateMH, Tsampling: SamplingMH](epoch, force_covs=False)[source]#

Bases: jexplore.sampling.epoch.Epoch[Tstate, Tsampling]

Epoch definition class for Metropolis-Hastings.

With respect to the parent jexplore.sampling.epoch.Epoch class, this class defines jexplore.sampling.mh.StateMH as the default state class and overrides the complete method to also compute the log likelihood and log prior of the epoch points.

Parameters:
  • epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively this can be a dictionary with the definition of the epoch samples.

  • force_covs (bool) – if true, forces the covariance matrices to be recomputed.

statecls: type[Tstate] | type[StateMH]#

State class

complete(sampling, compute=None)[source]#

Complete the definition and the parameters of the epoch.

With respect to jexplore.sampling.epoch.Epoch.complete, this implementation of the complete method provides a base for the compute parameter with the instructions to compute the log likelihood and log prior attributes of the epoch samples.

Parameters:
  • sampling (Tsampling) – sampling parameters.

  • compute (dict | None) – dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are (func, inpars) tuples, where func is a callable for the computation of the corresponding parameter and inpars is the list of the names of the input parameters.

Returns:

this object after update.

Return type:

Self

class SamplingMH[Tspace: jexplore.sampling.space.Space](nwalker, temps, loglik, logprior, dim=None, space=None, inpars=None)[source]#

Bases: jexplore.sampling.base.Sampling[Tspace]

Markov sampling defining parameters for Metropolis-Hastings sampling. With respect to the parent jexplore.sampling.base.Sampling class the definition of the sampling includes the following parameters

Parameters:

The nchain parameter is computed from nwalker and temps.

nwalker: int#

number of walkers

temps: jax.Array#

temperature ladder

loglik: jexplore.sampling.state.ArrayFn#

log likelihood function

logprior: jexplore.sampling.state.ArrayFn#

log prior function

inpars: list[str]#

list of input parameters of loglik and logprior

to_backend()[source]#

Save the attributes of the object in backend format.

Return type:

dict

get_sampler(steps=None, backend=None)[source]#

Return a sampler with default steps and backend.

Parameters:
  • steps – list of Step-like instances. If None, use stretch and tswap (the latter only if there is more than one temperature).

  • backend – a Backend instance. If None, use the default one.

Returns:

a JaxSampler instance.

get_epoch(p)[source]#

Return an Epoch object from a starting sample, to be used by the MH sampler as the initial epoch.

Parameters:

p – samples array of shape (nwalker * ntemp, dim).

Returns:

an EpochMH instance
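As a small illustration of the expected input shape, the sketch below builds a starting sample for a given number of walkers and a temperature ladder. It assumes that the number of chains is the product of `nwalker` and the number of temperatures, as the shape above suggests; the ordering of walkers and temperatures along the chain axis is not specified here, and NumPy replaces `jax.numpy` for brevity.

```python
import numpy as np

nwalker, dim = 4, 3
temps = np.array([1.0, 2.0, 4.0])       # temperature ladder with 3 rungs
nchain = nwalker * len(temps)           # assumed relation between nwalker, temps and nchain

rng = np.random.default_rng(1)
p0 = rng.normal(size=(nchain, dim))     # starting sample, one row per chain
```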

class StateMH[source]#

Bases: jexplore.sampling.state.State

This class provides the definition of a Markov chain state for a Metropolis-Hastings Markov sampler. With respect to the parent jexplore.sampling.state.State class, it simply adds log likelihood and log prior values to the state attributes (operations on these parameters are handled by the parent class methods).

Parameters:
  • p – status point (nchains, dim)

  • ll – log likelihood values for each chain point (nchains, 1).

  • lp – log prior values for each chain point (nchains, 1).

ll: jax.Array#

log likelihood values

lp: jax.Array#

log prior values.

class EpochMS[Tstate: StateMS, Tsampling: SamplingMS](epoch, force_covs=False)[source]#

Bases: jexplore.sampling.mh.EpochMH[Tstate, Tsampling]

Epoch definition class for model selection.

With respect to the parent jexplore.sampling.epoch.Epoch class, this class defines jexplore.sampling.mh.StateMS as the default state class and overrides the complete method to also compute the log likelihood and log prior of the epoch points (with the model selection signature).

Parameters:
  • epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively this can be a dictionary with the definition of the epoch samples.

  • force_covs (bool) – if true, forces the covariance matrices to be recomputed.

statecls#

State class

class SamplingMS[Tspace: jexplore.sampling.space.Space](nwalker, temps, masks, loglik=None, logprior=None, allmodsll=None, allmodslp=None, pseudo=None, dim=None, space=None, inpars=None)[source]#

Bases: jexplore.sampling.mh.SamplingMH[Tspace]

Markov sampling defining parameters for model selection. With respect to the parent jexplore.sampling.mh.SamplingMH class, the definition of the sampling includes the following parameters:

Parameters:
  • nwalker (int) – number of walkers per temperature.

  • temps (jax.Array) – temperature ladder

  • masks (list[list[int] | None | int | tuple[int, int]]) –

    list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:

    1. None: all the space coordinates

    2. integer d: first d dimensions of the space

    3. tuple (a, b): defining the corresponding coordinates slice

    4. list of integers: explicit list of coordinates indexes

  • loglik (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – log likelihood function \(logL(k, p)\), a list of \(logL(p)\) functions, one for each model (in which case \(logL(k, p)\) is defined by branching on the functions of the list) or None (in which case \(logL(k, p)\) is defined by selecting one component of the all models log likelihood defined by allmodsll, thus with a speculative computing approach).

  • allmodsll (jexplore.sampling.state.ArrayFn | None) – function \(vlogL(p)\) returning the values of the log likelihood for all models at the point \(p\) or None. In the latter case it is defined by stacking the return values of \(logL(k, p)\) defined by loglik for all possible k. Note: allmodsll and loglik cannot be both None.

  • logprior (jexplore.sampling.state.ArrayFn | list[jexplore.sampling.state.ArrayFn] | None) – log prior function \(log \pi(k, p)\), a list of \(log \pi(p)\) functions, one for each model (in which case \(log \pi(k, p)\) is defined by branching on the functions of the list) or None (in which case \(log \pi(k, p)\) is defined by selecting one component of the all models log prior defined by allmodslp, thus with a speculative computing approach).

  • allmodslp (jexplore.sampling.state.ArrayFn | None) – function \(vlog \pi(p)\) returning the values of the log prior for all models at the point \(p\) or None. In the latter case it is defined by stacking the return values of \(log \pi(k, p)\) defined by logprior for all possible k. Note: allmodslp and logprior cannot be both None.

  • pseudo (list[jexplore.tools.distributions.Distr | None] | None) –

    list of pseudo prior distributions for all the models. Each element of the list can be:

    1. A jexplore.tools.distributions.Distr object defined on the complementary space of the corresponding model, in which case it is taken as is.

    2. A jexplore.tools.distributions.Distr object defined on the full space, in which case it is masked to match the model mask.

    3. None. In this case, or if the model takes the whole space, a delta distribution centered at 0 is used (and the evaluation returns identically 0).

    If the whole argument is None (Default), a list of None elements is assumed and the distribution evaluation is identically 0.

  • dim (int | None) – dimensionality of the full space.

  • space (Tspace | None) – jexplore.sampling.space object describing the target space.

  • inpars (list[str] | None) – list of input parameter(s) used for the computation of log likelihood and log prior.

masks: list[jax.Array]#

mask of parameters indexes of each model space

cmasks: list[jax.Array]#

complementary mask of parameters indexes of each model space

mod_update: jexplore.sampling.state.ArrayFn#

\(f(k, a, b)\) updating array a of shape self.dim with an array b of the same shape, but only on the components of the mask of model k

vmod_update: jexplore.sampling.state.ArrayFn#

vectorized version of self.mod_update so that it can perform the same update operation on vectors k, a and b of shapes (self.nwalker, ), (self.nwalker, self.dim) and (self.nwalker, self.dim)
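The masked update performed by these two functions can be illustrated as below. This is a sketch under assumptions: the boolean membership table `member` and the function bodies are made up for the example, and NumPy fancy indexing stands in for the vectorization that the library presumably obtains with JAX tools.

```python
import numpy as np

dim = 4
masks = [np.array([0, 1]), np.array([0, 1, 2, 3])]   # model 0 uses 2 coordinates, model 1 all 4

# member[k, j] is True when coordinate j belongs to the space of model k.
member = np.zeros((len(masks), dim), dtype=bool)
for k, m in enumerate(masks):
    member[k, m] = True

def mod_update(k, a, b):
    """Copy b into a on the coordinates of model k only; a and b have shape (dim,)."""
    return np.where(member[k], b, a)

def vmod_update(k, a, b):
    """Same update for a vector of model indices k and arrays of shape (n, dim)."""
    return np.where(member[k], b, a)    # member[k] has shape (n, dim) when k is a vector

a = np.zeros((2, dim))
b = np.ones((2, dim))
out = vmod_update(np.array([0, 1]), a, b)
```

The first row is updated on the two coordinates of model 0 only, while the second row, which belongs to the full-space model, is replaced entirely.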

allmodsll: jexplore.sampling.state.ArrayFn#

Function \(F(p)\) returning the log likelihood for all models

allmodslp: jexplore.sampling.state.ArrayFn#

Function \(\pi(p)\) returning the log prior for all models

ppleval: jexplore.sampling.state.ArrayFn#

Log pseudo prior function \(\tilde{\pi}(k, p)\)

allmodspp: jexplore.sampling.state.ArrayFn#

Function \(\tilde{\pi}(p)\) returning the log pseudo prior for all models

ppdraw: list[jexplore.tools.distributions.DrawFn]#

List of pseudo prior drawing functions

nmodels: int#

number of models

to_backend()[source]#

Save the attributes of the object in backend format.

Return type:

dict

static get_masks(masks, dim)[source]#

Get model masks as lists of indices.

Parameters:
  • masks (list[list[int] | None | int | tuple[int, int]]) –

    list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:

    1. None: all the space coordinates

    2. integer d: first d dimensions of the space

    3. tuple (a, b): defining the corresponding coordinates slice

    4. list of integers: explicit list of coordinates indexes

  • dim (int) – full space dimension.

Returns:

list of model masks (as arrays of indices) and list of model complementary masks.

Return type:

tuple[list[jax.Array], list[jax.Array]]
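The four accepted mask forms can be normalized along the lines of the sketch below. It is an illustration rather than the library's code: the helper name `masks_to_indices` is hypothetical, the tuple `(a, b)` is assumed to be a half-open slice `[a, b)`, and NumPy arrays are returned instead of `jax.Array` objects.

```python
import numpy as np

def masks_to_indices(masks, dim):
    """Turn mask specifications into index arrays plus their complements."""
    full = np.arange(dim)
    out, comp = [], []
    for m in masks:
        if m is None:
            idx = full                      # all the space coordinates
        elif isinstance(m, int):
            idx = full[:m]                  # first m dimensions
        elif isinstance(m, tuple):
            idx = full[m[0]:m[1]]           # slice, assumed half-open [a, b)
        else:
            idx = np.asarray(m, dtype=int)  # explicit list of indexes
        out.append(idx)
        comp.append(np.setdiff1d(full, idx))   # coordinates not used by the model
    return out, comp

masks, cmasks = masks_to_indices([None, 2, (1, 3), [0, 3]], dim=4)
```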

static ldist_from_list(logdists, masks, stack=False)[source]#

Defines a product space log distribution from a list of single-model distributions.

Parameters:
  • logdists (list[jexplore.sampling.state.ArrayFn]) – list of single space distributions \(log P(p)\)

  • masks (list[jax.Array]) – mask of parameters indexes of each model space

  • stack (bool) – if true defines the function returning the values for all models by stacking the results of all the functions in the list.

Returns:

full product space log distribution.

Return type:

jexplore.sampling.state.ArrayFn
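One way to picture the construction is the sketch below. It rests on an assumption that is not stated above, namely that each single-model function receives only the coordinates selected by its mask; the helper name `product_logdist` and the two toy distributions are likewise hypothetical, and plain Python indexing replaces the JAX branching the library presumably uses.

```python
import numpy as np

def product_logdist(logdists, masks, stack=False):
    """Build logP(k, p), or the stacked all-models version when stack=True."""
    def one_model(k, p):
        # Evaluate the distribution of model k on its own coordinates.
        return logdists[k](p[masks[k]])

    def all_models(p):
        # Evaluate every model and stack the results into one array.
        return np.array([f(p[m]) for f, m in zip(logdists, masks)])

    return all_models if stack else one_model

# Two toy single-model log distributions (unnormalized).
logdists = [lambda x: -0.5 * np.sum(x**2), lambda x: -np.sum(np.abs(x))]
masks = [np.array([0]), np.array([0, 1])]

p = np.array([1.0, -2.0])
logp = product_logdist(logdists, masks)
vlogp = product_logdist(logdists, masks, stack=True)
```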

class StateMS[source]#

Bases: jexplore.sampling.mh.StateMH

Definition of a markov chain state for model selection. With respect to the parent jexplore.sampling.mh.StateMH class, it adds the model index k for each chain (operations on these parameters are handled by the base class methods).

Parameters:
  • k – space index (nchains, 1)

  • p – status point (nchains, dim)

  • ll – log likelihood values for each chain point (nchains, 1).

  • lp – log prior values for each chain point (nchains, 1).

k: jax.Array#

model index

class Box(dim=None, size=jnp.inf, box=None, wrapped=None)[source]#

Bases: Space

Simple rectangular box space.

Parameters:
  • dim (int | None) – dimension of the box. Only used if the box parameter is not provided.

  • size (float) – size of the box (assumed having equal size in all dimensions) only used if the box parameter is not provided. Default: infinity.

  • box (list[list[float]] | None) – list of lists defining the segments boundaries for each box dimensions. If not provided a box with dim dimensions of equal size size is considered.

  • wrapped (list[int] | None) – list of periodic dimensions indexes. These dimensions will be considered unbounded and the corresponding box intervals will be interpreted as principal domain intervals. Default: empty list.

bounds: jax.Array#

Box bounds

wrap_dims: list[int]#

Wrapped dimensions indexes

wrap_domain: jax.Array#

wrapped dimensions principal domain

dim#

Dimension of the target space

inspace(points)[source]#

Check which of a set of points lie in the defined space.

Parameters:

points (jax.Array) – set of points to be checked, it can have the state shape \((N_{chains}, D)\) or the samples shape \((N_{chains}, D, S)\).

Returns:

a boolean mask of shape \((N_{chains})\) or \((N_{chains}, S)\) selecting points that lie in the defined space.

Return type:

jax.Array
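A membership test handling both accepted shapes might look like the sketch below. The `(D, 2)` layout of `bounds`, the inclusive comparison at the edges and the standalone function form are assumptions for the example; the treatment of wrapped dimensions as unbounded is left out for brevity.

```python
import numpy as np

def inspace(points, bounds):
    """bounds: (D, 2) array of [low, high]; points: (N, D) or (N, D, S)."""
    low, high = bounds[:, 0], bounds[:, 1]
    if points.ndim == 3:
        # Add a trailing axis so the bounds broadcast over the S samples.
        low, high = low[:, None], high[:, None]
    inside = (points >= low) & (points <= high)
    return inside.all(axis=1)       # a point is inside only if every coordinate is

bounds = np.array([[0.0, 1.0], [-1.0, 1.0]])
pts = np.array([[0.5, 0.0], [1.5, 0.0]])            # state shape (N=2, D=2)
mask2d = inspace(pts, bounds)

batch = np.stack([pts, pts[::-1]], axis=-1)         # samples shape (N=2, D=2, S=2)
mask3d = inspace(batch, bounds)
```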

wrap(points)[source]#

Folds the wrapped dimensions of the target space.

Parameters:

points (jax.Array) – set of points to be processed, it can have the state shape \((N_{chains}, D)\) or the samples shape \((N_{chains}, D, S)\).

Returns:

the same set of points with the wrapped dimensions folded.

Return type:

jax.Array
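Folding a periodic coordinate into its principal domain is commonly done with a modulo operation, as in the sketch below. The function signature and the `(W, 2)` layout of `wrap_domain` are assumptions made for illustration, and the arithmetic is a generic technique rather than the library's confirmed implementation.

```python
import numpy as np

def wrap(points, wrap_dims, wrap_domain):
    """wrap_domain: (W, 2) array of [low, high] for each wrapped dimension."""
    out = points.copy()
    low, high = wrap_domain[:, 0], wrap_domain[:, 1]
    if points.ndim == 3:
        # Add a trailing axis so the domain broadcasts over the S samples.
        low, high = low[:, None], high[:, None]
    # Shift to the origin, take the remainder by the period, shift back.
    out[:, wrap_dims] = low + np.mod(points[:, wrap_dims] - low, high - low)
    return out

pts = np.array([[0.5, 5.0], [0.5, -1.0]])
wrapped = wrap(pts, wrap_dims=[1], wrap_domain=np.array([[0.0, 4.0]]))
```

Only dimension 1 is periodic here, so the first coordinate of each point is left untouched.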

class State[source]#

This class provides the base definition of markov chains state.

This minimal definition corresponds to a single parameter p with shape \((N, D)\) describing the point in the \(D\)-dimensional space for each one of the \(N\) chains of the sampling to be performed.

This can be specialized, in child classes, to different types of sampling by adding extra parameters (e.g. jexplore.sampling.mh.StateMH) with shape \((N, d)\), where \(d\) depends on the parameter.

The class also provides basic state manipulation methods that are generic for all the child classes defined as above. Some of these methods can also deal with “batched” states, i.e. State objects having parameters with shape \((N, d, I)\), where \(I\) is the number of performed sampling iterations, as they are returned by the jax.lax.scan loops of the sampler.

Parameters:

p – status point (nchains, dim)

p: jax.Array#

status point

compute(par, func, inpar)[source]#

Populate the values of one parameter \(p\) as a function of other parameters of the state.

Parameters:
  • par (str) – parameter name.

  • func (ArrayFn) – function defining the parameter value.

  • inpar (List[str]) – list of the names of the input parameters of the function.

Returns:

a new state with the populated values.

Return type:

Self

update(update, pars=None)[source]#

Updates a set of parameters \(p_i\) of the state by calling an update function:

\[p_i = u(i, p_i)\]

Parameters:
  • update (Callable[[str, jax.Array], jax.Array]) – the update function.

  • pars (List[str] | None) – list of the names of parameters to be updated. None (default): all parameters.

Returns:

a new state with the updated parameters.

Return type:

Self

update_mask(other, mask=True, pars=None)[source]#

Import values of some parameters, and for some selected chains, from another state. Chains are selected by a boolean mask (1-array of size nchains).

Parameters:
  • other (Self) – other state. The parameters to update should be defined and have the same shape as in the current state.

  • mask (jax.Array | bool) – boolean mask for the chains to be updated.

  • pars (List[str] | None) – list of the names of the attributes to be updated.

Returns:

a new state with the updated parameters.

Return type:

Self
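The masked import can be pictured with the following sketch, in which a state is modeled as a plain dictionary of `(N, d)` arrays. That representation and the standalone function are assumptions for illustration; the actual State class is not reproduced here.

```python
import numpy as np

def update_mask(current, other, mask=True, pars=None):
    """Return a new state taking the selected chains of `pars` from `other`."""
    pars = list(current) if pars is None else pars
    new = dict(current)                         # the input state is left untouched
    sel = np.asarray(mask)[..., None]           # (N, 1), broadcasts over the d columns
    for name in pars:
        new[name] = np.where(sel, other[name], current[name])
    return new

current = {"p": np.zeros((3, 2)), "ll": np.zeros((3, 1))}
other = {"p": np.ones((3, 2)), "ll": np.ones((3, 1))}
mask = np.array([True, False, True])

new = update_mask(current, other, mask, pars=["p"])
```

Only parameter `p` of chains 0 and 2 is imported, while `ll` keeps its current values because it is not listed in `pars`.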

update_slice(other, slc, pars=None)[source]#

Import values of some parameters, and for some selected chains, from another state. Chains are selected by an index slice (i.e. a list of integers).

Parameters:
  • other (Self) – other state. The parameters to update should be defined and have the same shape as the chains slice.

  • slc (jax.Array) – list of slice indexes.

  • pars (List[str] | None) – list of the names of the attributes to be updated.

Returns:

a new state with the updated parameters.

Return type:

Self

set_val(val, cslc=None, pslc=None, par='p')[source]#

Update one sample parameter slicing both in chains and parameter dimension.

Parameters:
  • val (jax.Array) – values (should have the same shape as the slice)

  • cslc (jax.Array | None) – chains slice. Default: all.

  • pslc (jax.Array | None) – parameter dimensions slice. Default: all.

  • par (str) – parameter name. Default: p

Returns:

new state

Return type:

Self

slice(slc)[source]#

Return a state with parameter values corresponding to a slice of the current state chains. That is, if the current state has \(N\) chains, we define a slice as a set of indexes \(\{0 \leq i_k < N, k=0,\dots,\tilde{N}-1\}\). This method constructs a new state having one parameter

\[\tilde{p}_{k, \alpha} = p_{i_k, \alpha},~~~~~k=0,\dots,\tilde{N}-1, ~~~~~~\alpha=0,\dots,d-1\]

for each \((N, d)\) shaped parameter \(p\) of the current state.

Parameters:

slc (jax.Array) – the slice as a list of indexes.

Returns:

the new “sliced” state.

Return type:

Self
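In array terms the formula above is a row selection applied to every parameter. The sketch below shows this with a state modeled as a dictionary of `(N, d)` arrays, which is an assumption made for the example, as is the helper name `slice_state`.

```python
import numpy as np

def slice_state(state, slc):
    """Select the chains listed in slc from every (N, d) parameter."""
    return {name: val[slc] for name, val in state.items()}

state = {
    "p": np.arange(6.0).reshape(3, 2),            # N=3 chains, d=2
    "ll": np.array([[10.0], [11.0], [12.0]]),     # N=3 chains, d=1
}
sub = slice_state(state, np.array([2, 0]))        # keep chains 2 and 0, in that order
```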

swap(even, odd, accept)[source]#

Swap all parameters values between different chains.

Parameters:
  • even (jax.Array) – set of indexes to be swapped

  • odd (jax.Array) – other set of indexes to be swapped (should have the same size)

  • accept (jax.Array) – boolean mask to select swaps that should be performed

Returns:

new state with swapped values.

Return type:

Self
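A conditional pairwise swap of this kind can be sketched as below. The dictionary representation of the state is an assumption, and the example supposes that `even` and `odd` are disjoint index sets of equal size with one `accept` flag per pair.

```python
import numpy as np

def swap(state, even, odd, accept):
    """Exchange chains even[j] and odd[j] wherever accept[j] is True."""
    acc = accept[:, None]                       # broadcast over the parameter columns
    new = {}
    for name, val in state.items():
        out = val.copy()
        src_even, src_odd = val[even], val[odd]
        out[even] = np.where(acc, src_odd, src_even)
        out[odd] = np.where(acc, src_even, src_odd)
        new[name] = out
    return new

state = {"p": np.array([[0.0], [1.0], [2.0], [3.0]])}
swapped = swap(
    state,
    even=np.array([0, 2]),
    odd=np.array([1, 3]),
    accept=np.array([True, False]),             # swap the pair (0, 1), keep the pair (2, 3)
)
```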

get_in_batch(ind=-1, batch=False)[source]#

This method assumes that the current state is batched. Thus its parameters have shape \((N, d, I)\) where \(I\) is the number of performed Markov iterations. In this case the method returns a new state corresponding to one specific iteration.

Parameters:
  • ind (int) – index of the iteration.

  • batch (bool) – if True the returned state will be a 1-iteration batched state.

Returns:

the selected state.

Return type:

Self
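Selecting one iteration from a batched state amounts to indexing the last axis, as in the sketch below. The dictionary representation of the state and the standalone function are assumptions for illustration only.

```python
import numpy as np

def get_in_batch(state, ind=-1, batch=False):
    """Pick iteration `ind` from parameters of shape (N, d, I)."""
    # Indexing with a one-element list keeps the trailing iteration axis.
    sel = [ind] if batch else ind
    return {name: val[..., sel] for name, val in state.items()}

batched = {"p": np.arange(6.0).reshape(2, 1, 3)}      # N=2, d=1, I=3
last = get_in_batch(batched)                          # plain state, shape (2, 1)
kept = get_in_batch(batched, ind=0, batch=True)       # 1-iteration batch, shape (2, 1, 1)
```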

classmethod from_dict(state_d, mandatory=None, batch=False)[source]#

Instantiates a state object from a dictionary. It will select all entries in the dictionary that correspond to the parameters of this class of states, and use the corresponding values to define the state.

Parameters:
  • state_d (dict) – input dictionary.

  • mandatory (List[str] | None) – list of names of parameters that have to be present in the dictionary. Non-mandatory parameters, if absent, will be replaced by empty arrays. If None, all parameters are mandatory.

  • batch (bool) – when true, it forces the instantiated state to have a “batched” shape.

Returns:

the state defined by the dictionary.

Return type:

Self