Source code for jexplore.sampling.base

"""
This module provides the definition of the base
:py:attr:`jexplore.sampling.base.Sampling` class.
"""

from dataclasses import dataclass

import numpy as np

from .space import Box, Space


[docs] @dataclass class Sampling[Tspace: Space]: """Markov sampling defining parameters. In this base class version the parameter are the dimension of the target space and the number of chains. Child classes may specialize to specific type of markov sampling (e.g. :py:attr:`jexplore.sampling.mh.SamplingMH`) :param dim: dimension of the target space. :param nchains: number of chains. :param space: :py:attr:`jexplore.sampling.space` object describing the target space. the state will then be defined by a (nchain, dim) point""" dim: int """Dimension of the target space""" nchain: int """Number of chains""" space: Tspace | Box """Target space.""" def __init__( self, nchain: int, dim: int | None = None, space: Tspace | None = None ): self.nchain = nchain if space is not None: self.space = space self.dim = self.space.dim elif dim is not None: self.dim = dim self.space = Box(dim=dim) else: raise ValueError( "You need at least to define the dimension of the chains space." )
[docs] def to_backend(self) -> dict: """Save the attributes of the object in backend format.""" return { "nchain": np.array(self.nchain), "dim": np.array(self.dim), "class": f"{self.__class__.__module__}.{self.__class__.__qualname__}", }