jexplore.sampling
=================

.. py:module:: jexplore.sampling

.. autoapi-nested-parse::

   Classes for sampling



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/jexplore/sampling/base/index
   /autoapi/jexplore/sampling/epoch/index
   /autoapi/jexplore/sampling/mh/index
   /autoapi/jexplore/sampling/modsel/index
   /autoapi/jexplore/sampling/space/index
   /autoapi/jexplore/sampling/state/index


Classes
-------

.. autoapisummary::

   jexplore.sampling.Sampling
   jexplore.sampling.Epoch
   jexplore.sampling.EpochCycle
   jexplore.sampling.EpochStats
   jexplore.sampling.StepSample
   jexplore.sampling.EpochMH
   jexplore.sampling.SamplingMH
   jexplore.sampling.StateMH
   jexplore.sampling.EpochMS
   jexplore.sampling.SamplingMS
   jexplore.sampling.StateMS
   jexplore.sampling.Box
   jexplore.sampling.State


Package Contents
----------------

.. py:class:: Sampling[Tspace: jexplore.sampling.space.Space](nchain, dim = None, space = None)

   Markov sampling defining parameters.

   In this base class the parameters are the dimension of the target space and the number of chains. Child classes may specialize to specific types of Markov sampling (e.g. :py:class:`jexplore.sampling.mh.SamplingMH`).

   :param dim: dimension of the target space.
   :param nchain: number of chains.
   :param space: :py:class:`jexplore.sampling.space.Space` object describing the target space.

   The state is then defined by a point of shape ``(nchain, dim)``.


   .. py:attribute:: dim
      :type: int

      Dimension of the target space.


   .. py:attribute:: nchain
      :type: int

      Number of chains.


   .. py:attribute:: space
      :type: Tspace | jexplore.sampling.space.Box

      Target space.


   .. py:method:: to_backend()

      Save the attributes of the object in backend format.


.. py:class:: Epoch[Tstate: jexplore.sampling.state.State, Tsampling: jexplore.sampling.base.Sampling](epoch, force_covs = False)

   Base epoch definition class.

   This encapsulates the samples, the per-chain covariance matrices, the statistics object, the definition of the epoch cycle and the definition of the current sampling.
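
   The per-chain covariance matrices can be pictured with the short sketch below. It assumes, for illustration only, that each matrix is the empirical covariance of one chain over the iterations of a batched ``(N, D, I)`` sample array; the helper name ``chain_covariances`` is hypothetical and is not part of ``jexplore``.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def chain_covariances(samples: jax.Array) -> jax.Array:
          """Per-chain empirical covariances (sketch, hypothetical helper).

          samples: batched points with shape (N, D, I), i.e. N chains,
                   D dimensions and I iterations.
          returns: array with shape (N, D, D).
          """
          # jnp.cov treats rows as variables and columns as observations,
          # so each (D, I) slice gives a (D, D) matrix; vmap maps over chains.
          return jax.vmap(jnp.cov)(samples)


      key = jax.random.PRNGKey(0)
      samples = jax.random.normal(key, (4, 3, 500))  # 4 chains, 3 dims, 500 iterations
      covs = chain_covariances(samples)
      print(covs.shape)  # (4, 3, 3)

   Vectorizing with :py:func:`jax.vmap` keeps the chain axis leading, which matches the :math:`(N, D, D)` layout described for the ``covs`` attribute.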

   The class provides methods to initialize the epoch, get a specific state, run steps and instantiate an epoch from the result of a sampler run.

   :param epoch: `Epoch` object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
   :param force_covs: if true, forces the recomputation of the covariance matrices.


   .. py:attribute:: samples
      :type: Tstate

      Batched state defining the sequence of sampling states.


   .. py:attribute:: covs
      :type: jax.Array

      Array with shape :math:`(N, D, D)` containing the covariance matrices of all chains.


   .. py:attribute:: stats
      :type: EpochStats[Tstate]

      Epoch statistics.


   .. py:attribute:: cycle
      :type: EpochCycle[Tstate]

      Definition of the epoch cycle.


   .. py:attribute:: sampling
      :type: Tsampling

      Sampling definition.


   .. py:attribute:: statecls
      :type: type[Tstate] | type[jexplore.sampling.state.State]

      State class corresponding to this epoch class.


   .. py:method:: complete(sampling, compute = None)

      Complete the definition and the parameters of the epoch.

      :param sampling: sampling parameters.
      :param compute: dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are `(func, inpars)` tuples, where `func` is a callable computing the corresponding parameter and `inpars` is the list of the names of the input parameters.
      :return: this object after the update.


   .. py:method:: set_cycle(cycle)

      Set the cycle of this epoch.

      :param cycle: the `EpochCycle` value to be set.


   .. py:method:: run_step(key, ind, state)

      Run a step of the epoch. For a given substep of the epoch cycle, it draws one of the alternative steps and performs it.

      :param key: PRNG key.
      :param ind: the index of the substep.
      :param state: current state.
      :return: the new state, the step stats and the index of the substep alternative actually performed.


   .. py:method:: from_run(states, stats, steps)

      Take the stacks of states, step statistics and step indexes from an epoch run and update the epoch with these data.

      :param states: stack of states.
      :param stats: stack of step stats.
      :param steps: stack of step indexes.
      :return: this object after the update.


   .. py:method:: to_backend(burn, stats)

      Dump the epoch in backend format.


.. py:class:: EpochCycle[Tstate: jexplore.sampling.state.State](steps)

   Class containing the definition of an epoch steps cycle.

   :param steps: list of substeps, each one defined by a dictionary of alternative steps with their weights.


   .. py:attribute:: steps
      :type: tuple[StepSample[Tstate], Ellipsis]

      Tuple of all the steps.


   .. py:attribute:: weights
      :type: jax.Array

      Matrix of weights. It has shape :math:`(S, A_{max})`, where :math:`S` is the number of serial substeps in the cycle and :math:`A_{max}` is the maximum number of alternatives over all these substeps. Each element of the matrix that actually corresponds to a defined step/alternative contains the weight of the corresponding alternative.


   .. py:attribute:: isteps
      :type: jax.Array

      :math:`(S, A_{max})` matrix. Each element of the matrix that actually corresponds to a defined step/alternative contains the index of the corresponding `StepSample` function in the `steps` tuple.


   .. py:attribute:: n
      :type: int

      Number of sequential substeps.


   .. py:method:: to_backend()

      Dump the cycle definition into a `numpy` array. This array has a row for each one of the defined substeps/alternatives, containing:

      * `name`: the qualified name of the `StepSample` function;
      * `incycle`: the index of the substep;
      * `weight`: the weight of the corresponding substep alternative.


.. py:class:: EpochStats[Tstate: jexplore.sampling.state.State]

   Class storing the aggregated statistics of the epoch steps.

   :param steps: array of the steps performed at each iteration.
   :param counts: counters of the number of executions of each step.
   :param stats: chain acceptance statistics for each step.
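
   How these three arrays can be derived from the raw per-iteration output of a run is sketched below in plain JAX. This is an illustration, not the `jexplore` implementation: it assumes that iteration :math:`i` executes substep :math:`i \bmod S` of the cycle, and the helper name ``aggregate_stats`` is hypothetical.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def aggregate_stats(changed, alts, isteps, nsteps):
          """Aggregate per-iteration step outcomes (sketch, hypothetical helper).

          changed: (N, I) boolean map of the chains modified at each iteration.
          alts:    (I,) index of the alternative performed at each iteration.
          isteps:  (S, A_max) map from (substep, alternative) to step index.
          nsteps:  C, the number of distinct steps.
          """
          n_iter = alts.shape[0]
          # Assumption: iteration i runs substep i % S of the cycle.
          substep = jnp.arange(n_iter) % isteps.shape[0]
          steps = isteps[substep, alts]                        # (I,) step indexes
          counts = jnp.bincount(steps, length=nsteps)          # (C,) executions per step
          onehot = jax.nn.one_hot(steps, nsteps)               # (I, C)
          per_chain = changed.astype(onehot.dtype) @ onehot    # (N, C) changes per step
          # Fraction of the executions of each step that modified each chain.
          rates = per_chain / jnp.maximum(counts, 1)
          return steps, counts, rates


      isteps = jnp.array([[0, 1], [2, 0]])       # S = 2 substeps, A_max = 2
      alts = jnp.array([0, 1, 1, 0])             # I = 4 iterations
      changed = jnp.array([[True, False, True, True],
                           [False, False, False, True]])
      steps, counts, rates = aggregate_stats(changed, alts, isteps, nsteps=3)
      print(steps)   # [0 0 1 2]
      print(counts)  # [2 1 1]

   Using a one-hot matrix turns the per-step, per-chain tally into a single matrix product, which stays jit-compatible because ``nsteps`` fixes every output shape.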


   .. py:attribute:: steps
      :type: jax.Array

      This is an :math:`I`-sized array, where :math:`I` is the total number of steps performed during the sampling. The entries of the array are indexes into the `steps` tuple of the epoch's `EpochCycle` instance.


   .. py:attribute:: counts
      :type: jax.Array

      This array has the same size as the `steps` tuple of the epoch's `EpochCycle` instance. Each element contains the number of times the corresponding step was performed during the sampling.


   .. py:attribute:: stats
      :type: jax.Array

      This array has shape :math:`(N, C)`, where :math:`C` is the size of the `steps` tuple of the epoch's `EpochCycle` instance (that is, the number of distinct steps) and :math:`N` is the number of chains. The values are statistics of how many times each chain was modified by each type of step (i.e. the proposal acceptance rate in the case of MH).


   .. py:method:: from_run(stats, steps, cycle)
      :staticmethod:

      Each step returns an :math:`N`-sized boolean map of the chains that have been changed by the step, and the index identifying the step performed. An epoch run returns these arrays stacked over all iterations. This method aggregates and organizes this information and instantiates the corresponding `EpochStats` object.

      :param stats: stacked stats array for all steps in the sampling. This has shape :math:`(N, I)`.
      :param steps: array with the indices of the steps performed. This is an :math:`I`-sized array whose entries are indices :math:`a = 0, \dots, A_{max} - 1` identifying which one of the possible alternatives was actually performed at each step.
      :param cycle: `EpochCycle` instance.
      :return: the `EpochStats` corresponding to these data.


   .. py:method:: to_backend()

      Dump the data of this `EpochStats` instance, along with the information from the corresponding `EpochCycle` instance, to a dictionary in a form suitable for dumping to the backend.

      The output dictionary has keys:

      * `steps`: `numpy.ndarray` cast of the object's `steps` attribute;
      * `counts`: `numpy.ndarray` cast of the object's `counts` attribute;
      * `stats`: `numpy.ndarray` cast of the object's `stats` attribute;
      * `steps_defs`: data from the `EpochCycle` instance, as returned by :py:meth:`jexplore.sampling.epoch.EpochCycle.to_backend`.


.. py:class:: StepSample[Tstate: jexplore.sampling.state.State]

   Bases: :py:obj:`Protocol`

   Prototype for a step sampling function.

   :param jax.Array key: PRNG key.
   :param Tstate state: current state.
   :return: the new state and the chain-changing stats of the step (shape ``(nchain,)``).
   :rtype: tuple[Tstate, jax.Array]


.. py:class:: EpochMH[Tstate: StateMH, Tsampling: SamplingMH](epoch, force_covs = False)

   Bases: :py:obj:`jexplore.sampling.epoch.Epoch`\ [\ :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Epoch definition class for Metropolis-Hastings.

   With respect to the parent :py:class:`jexplore.sampling.epoch.Epoch` class, this class defines :py:class:`jexplore.sampling.mh.StateMH` as the default state class and overrides the `complete` method to also compute the log likelihood and log prior of the epoch points.

   :param epoch: `Epoch` object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
   :param force_covs: if true, forces the recomputation of the covariance matrices.


   .. py:attribute:: statecls
      :type: type[Tstate] | type[StateMH]

      State class.


   .. py:method:: complete(sampling, compute = None)

      Complete the definition and the parameters of the epoch.

      :param sampling: sampling parameters.
      :param compute: dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are `(func, inpars)` tuples, where `func` is a callable computing the corresponding parameter and `inpars` is the list of the names of the input parameters.
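
      The `(func, inpars)` convention of `compute` can be illustrated on a plain dictionary of arrays. In this sketch the helper ``apply_compute`` and the toy ``loglik``/``logprior`` functions are hypothetical, and it assumes that each `func` receives its inputs as positional arguments, in the order given by `inpars`.

      .. code-block:: python

         import jax.numpy as jnp


         def apply_compute(params, compute):
             """Fill missing entries of ``params`` from a compute dictionary (sketch).

             params:  dict mapping parameter names to arrays with shape (N, d).
             compute: dict mapping a parameter name to a (func, inpars) tuple.
             """
             out = dict(params)
             for name, (func, inpars) in compute.items():
                 if name not in out:
                     # Assumption: inputs are passed positionally, in inpars order.
                     out[name] = func(*(out[par] for par in inpars))
             return out


         def loglik(p):
             # Toy Gaussian log likelihood, one value per chain: shape (N, 1).
             return -0.5 * jnp.sum(p**2, axis=1, keepdims=True)


         def logprior(p):
             # Toy flat log prior: shape (N, 1).
             return jnp.zeros((p.shape[0], 1))


         compute = {"ll": (loglik, ["p"]), "lp": (logprior, ["p"])}
         params = apply_compute({"p": jnp.ones((4, 2))}, compute)
         print(params["ll"].shape, params["lp"].shape)  # (4, 1) (4, 1)

      Keeping the input names next to each function lets the same dictionary drive the computation of any state attribute without hard-coding the argument list.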

      With respect to :py:meth:`jexplore.sampling.epoch.Epoch.complete`, this implementation of the `complete` method provides a base `compute` dictionary with the instructions to compute the log likelihood and log prior attributes of the epoch samples.

      :return: this object after the update.


.. py:class:: SamplingMH[Tspace: jexplore.sampling.space.Space](nwalker, temps, loglik, logprior, dim = None, space = None, inpars = None)

   Bases: :py:obj:`jexplore.sampling.base.Sampling`\ [\ :py:obj:`Tspace`\ ]

   Markov sampling defining parameters for Metropolis-Hastings sampling.

   With respect to the parent :py:class:`jexplore.sampling.base.Sampling` class, the definition of the sampling includes the following parameters:

   :param nwalker: number of walkers per temperature.
   :param temps: temperature ladder.
   :param loglik: log likelihood function.
   :param logprior: log prior function.

   The `nchain` parameter is computed from `nwalker` and `temps`.


   .. py:attribute:: nwalker
      :type: int

      Number of walkers.


   .. py:attribute:: temps
      :type: jax.Array

      Temperature ladder.


   .. py:attribute:: loglik
      :type: jexplore.sampling.state.ArrayFn

      Log likelihood function.


   .. py:attribute:: logprior
      :type: jexplore.sampling.state.ArrayFn

      Log prior function.


   .. py:attribute:: inpars
      :type: list[str]

      List of the input parameters of `loglik` and `logprior`.


   .. py:method:: to_backend()

      Save the attributes of the object in backend format.


   .. py:method:: get_sampler(steps=None, backend=None)

      Return a sampler with default steps and backend.

      :param steps: list of `Step`-like instances. If None, use stretch and tswap (the latter only if there is more than one temperature).
      :param backend: a `Backend` instance. If None, use the default one.
      :return: a `JaxSampler` instance.


   .. py:method:: get_epoch(p)

      Return an `Epoch` object built from a starting sample, to be used by the MH sampler as the initial epoch.

      :param p: samples array of shape ``(nwalker * ntemp, dim)``.
      :return: an `EpochMH` instance.


.. py:class:: StateMH

   Bases: :py:obj:`jexplore.sampling.state.State`

   This class provides the definition of a Markov chain state for a Metropolis-Hastings Markov sampler.

   With respect to the parent :py:class:`jexplore.sampling.state.State` class, it simply adds the log likelihood and log prior values to the state attributes (operations on these parameters are handled by the parent class methods).

   :param p: state point (nchains, dim).
   :param ll: log likelihood values for each chain point (nchains, 1).
   :param lp: log prior values for each chain point (nchains, 1).


   .. py:attribute:: ll
      :type: jax.Array

      Log likelihood values.


   .. py:attribute:: lp
      :type: jax.Array

      Log prior values.


.. py:class:: EpochMS[Tstate: StateMS, Tsampling: SamplingMS](epoch, force_covs = False)

   Bases: :py:obj:`jexplore.sampling.mh.EpochMH`\ [\ :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Epoch definition class for model selection.

   With respect to the parent :py:class:`jexplore.sampling.mh.EpochMH` class, this class defines :py:class:`jexplore.sampling.modsel.StateMS` as the default state class and overrides the `complete` method to compute the log likelihood and log prior of the epoch points with the model selection signature.

   :param epoch: `Epoch` object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.
   :param force_covs: if true, forces the recomputation of the covariance matrices.


   .. py:attribute:: statecls

      State class.


.. py:class:: SamplingMS[Tspace: jexplore.sampling.space.Space](nwalker, temps, masks, loglik = None, logprior = None, allmodsll = None, allmodslp = None, pseudo = None, dim = None, space = None, inpars = None)

   Bases: :py:obj:`jexplore.sampling.mh.SamplingMH`\ [\ :py:obj:`Tspace`\ ]

   Markov sampling defining parameters for model selection.

   With respect to the parent :py:class:`jexplore.sampling.mh.SamplingMH` class, the definition of the sampling includes the following parameters:

   :param nwalker: number of walkers per temperature.
   :param temps: temperature ladder.
   :param masks: list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:

      #. `None`: all the space coordinates;
      #. integer `d`: the first `d` dimensions of the space;
      #. tuple (`a`, `b`): the corresponding slice of coordinates;
      #. list of integers: explicit list of coordinate indexes.

   :param loglik: log likelihood function :math:`\log L(k, p)`; a list of functions :math:`\log L(p)`, one for each model (in which case :math:`\log L(k, p)` is defined by branching on the functions of the list); or `None` (in which case :math:`\log L(k, p)` is defined by selecting one component of the all-models log likelihood defined by `allmodsll`, thus with a speculative computing approach).
   :param allmodsll: function :math:`v\log L(p)` returning the values of the log likelihood for all models at the point :math:`p`, or `None`. In the latter case it is defined by stacking the return values of :math:`\log L(k, p)` defined by `loglik` for all possible `k`.

      Note: `allmodsll` and `loglik` cannot both be `None`.

   :param logprior: log prior function :math:`\log \pi(k, p)`; a list of functions :math:`\log \pi(p)`, one for each model (in which case :math:`\log \pi(k, p)` is defined by branching on the functions of the list); or `None` (in which case :math:`\log \pi(k, p)` is defined by selecting one component of the all-models log prior defined by `allmodslp`, thus with a speculative computing approach).
   :param allmodslp: function :math:`v\log \pi(p)` returning the values of the log prior for all models at the point :math:`p`, or `None`. In the latter case it is defined by stacking the return values of :math:`\log \pi(k, p)` defined by `logprior` for all possible `k`.
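
      The relation between the per-model and the all-models forms can be sketched in plain JAX. This is an illustration, not the `jexplore` code: the toy per-model functions are hypothetical and take the full point :math:`p`. Stacking gives the all-models function, branching uses :py:func:`jax.lax.switch`, and the speculative variant evaluates every model and then selects component :math:`k`. The same construction applies to `loglik` and `allmodsll`.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         # Hypothetical per-model log prior functions of the full point p.
         logpriors = [
             lambda p: -0.5 * p[0] ** 2,        # model 0 uses coordinate 0 only
             lambda p: -0.5 * jnp.sum(p**2),    # model 1 uses all coordinates
         ]


         def allmodslp(p):
             # All-models function obtained by stacking: shape (n_models,).
             return jnp.stack([f(p) for f in logpriors])


         def logprior_branch(k, p):
             # Branching: only the function of model k is evaluated.
             return jax.lax.switch(k, logpriors, p)


         def logprior_spec(k, p):
             # Speculative: evaluate all models, then select component k.
             return allmodslp(p)[k]


         p = jnp.array([1.0, 2.0])
         print(allmodslp(p))  # [-0.5 -2.5]

      Branching avoids evaluating unused models, while the speculative form trades extra evaluations for a computation that vectorizes uniformly across chains with different `k`.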

      Note: `allmodslp` and `logprior` cannot both be `None`.

   :param pseudo: list of pseudo-prior distributions for all the models. Each element of the list can be:

      #. A :py:class:`jexplore.tools.distributions.Distr` object defined on the complementary space of the corresponding model, in which case it is used as is.
      #. A :py:class:`jexplore.tools.distributions.Distr` object defined on the full space, in which case it is masked to match the model mask.
      #. `None`: if the element is `None`, or if the model spans the whole space, a delta distribution centered at 0 is used (and its evaluation returns identically 0).

      If the whole argument is `None` (default), a list of `None` elements is assumed and the distribution evaluation is identically 0.

   :param dim: dimensionality of the full space.
   :param space: :py:class:`jexplore.sampling.space.Space` object describing the target space.
   :param inpars: list of input parameter(s) used for the computation of the log likelihood and log prior.


   .. py:attribute:: masks
      :type: list[jax.Array]

      Masks of parameter indexes of each model space.


   .. py:attribute:: cmasks
      :type: list[jax.Array]

      Complementary masks of parameter indexes of each model space.


   .. py:attribute:: mod_update
      :type: jexplore.sampling.state.ArrayFn

      Function :math:`f(k, a, b)` updating an array `a` of shape ``(self.dim,)`` with the values of an array `b` of the same shape, but only on the components of the mask of model `k`.


   .. py:attribute:: vmod_update
      :type: jexplore.sampling.state.ArrayFn

      Vectorized version of `self.mod_update`, performing the same update operation on arrays `k`, `a` and `b` of shapes ``(self.nwalker,)``, ``(self.nwalker, self.dim)`` and ``(self.nwalker, self.dim)``.


   .. py:attribute:: allmodsll
      :type: jexplore.sampling.state.ArrayFn

      Function :math:`F(p)` returning the log likelihood for all models.


   .. py:attribute:: allmodslp
      :type: jexplore.sampling.state.ArrayFn

      Function :math:`\pi(p)` returning the log prior for all models.


   .. py:attribute:: ppleval
      :type: jexplore.sampling.state.ArrayFn

      Log pseudo-prior function :math:`\tilde{\pi}(k, p)`.


   .. py:attribute:: allmodspp
      :type: jexplore.sampling.state.ArrayFn

      Function :math:`\tilde{\pi}(p)` returning the log pseudo-prior for all models.


   .. py:attribute:: ppdraw
      :type: list[jexplore.tools.distributions.DrawFn]

      List of pseudo-prior drawing functions.


   .. py:attribute:: nmodels
      :type: int

      Number of models.


   .. py:method:: to_backend()

      Save the attributes of the object in backend format.


   .. py:method:: get_masks(masks, dim)
      :staticmethod:

      Get the model masks as lists of indices.

      :param masks: list of masks identifying the coordinates of the space of each one of the models. This must be a list whose length is the number of models. Elements of the list can be:

         #. `None`: all the space coordinates;
         #. integer `d`: the first `d` dimensions of the space;
         #. tuple (`a`, `b`): the corresponding slice of coordinates;
         #. list of integers: explicit list of coordinate indexes.

      :param dim: full space dimension.
      :return: list of model masks (as arrays of indices) and list of the models' complementary masks.


   .. py:method:: ldist_from_list(logdists, masks, stack = False)
      :staticmethod:

      Define a product space log distribution from a list of single-model distributions.

      :param logdists: list of single-space distributions :math:`\log P(p)`.
      :param masks: masks of parameter indexes of each model space.
      :param stack: if true, define the function returning the values for all models by stacking the results of all the functions in the list.
      :return: full product space log distribution.


.. py:class:: StateMS

   Bases: :py:obj:`jexplore.sampling.mh.StateMH`

   Definition of a Markov chain state for model selection.

   With respect to the parent :py:class:`jexplore.sampling.mh.StateMH` class, it adds the model index `k` for each chain (operations on these parameters are handled by the base class methods).
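
   Because each chain carries its own model index `k`, coordinate updates have to be restricted to the mask of that model, which is the role of `mod_update` and `vmod_update` in `SamplingMS`. The sketch below is not the library implementation: ``normalize_mask`` is a hypothetical helper following the four mask forms accepted by `SamplingMS.get_masks` (it assumes a tuple ``(a, b)`` denotes the half-open range ``a:b``), and stacking the masks into a boolean ``(n_models, dim)`` matrix is an assumption of this example.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def normalize_mask(mask, dim):
          """Turn one mask specification into an array of coordinate indexes."""
          if mask is None:                      # all coordinates
              return jnp.arange(dim)
          if isinstance(mask, int):             # first `mask` coordinates
              return jnp.arange(mask)
          if isinstance(mask, tuple):           # assumed half-open slice a:b
              start, stop = mask
              return jnp.arange(start, stop)
          return jnp.asarray(mask)              # explicit list of indexes


      dim = 3
      masks = [normalize_mask(m, dim) for m in [2, None, (1, 3), [0, 2]]]
      cmasks = [jnp.setdiff1d(jnp.arange(dim), m) for m in masks]

      # Boolean (n_models, dim) matrix: True where model k owns the coordinate.
      bool_masks = jnp.stack([jnp.zeros(dim, dtype=bool).at[m].set(True) for m in masks])


      def mod_update(k, a, b):
          # Take b on the coordinates of model k, keep a elsewhere.
          return jnp.where(bool_masks[k], b, a)


      vmod_update = jax.vmap(mod_update)        # maps over chains

      k = jnp.array([0, 2])                     # model index of each chain
      a = jnp.zeros((2, dim))
      b = jnp.ones((2, dim))
      print(vmod_update(k, a, b))
      # [[1. 1. 0.]
      #  [0. 1. 1.]]

   A dense boolean matrix makes the per-chain update a single `where` with a dynamic row lookup, which avoids the ragged index lists that different model dimensions would otherwise require inside :py:func:`jax.vmap`.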

   :param k: model index (nchains, 1).
   :param p: state point (nchains, dim).
   :param ll: log likelihood values for each chain point (nchains, 1).
   :param lp: log prior values for each chain point (nchains, 1).


   .. py:attribute:: k
      :type: jax.Array

      Model index.


.. py:class:: Box(dim = None, size = jnp.inf, box = None, wrapped = None)

   Bases: :py:obj:`Space`

   Simple rectangular box space.

   :param dim: dimension of the box. Only used if the `box` parameter is not provided.
   :param size: size of the box (assumed equal in all dimensions). Only used if the `box` parameter is not provided. Default: infinity.
   :param box: list of lists defining the segment boundaries for each box dimension. If not provided, a box with `dim` dimensions of equal size `size` is considered.
   :param wrapped: list of periodic dimension indexes. These dimensions will be considered unbounded and the corresponding box `intervals` will be interpreted as principal domain intervals. Default: empty list.


   .. py:attribute:: bounds
      :type: jax.Array

      Box bounds.


   .. py:attribute:: wrap_dims
      :type: list[int]

      Wrapped dimension indexes.


   .. py:attribute:: wrap_domain
      :type: jax.Array

      Principal domain of the wrapped dimensions.


   .. py:attribute:: dim

      Dimension of the target space.


   .. py:method:: inspace(points)

      Check which points of a set lie in the defined space.

      :param points: set of points to be checked. It can have the `state` shape :math:`(N_{chains}, D)` or the `samples` shape :math:`(N_{chains}, D, S)`.
      :return: a boolean mask of shape :math:`(N_{chains})` or :math:`(N_{chains}, S)` selecting the points that lie in the defined space.


   .. py:method:: wrap(points)

      Fold the wrapped dimensions of the target space.

      :param points: set of points to be processed. It can have the `state` shape :math:`(N_{chains}, D)` or the `samples` shape :math:`(N_{chains}, D, S)`.
      :return: the same set of points with the wrapped dimensions folded.


.. py:class:: State

   This class provides the base definition of a Markov chain state.

   This minimal definition corresponds to a single parameter `p` with shape :math:`(N, D)`, describing the point in the :math:`D`-dimensional space for each one of the :math:`N` chains of the sampling to be performed.

   This can be specialized, in child classes, to different types of sampling by adding extra parameters (e.g. :py:class:`jexplore.sampling.mh.StateMH`) with shape :math:`(N, d)`, where :math:`d` depends on the parameter.

   The class also provides basic state manipulation methods, which are generic for all the child classes defined as above. Some of these methods can also deal with "batched" states, i.e. `State` objects having parameters with shape :math:`(N, d, I)`, where :math:`I` is the number of performed sampling iterations, as they are returned by the :py:func:`jax.lax.scan` loops of the sampler.

   :param p: state point (nchains, dim).


   .. py:attribute:: p
      :type: jax.Array

      State point.


   .. py:method:: compute(par, func, inpar)

      Populate the values of one parameter :math:`p` as a function of other parameters of the state.

      :param par: parameter name.
      :param func: function defining the parameter value.
      :param inpar: list of the names of the input parameters of the function.
      :return: a new state with the populated values.


   .. py:method:: update(update, pars = None)

      Update a set of parameters :math:`p_i` of the state by calling an update function:

      .. math::

         p_i = u(i, p_i)

      :param update: the update function.
      :param pars: list of the names of the parameters to be updated. None (default): all parameters.
      :returns: a new state with the updated parameters.


   .. py:method:: update_mask(other, mask = True, pars = None)

      Import the values of some parameters, for some selected chains, from another state. Chains are selected by a boolean mask (1-D array of size `nchains`).

      :param other: other state. The parameters to update should be defined and have the same shape as in the current state.
      :param mask: boolean mask for the chains to be updated.
      :param pars: list of the names of the attributes to be updated.
      :return: a new state with the updated parameters.


   .. py:method:: update_slice(other, slc, pars = None)

      Import the values of some parameters, for some selected chains, from another state. Chains are selected by an index slice (i.e. a list of integers).

      :param other: other state. The parameters to update should be defined and have the same shape as the chains slice.
      :param slc: list of slice indexes.
      :param pars: list of the names of the attributes to be updated.
      :return: a new state with the updated parameters.


   .. py:method:: set_val(val, cslc = None, pslc = None, par = 'p')

      Update one sample parameter, slicing both in the chain and in the parameter dimensions.

      :param val: values (should have the same shape as the slice).
      :param cslc: chains slice. Default: all.
      :param pslc: parameter dimensions slice. Default: all.
      :param par: parameter name. Default: `p`.
      :return: new state.


   .. py:method:: slice(slc)

      Return a state with parameter values corresponding to a slice of the current state chains.

      That is, if the current state has :math:`N` chains, we define a slice as a set of indexes :math:`\{0 \leq i_k < N, k=0,\dots,\tilde{N}-1\}`. This method constructs a new state having the parameter

      .. math::

         \tilde{p}_{k, \alpha} = p_{i_k, \alpha}, \quad k=0,\dots,\tilde{N}-1, \quad \alpha=0,\dots,d-1

      for each :math:`(N, d)`-shaped parameter :math:`p` of the current state.

      :param slc: the slice as a list of indexes.
      :returns: the new "sliced" state.


   .. py:method:: swap(even, odd, accept)

      Swap all parameter values between different chains.

      :param even: set of indexes to be swapped.
      :param odd: other set of indexes to be swapped (should have the same size).
      :param accept: boolean mask selecting the swaps that should be performed.
      :return: new state with swapped values.


   .. py:method:: get_in_batch(ind = -1, batch = False)

      This method assumes that the current state is batched.
      Thus its parameters have shape :math:`(N, d, I)`, where :math:`I` is the number of performed Markov iterations. In this case the method returns a new state corresponding to one specific iteration.

      :param ind: index of the iteration.
      :param batch: if True, the returned state will be a 1-iteration batched state.
      :returns: the selected state.


   .. py:method:: from_dict(state_d, mandatory = None, batch = False)
      :classmethod:

      Instantiate a state object from a dictionary. It selects all the entries of the dictionary that correspond to the parameters of this class of states, and uses the corresponding values to define the state.

      :param state_d: input dictionary.
      :param mandatory: list of the names of the parameters that have to be present in the dictionary. Non-mandatory parameters, if absent, will be replaced by empty arrays. If `None`, all parameters are mandatory.
      :param batch: when true, it forces the instantiated state to have a "batched" shape.
      :returns: the state defined by the dictionary.
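
   The chain-wise operations of `State` (`slice`, `update_mask`, `swap` and `get_in_batch`) act in the same way on every parameter of the state. The sketch below shows, on a single :math:`(N, d)` array and in plain JAX, the indexing each of them is described to perform. It is an illustration only: the function names are hypothetical rather than the `jexplore` implementation, and it assumes that `swap` exchanges chain ``even[j]`` with chain ``odd[j]`` when ``accept[j]`` is true.

   .. code-block:: python

      import jax.numpy as jnp


      def slice_chains(p, slc):
          # slice: keep the chains listed in slc, p_tilde[k] = p[slc[k]].
          return p[jnp.asarray(slc)]


      def update_mask(p, other, mask):
          # update_mask: take the values of `other` on the chains where mask is True.
          return jnp.where(mask[:, None], other, p)


      def swap(p, even, odd, accept):
          # swap: exchange chain even[j] with chain odd[j] where accept[j] is True.
          even, odd = jnp.asarray(even), jnp.asarray(odd)
          acc = accept[:, None]
          new_even = jnp.where(acc, p[odd], p[even])
          new_odd = jnp.where(acc, p[even], p[odd])
          return p.at[even].set(new_even).at[odd].set(new_odd)


      def get_in_batch(p_batched, ind=-1):
          # get_in_batch: select one iteration of a batched (N, d, I) parameter.
          return p_batched[..., ind]


      p = jnp.arange(8.0).reshape(4, 2)        # 4 chains, d = 2
      print(slice_chains(p, [3, 0]))           # rows 3 and 0 of p
      print(swap(p, [0, 2], [1, 3], jnp.array([True, False])))  # only chains 0 and 1 exchanged

   Since every parameter keeps the chain index as its leading axis, one way to lift these operations from a single array to a whole state would be to apply the same indexing to each parameter, for instance with :py:func:`jax.tree_util.tree_map`.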