jexplore.sampling.epoch#

This module defines the base jexplore.sampling.epoch.Epoch class, along with the ancillary classes and types needed to run the Markov chain sampler.

Classes#

StepSample

Prototype for a step sampling function.

EpochCycle

Class containing the definition of an epoch steps cycle.

EpochStats

Class holding the aggregated statistics of the epoch steps.

Epoch

Base Epoch definition class.

Module Contents#

class StepSample[Tstate: jexplore.sampling.state.State][source]#

Bases: Protocol

Prototype for a step sampling function.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

Returns:

the new state and the step's chain-change stats, i.e. which chains were modified by the step (nchain).

Return type:

tuple[Tstate, jax.Array]
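As an illustration of the call shape described above, here is a minimal sketch of a function that would fit the (key, state) -> (state, stats) pattern. ToyState, log_prob and random_walk_step are hypothetical names, not part of jexplore, and the real State class may carry more fields than a single position array.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a State: positions of N chains in D dimensions."""
    x: jax.Array  # shape (N, D)


def log_prob(x: jax.Array) -> jax.Array:
    """Standard normal log-density (up to a constant), one value per chain."""
    return -0.5 * jnp.sum(x**2, axis=-1)


def random_walk_step(key: jax.Array, state: ToyState) -> tuple[ToyState, jax.Array]:
    """Random-walk Metropolis step returning the new state and a per-chain boolean map."""
    key_prop, key_acc = jax.random.split(key)
    proposal = state.x + 0.5 * jax.random.normal(key_prop, state.x.shape)
    log_ratio = log_prob(proposal) - log_prob(state.x)
    # Accept where log(u) < log_ratio, independently for each chain.
    log_u = jnp.log(jax.random.uniform(key_acc, log_ratio.shape))
    accepted = log_u < log_ratio
    new_x = jnp.where(accepted[:, None], proposal, state.x)
    return ToyState(new_x), accepted


state = ToyState(jnp.zeros((4, 2)))
new_state, changed = random_walk_step(jax.random.PRNGKey(0), state)
```

Here changed has one boolean entry per chain, which matches the "boolean map of the chains that have been changed" wording used for EpochStats.from_run.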

class EpochCycle[Tstate: jexplore.sampling.state.State](steps)[source]#

Class containing the definition of an epoch steps cycle.

Parameters:

steps (list[dict[StepSample[Tstate], float]]) – list of substeps, each one defined by a dictionary mapping the alternative steps to their weights.

steps: tuple[StepSample[Tstate], Ellipsis]#

tuple of all the steps.

weights: jax.Array#

Matrix of weights. This has shape \((S, A_{max})\) where \(S\) is the number of serial substeps in the cycle and \(A_{max}\) is the max number of alternatives in all these substeps. Each element of the matrix that actually corresponds to a defined step/alternative contains the weight of the corresponding alternative.

isteps: jax.Array#

\((S, A_{max})\) matrix. Each element of the matrix that actually corresponds to a defined step/alternative contains the index of the corresponding StepSample function in the steps tuple.

n: int#

number of sequential substeps.
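The layout of steps, weights and isteps can be sketched with NumPy as follows. This is only an illustration of the padded \((S, A_{max})\) matrices described above: the step functions are placeholders, and the padding values (0 for weights, -1 for isteps) are assumptions, since the values stored in undefined entries are not documented here.

```python
import numpy as np


def step_a(key, state): ...
def step_b(key, state): ...
def step_c(key, state): ...


# Two serial substeps: the first draws between step_a and step_b,
# the second always runs step_c.
cycle_def = [{step_a: 0.75, step_b: 0.25}, {step_c: 1.0}]

# Distinct step functions, in order of first appearance.
steps = tuple(dict.fromkeys(f for sub in cycle_def for f in sub))

n = len(cycle_def)                          # S
a_max = max(len(sub) for sub in cycle_def)  # A_max

weights = np.zeros((n, a_max))
isteps = np.full((n, a_max), -1)  # -1 as padding is an assumption
for s, sub in enumerate(cycle_def):
    for a, (func, w) in enumerate(sub.items()):
        weights[s, a] = w
        isteps[s, a] = steps.index(func)
```

With this input, weights is [[0.75, 0.25], [1.0, 0.0]] and isteps is [[0, 1], [2, -1]].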

to_backend()[source]#

Dumps the cycle definition into a numpy array. This array has one row for each defined substep/alternative, containing:

  • name: the qualified name of the StepSample function.

  • incycle: the index of the substep

  • weight: the weight of the corresponding substep alternative.

Return type:

numpy.ndarray
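One plausible shape for such an array is a NumPy structured array with the three fields listed above. The sketch below is a guess at the layout, not the library's implementation: the field dtypes and the way the qualified name is built are assumptions.

```python
import numpy as np


def step_a(key, state): ...
def step_b(key, state): ...
def step_c(key, state): ...


cycle_def = [{step_a: 0.75, step_b: 0.25}, {step_c: 1.0}]

# One row per defined substep/alternative.
rows = [
    (f"{func.__module__}.{func.__qualname__}", s, w)
    for s, sub in enumerate(cycle_def)
    for func, w in sub.items()
]
# Field names follow the list above; the dtypes are illustrative guesses.
table = np.array(rows, dtype=[("name", "U128"), ("incycle", "i4"), ("weight", "f8")])
```

Each field can then be read as a column, for example table["incycle"] or table["weight"].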

class EpochStats[Tstate: jexplore.sampling.state.State][source]#

Class holding the aggregated statistics of the epoch steps.

Parameters:
  • steps – array of steps performed at each iteration.

  • counts – counters of the number of executions of each step.

  • stats – chains acceptance statistics for each step.

steps: jax.Array#

This is an \(I\)-sized array, where \(I\) is the total number of steps performed during the sampling. Each entry is the index of the corresponding step in the steps tuple of the epoch’s EpochCycle instance.

counts: jax.Array#

This array has the same size as the steps tuple in the epoch’s EpochCycle instance. Each element contains the count of how many times the corresponding step was performed during the sampling.

stats: jax.Array#

This array has shape \((N, C)\), where \(C\) is the size of the steps tuple in the epoch’s EpochCycle instance (that is, the number of distinct steps) and \(N\) is the number of chains. The values are statistics of how many times each chain was modified by each type of step (i.e. the proposal acceptance rate in the case of MH).

static from_run(stats, steps, cycle)[source]#

Each step returns an \(N\)-sized boolean map of the chains that have been changed by the step, together with the index identifying the step performed. An epoch run returns the stacking of these arrays over all iterations. This method aggregates and organizes this information and instantiates the corresponding EpochStats object.

Parameters:
  • stats (jax.Array) – stacked stats array for all steps in the sampling. This has shape \((N, I)\).

  • steps (jax.Array) – array with the indices of the steps performed. This is an \(I\)-sized array whose entries are indices \(a = 0, \dots, A_{max} - 1\) identifying which one of the possible alternatives was actually performed at each step.

  • cycle (EpochCycle[Tstate]) – EpochCycle instance.

Returns:

the EpochStats corresponding to these data.

Return type:

EpochStats[Tstate]
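The aggregation can be sketched with NumPy as below. Two points are assumptions rather than documented behaviour: that iteration \(i\) runs substep \(i \bmod S\), and that undefined entries of isteps are padded with -1. The variable names alt, changed, nchanged and rate are illustrative.

```python
import numpy as np

# Hypothetical cycle: S = 2 substeps, A_max = 2, C = 3 distinct steps.
isteps = np.array([[0, 1], [2, -1]])  # -1 marks an undefined alternative (assumed)
n_sub, n_steps = isteps.shape[0], 3

# Run output for I = 6 iterations over N = 2 chains.
alt = np.array([0, 0, 1, 0, 0, 0])  # alternative drawn at each iteration
changed = np.array(
    [[1, 1, 0, 1, 1, 0],
     [0, 1, 1, 1, 0, 1]], dtype=bool)  # shape (N, I)

# Assumption: iteration i runs substep i mod S.
substep = np.arange(alt.size) % n_sub
step_idx = isteps[substep, alt]  # index into the steps tuple, shape (I,)

counts = np.bincount(step_idx, minlength=n_steps)    # shape (C,)
onehot = step_idx[:, None] == np.arange(n_steps)     # shape (I, C)
nchanged = changed.astype(int) @ onehot.astype(int)  # shape (N, C)
rate = nchanged / np.maximum(counts, 1)              # per-chain change rate
```

Whether the stored stats array holds raw counts (nchanged here) or rates (rate here) is not stated on this page, so both are shown.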

to_backend()[source]#

Dumps the data of this EpochStats instance, along with the information from the corresponding EpochCycle instance, to a dictionary in a form suitable for dumping to the backend. The output dictionary has keys:

  • steps: numpy.ndarray cast of the object's steps attribute

  • counts: numpy.ndarray cast of the object's counts attribute

  • stats: numpy.ndarray cast of the object's stats attribute

  • steps_defs: data from EpochCycle instance as returned by jexplore.sampling.epoch.EpochCycle.to_backend

Return type:

dict

class Epoch[Tstate: jexplore.sampling.state.State, Tsampling: jexplore.sampling.base.Sampling](epoch, force_covs=False)[source]#

Base Epoch definition class.

This encapsulates the samples, the covariance matrices, the stats object, the definition of the epoch cycle and the definition of the current sampling. The class provides methods to initialize the epoch, get a specific state, run steps and instantiate an epoch from the result of a sampler run.

Parameters:
  • epoch (Self | dict) – Epoch object from which the new instance can be derived by computing the covariance and getting the last sample. Alternatively, this can be a dictionary with the definition of the epoch samples.

  • force_covs (bool) – if true, forces the covariance matrices to be recomputed.

samples: Tstate#

batched state defining the sampling states sequence.

covs: jax.Array#

Array with shape \((N, D, D)\) containing the list of covariance matrices of all chains.
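For orientation, a per-chain covariance array of this shape can be computed as in the NumPy sketch below. It assumes \(N\) is the number of chains, \(D\) is the dimension of the sampled parameter vector, and the samples are available as a plain \((T, N, D)\) array. None of this is confirmed by the page, and the actual state layout in jexplore may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=(500, 3, 2))  # hypothetical (T, N, D) stack of positions

# Center each chain on its own mean, then contract over the iteration axis.
centered = samples - samples.mean(axis=0)
covs = np.einsum("tni,tnj->nij", centered, centered) / (samples.shape[0] - 1)
```

The result has one unbiased \((D, D)\) sample covariance per chain, equal to np.cov(samples[:, n, :], rowvar=False) for chain n.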

stats: EpochStats[Tstate]#

epoch statistics.

cycle: EpochCycle[Tstate]#

definition of the epoch cycle

sampling: Tsampling#

sampling definition

statecls: type[Tstate] | type[jexplore.sampling.state.State]#

state class corresponding to this epoch class

complete(sampling, compute=None)[source]#

Complete the definition and the parameters of the epoch.

Parameters:
  • sampling (Tsampling) – sampling parameters.

  • compute (dict | None) – dictionary defining how to compute the missing parameters. The keys are the names of the parameters to be computed. Values are (func, inpars) tuples, where func is a callable that computes the corresponding parameter and inpars is the list of the names of its input parameters.

Returns:

this object after update.

Return type:

Self
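The (func, inpars) layout of the compute dictionary can be illustrated in plain Python. The parameter names below (temperature, ndim, beta, scale) are made up for the example, and the resolution loop is only a guess at how such a dictionary might be consumed. In particular, it assumes the inputs are passed positionally in the order given by inpars, and that entries are processed in insertion order.

```python
# Hypothetical parameters; only the (func, inpars) layout comes from the description above.
params = {"temperature": 2.0, "ndim": 3}

compute = {
    "beta": (lambda temperature: 1.0 / temperature, ["temperature"]),
    "scale": (lambda ndim, beta: beta / ndim, ["ndim", "beta"]),
}

for name, (func, inpars) in compute.items():
    if name not in params:  # compute only what is missing
        params[name] = func(*(params[p] for p in inpars))
```

Because scale depends on beta, beta has to appear before it in the dictionary under these assumptions.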

set_cycle(cycle)[source]#

Sets this epoch cycle.

Parameters:

cycle (EpochCycle[Tstate]) – the EpochCycle value to be set.

Return type:

None

run_step(key, ind, state)[source]#

Runs a step of the epoch. For a given substep of the epoch cycle it draws one of the alternative steps and performs it.

Parameters:
  • key (jax.Array) – PRNG key

  • ind (int) – the index of the substep

  • state (Tstate) – current state

Returns:

the new state, the step stats and the index of the substep alternative actually performed.

Return type:

tuple[Tstate, jax.Array, jax.Array]
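The "draw one alternative, then perform it" logic might look like the JAX sketch below. It is a stand-alone illustration and not the method's source: the toy step functions, the module-level weights and isteps arrays, and the use of jax.random.categorical with jax.lax.switch are all assumptions chosen to show one way such a dispatch can be written.

```python
import jax
import jax.numpy as jnp


# Toy steps on a plain array state; each returns (new_state, changed-per-chain).
def shift_up(key, state):
    return state + 1.0, jnp.ones(state.shape[0], dtype=bool)


def shift_down(key, state):
    return state - 1.0, jnp.ones(state.shape[0], dtype=bool)


def stay(key, state):
    return state, jnp.zeros(state.shape[0], dtype=bool)


steps = (shift_up, shift_down, stay)
weights = jnp.array([[0.75, 0.25], [1.0, 0.0]])  # (S, A_max), zero-padded
isteps = jnp.array([[0, 1], [2, 0]])             # padded entry has zero weight, never drawn


def run_step(key, ind, state):
    key_draw, key_step = jax.random.split(key)
    # Draw an alternative of substep `ind` with probability proportional to its weight.
    alt = jax.random.categorical(key_draw, jnp.log(weights[ind]))
    # Dispatch to the selected step function; all branches share one signature.
    new_state, changed = jax.lax.switch(isteps[ind, alt], steps, key_step, state)
    return new_state, changed, alt


new_state, changed, alt = run_step(jax.random.PRNGKey(0), 1, jnp.zeros(3))
```

Zero-weight entries get a log-weight of minus infinity, so the categorical draw never selects a padded alternative.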

from_run(states, stats, steps)[source]#

Gets the stacks of states, steps statistics and steps indexes from an epoch run and updates the epoch with these data.

Parameters:
  • states (Tstate) – stack of states.

  • stats (jax.Array) – stack of steps stats

  • steps (jax.Array) – stack of steps indexes

Returns:

this object after updating

Return type:

Self

to_backend(burn, stats)[source]#

Dumps the epoch in backend format.

Parameters:
  • burn (int)

  • stats (bool)

Return type:

tuple[int, int, dict | None]