Source code for jexplore.sampling.base
"""
This module provides the definition of the base
:py:class:`jexplore.sampling.base.Sampling` class.
"""
from dataclasses import dataclass
import numpy as np
from .space import Box, Space
[docs]
@dataclass
class Sampling[Tspace: Space]:
"""Markov sampling defining parameters.
In this base class version the parameter are the dimension of the
target space and the number of chains. Child classes may specialize
to specific type of markov sampling (e.g.
:py:attr:`jexplore.sampling.mh.SamplingMH`)
:param dim: dimension of the target space.
:param nchains: number of chains.
:param space: :py:attr:`jexplore.sampling.space` object
describing the target space.
the state will then be defined by a (nchain, dim) point"""
dim: int
"""Dimension of the target space"""
nchain: int
"""Number of chains"""
space: Tspace | Box
"""Target space."""
def __init__(
self, nchain: int, dim: int | None = None, space: Tspace | None = None
):
self.nchain = nchain
if space is not None:
self.space = space
self.dim = self.space.dim
elif dim is not None:
self.dim = dim
self.space = Box(dim=dim)
else:
raise ValueError(
"You need at least to define the dimension of the chains space."
)
[docs]
def to_backend(self) -> dict:
"""Save the attributes of the object in
backend format."""
return {
"nchain": np.array(self.nchain),
"dim": np.array(self.dim),
"class": f"{self.__class__.__module__}.{self.__class__.__qualname__}",
}