jexplore.sampling.epoch
=======================

.. py:module:: jexplore.sampling.epoch

.. autoapi-nested-parse::

   This module provides the definitions of the base :py:class:`jexplore.sampling.epoch.Epoch` class
   and of some ancillary classes and types needed for running the Markov chain sampler.



Classes
-------

.. autoapisummary::

   jexplore.sampling.epoch.StepSample
   jexplore.sampling.epoch.EpochCycle
   jexplore.sampling.epoch.EpochStats
   jexplore.sampling.epoch.Epoch


Module Contents
---------------

.. py:class:: StepSample[Tstate: jexplore.sampling.state.State]

   Bases: :py:obj:`Protocol`


   Protocol for a step sampling function.

   :param jax.Array key: PRNG key
   :param Tstate state: current state
   :return: the new state and the chain-change statistics of the step, an :math:`N`-sized array
      (one entry per chain) flagging which chains were changed by the step.
   :rtype: tuple[Tstate, jax.Array]


.. py:class:: EpochCycle[Tstate: jexplore.sampling.state.State](steps)

   Class containing the definition of an epoch steps cycle.

   :param steps: list of substeps, each one defined by a dictionary of alternative steps with
      their weights.


   .. py:attribute:: steps
      :type: tuple[StepSample[Tstate], Ellipsis]

      Tuple of all the steps.


   .. py:attribute:: weights
      :type: jax.Array

      Matrix of weights. This has shape :math:`(S, A_{max})`, where :math:`S` is the number of
      serial substeps in the cycle and :math:`A_{max}` is the maximum number of alternatives over
      all these substeps. Each element of the matrix that corresponds to a defined
      substep/alternative contains the weight of that alternative.


   .. py:attribute:: isteps
      :type: jax.Array

      :math:`(S, A_{max})` matrix. Each element of the matrix that corresponds to a defined
      substep/alternative contains the index of the corresponding `StepSample` function in the
      `steps` tuple.


   .. py:attribute:: n
      :type: int

      Number of sequential substeps (:math:`S`).


   .. py:method:: to_backend()

      Dump the cycle definition into a `numpy` array.
      The array has one row for each defined substep/alternative, with the fields:

      * `name`: the qualified name of the `StepSample` function.
      * `incycle`: the index of the substep.
      * `weight`: the weight of the corresponding substep alternative.


.. py:class:: EpochStats[Tstate: jexplore.sampling.state.State]

   Class collecting the aggregated statistics of the epoch steps.

   :param steps: array of the steps performed at each iteration.
   :param counts: counters of the number of executions of each step.
   :param stats: chain acceptance statistics for each step.


   .. py:attribute:: steps
      :type: jax.Array

      An :math:`I`-sized array, where :math:`I` is the total number of steps performed during the
      sampling. Each entry is the index of the performed step in the `steps` tuple of the epoch's
      `EpochCycle` instance.


   .. py:attribute:: counts
      :type: jax.Array

      This array has the same size as the `steps` tuple of the epoch's `EpochCycle` instance.
      Each element contains the number of times the corresponding step was performed during the
      sampling.


   .. py:attribute:: stats
      :type: jax.Array

      This array has shape :math:`(N, C)`, where :math:`C` is the size of the `steps` tuple of the
      epoch's `EpochCycle` instance (that is, the number of distinct steps) and :math:`N` is the
      number of chains. The values are statistics of how many times each chain was modified by
      each type of step (i.e. the proposal acceptance rate in the case of Metropolis-Hastings).


   .. py:method:: from_run(stats, steps, cycle)
      :staticmethod:


      Each step returns an :math:`N`-sized boolean map of the chains that have been changed by the
      step, together with the index identifying the step performed. An epoch run returns the
      stacking of these arrays over all iterations. This method aggregates and organizes this
      information and instantiates the corresponding `EpochStats` object.

      :param stats: stacked stats array for all steps in the sampling. This has shape
         :math:`(N, I)`.
      :param steps: array with the indices of the steps performed. This is an :math:`I`-sized
         array whose entries are indices :math:`a = 0, \dots, A_{max} - 1` identifying which one of
         the possible alternatives was actually performed at each step.
      :param cycle: `EpochCycle` instance.
      :return: the `EpochStats` corresponding to these data.


   .. py:method:: to_backend()

      Dump the data of this `EpochStats` instance, along with the information from the
      corresponding `EpochCycle` instance, to a dictionary in a form suitable for writing to the
      backend. The output dictionary has the keys:

      * `steps`: `numpy.ndarray` cast of the object's `steps` attribute.
      * `counts`: `numpy.ndarray` cast of the object's `counts` attribute.
      * `stats`: `numpy.ndarray` cast of the object's `stats` attribute.
      * `steps_defs`: data from the `EpochCycle` instance, as returned by
        :py:meth:`jexplore.sampling.epoch.EpochCycle.to_backend`.


.. py:class:: Epoch[Tstate: jexplore.sampling.state.State, Tsampling: jexplore.sampling.base.Sampling](epoch, force_covs = False)

   Base Epoch definition class. This encapsulates the samples, the covariance matrices, the stats
   object, the definition of the epoch cycle and the definition of the current sampling. The class
   provides methods to initialize the epoch, get a specific state, run steps and instantiate an
   epoch from the result of a sampler run.

   :param epoch: `Epoch` object from which the new instance is derived, by computing the
      covariance and taking the last sample. Alternatively, this can be a dictionary with the
      definition of the epoch samples.
   :param force_covs: if true, forces the recomputation of the covariance matrices.


   .. py:attribute:: samples
      :type: Tstate

      Batched state defining the sequence of sampled states.


   .. py:attribute:: covs
      :type: jax.Array

      Array with shape :math:`(N, D, D)` containing the covariance matrices of all the chains,
      where :math:`N` is the number of chains and :math:`D` the number of sampled parameters.


   .. py:attribute:: stats
      :type: EpochStats[Tstate]

      Epoch statistics.


   .. py:attribute:: cycle
      :type: EpochCycle[Tstate]

      Definition of the epoch cycle.


   .. py:attribute:: sampling
      :type: Tsampling

      Sampling definition.


   .. py:attribute:: statecls
      :type: type[Tstate] | type[jexplore.sampling.state.State]

      State class corresponding to this epoch class.


   .. py:method:: complete(sampling, compute = None)

      Complete the definition and the parameters of the epoch.

      :param sampling: sampling parameters.
      :param compute: dictionary defining how to compute the missing parameters. The keys are the
         names of the parameters to be computed. Values are `(func, inpars)` tuples, where `func`
         is a callable computing the corresponding parameter and `inpars` is the list of the names
         of the input parameters.
      :return: this object after the update.


   .. py:method:: set_cycle(cycle)

      Set the cycle of this epoch.

      :param cycle: the `EpochCycle` value to be set.


   .. py:method:: run_step(key, ind, state)

      Run a step of the epoch. For a given substep of the epoch cycle, draw one of the
      alternative steps and perform it.

      :param key: PRNG key
      :param ind: the index of the substep
      :param state: current state
      :return: the new state, the step stats and the index of the substep alternative actually
         performed.


   .. py:method:: from_run(states, stats, steps)

      Take the stacks of states, step statistics and step indexes from an epoch run and update
      the epoch with these data.

      :param states: stack of states.
      :param stats: stack of step stats.
      :param steps: stack of step indexes.
      :return: this object after the update.


   .. py:method:: to_backend(burn, stats)

      Dump the epoch in backend format.
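
The padded :math:`(S, A_{max})` layout of `EpochCycle` and the aggregation described for
`EpochStats.from_run` can be illustrated with the NumPy sketch below. It is not the module's
implementation (which stores `jax.Array` objects): the helpers `build_cycle`, `draw_alternative`
and `aggregate_stats` are hypothetical, plain strings stand in for `StepSample` functions, and it
assumes that undefined slots are padded with weight 0 and index -1, that a step appearing in
several substeps has a single entry in `steps`, and that iteration :math:`i` runs substep
:math:`i \bmod S`.

.. code-block:: python

   import numpy as np


   def build_cycle(substeps):
       """Hypothetical helper: pad a list of {step: weight} dicts into (S, A_max) arrays."""
       # Distinct steps in first-seen order (assumption: repeated steps share one entry).
       steps = []
       for alternatives in substeps:
           for step in alternatives:
               if step not in steps:
                   steps.append(step)
       n_sub = len(substeps)
       a_max = max(len(alternatives) for alternatives in substeps)
       # Undefined (substep, alternative) slots get weight 0 and index -1 (assumption).
       weights = np.zeros((n_sub, a_max))
       isteps = np.full((n_sub, a_max), -1, dtype=int)
       for s, alternatives in enumerate(substeps):
           for a, (step, weight) in enumerate(alternatives.items()):
               weights[s, a] = weight
               isteps[s, a] = steps.index(step)
       return tuple(steps), weights, isteps


   def draw_alternative(rng, weights, s):
       """Draw an alternative index a for substep s; zero-weight padding is never drawn."""
       p = weights[s] / weights[s].sum()
       return int(rng.choice(len(p), p=p))


   def aggregate_stats(stats, alt, isteps):
       """Turn (N, I) change flags and I alternative indices into EpochStats-like arrays."""
       n_sub = isteps.shape[0]
       n_steps = int(isteps.max()) + 1
       # Map (substep i % S, alternative a_i) to an index into the `steps` tuple.
       step_idx = isteps[np.arange(len(alt)) % n_sub, alt]
       counts = np.bincount(step_idx, minlength=n_steps)
       # (I, C) one-hot matrix: which distinct step ran at each iteration.
       onehot = (step_idx[:, None] == np.arange(n_steps)).astype(float)
       changes = stats.astype(float) @ onehot  # (N, C) changes per chain and step
       rates = changes / np.maximum(counts, 1)  # fraction of executions that changed the chain
       return step_idx, counts, rates


   steps, weights, isteps = build_cycle([{"mh": 1.0}, {"mh": 0.5, "de": 0.5}])
   # steps == ("mh", "de"); weights == [[1.0, 0.0], [0.5, 0.5]]; isteps == [[0, -1], [0, 1]]
   rng = np.random.default_rng(0)
   alt = np.array([draw_alternative(rng, weights, i % 2) for i in range(4)])

A padded, fixed-shape layout of this kind suits JAX, where arrays traced under `jax.jit` must
have static shapes, so every substep can be handled by the same vectorized operation.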