"""
This module provides the definition of the base
:py:class:`jexplore.sampling.state.State` class.
"""
from dataclasses import asdict, dataclass
from typing import Callable, List, Protocol, Self, cast
import jax
import jax.numpy as jnp
# pylint: disable=too-few-public-methods
class ArrayFn(Protocol):
    """Structural type for array-valued callables.

    Any callable that accepts an arbitrary number of positional
    :py:class:`jax.Array` arguments and returns a single
    :py:class:`jax.Array` satisfies this protocol.
    """

    def __call__(self, *args: jax.Array) -> jax.Array:
        """Evaluate the callable on the given arrays."""
@jax.tree_util.register_dataclass
@dataclass
class State:
    """
    This class provides the base definition of Markov-chain states.

    Such a minimal definition corresponds to a single parameter `p` with
    shape :math:`(N, D)`, describing the point in the :math:`D`-dimensional
    space for each one of the :math:`N` chains of the sampling to be
    performed.

    This can be specialized, in child classes, to different types of sampling
    by adding extra parameters (e.g. :py:attr:`jexplore.sampling.mh.StateMH`)
    with shape :math:`(N, d)`, where :math:`d` depends on the parameter.

    The class also provides basic state manipulation methods which are
    generic for all the child classes defined as above.

    Some of these methods can also deal with "batched" states, i.e. `State`
    objects having parameters with shape :math:`(N, d, I)` - where :math:`I`
    is the number of performed sampling iterations - as they are returned by
    the :py:attr:`jax.lax.scan` loops of the sampler.

    :param p: state point (nchains, dim)
    """

    p: jax.Array
    """state point (nchains, dim)"""

    def compute(self, par: str, func: ArrayFn, inpar: List[str]) -> Self:
        """
        Populate the values of one parameter as a function of other parameters
        of the state.

        :param par: parameter name.
        :param func: function defining the parameter value. It is mapped over
                     the chains (leading axis) with :py:func:`jax.vmap`.
        :param inpar: list of the names of the input parameters of the function,
                      passed positionally in this order.
        :return: a new state with the populated values.
        """
        _val = jax.vmap(func)(*[getattr(self, _inp) for _inp in inpar])
        # Flatten each chain's output into one row so the new parameter has
        # the (nchains, d) layout shared by all the state parameters.
        _val = _val.reshape(_val.shape[0], -1)
        return self.__class__(**(asdict(self) | {par: _val}))

    def update(
        self,
        update: Callable[[str, jax.Array], jax.Array],
        pars: List[str] | None = None,
    ) -> Self:
        """
        Update a set of parameters :math:`p_i` of the state by calling an
        update function:

        .. math::

            p_i = u(i, p_i)

        :param update: the update function, called with the parameter name and
                       its current value.
        :param pars: list of the names of the parameters to be updated.
                     None (default): all parameters.
        :returns: a new state with the updated parameters.
        """
        state_d = asdict(self)
        pars = list(state_d.keys()) if pars is None else pars
        return self.__class__(
            **(state_d | {_name: update(_name, state_d[_name]) for _name in pars})
        )

    def update_mask(
        self, other: Self, mask: jax.Array | bool = True, pars: List[str] | None = None
    ) -> Self:
        """Import values of some parameters, and for some selected chains, from
        another state.

        Chains are selected by a boolean mask (1-array of size `nchains`). The
        mask is broadcast along all the trailing axes of each parameter, so it
        works for both :math:`(N, d)` and batched :math:`(N, d, I)` parameters.

        :param other: other state. The parameters to update should be defined and
                      have the same shape as in the current state.
        :param mask: boolean mask for the chains to be updated. A scalar (such as
                     the default `True`) selects all chains.
        :param pars: list of the names of the parameters to be updated.
                     None (default): all parameters.
        :return: a new state with the updated parameters.
        """

        def _select(_name: str, _val: jax.Array) -> jax.Array:
            _mask = jnp.asarray(mask)
            # Append trailing unit axes so the per-chain mask lines up with the
            # leading (chain) axis whatever the rank of the parameter is.
            _mask = _mask.reshape(_mask.shape + (1,) * (_val.ndim - _mask.ndim))
            return cast(jax.Array, jnp.where(_mask, getattr(other, _name), _val))

        return self.update(update=_select, pars=pars)

    def update_slice(
        self, other: Self, slc: jax.Array, pars: List[str] | None = None
    ) -> Self:
        """Import values of some parameters, and for some selected chains, from
        another state.

        Chains are selected by an index slice (i.e. a list of integers).

        :param other: other state. The parameters to update should be defined and
                      have the same shape as the selected chains slice.
        :param slc: list of slice indexes.
        :param pars: list of the names of the parameters to be updated.
                     None (default): all parameters.
        :return: a new state with the updated parameters.
        """
        # pylint: disable=no-member
        return self.update(
            lambda _name, _val: _val.at[slc, :].set(getattr(other, _name)), pars
        )

    def set_val(
        self,
        val: jax.Array,
        cslc: jax.Array | None = None,
        pslc: jax.Array | None = None,
        par: str = "p",
    ) -> Self:
        """Update one parameter, slicing both in the chains and in the parameter
        dimensions.

        :param val: values (should have the same shape as the slice).
        :param cslc: chains slice. Default: all.
        :param pslc: parameter dimensions slice. Default: all.
        :param par: parameter name. Default: p
        :return: new state
        """
        _target = getattr(self, par)
        # Default slices span the whole target parameter.
        cslc = jnp.arange(_target.shape[0]) if cslc is None else cslc
        pslc = jnp.arange(_target.shape[1]) if pslc is None else pslc
        return self.update(
            lambda _name, _val: _val.at[jnp.ix_(cslc, pslc)].set(val), [par]
        )

    def slice(self, slc: jax.Array) -> Self:
        r"""Return a state with parameter values corresponding to a slice of
        the current state chains.

        That is, if the current state has :math:`N` chains, we define a slice
        as a set of indexes
        :math:`\{0 \leq i_k < N, k=0,\dots,\tilde{N}-1\}`.
        This method constructs a new state having parameters

        .. math::

            \tilde{p}_{k, \alpha} = p_{i_k, \alpha},~~~~~k=0,\dots,\tilde{N}-1,
            ~~~~~~\alpha=0,\dots,d-1

        for each :math:`(N, d)` shaped parameter :math:`p` of the current state.

        :param slc: the slice as a list of indexes.
        :returns: the new "sliced" state.
        """
        return self.update(lambda _name, _val: _val[slc])

    def swap(self, even: jax.Array, odd: jax.Array, accept: jax.Array) -> Self:
        """Swap all parameter values between different chains.

        :param even: set of indexes to be swapped.
        :param odd: other set of indexes to be swapped (should have the same size).
        :param accept: boolean mask, one entry per (even, odd) pair, selecting
                       the swaps that should be performed.
        :return: new state with swapped values.
        """
        return self.update(lambda _name, _val: self._swap_rows(_val, even, odd, accept))

    @staticmethod
    def _swap_rows(
        arr: jax.Array, even: jax.Array, odd: jax.Array, accept: jax.Array
    ) -> jax.Array:
        """Swap the values of one parameter between pairs of chains."""
        ai, aj = arr[even], arr[odd]
        # Broadcast the per-pair mask along all the trailing axes.
        acc_b = accept.reshape(accept.shape + (1,) * (ai.ndim - 1))
        arr = arr.at[even].set(jnp.where(acc_b, aj, ai))
        arr = arr.at[odd].set(jnp.where(acc_b, ai, aj))
        return arr

    def get_in_batch(self, ind: int = -1, batch: bool = False) -> Self:
        """Select one iteration of a batched state.

        This method assumes that the current state is batched, thus its
        parameters have shape :math:`(N, d, I)` where :math:`I` is the
        number of performed Markov iterations. In this case the method returns
        a new state corresponding to one specific iteration.

        :param ind: index of the iteration.
        :param batch: if True the returned state will be a 1-iteration batched
                      state.
        :returns: the selected state.
        """
        _ret = self.update(lambda _name, _val: _val[:, :, ind])
        if batch:
            # Restore a trailing iteration axis of length 1.
            _ret = _ret.update(lambda _name, _val: _val[:, :, jnp.newaxis])
        return _ret

    @classmethod
    def from_dict(
        cls, state_d: dict, mandatory: List[str] | None = None, batch: bool = False
    ) -> Self:
        """Instantiate a state object from a dictionary.

        It selects all the entries in the dictionary that correspond to the
        parameters of this class of states, and uses the corresponding values
        to define the state.

        :param state_d: input dictionary.
        :param mandatory: list of names of the parameters that have to be present
                          in the dictionary. Non-mandatory parameters, if absent,
                          are replaced by empty arrays. If `None` all parameters
                          are mandatory.
        :param batch: when true it forces the instantiated state to have a
                      "batched" shape.
        :returns: the state defined by the dictionary.
        :raises ValueError: if a mandatory parameter is missing from `state_d`.
        """
        # pylint: disable=no-member
        _all = list(cls.__dataclass_fields__.keys())
        mandatory = _all if mandatory is None else mandatory
        _dict = {}
        for _par in _all:
            if _par not in mandatory:
                # Optional parameters default to an empty (1, 0) array.
                _dict[_par] = state_d.get(_par, jnp.array([[]]))
                continue
            try:
                _dict[_par] = state_d[_par]
            except KeyError as err:
                # Only a missing key means "missing value"; any other error
                # (e.g. a non-mapping input) propagates unchanged.
                raise ValueError(
                    f"dict defining a state requires {_par} values."
                ) from err
        if batch:
            # Add a trailing iteration axis to every non-batched value.
            _dict = {
                _key: _val if _val.ndim == 3 else _val[:, :, jnp.newaxis]
                for _key, _val in _dict.items()
            }
        return cls(**_dict)