jexplore.sampling.state

This module provides the definition of the base class jexplore.sampling.state.State.

Classes

ArrayFn

Protocol for a function with a generic number of jax.Array positional inputs and a single jax.Array output.

State

This class provides the base definition of the state of a set of Markov chains.

Module Contents

class ArrayFn

Bases: Protocol

Protocol for a function taking a generic number of jax.Array positional parameters and returning a single jax.Array.
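The exact source of ArrayFn is not reproduced on this page. The sketch below assumes it is a typing.Protocol whose __call__ accepts any number of jax.Array positional arguments; ArrayFnLike and neg_half_sq_norm are hypothetical names used only to show what a conforming function looks like.

```python
from typing import Protocol

import jax
import jax.numpy as jnp


class ArrayFnLike(Protocol):
    """Hypothetical restatement of ArrayFn (not jexplore's source): a callable
    taking jax.Array positional arguments and returning one jax.Array."""

    def __call__(self, *args: jax.Array) -> jax.Array: ...


def neg_half_sq_norm(*args: jax.Array) -> jax.Array:
    # -0.5 * |x|^2 over the last axis of each input, summed over all inputs;
    # keepdims=True maps (N, D) inputs to an (N, 1) output.
    return sum(-0.5 * jnp.sum(a**2, axis=-1, keepdims=True) for a in args)


f: ArrayFnLike = neg_half_sq_norm
p = jnp.ones((4, 3))  # N = 4 chains in D = 3 dimensions
out_one = f(p)        # a single input array
out_two = f(p, p)     # the same function with two input arrays
```

Because the inputs are passed positionally, the caller decides how many arrays to supply, which is what lets one protocol cover functions of one, two or more state parameters.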

class State

This class provides the base definition of the state of a set of Markov chains.

This minimal definition consists of a single parameter p with shape \((N, D)\), describing the point in the \(D\)-dimensional space for each of the \(N\) chains of the sampling to be performed.

Child classes can specialize this definition to different types of sampling (e.g. jexplore.sampling.mh.StateMH) by adding extra parameters, each with shape \((N, d)\), where \(d\) depends on the parameter.
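To make the shape convention concrete, here is a hypothetical toy state with p and one extra per-chain parameter. The NamedTuple representation and the field name logp are assumptions for illustration only; the real State and jexplore.sampling.mh.StateMH may be defined differently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a State subclass; not jexplore's code."""

    p: jax.Array     # (N, D): current point of each chain
    logp: jax.Array  # (N, 1): hypothetical extra per-chain parameter (d = 1)


N, D = 8, 2
state = ToyState(
    p=jax.random.normal(jax.random.PRNGKey(0), (N, D)),
    logp=jnp.zeros((N, 1)),
)

# NamedTuples are JAX pytrees, so generic code can visit every parameter;
# all of them share the leading chain axis N.
chain_counts = [leaf.shape[0] for leaf in jax.tree_util.tree_leaves(state)]
```

Sharing the leading axis \(N\) is what allows chain-wise operations to be written once for every parameter, whatever its trailing size \(d\).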

The class also provides basic state-manipulation methods that are generic for all child classes defined as above. Some of these methods can also handle “batched” states, i.e. State objects whose parameters have shape \((N, d, I)\), where \(I\) is the number of performed sampling iterations, as returned by the jax.lax.scan loops of the sampler.
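jax.lax.scan stacks the per-step outputs along a new leading axis, so a loop like the one below yields an array of shape \((I, N, D)\); moving that axis last gives the \((N, d, I)\) layout described above. The random-walk step function is hypothetical, and whether jexplore's sampler rearranges axes in exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp


def step(p, key):
    # Hypothetical random-walk move applied to all N chains: (N, D) -> (N, D).
    p_new = p + 0.1 * jax.random.normal(key, p.shape)
    return p_new, p_new  # (new carry, per-iteration output)


N, D, I = 4, 3, 10
keys = jax.random.split(jax.random.PRNGKey(0), I)
_, trace = jax.lax.scan(step, jnp.zeros((N, D)), keys)  # stacked: (I, N, D)
batched_p = jnp.moveaxis(trace, 0, -1)                  # batched: (N, D, I)
```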

Parameters:

p – current point of each chain, with shape (nchains, dim)

p: jax.Array

Current point of each chain, with shape (nchains, dim).

compute(par, func, inpar)

Populates the values of one parameter \(p\) as a function of other parameters of the state.

Parameters:
  • par (str) – parameter name.

  • func (ArrayFn) – function defining the parameter value.

  • inpar (List[str]) – list of the names of the input parameters of the function.

Returns:

a new state with the populated values.

Return type:

Self
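The semantics of compute can be sketched on a plain dict of arrays keyed by parameter name. This is a minimal illustration under that assumption, not jexplore's implementation; compute_sketch, log_gauss and the parameter name logp are hypothetical.

```python
from typing import Callable, Dict, List

import jax
import jax.numpy as jnp

Params = Dict[str, jax.Array]


def compute_sketch(params: Params, par: str,
                   func: Callable[..., jax.Array], inpar: List[str]) -> Params:
    """Hypothetical mirror of State.compute on a dict of parameters."""
    new = dict(params)  # copy: the input "state" is left unchanged
    new[par] = func(*(params[name] for name in inpar))
    return new


def log_gauss(p: jax.Array) -> jax.Array:
    # Unnormalised standard-normal log-density per chain: (N, D) -> (N, 1).
    return -0.5 * jnp.sum(p**2, axis=-1, keepdims=True)


params = {"p": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "logp": jnp.zeros((2, 1))}
params = compute_sketch(params, "logp", log_gauss, ["p"])
```

Returning a new container rather than mutating the old one matches the "returns a new state" contract and keeps the operation usable inside jitted code.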

update(update, pars=None)

Updates a set of parameters \(p_i\) of the state by calling an update function:

\[p_i = u(i, p_i)\]

where \(i\) is the name of the parameter.

Parameters:
  • update (Callable[[str, jax.Array], jax.Array]) – the update function \(u\).

  • pars (List[str] | None) – list of the names of the parameters to be updated. If None (default), all parameters are updated.

Returns:

a new state with the updated parameters.

Return type:

Self
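A minimal sketch of the update rule \(p_i = u(i, p_i)\), assuming the state is represented as a dict of arrays; update_sketch and the update function u are hypothetical. Passing the parameter name to \(u\) lets a single function treat parameters differently.

```python
from typing import Callable, Dict, List, Optional

import jax
import jax.numpy as jnp

Params = Dict[str, jax.Array]


def update_sketch(params: Params,
                  update: Callable[[str, jax.Array], jax.Array],
                  pars: Optional[List[str]] = None) -> Params:
    """Hypothetical mirror of State.update: p_i = u(i, p_i) for selected i."""
    names = list(params) if pars is None else pars  # None: all parameters
    return {i: (update(i, v) if i in names else v) for i, v in params.items()}


def u(name: str, value: jax.Array) -> jax.Array:
    # Clip the point p; zero out any other parameter.
    return jnp.clip(value, -1.0, 1.0) if name == "p" else value * 0.0


params = {"p": jnp.array([[0.5, -3.0]]), "logp": jnp.array([[7.0]])}
only_p = update_sketch(params, u, ["p"])  # logp left untouched
both = update_sketch(params, u)           # every parameter updated
```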

update_mask(other, mask=True, pars=None)

Imports the values of some parameters, for a selection of chains, from another state. Chains are selected by a boolean mask (a 1-D array of size nchains).

Parameters:
  • other (Self) – other state. The parameters to update should be defined and have the same shape as in the current state.

  • mask (jax.Array | bool) – boolean mask for the chains to be updated.

  • pars (List[str] | None) – list of the names of the parameters to be updated.

Returns:

a new state with the updated parameters.

Return type:

Self
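Chain-wise masked import can be sketched with jnp.where, broadcasting the \((N,)\) mask over each parameter's trailing axes. The dict representation, the name update_mask_sketch and the "None means all parameters" behaviour are assumptions; a typical use would be keeping accepted proposals after a Metropolis-Hastings test.

```python
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp

Params = Dict[str, jax.Array]


def update_mask_sketch(params: Params, other: Params,
                       mask: Union[jax.Array, bool] = True,
                       pars: Optional[List[str]] = None) -> Params:
    """Hypothetical mirror of State.update_mask on dicts of (N, d) arrays."""
    names = list(params) if pars is None else pars  # assumption: None = all
    mask = jnp.asarray(mask)
    out = dict(params)
    for name in names:
        cur = params[name]
        # (N,) -> (N, 1, ...) so the mask broadcasts over trailing axes;
        # a scalar mask becomes (1, 1, ...) and selects every chain.
        m = mask.reshape(mask.shape + (1,) * (cur.ndim - mask.ndim))
        out[name] = jnp.where(m, other[name], cur)
    return out


current = {"p": jnp.zeros((3, 2))}
proposal = {"p": jnp.ones((3, 2))}
accept = jnp.array([True, False, True])  # e.g. per-chain acceptance decisions
new = update_mask_sketch(current, proposal, accept)
```

Using jnp.where keeps the array shapes fixed, so the masked update stays compatible with jit and jax.lax.scan.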

update_slice(other, slc, pars=None)

Imports the values of some parameters, for a selection of chains, from another state. Chains are selected by an index slice, i.e. an array of integer chain indices.

Parameters:
  • other (Self) – other state. The parameters to update should be defined and have the shape of the selected chain slice.

  • slc (jax.Array) – array of chain indices.

  • pars (List[str] | None) – list of the names of the parameters to be updated.

Returns:

a new state with the updated parameters.

Return type:

Self
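Index-based import maps naturally onto JAX's functional indexed update, x.at[idx].set(values). The sketch assumes a dict-of-arrays state and treats pars=None as "all parameters"; update_slice_sketch is a hypothetical name.

```python
from typing import Dict, List, Optional

import jax
import jax.numpy as jnp

Params = Dict[str, jax.Array]


def update_slice_sketch(params: Params, other: Params, slc: jax.Array,
                        pars: Optional[List[str]] = None) -> Params:
    """Hypothetical mirror of State.update_slice on dicts of (N, d) arrays."""
    names = list(params) if pars is None else pars  # assumption: None = all
    out = dict(params)
    for name in names:
        # Row j of other[name] (shape (len(slc), d)) replaces chain slc[j].
        out[name] = params[name].at[slc].set(other[name])
    return out


current = {"p": jnp.zeros((4, 2))}
other = {"p": jnp.array([[1.0, 1.0], [2.0, 2.0]])}
new = update_slice_sketch(current, other, jnp.array([3, 0]))
```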

set_val(val, cslc=None, pslc=None, par='p')

Updates the values of one parameter, slicing in both the chain and the parameter dimensions.

Parameters:
  • val (jax.Array) – values to set (should have the same shape as the slice).

  • cslc (jax.Array | None) – chains slice. Default: all.

  • pslc (jax.Array | None) – parameter dimensions slice. Default: all.

  • par (str) – parameter name. Default: p

Returns:

new state

Return type:

Self
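One way to read "slicing both in chains and parameter dimension" is as a rectangular block selected by outer-product indexing, which the sketch below assumes; the dict state and set_val_sketch are hypothetical, and the actual method may index differently.

```python
from typing import Dict, Optional

import jax
import jax.numpy as jnp

Params = Dict[str, jax.Array]


def set_val_sketch(params: Params, val: jax.Array,
                   cslc: Optional[jax.Array] = None,
                   pslc: Optional[jax.Array] = None,
                   par: str = "p") -> Params:
    """Hypothetical mirror of State.set_val for one (N, d) parameter."""
    x = params[par]
    rows = jnp.arange(x.shape[0]) if cslc is None else jnp.asarray(cslc)
    cols = jnp.arange(x.shape[1]) if pslc is None else jnp.asarray(pslc)
    new = dict(params)
    # Outer-product indexing: val must have shape (len(rows), len(cols)).
    new[par] = x.at[rows[:, None], cols[None, :]].set(val)
    return new


params = {"p": jnp.zeros((3, 4))}
# Chain 1, dimensions 0 and 2.
new = set_val_sketch(params, jnp.array([[7.0, 8.0]]),
                     cslc=jnp.array([1]), pslc=jnp.array([0, 2]))
```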

slice(slc)

Returns a state whose parameter values correspond to a slice of the current state's chains. That is, if the current state has \(N\) chains, a slice is a set of indices \(\{i_k : 0 \leq i_k < N,\ k=0,\dots,\tilde{N}-1\}\). This method constructs a new state having one parameter

\[\tilde{p}_{k, \alpha} = p_{i_k, \alpha}, \quad k=0,\dots,\tilde{N}-1, \quad \alpha=0,\dots,d-1\]

for each \((N, d)\)-shaped parameter \(p\) of the current state.

Parameters:

slc (jax.Array) – the slice, as an array of chain indices.

Returns:

the new “sliced” state.

Return type:

Self
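The definition \(\tilde{p}_{k, \alpha} = p_{i_k, \alpha}\) is a gather along the chain axis, which can be applied to every parameter at once with jax.tree_util.tree_map. The dict state and slice_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp


def slice_sketch(params, slc):
    # tilde p_{k, alpha} = p_{i_k, alpha}: gather rows i_k on the chain axis
    # of every parameter in the (hypothetical) dict-of-arrays state.
    return jax.tree_util.tree_map(lambda v: v[slc], params)


params = {"p": jnp.arange(8.0).reshape(4, 2), "logp": jnp.arange(4.0)[:, None]}
sub = slice_sketch(params, jnp.array([2, 0]))  # N~ = 2 chains: i_0 = 2, i_1 = 0
```

Since only the leading axis is indexed, the same gather would also work on batched \((N, d, I)\) parameters.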

swap(even, odd, accept)

Swaps the values of all parameters between pairs of chains.

Parameters:
  • even (jax.Array) – first set of chain indices to be swapped.

  • odd (jax.Array) – second set of chain indices to be swapped (should have the same size as even).

  • accept (jax.Array) – boolean mask selecting which swaps should be performed.

Returns:

new state with swapped values.

Return type:

Self
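A pairwise swap of this kind (as used, for instance, in replica-exchange moves) can be sketched by building a chain permutation and gathering every parameter with it. The sketch assumes even and odd are disjoint, that accept has one entry per (even[k], odd[k]) pair, and a dict-of-arrays state; swap_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def swap_sketch(params, even, odd, accept):
    """Hypothetical mirror of State.swap; assumes even/odd are disjoint."""
    n = jax.tree_util.tree_leaves(params)[0].shape[0]  # number of chains N
    perm = jnp.arange(n)
    # Accepted pair k: chain even[k] reads from odd[k] and vice versa;
    # rejected pairs keep their own index.
    perm = perm.at[even].set(jnp.where(accept, odd, even))
    perm = perm.at[odd].set(jnp.where(accept, even, odd))
    return jax.tree_util.tree_map(lambda v: v[perm], params)


params = {"p": jnp.arange(4.0)[:, None]}  # chain i holds value i
swapped = swap_sketch(params, jnp.array([0, 2]), jnp.array([1, 3]),
                      jnp.array([True, False]))
```

Expressing the swap as a single permutation keeps all parameters consistent: a chain's p and its extra parameters always move together.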

get_in_batch(ind=-1, batch=False)

This method assumes that the current state is batched, i.e. its parameters have shape \((N, d, I)\), where \(I\) is the number of performed Markov iterations. It returns a new state corresponding to one specific iteration.

Parameters:
  • ind (int) – index of the iteration. Default: -1 (the last iteration).

  • batch (bool) – if True, the returned state is a batched state containing a single iteration.

Returns:

the selected state.

Return type:

Self
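Selecting one iteration amounts to indexing the trailing \(I\) axis. The sketch below assumes a dict-of-arrays state and that batch=True keeps a length-1 iteration axis, giving shape \((N, d, 1)\); get_in_batch_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def get_in_batch_sketch(params, ind=-1, batch=False):
    def pick(v):
        x = v[..., ind]  # (N, d, I) -> (N, d)
        # Assumption: a 1-iteration batched state keeps a trailing axis of 1.
        return jnp.expand_dims(x, -1) if batch else x
    return jax.tree_util.tree_map(pick, params)


batched = {"p": jnp.arange(24.0).reshape(2, 3, 4)}  # N = 2, d = 3, I = 4
last = get_in_batch_sketch(batched)                  # last iteration
first_b = get_in_batch_sketch(batched, ind=0, batch=True)
```

Using expand_dims rather than a slice such as ind:ind+1 avoids the empty result that slice would give for the default ind=-1.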

classmethod from_dict(state_d, mandatory=None, batch=False)

Instantiates a state object from a dictionary. It selects all entries of the dictionary that correspond to the parameters of this state class and uses their values to define the state.

Parameters:
  • state_d (dict) – input dictionary.

  • mandatory (List[str] | None) – list of the names of the parameters that must be present in the dictionary. Non-mandatory parameters, if absent, are replaced by empty arrays. If None, all parameters are mandatory.

  • batch (bool) – when True, forces the instantiated state to have a “batched” shape.

Returns:

the state defined by the dictionary.

Return type:

Self
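The selection logic of from_dict can be sketched as follows, returning a dict instead of a state object. The parameter set PARAM_NAMES, the KeyError on a missing mandatory entry and the zero-length placeholder shape are assumptions, and the batch flag is deliberately not modelled.

```python
from typing import Dict, List, Optional

import jax
import jax.numpy as jnp

# Hypothetical parameter set of one state class (stand-in for its fields).
PARAM_NAMES = ("p", "logp")


def from_dict_sketch(state_d: dict,
                     mandatory: Optional[List[str]] = None) -> Dict[str, jax.Array]:
    """Hypothetical mirror of State.from_dict (batch flag not modelled)."""
    required = PARAM_NAMES if mandatory is None else mandatory  # None: all
    missing = [name for name in required if name not in state_d]
    if missing:
        raise KeyError(f"missing mandatory parameters: {missing}")
    # Keep only entries that are parameters of this class; absent optional
    # parameters become empty (zero-length) arrays.
    return {name: (jnp.asarray(state_d[name]) if name in state_d
                   else jnp.zeros((0,)))
            for name in PARAM_NAMES}


state = from_dict_sketch({"p": [[1.0, 2.0]], "unrelated": 5}, mandatory=["p"])
```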