Source code for jexplore.steps.modsel.model_switch

"""
This module defines the classes for the Model Selection
model-switching steps.
"""

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMS, SamplingMS, StateMS, StepSample
from jexplore.sampling.state import ArrayFn
from jexplore.steps.mh import MHStep
from jexplore.steps.modsel.swap import OrderByModel
from jexplore.steps.step import Step


[docs] class DrawPseudo[ Tepoch: EpochMS = EpochMS, Tstate: StateMS = StateMS, Tsampling: SamplingMS = SamplingMS, ]( Step[Tepoch, Tstate, Tsampling] # type: ignore[type-var] ): """ Draws model pseudo prior samples for each chain """ draw_funcs: list """static list of pseudo prior rowing function for each model"""
[docs] def build(self, epoch: Tepoch): super().build(epoch) self.draw_funcs = [ lambda p, k, mod=_mod: self._draw_fun(p, k, mod) for _mod, _pp in enumerate(self.sampling.ppdraw) ]
def _draw_fun( self, p: jax.Array, key: jax.Array, mod: int, ) -> jax.Array: _pmask = self.sampling.cmasks[mod] return p.at[_pmask].set(self.sampling.ppdraw[mod](key, (1,))[1][0, :])
[docs] def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: """Pseudo drawing step :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ def _draw_chain(_mod, _p, _key): return jax.lax.switch(_mod, self.draw_funcs, _p, _key) keys = jax.random.split(key, self.sampling.nchain) p = jax.vmap(_draw_chain)(state.k[:, 0], state.p, keys) return state.set_val(p), jnp.ones(state.p.shape[0])
[docs] class DrawModel[ Tepoch: EpochMS = EpochMS, Tstate: StateMS = StateMS, Tsampling: SamplingMS = SamplingMS, ]( MHStep[Tepoch, Tstate, Tsampling] # type: ignore[type-var] ): """ Draw models for each chain from a categorical which log weights are the sum of log likelihood, log priori and log pseudo prior. :param reorder: reorder chains by model (temp by temp) after running the model draw. :param draw_pseudo: draw from pseudo prior before drawing the models. """ _wrap: bool = False _reorder: StepSample[Tstate] _draw_pseudo: StepSample[Tstate] _evalf: dict[str, ArrayFn] reorder: bool """Reorder the chain by model (and temp) after drawing models""" draw_pseudo: bool """Draw pseudo prior before drawing models""" def __init__(self, reorder: bool = True, draw_pseudo: bool = True): super().__init__() self.reorder = reorder self.draw_pseudo = draw_pseudo
[docs] def build(self, epoch: Tepoch) -> None: super().build(epoch) self._reorder = OrderByModel[Tepoch, Tstate, Tsampling]().builder(epoch) self._draw_pseudo = DrawPseudo[Tepoch, Tstate, Tsampling]().builder(epoch) self._evalf = { "ll": self.sampling.allmodsll, "lp": self.sampling.allmodslp, "pp": self.sampling.allmodspp, }
[docs] def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]: """Model drawing step :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ keys = jax.random.split(key, 3) state = self._draw_pseudo(keys[0], state)[0] if self.draw_pseudo else state vals = {} for _key, _func in self._evalf.items(): # Note: vals have shape (nchains, nmodel) vals[_key] = jax.vmap(_func, in_axes=(0,))(state.p) logits = self.beta * vals["ll"] + vals["lp"] + vals["pp"] ks = jax.random.categorical(keys[1], logits, axis=1).reshape(-1, 1) _check = jnp.take_along_axis(vals["ll"], ks, axis=1) state = state.set_val(ks, par="k") # Populate ll and lp with the good values for _par in ["ll", "lp"]: state = state.set_val( jnp.take_along_axis(vals[_par], ks, axis=1), par=_par, ) state = self._reorder(keys[2], state)[0] if self.reorder else state return state, jnp.ones(state.p.shape[0])