Source code for jexplore.steps.modsel.swap
"""
This module defines the class for Model Selection
steps swapping chains.
"""
import jax
import jax.numpy as jnp
from jexplore.sampling import EpochMS, SamplingMS, StateMS
from jexplore.steps.step import Step
from jexplore.steps.tswap import TSwap
# Specialisation of the generic temperature-swap step with the epoch, state
# and sampling type parameters bound to the model-selection (*MS) variants.
TSwapMS = TSwap[EpochMS, StateMS, SamplingMS] # type: ignore[type-var]
"""Temperature swap for model selection"""
[docs]
class OrderByModel[
Tepoch: EpochMS = EpochMS,
Tstate: StateMS = StateMS,
Tsampling: SamplingMS = SamplingMS,
](
Step[Tepoch, Tstate, Tsampling] # type: ignore[type-var]
):
"""
Swap chains to restore model ordering (temp by temp)
"""
_wrap: bool = False
[docs]
def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]:
"""Pseudo drawing step
:param key: PRNG key
:param state: current state
:return: new state and the boolean mask of the chains modified by the step.
"""
ntemps = self.sampling.temps.shape[0]
_idx = jnp.argsort(state.k.reshape(-1, ntemps, 1), axis=0)
return (
state.update(
lambda _name, _val: jnp.take_along_axis(
_val.reshape(-1, ntemps, _val.shape[1]), _idx, axis=0
).reshape(-1, _val.shape[1])
),
jnp.ones(state.k.shape[0]),
)