# Source code for jexplore.sampling.space
"""
This module provides the definitions for the base target space classes:
the prototype class :py:class:`jexplore.sampling.space.Space` and the class
defining a simple box space, :py:class:`jexplore.sampling.space.Box`.
"""
from dataclasses import dataclass
from typing import Protocol, cast
import jax
import jax.numpy as jnp
@dataclass(eq=False)
class Space(Protocol):
    """Markov sampling target space prototype class.

    Child space classes should implement an ``inspace`` method and a ``wrap``
    method.

    ``eq=False`` is deliberate: the dataclass-generated ``__eq__`` would compare
    only ``dim``, so subclasses that keep extra state without being dataclasses
    themselves (e.g. two ``Box`` instances with different bounds) would compare
    equal, and the ``__hash__ = None`` that accompanies it would make every such
    subclass instance unhashable.

    :param dim: dimension of the target space.
    """

    dim: int
    """Dimension of the target space"""

    def inspace(self, points: jax.Array) -> jax.Array:
        """Check which points of a set lie in the defined space.

        :param points: set of points to be checked; it can have the `state`
                       shape :math:`(N_{chains}, D)` or the `samples` shape
                       :math:`(N_{chains}, D, S)`.
        :return: a boolean mask of shape :math:`(N_{chains})` or
                 :math:`(N_{chains}, S)` selecting the points that lie in the
                 defined space.
        """
        # pylint: disable=unnecessary-ellipsis
        ...

    def wrap(self, points: jax.Array) -> jax.Array:
        """Fold the wrapped dimensions of the target space.

        :param points: set of points to be processed; it can have the `state`
                       shape :math:`(N_{chains}, D)` or the `samples` shape
                       :math:`(N_{chains}, D, S)`.
        :return: the same set of points with the wrapped dimensions folded.
        """
        # pylint: disable=unnecessary-ellipsis
        ...
class Box(Space):
    """Simple rectangular box space.

    :param dim: dimension of the box. Only used if the `box` parameter is not
                provided.
    :param size: size of the box (assumed equal in all dimensions), only used
                 if the `box` parameter is not provided. Default: infinity.
    :param box: list of ``[lower, upper]`` pairs, one per box dimension. If not
                provided, a box with `dim` dimensions of equal size `size`,
                centred on the origin, is considered. The given list is copied
                and never modified.
    :param wrapped: list of periodic dimension indexes. These dimensions are
                    considered unbounded and the corresponding `box` intervals
                    are interpreted as principal domain intervals.
                    Default: empty list.
    :raises ValueError: if neither `dim` nor `box` is provided, or if a box
                        interval does not have exactly two entries.
    """

    bounds: jax.Array
    """Box bounds"""
    wrap_dims: list[int]
    """Wrapped dimensions indexes"""
    wrap_domain: jax.Array
    """wrapped dimensions principal domain"""

    def __init__(
        self,
        dim: int | None = None,
        size: float = jnp.inf,
        box: list[list[float]] | None = None,
        wrapped: list[int] | None = None,
    ):
        if (box is None) and (dim is None):
            raise ValueError("You need at least to provide the dimension of the box.")
        if box is None:
            intervals = [[-size / 2, size / 2] for _ in range(cast(int, dim))]
        else:
            # Work on a copy: the wrapped entries are overwritten below and the
            # caller's list must stay reusable (e.g. to build another Box).
            intervals = [list(interval) for interval in box]
        if any(len(interval) != 2 for interval in intervals):
            raise ValueError("Each box interval must be a [lower, upper] pair.")
        self.dim = len(intervals)
        # Copy so later edits of the caller's list cannot desynchronise
        # `wrap_dims` from `wrap_domain`.
        self.wrap_dims = [] if wrapped is None else list(wrapped)
        # Read every principal domain before marking any dimension unbounded,
        # so a repeated index still gets its finite domain, not [-inf, inf].
        domains = [intervals[idx] for idx in self.wrap_dims]
        for idx in self.wrap_dims:
            # wrapped dimensions never make a point fall outside the box
            intervals[idx] = [-jnp.inf, +jnp.inf]
        # reshape keeps a consistent (n, 2) layout even when the list is empty
        self.wrap_domain = jnp.array(domains).reshape(len(domains), 2)
        self.bounds = jnp.array(intervals).reshape(self.dim, 2)

    def inspace(self, points: jax.Array) -> jax.Array:
        """Check which points lie strictly inside the box bounds.

        Wrapped dimensions have bounds ``(-inf, +inf)``, so any finite
        coordinate passes the check along them.

        :param points: set of points to be checked; it can have the `state`
                       shape :math:`(N_{chains}, D)` or the `samples` shape
                       :math:`(N_{chains}, D, S)`.
        :return: a boolean mask of shape :math:`(N_{chains})` or
                 :math:`(N_{chains}, S)` selecting the points inside the box.
        """
        # Broadcast the per-dimension bounds over any trailing sample axis:
        # shape (D,) for state-shaped input, (D, 1) for samples-shaped input.
        trailing = (1,) * (points.ndim - 2)
        lower = self.bounds[:, 0].reshape((-1,) + trailing)
        upper = self.bounds[:, 1].reshape((-1,) + trailing)
        return jnp.all((points > lower) & (points < upper), axis=1)

    def wrap(self, points: jax.Array) -> jax.Array:
        """Fold the wrapped dimensions into their principal domains.

        Each wrapped coordinate ``x`` with principal domain ``[lo, hi]`` is
        replaced by ``(x - lo) % (hi - lo) + lo``.

        :param points: set of points to be processed; it can have the `state`
                       shape :math:`(N_{chains}, D)` or the `samples` shape
                       :math:`(N_{chains}, D, S)`.
        :return: the same set of points with the wrapped dimensions folded.
        """
        if not self.wrap_dims:
            return points
        # `wrap_dims` is a static Python list, so this loop is unrolled at
        # trace time and the dimension indexes stay plain integers.
        for dim, (low, high) in zip(self.wrap_dims, self.wrap_domain):
            folded = (points[:, dim] - low) % (high - low) + low
            points = points.at[:, dim].set(folded)
        return points