Source code for jexplore.sampler

"""
Jexplore: JAX-based Markov chain Monte Carlo (MCMC) sampler
"""

import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Generic, TypeAlias, TypeVar, cast

import jax
import jax.numpy as jnp
import numpy as np

from .backends import DefaultBackend
from .sampling import Epoch, EpochCycle, Sampling, State, StepSample

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
"""logging instance"""

# Type variables used to parametrize the sampler classes and the step
# builders below. Each one is bounded by the matching base class imported
# from `.sampling` / `.backends`.

# Any `Epoch` subclass.
Tepoch = TypeVar("Tepoch", bound=Epoch)
"""placeholder for epoch class"""

# Any `Sampling` subclass.
Tsampling = TypeVar("Tsampling", bound=Sampling)
"""placeholder for samping definition class"""

# Any `DefaultBackend` subclass.
Tbackend = TypeVar("Tbackend", bound=DefaultBackend)
"""placeholder for backend class"""

# Any `State` subclass.
Tstate = TypeVar("Tstate", bound=State)
"""placholder for state class"""

# Generic alias over (Tepoch, Tstate): a callable taking one epoch and
# returning the `StepSample` to use for that epoch.
StepBuilder: TypeAlias = Callable[[Tepoch], StepSample[Tstate]]
"""Prototype definition of a step builder.
This is a function getting an epoch as argument and
returning a :py:attr:`jexplore.sampling.epoch.StepSample` function."""


class Steps(list[dict[StepBuilder[Tepoch, Tstate], float]]):
    """Sampler steps definition.

    This is a list of dictionaries. Each list element is one sequential
    sub-step; the keys of its dictionary are the builders of the alternative
    steps available for that sub-step, and the values are their weights.
    """

    def epoch_cycle(self, epoch: Tepoch) -> EpochCycle:
        """Build the epoch cycle for ``epoch``.

        Every ``StepBuilder`` key is called with ``epoch`` as its only
        argument and is replaced, in the resulting dictionaries, by the
        :py:attr:`jexplore.sampling.epoch.StepSample` it returns. Weights
        are carried over unchanged.

        :param epoch: this epoch
        :return: the epoch cycle wrapping the resulting list of dictionaries.
        """
        built_steps: list[dict[StepSample[State], float]] = []
        for substep in self:
            alternatives: dict[StepSample[State], float] = {}
            for builder, weight in substep.items():
                # The cast only affects static typing; at runtime the key
                # is whatever object the builder returned.
                sample_fn = cast(StepSample[State], builder(epoch))
                alternatives[sample_fn] = weight
            built_steps.append(alternatives)
        return EpochCycle(steps=built_steps)
@dataclass
class JaxSampler(Generic[Tsampling, Tepoch, Tbackend]):
    """Main jexplore sampler class.

    :param sampling: sampling definition
    :param steps: steps definition
    :param backend: backend instance
    """

    sampling: Tsampling
    """sampling parameters definition"""

    steps: Steps
    """steps definition"""

    backend: Tbackend
    """backend instance"""

    def run(
        self, epoch: Tepoch, niters: int, nepoch: int = 1, seed: int | None = None
    ) -> Tepoch:
        """Run the MCMC, store each epoch in the backend and return the last one.

        The starting epoch is first re-built through its own class, then for
        each of the ``nepoch`` epochs: the epoch is completed with the
        sampling definition, its cycle is built from :py:attr:`steps`,
        ``niters`` iterations are run, and the result is handed to
        ``backend.ingest``. The next epoch is built from the one just run
        with ``force_covs=True``.

        :param epoch: starting epoch
        :param niters: number of iterations per epoch
        :param nepoch: number of epochs
        :param seed: seed. If None, a seed is drawn with
            ``numpy.random.randint(0, 1000)``.
        :return: the last epoch that was run (the re-built starting epoch
            when ``nepoch`` is 0).
        """
        # NOTE(review): the default seed is drawn from only 1000 values, so
        # unseeded runs can collide; pass an explicit ``seed`` when
        # independent chains are required.
        key = jax.random.key(np.random.randint(0, 1000) if seed is None else seed)
        epoch = epoch.__class__(epoch)
        last_epoch = epoch
        for _ in range(nepoch):
            # Drop our reference to the previous epoch before running the
            # next one (presumably so its buffers can be released early).
            del last_epoch
            epoch.complete(self.sampling)
            epoch.cycle = self.steps.epoch_cycle(epoch)
            key, epoch = self.run_epoch(epoch, niters, key)
            self.backend.ingest(epoch)
            last_epoch = epoch
            epoch = epoch.__class__(epoch, force_covs=True)
        return last_epoch

    @staticmethod
    def run_epoch(epoch: Tepoch, niters: int, key) -> tuple[jax.Array, Tepoch]:
        """Run the iterations of one MCMC epoch and return the resulting epoch.

        One scan step is executed per sub-step of the epoch cycle, i.e.
        ``niters * epoch.cycle.n`` steps in total, each with its own PRNG key.

        :param epoch: the epoch to run; it provides the starting samples,
            the cycle and the ``run_step`` function.
        :param niters: number of MCMC iterations.
        :param key: PRNG key; it is split into one key per step plus one
            key returned to the caller.
        :return: the key to use for subsequent sampling, and the new epoch
            built with ``epoch.from_run``.
        """
        state = epoch.samples.get_in_batch()
        n_substeps = epoch.cycle.n
        n_steps = niters * n_substeps

        # One key per scan step, plus a leading key handed back to the caller.
        keys = jax.random.split(key, n_steps + 1)
        key = keys[0]
        keys = keys[1:]

        # Index tables built on-device: the running step number selects the
        # step's key, and the position within the cycle selects the sub-step.
        step_ids = jnp.arange(n_steps)
        substep_ids = step_ids % n_substeps

        def body(_state, _iter):
            _step_id, _idx = _iter
            _key = keys[_step_id]
            _state, _acc_rate, _step = epoch.run_step(_key, _idx, _state)
            return (_state), (_state, _acc_rate, _step)

        _, (states, stats, steps) = jax.lax.scan(
            body, (state), (step_ids, substep_ids)
        )
        return key, epoch.from_run(states, stats, steps)