jexplore.sampling.epoch
This module provides the definitions of the base
jexplore.sampling.epoch.Epoch class and of some
ancillary classes and types needed for running
the Markov chain sampler.
Classes
StepSample | Prototype for a step sampling function.
EpochCycle | Class containing the definition of an epoch steps cycle.
EpochStats | Class collecting the aggregated statistics of the epoch steps.
Epoch | Base Epoch definition class.
Module Contents
- class StepSample[Tstate: jexplore.sampling.state.State]
Bases: Protocol
Prototype for a step sampling function.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
- Returns:
the new state and the step's per-chain change statistics (an array of size nchain).
- Return type:
tuple[Tstate, jax.Array]
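As an illustration of the prototype, here is a minimal sketch of a conforming step: a Gaussian random-walk Metropolis update applied to a batch of chains. The `ToyState` named tuple, the `log_density` target and the `rw_metropolis` function are hypothetical stand-ins, not part of jexplore; a real step would receive a `jexplore.sampling.state.State` subclass instead. The returned boolean array plays the role of the per-chain change statistics.

```python
from typing import NamedTuple, Protocol, TypeVar

import jax
import jax.numpy as jnp

Tstate = TypeVar("Tstate")


class StepSample(Protocol[Tstate]):
    """Structural type mirroring the documented StepSample signature."""

    def __call__(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: ...


class ToyState(NamedTuple):
    """Hypothetical stand-in for a jexplore State: N chains in D dimensions."""

    x: jax.Array  # shape (N, D)


def log_density(x: jax.Array) -> jax.Array:
    # Standard normal log density (up to a constant), one value per chain.
    return -0.5 * jnp.sum(x**2, axis=-1)


def rw_metropolis(key: jax.Array, state: ToyState) -> tuple[ToyState, jax.Array]:
    """Hypothetical step: Gaussian random-walk Metropolis on every chain."""
    kprop, kacc = jax.random.split(key)
    prop = state.x + 0.5 * jax.random.normal(kprop, state.x.shape)
    log_ratio = log_density(prop) - log_density(state.x)
    u = jax.random.uniform(kacc, (state.x.shape[0],))
    accepted = jnp.log(u) < log_ratio  # (N,) boolean map of changed chains
    new_x = jnp.where(accepted[:, None], prop, state.x)
    return ToyState(new_x), accepted


# Any callable with this signature satisfies the protocol structurally.
step: StepSample[ToyState] = rw_metropolis
new_state, changed = step(jax.random.PRNGKey(0), ToyState(jnp.zeros((4, 2))))
```

Using a `Protocol` means step functions need no common base class: any callable taking `(key, state)` and returning `(state, stats)` type-checks as a step.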
- class EpochCycle[Tstate: jexplore.sampling.state.State](steps)
Class containing the definition of an epoch steps cycle.
- Parameters:
steps (list[dict[StepSample[Tstate], float]]) – list of substeps, each defined by a dictionary mapping its alternative steps to their weights.
- steps: tuple[StepSample[Tstate], ...]
tuple of all the steps.
- weights: jax.Array
Matrix of weights. This has shape \((S, A_{max})\) where \(S\) is the number of serial substeps in the cycle and \(A_{max}\) is the max number of alternatives in all these substeps. Each element of the matrix that actually corresponds to a defined step/alternative contains the weight of the corresponding alternative.
- isteps: jax.Array
\((S, A_{max})\) matrix. Each element of the matrix that actually corresponds to a defined step/alternative contains the index of the corresponding StepSample function in the steps tuple.
- n: int
number of sequential substeps.
- to_backend()
Dumps the cycle definition into a numpy array with one row for each defined substep/alternative. Each row contains:
name: the qualified name of the StepSample function.
incycle: the index of the substep
weight: the weight of the corresponding substep alternative.
- Return type:
numpy.ndarray
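The relation between the steps argument and the steps, weights and isteps attributes can be sketched as follows. This is a hypothetical reconstruction, not the jexplore code: it assumes distinct step functions are collected in order of first appearance and that unused cells of the (S, A_max) matrices are padded with weight 0 and index -1 (the source does not specify the padding). `build_cycle` and `cycle_to_backend` are illustrative names, and the field dtypes of the dumped array are assumptions.

```python
from typing import Callable

import jax.numpy as jnp
import numpy as np


def build_cycle(steps: list[dict[Callable, float]]):
    """Hypothetical helper: derive (steps tuple, weights, isteps) from substep dicts."""
    # Distinct step functions, in order of first appearance (assumed ordering).
    uniq: list[Callable] = []
    for sub in steps:
        for fn in sub:
            if fn not in uniq:
                uniq.append(fn)
    n_sub = len(steps)
    amax = max(len(sub) for sub in steps)
    # Padding cells: weight 0 and index -1 (assumed convention).
    weights = np.zeros((n_sub, amax), dtype=np.float32)
    isteps = np.full((n_sub, amax), -1, dtype=np.int32)
    for s, sub in enumerate(steps):
        for a, (fn, w) in enumerate(sub.items()):
            weights[s, a] = w
            isteps[s, a] = uniq.index(fn)
    return tuple(uniq), jnp.asarray(weights), jnp.asarray(isteps)


def cycle_to_backend(steps: list[dict[Callable, float]]) -> np.ndarray:
    """Hypothetical analogue of EpochCycle.to_backend: one row per substep/alternative."""
    dtype = [("name", "U128"), ("incycle", "i4"), ("weight", "f8")]
    rows = [(fn.__qualname__, s, w) for s, sub in enumerate(steps) for fn, w in sub.items()]
    return np.array(rows, dtype=dtype)


def step_a(key, state):  # toy placeholder step
    return state, None


def step_b(key, state):  # toy placeholder step
    return state, None


# Substep 0 always runs step_a; substep 1 picks step_a or step_b with weights 0.7/0.3.
cycle_def = [{step_a: 1.0}, {step_a: 0.7, step_b: 0.3}]
fns, weights, isteps = build_cycle(cycle_def)
table = cycle_to_backend(cycle_def)
```

Here step_a appears in both substeps but only once in the steps tuple, which is why isteps stores indices into that tuple rather than the functions themselves.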
- class EpochStats[Tstate: jexplore.sampling.state.State]
Class collecting the aggregated statistics of the epoch steps.
- Parameters:
steps – array of steps performed at each iteration.
counts – counters of the number of executions of each step.
stats – per-chain acceptance statistics for each step.
- steps: jax.Array
This is an \(I\)-sized array, where \(I\) is the total number of steps performed during the sampling. Each entry is the index of the corresponding step in the steps tuple of the epoch’s EpochCycle instance.
- counts: jax.Array
This array has the same size as the steps tuple in the epoch’s EpochCycle instance. Each element counts how many times the corresponding step was performed during the sampling.
- stats: jax.Array
This array has shape \((N, C)\), where \(N\) is the number of chains and \(C\) is the size of the steps tuple in the epoch’s EpochCycle instance (that is, the number of distinct steps). The values are statistics of how many times each chain was modified by each type of step (i.e., the proposal acceptance rate in the case of Metropolis-Hastings).
- static from_run(stats, steps, cycle)
Each step returns an \(N\)-sized boolean map of the chains that have been changed by the step, together with the index identifying the alternative performed. An epoch run returns these arrays stacked over all iterations. This method aggregates and organizes this information and instantiates the corresponding EpochStats object.
- Parameters:
stats (jax.Array) – stacked stats array for all steps in the sampling. This has shape \((N, I)\).
steps (jax.Array) – array with the indices of the steps performed. This is an \(I\)-sized array whose entries are indices \(a = 0, \dots, A_{max} - 1\) identifying which of the possible alternatives was actually performed at each step.
cycle (EpochCycle[Tstate]) – EpochCycle instance.
- Returns:
the EpochStats corresponding to these data.
- Return type:
EpochStats[Tstate]
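The aggregation described above can be sketched in JAX. This is not the jexplore implementation: it assumes that iteration i executes substep i mod S of the cycle, and it normalizes the per-chain change counts by the step counts to obtain an acceptance rate. `aggregate_run` and its `n_distinct` argument (the size C of the steps tuple) are hypothetical.

```python
import jax
import jax.numpy as jnp


def aggregate_run(stats: jax.Array, steps: jax.Array, isteps: jax.Array, n_distinct: int):
    """Hypothetical aggregation in the spirit of EpochStats.from_run.

    stats:  (N, I) boolean map of the chains changed at each iteration
    steps:  (I,)   alternative index a performed at each iteration
    isteps: (S, A_max) index of each substep/alternative in the steps tuple
    """
    n_iter = steps.shape[0]
    n_sub = isteps.shape[0]
    # Assumption: iteration i runs substep i mod S of the cycle.
    sub = jnp.arange(n_iter) % n_sub
    global_steps = isteps[sub, steps]  # (I,) indices into the steps tuple
    counts = jnp.bincount(global_steps, length=n_distinct)  # (C,)
    onehot = jax.nn.one_hot(global_steps, n_distinct)  # (I, C)
    changed = stats.astype(jnp.float32) @ onehot  # (N, C) changes per chain and step
    # Per-chain acceptance rate; steps that never ran get 0.
    rate = changed / jnp.maximum(counts, 1)
    return global_steps, counts, rate


isteps = jnp.array([[0, -1], [0, 1]])
steps = jnp.array([0, 1, 0, 0])
stats = jnp.array([[True, True, False, True],
                   [False, False, True, True]])
global_steps, counts, rate = aggregate_run(stats, steps, isteps, n_distinct=2)
```

In this example the alternative index 0 at iterations 0 and 3 maps to the same step (index 0 in the steps tuple), so counts are [3, 1] even though the cycle has two substeps.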
- to_backend()
Dump the data of this EpochStats instance, along with the information from the corresponding EpochCycle instance, to a dictionary in a form suitable for writing to a backend. The output dictionary has keys:
steps: numpy.ndarray cast of the object's steps attribute
counts: numpy.ndarray cast of the object's counts attribute
stats: numpy.ndarray cast of the object's stats attribute
steps_defs: data from the EpochCycle instance, as returned by jexplore.sampling.epoch.EpochCycle.to_backend
- Return type:
dict
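A sketch of the output layout, assuming the attributes are JAX arrays cast with `np.asarray`. `stats_to_backend` is a hypothetical helper, the steps_defs field layout is an assumption, and the in-memory `.npz` archive only illustrates that a dictionary of plain numpy arrays can be written by standard tools; it is not necessarily the backend jexplore uses.

```python
import io

import jax.numpy as jnp
import numpy as np


def stats_to_backend(steps, counts, stats, steps_defs: np.ndarray) -> dict:
    """Hypothetical analogue of EpochStats.to_backend: cast JAX arrays to numpy."""
    return {
        "steps": np.asarray(steps),
        "counts": np.asarray(counts),
        "stats": np.asarray(stats),
        "steps_defs": steps_defs,  # e.g. the structured array from EpochCycle.to_backend
    }


steps_defs = np.array(
    [("step_a", 0, 1.0), ("step_b", 1, 0.3)],
    dtype=[("name", "U128"), ("incycle", "i4"), ("weight", "f8")],
)
out = stats_to_backend(jnp.array([0, 1, 0]), jnp.array([2, 1]), jnp.ones((4, 2)), steps_defs)

# Illustrative backend only: an in-memory .npz archive.
buf = io.BytesIO()
np.savez(buf, **out)
buf.seek(0)
loaded = np.load(buf)
```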
- class Epoch[Tstate: jexplore.sampling.state.State, Tsampling: jexplore.sampling.base.Sampling](epoch, force_covs=False)
Base Epoch definition class.
This encapsulates the samples, the per-chain covariance matrices, the stats object, the definition of the epoch cycle and the definition of the current sampling. The class provides methods to initialize the epoch, get a specific state, run steps and instantiate an epoch from the result of a sampler run.
- Parameters:
epoch (Self | dict) – Epoch object from which the new instance is derived, by computing the covariance and taking the last sample. Alternatively, a dictionary with the definition of the epoch samples.
force_covs (bool) – if True, forces the covariance matrices to be recomputed.
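Deriving a new epoch from a previous one involves a per-chain covariance and the last sample. A minimal sketch, assuming the samples are available as a plain array of shape (T, N, D) (T draws, N chains, D dimensions); the real Epoch works on a batched Tstate, and `chain_covariances` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np


def chain_covariances(samples: jax.Array) -> jax.Array:
    """Per-chain sample covariance: (T, N, D) samples -> (N, D, D) matrices."""
    # vmap over the chain axis (axis 1); rowvar=False treats each row as one draw.
    return jax.vmap(lambda x: jnp.cov(x, rowvar=False), in_axes=1)(samples)


samples = jax.random.normal(jax.random.PRNGKey(0), (500, 3, 2))  # T=500, N=3, D=2
covs = chain_covariances(samples)  # (N, D, D), the layout of the covs attribute
last = samples[-1]  # (N, D) starting point for the next epoch
```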
- samples: Tstate
batched state holding the sequence of sampled states.
- covs: jax.Array
Array with shape \((N, D, D)\) containing the covariance matrix of each of the \(N\) chains.
- stats: EpochStats[Tstate]
epoch statistics.
- cycle: EpochCycle[Tstate]
definition of the epoch cycle
- sampling: Tsampling
sampling definition
- statecls: type[Tstate] | type[jexplore.sampling.state.State]
state class corresponding to this epoch class
- complete(sampling, compute=None)
Complete the definition and the parameters of the epoch.
- Parameters:
sampling (Tsampling) – sampling parameters.
compute (dict | None) – dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are (func, inpars) tuples, where func is a callable computing the corresponding parameter and inpars is the list of the names of its input parameters.
- Returns:
this object after update.
- Return type:
Self
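To make the (func, inpars) convention of compute concrete, here is a sketch in plain Python. It assumes the parameters live in a flat dictionary, that entries are evaluated in dictionary order (so a later entry may depend on an earlier one), and that parameters already present are not recomputed; none of this is confirmed by the source, and `fill_missing` is hypothetical.

```python
from typing import Any, Callable


def fill_missing(
    params: dict[str, Any], compute: dict[str, tuple[Callable, list[str]]]
) -> dict[str, Any]:
    """Hypothetical helper: compute missing parameters from (func, inpars) specs."""
    out = dict(params)
    for name, (func, inpars) in compute.items():
        if name not in out:  # assumption: parameters already present are kept
            out[name] = func(*(out[p] for p in inpars))
    return out


params = fill_missing(
    {"a": 2.0},
    {
        "b": (lambda a: 3 * a, ["a"]),
        "c": (lambda a, b: a + b, ["a", "b"]),  # uses "b", computed just before
    },
)
```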
- set_cycle(cycle)
Sets the cycle of this epoch.
- Parameters:
cycle (EpochCycle[Tstate]) – the EpochCycle value to be set.
- Return type:
None
- run_step(key, ind, state)
Runs a step of the epoch. For a given substep of the epoch cycle, it draws one of the alternative steps and performs it.
- Parameters:
key (jax.Array) – PRNG key
ind (int) – the index of the substep
state (Tstate) – current state
- Returns:
the new state, the step stats and the index of the substep alternative actually performed.
- Return type:
tuple[Tstate, jax.Array, jax.Array]
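The draw-then-dispatch logic can be sketched with `jax.random.categorical` and `jax.lax.switch`. This is an assumption about how such a step could be written, not the jexplore code: it takes the cycle's steps, weights and isteps explicitly, and assumes padded alternatives carry weight 0. `shift` and `stay` are toy steps.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def run_step(key, ind, state, steps: Sequence[Callable], weights, isteps):
    """Hypothetical sketch: draw an alternative of substep `ind` and run it."""
    kchoice, kstep = jax.random.split(key)
    # Padding entries are assumed to have weight 0, hence logit -inf: never drawn.
    a = jax.random.categorical(kchoice, jnp.log(weights[ind]))
    # Dispatch to the selected step function via its index in the steps tuple.
    new_state, stats = jax.lax.switch(isteps[ind, a], steps, kstep, state)
    return new_state, stats, a


def shift(key, x):  # toy step: changes every chain
    return x + 1.0, jnp.ones(x.shape, dtype=bool)


def stay(key, x):  # toy step: changes no chain
    return x, jnp.zeros(x.shape, dtype=bool)


steps = (shift, stay)
weights = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # deterministic choice per substep
isteps = jnp.array([[0, -1], [0, 1]])
x0 = jnp.zeros(3)
s0, st0, a0 = run_step(jax.random.PRNGKey(0), 0, x0, steps, weights, isteps)
s1, st1, a1 = run_step(jax.random.PRNGKey(1), 1, x0, steps, weights, isteps)
```

`jax.lax.switch` requires every branch to return the same pytree structure and dtypes, which is one practical reason for all steps to share the (state, stats) return shape.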
- from_run(states, stats, steps)
Gets the stacks of states, steps statistics and steps indexes from an epoch run and updates the epoch with these data.
- Parameters:
states (Tstate) – stack of states.
stats (jax.Array) – stack of steps stats
steps (jax.Array) – stack of steps indexes
- Returns:
this object after updating
- Return type:
Self
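For context, the stacks consumed by from_run could come from a loop like the one below, a sketch built on `jax.lax.scan`. It assumes iteration i executes substep i mod S, repeats the draw-then-dispatch logic inline, and transposes the per-iteration stats to the (N, I) layout documented for EpochStats.from_run; `run_epoch` and the toy steps are hypothetical.

```python
import jax
import jax.numpy as jnp


def shift(key, x):  # toy step: changes every chain
    return x + 1.0, jnp.ones(x.shape, dtype=bool)


def stay(key, x):  # toy step: changes no chain
    return x, jnp.zeros(x.shape, dtype=bool)


STEPS = (shift, stay)
WEIGHTS = jnp.array([[1.0, 0.0], [0.0, 1.0]])
ISTEPS = jnp.array([[0, -1], [0, 1]])


def run_epoch(key, state, n_iter: int):
    """Hypothetical epoch loop producing (states, stats, steps) stacks."""
    n_sub = ISTEPS.shape[0]

    def body(carry, i):
        key, state = carry
        key, kchoice, kstep = jax.random.split(key, 3)
        ind = i % n_sub  # assumption: substeps are visited in cycle order
        a = jax.random.categorical(kchoice, jnp.log(WEIGHTS[ind]))
        state, changed = jax.lax.switch(ISTEPS[ind, a], STEPS, kstep, state)
        return (key, state), (state, changed, a)

    _, (states, stats, steps) = jax.lax.scan(body, (key, state), jnp.arange(n_iter))
    # scan stacks along axis 0: stats is (I, N), transposed here to (N, I).
    return states, stats.T, steps


states, stats, steps = run_epoch(jax.random.PRNGKey(0), jnp.zeros(3), 4)
```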