# Source code for jexplore.steps.modsel.inmodel_single

"""
This module defines the classes for Model Selection
in-model steps acting separately on each chain.
"""

from typing import cast

import jax
import jax.numpy as jnp

from jexplore.sampling import EpochMS, SamplingMS, StateMS
from jexplore.steps.mh import AllChains
from jexplore.steps.single_masked import SingleMasked


[docs] class AllChainsWrap[ Tepoch: EpochMS = EpochMS, Tstate: StateMS = StateMS, Tsampling: SamplingMS = SamplingMS, Tstep: AllChains = AllChains, ]( AllChains[Tepoch, Tstate, Tsampling] # type: ignore[type-var] ): """Model selection wrapper for :py:attr:`jexplore.steps.mh.AllChains` steps. The step is wrapped so that it is applied only on chains that are in a model which (dimensions) mask fully contains the mask of the step. Note that the resulting step does not propose changes in to the chains' models. :param step: the step to be wrapped""" wrapped_step: Tstep active_models: jax.Array def __init__(self, step: Tstep): super().__init__() self.wrapped_step = step
[docs] def build(self, epoch: Tepoch) -> None: super().build(epoch) self.wrapped_step.build(epoch) # type: ignore[arg-type] # The only active modes are those which mask fully contains the # proposal mask. # For the other models the proposition acts like the identity. _amod = jnp.array( [ _ind for _ind, _mask in enumerate(self.sampling.masks) if jnp.all(jnp.isin(self.wrapped_step.mask, _mask)) ] ) self.active_models = ( jnp.zeros((self.sampling.nmodels,), dtype=bool).at[_amod].set(True) )
[docs] def proposal( self, key: jax.Array, state: Tstate ) -> tuple[jax.Array, Tstate, jax.Array]: """Propose a new state. The proposed state is identical to the old one for all chains in inactive models (i.e. models whose mask do not contain the mask of the wrapped step) and correspond to the state proposed by the wrapped step for all other chains. :param key: PRNG key :param state: current state :return: new state and the boolean mask of the chains modified by the step. """ key, prop, qxy = self.wrapped_step.proposal(key, state) _mask = self.active_models[state.k[:, 0], None] prop = cast(Tstate, prop.set_val(jnp.where(_mask, prop.p, state.p))) qxy = jnp.where(_mask[:, 0], qxy, 1.0) return key, prop, qxy
[docs] class InModelSC[ Tepoch: EpochMS = EpochMS, Tstate: StateMS = StateMS, Tsampling: SamplingMS = SamplingMS, Tstep: SingleMasked = SingleMasked, ]( AllChains[Tepoch, Tstate, Tsampling] # type: ignore[type-var] ): """In model single chain MH step. This is constructed out of a list of single chain proposals each one defined on a subspacei of the corresponding model space. :param steps: optional list of one (single chain MH) step instance for each model. :param stepclass: step class. If `steps` argument is not specified, the `build` method will instantiate one step for each model usng this step constructor. :param stepargs: step instantiation arguments.""" steps: list[type[Tstep] | None] """List of steps classes per model.""" stepsargs: list[dict] """Step instantiation arguments.""" _proplist: list """List of proposal functions.""" def __init__( self, steps: type[Tstep] | list[type[Tstep] | None], stepsargs: dict | list[dict] | None, ): super().__init__() if not isinstance(steps, list): steps = [steps] self.steps = steps if stepsargs is None: stepsargs = {} if isinstance(stepsargs, dict): stepsargs = [stepsargs] self.stepsargs = stepsargs # pylint: disable=protected-access @staticmethod def _prop_wrap(key, p, ch, step): prop, qxy = step.sc_masked__proposal(key, p[step.mask], ch) return p.at[step.mask].set(prop), qxy # pylint: disable=unused-argument
[docs] def build(self, epoch: Tepoch) -> None: super().build(epoch) nmod = self.sampling.nmodels if (len(self.steps) == 1) and (nmod > 1): self.steps = [self.steps[0] for _ in range(nmod)] if (len(self.stepsargs) == 1) and (nmod > 1): self.stepsargs = [self.stepsargs[0] for _ in range(nmod)] if not ((len(self.stepsargs) == len(self.steps)) and (len(self.steps) == nmod)): raise ValueError( "The size of the lists of steps and arguments should be the number of models" ) self._proplist = [] for _ind, _cls in enumerate(self.steps): # A None element in the list is a identity proposal if _cls is None: self._proplist.append(lambda key, p, ch: (p, jnp.ones(p.shape[1]))) continue # Instantiate a step with mask equal to the mask of the model of not # otherwise defined in the arguments. _step = _cls(**({"mask": self.sampling.masks[_ind]} | self.stepsargs[_ind])) _step.build(epoch) # type: ignore[arg-type] # If the mask of the step is not fully contained in the model mask # then we use an identity proposal if not jnp.all( jnp.isin(cast(jax.Array, _step.mask), self.sampling.masks[_ind]) ): self._proplist.append(lambda key, p, ch: (p, jnp.ones(p.shape[1]))) continue self._proplist.append( lambda key, p, ch, stp=_step: self._prop_wrap(key, p, ch, stp) )
[docs] def proposal( self, key: jax.Array, state: Tstate ) -> tuple[jax.Array, Tstate, jax.Array]: """In model single chain proposal. :param key: PRNG key :param state: current state :return: the new PRNG key, the proposed state, the transition log probability. """ chains = jnp.arange(state.p.shape[0]) keys = jax.random.split(key, self.sampling.nchain + 1) prop, qxy = jax.vmap( lambda _key, _p, _ch: jax.lax.switch( state.k[_ch, 0], self._proplist, _key, _p, _ch ), in_axes=(0, 0, 0), )(keys[1:], state.p, chains) state = state.set_val(prop) return keys[0], state, qxy