# Source code for jexplore.sampling.space

"""
This module provides the definitions for the base target space classes:
the prototyping class :py:class:`jexplore.sampling.space.Space` and the class
defining a simple box space, :py:class:`jexplore.sampling.space.Box`.
"""

from dataclasses import dataclass
from typing import Protocol, cast

import jax
import jax.numpy as jnp


@dataclass
class Space(Protocol):
    """Prototype of a Markov-sampling target space.

    Concrete spaces have to provide two methods: :py:meth:`inspace`, telling
    which points belong to the space, and :py:meth:`wrap`, folding the
    wrapped (periodic) dimensions of a set of points.

    :param dim: number of dimensions of the target space.
    """

    dim: int
    """Dimension of the target space"""

    def inspace(self, points: jax.Array) -> jax.Array:
        """Tell which of the given points lie in the defined space.

        :param points: points to check, either with the `state` shape
            :math:`(N_{chains}, D)` or with the `samples` shape
            :math:`(N_{chains}, D, S)`.
        :return: boolean mask of shape :math:`(N_{chains})` or
            :math:`(N_{chains}, S)` selecting the points that lie in the
            defined space.
        """

    def wrap(self, points: jax.Array) -> jax.Array:
        """Fold the wrapped dimensions of the target space.

        :param points: points to process, either with the `state` shape
            :math:`(N_{chains}, D)` or with the `samples` shape
            :math:`(N_{chains}, D, S)`.
        :return: the same set of points with the wrapped dimensions folded.
        """
[docs] class Box(Space): """Simple rectangular box space. :param dim: dimension of the box. Only used if the `box` parameter is not provided. :param size: size of the box (assumed having equal size in all dimensions) only used if the `box` parameter is not provided. Default: infinity. :param box: list of lists defining the segments boundaries for each box dimensions. If not provided a box with `dim` dimensions of equal size `size` is considered. :param wrapped: list of periodic dimensions indexes. These dimensions will be considererd unbound and the corresponding box `intervals` will be interpreted as principal domain intervals. Default: empty list. """ bounds: jax.Array """Box bounds""" wrap_dims: list[int] """Wrapped dimensions indexes""" wrap_domain: jax.Array """wrapped dimensions principal domain""" def __init__( self, dim: int | None = None, size: float = jnp.inf, box: list[list[float]] | None = None, wrapped: list[int] | None = None, ): if (box is None) and (dim is None): raise ValueError("You need at least to provide the dimension of the box.") if box is None: box = [[-size / 2, size / 2] for _ in range(cast(int, dim))] self.dim = len(box) self.wrap_dims = [] if wrapped is None else wrapped _wrap_domain = [] for _ind in self.wrap_dims: _wrap_domain.append(box[_ind]) box[_ind] = [-jnp.inf, +jnp.inf] self.wrap_domain = jnp.array(_wrap_domain) self.bounds = jnp.array(box)
[docs] def inspace(self, points: jax.Array) -> jax.Array: return jnp.all( jnp.array( [ (points[:, _ind] > _btm) & (points[:, _ind] < _top) for _ind, (_btm, _top) in enumerate(self.bounds) ] ), axis=0, )
[docs] def wrap(self, points: jax.Array) -> jax.Array: if len(self.wrap_dims) == 0: return points def _body(_pts, _idx): _ind, _btm, _top = _idx _ind = _ind.astype(int) _pts = _pts.at[:, _ind].set((_pts[:, _ind] - _btm) % (_top - _btm) + _btm) return (_pts), () idx = jnp.array( [ (_dim, self.wrap_domain[_ind, 0], self.wrap_domain[_ind, 1]) for _ind, _dim in enumerate(self.wrap_dims) ] ) return jax.lax.scan(_body, points, idx)[0]