jexplore.sampling.state
=======================

.. py:module:: jexplore.sampling.state

.. autoapi-nested-parse::

   This module provides the definition of the base :py:class:`jexplore.sampling.state.State` class.


Classes
-------

.. autoapisummary::

   jexplore.sampling.state.ArrayFn
   jexplore.sampling.state.State


Module Contents
---------------

.. py:class:: ArrayFn

   Bases: :py:obj:`Protocol`

   Protocol for a function taking any number of :py:class:`jax.Array` positional arguments and returning a single :py:class:`jax.Array`.


.. py:class:: State

   This class provides the base definition of a Markov chain state.

   This minimal definition consists of a single parameter `p` with shape :math:`(N, D)`, describing the point in the :math:`D`-dimensional space for each of the :math:`N` chains of the sampling to be performed.

   Child classes can specialize it to different types of sampling by adding extra parameters (e.g. :py:class:`jexplore.sampling.mh.StateMH`) with shape :math:`(N, d)`, where :math:`d` depends on the parameter.

   The class also provides basic state manipulation methods that are generic for all child classes defined as above. Some of these methods can also handle "batched" states, i.e. `State` objects whose parameters have shape :math:`(N, d, I)`, where :math:`I` is the number of performed sampling iterations, as returned by the :py:func:`jax.lax.scan` loops of the sampler.

   :param p: state point, with shape (nchains, dim).


   .. py:attribute:: p
      :type:  jax.Array

      State point.


   .. py:method:: compute(par, func, inpar)

      Populate the values of one parameter :math:`p` as a function of other parameters of the state.

      :param par: parameter name.
      :param func: function defining the parameter value.
      :param inpar: list of the names of the input parameters of the function.
      :return: a new state with the populated values.


   .. py:method:: update(update, pars = None)

      Update a set of parameters :math:`p_i` of the state by calling an update function:

      .. math:: p_i = u(i, p_i)

      :param update: the update function.
      :param pars: list of the names of the parameters to be updated. If None (default), all parameters are updated.
      :returns: a new state with the updated parameters.


   .. py:method:: update_mask(other, mask = True, pars = None)

      Import the values of some parameters, for some selected chains, from another state. Chains are selected by a boolean mask (a 1-D array of size `nchains`).

      :param other: other state. The parameters to update must be defined and have the same shape as in the current state.
      :param mask: boolean mask of the chains to be updated.
      :param pars: list of the names of the attributes to be updated.
      :return: a new state with the updated parameters.


   .. py:method:: update_slice(other, slc, pars = None)

      Import the values of some parameters, for some selected chains, from another state. Chains are selected by an index slice (i.e. a list of integers).

      :param other: other state. The parameters to update must be defined and have the same shape as the chain slice.
      :param slc: list of slice indexes.
      :param pars: list of the names of the attributes to be updated.
      :return: a new state with the updated parameters.


   .. py:method:: set_val(val, cslc = None, pslc = None, par = 'p')

      Update one parameter of the state, slicing along both the chain and the parameter dimensions.

      :param val: values (must have the same shape as the slice).
      :param cslc: chain slice. Default: all chains.
      :param pslc: parameter dimension slice. Default: all dimensions.
      :param par: parameter name. Default: `p`.
      :return: a new state.


   .. py:method:: slice(slc)

      Return a state whose parameter values correspond to a slice of the current state chains.

      That is, if the current state has :math:`N` chains, a slice is a set of indexes :math:`\{\, i_k : 0 \leq i_k < N,\ k = 0,\dots,\tilde{N}-1 \,\}`. This method constructs a new state with parameters

      .. math:: \tilde{p}_{k,\alpha} = p_{i_k,\alpha}, \quad k = 0,\dots,\tilde{N}-1, \quad \alpha = 0,\dots,d-1

      for each :math:`(N, d)`-shaped parameter :math:`p` of the current state.

      :param slc: the slice, as a list of indexes.
      :returns: the new "sliced" state.


   .. py:method:: swap(even, odd, accept)

      Swap the values of all parameters between two sets of chains.

      :param even: first set of chain indexes to be swapped.
      :param odd: second set of chain indexes to be swapped (must have the same size as `even`).
      :param accept: boolean mask selecting which swaps are performed.
      :return: a new state with swapped values.


   .. py:method:: get_in_batch(ind = -1, batch = False)

      This method assumes that the current state is batched, i.e. its parameters have shape :math:`(N, d, I)`, where :math:`I` is the number of performed Markov iterations. It returns a new state corresponding to one specific iteration.

      :param ind: index of the iteration.
      :param batch: if True, the returned state is a 1-iteration batched state.
      :returns: the selected state.


   .. py:method:: from_dict(state_d, mandatory = None, batch = False)
      :classmethod:

      Instantiate a state object from a dictionary. It selects all entries of the dictionary that correspond to parameters of this class of states and uses their values to define the state.

      :param state_d: input dictionary.
      :param mandatory: list of the names of the parameters that must be present in the dictionary. Non-mandatory parameters, if absent, are replaced by empty arrays. If `None`, all parameters are mandatory.
      :param batch: when True, forces the instantiated state to have a "batched" shape.
      :returns: the state defined by the dictionary.
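To illustrate how ``State.compute`` and ``State.update`` behave, the following minimal sketch uses a hypothetical frozen dataclass ``ToyState`` (not part of jexplore) with a position ``p`` of shape :math:`(N, D)` and an extra per-chain parameter ``logp`` of shape :math:`(N, 1)`. It assumes that ``func`` behaves like an ``ArrayFn`` and that the update function receives the parameter name together with its current value, following :math:`p_i = u(i, p_i)`. It is a sketch under these assumptions, not the library's implementation.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a State subclass (NOT the jexplore class):
# `p` has shape (N, D), `logp` has shape (N, 1).
@dataclasses.dataclass(frozen=True)
class ToyState:
    p: jax.Array
    logp: jax.Array

    def compute(self, par, func, inpar):
        # Evaluate func on the named input parameters and store the
        # result under `par`; the original state is left unchanged.
        value = func(*(getattr(self, name) for name in inpar))
        return dataclasses.replace(self, **{par: value})

    def update(self, update, pars=None):
        # Apply p_i <- u(i, p_i) to each selected parameter (all by default),
        # assuming u receives the parameter name and its current value.
        if pars is None:
            pars = [field.name for field in dataclasses.fields(self)]
        return dataclasses.replace(
            self, **{name: update(name, getattr(self, name)) for name in pars}
        )


def log_gauss(p):
    # Unnormalised standard-normal log-density, one value per chain: (N, 1).
    return -0.5 * jnp.sum(p**2, axis=1, keepdims=True)


state = ToyState(p=jnp.ones((4, 2)), logp=jnp.zeros((4, 1)))
state = state.compute("logp", log_gauss, ["p"])  # every logp entry is -1.0
shifted = state.update(lambda name, x: x + 1.0, pars=["p"])  # p becomes 2.0
```

Returning a new object instead of mutating in place mirrors the "returns a new state" contract of the methods above, which also fits JAX's immutable arrays.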
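The chain-selection logic described for ``State.update_mask`` and ``State.swap`` can be sketched with plain dictionaries of arrays standing in for a state. The helpers ``masked_update`` and ``masked_swap`` are hypothetical, and the sketch assumes that the two index sets passed to the swap are disjoint and paired element-wise (``even[k]`` is exchanged with ``odd[k]`` when ``accept[k]`` is true); the actual jexplore code may differ.

```python
import jax.numpy as jnp


def _per_chain(flags, x):
    # Reshape a per-chain flag array of shape (K,) so that it broadcasts
    # against a parameter of shape (K, d) or (K, d, I); scalars pass through.
    flags = jnp.asarray(flags)
    if flags.ndim == 1:
        flags = flags.reshape(flags.shape + (1,) * (x.ndim - 1))
    return flags


def masked_update(state, other, mask=True, pars=None):
    # Copy values from `other` for the chains where mask is True.
    out = dict(state)
    for name in (pars if pars is not None else state):
        out[name] = jnp.where(_per_chain(mask, state[name]), other[name], state[name])
    return out


def masked_swap(state, even, odd, accept):
    # Exchange chains even[k] <-> odd[k] wherever accept[k] is True,
    # assuming the two index sets are disjoint.
    out = {}
    for name, x in state.items():
        xe, xo = x[even], x[odd]
        acc = _per_chain(accept, xe)
        out[name] = x.at[even].set(jnp.where(acc, xo, xe)).at[odd].set(jnp.where(acc, xe, xo))
    return out


state = {"p": jnp.arange(8.0).reshape(4, 2), "logp": jnp.arange(4.0).reshape(4, 1)}
other = {"p": -jnp.ones((4, 2)), "logp": -jnp.ones((4, 1))}

# Rows 0 and 3 of p come from `other`; logp is not in `pars`, so it is kept.
mixed = masked_update(state, other, mask=jnp.array([True, False, False, True]), pars=["p"])

# Chains 0 and 1 are exchanged; the 2 <-> 3 swap is rejected.
swapped = masked_swap(state, jnp.array([0, 2]), jnp.array([1, 3]), jnp.array([True, False]))
```

Using ``jnp.where`` with a broadcast mask keeps the operation free of data-dependent control flow, which is what allows such updates to run inside ``jax.jit`` or ``jax.lax.scan``.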
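The following sketch illustrates ``State.slice`` and ``State.get_in_batch`` on a batched state, again with a plain dictionary standing in for a state and hypothetical helpers ``slice_chains`` and ``pick_iteration``. Note that :py:func:`jax.lax.scan` stacks the per-iteration outputs along a leading axis, giving shape :math:`(I, N, d)`; the ``jnp.moveaxis`` call is an assumption about how a sampler might arrive at the :math:`(N, d, I)` layout described above.

```python
import jax
import jax.numpy as jnp


def slice_chains(state, slc):
    # \tilde{p}_{k, alpha} = p_{i_k, alpha}: gather the selected chains
    # along the leading (chain) axis of every parameter.
    idx = jnp.asarray(slc)
    return {name: x[idx] for name, x in state.items()}


def pick_iteration(batched, ind=-1, batch=False):
    # Batched parameters have shape (N, d, I): select iteration `ind` on the
    # trailing axis, optionally keeping it as a length-1 batch axis.
    if batch:
        return {name: x[..., ind][..., None] for name, x in batched.items()}
    return {name: x[..., ind] for name, x in batched.items()}


def step(p, _):
    # Toy deterministic "sampler" step: shift every chain by 1.
    p_new = p + 1.0
    return p_new, p_new


p0 = jnp.arange(6.0).reshape(3, 2)  # N = 3 chains in D = 2 dimensions
_, trace = jax.lax.scan(step, p0, xs=None, length=5)  # shape (I, N, D) = (5, 3, 2)
batched = {"p": jnp.moveaxis(trace, 0, -1)}  # shape (N, D, I) = (3, 2, 5)

last = pick_iteration(batched)  # p equals p0 + 5, shape (3, 2)
first = pick_iteration(batched, ind=0, batch=True)  # p0 + 1, shape (3, 2, 1)
sub = slice_chains(last, [2, 0])  # chains 2 and 0 of `last`, shape (2, 2)
```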