"""
This module defines the classes for Model Selection
in-model steps acting separately on each chain.
"""
from typing import cast
import jax
import jax.numpy as jnp
from jexplore.sampling import EpochMS, SamplingMS, StateMS
from jexplore.steps.mh import AllChains
from jexplore.steps.single_masked import SingleMasked
[docs]
class AllChainsWrap[
Tepoch: EpochMS = EpochMS,
Tstate: StateMS = StateMS,
Tsampling: SamplingMS = SamplingMS,
Tstep: AllChains = AllChains,
](
AllChains[Tepoch, Tstate, Tsampling] # type: ignore[type-var]
):
"""Model selection wrapper for :py:attr:`jexplore.steps.mh.AllChains`
steps. The step is wrapped so that it is applied only on chains that
are in a model which (dimensions) mask fully contains the mask of the
step. Note that the resulting step does not propose changes in to
the chains' models.
:param step: the step to be wrapped"""
wrapped_step: Tstep
active_models: jax.Array
def __init__(self, step: Tstep):
super().__init__()
self.wrapped_step = step
[docs]
def build(self, epoch: Tepoch) -> None:
super().build(epoch)
self.wrapped_step.build(epoch) # type: ignore[arg-type]
# The only active modes are those which mask fully contains the
# proposal mask.
# For the other models the proposition acts like the identity.
_amod = jnp.array(
[
_ind
for _ind, _mask in enumerate(self.sampling.masks)
if jnp.all(jnp.isin(self.wrapped_step.mask, _mask))
]
)
self.active_models = (
jnp.zeros((self.sampling.nmodels,), dtype=bool).at[_amod].set(True)
)
[docs]
def proposal(
self, key: jax.Array, state: Tstate
) -> tuple[jax.Array, Tstate, jax.Array]:
"""Propose a new state. The proposed state is identical to the
old one for all chains in inactive models (i.e. models whose mask
do not contain the mask of the wrapped step) and correspond to the
state proposed by the wrapped step for all other chains.
:param key: PRNG key
:param state: current state
:return: new state and the boolean mask of the chains modified by the step.
"""
key, prop, qxy = self.wrapped_step.proposal(key, state)
_mask = self.active_models[state.k[:, 0], None]
prop = cast(Tstate, prop.set_val(jnp.where(_mask, prop.p, state.p)))
qxy = jnp.where(_mask[:, 0], qxy, 1.0)
return key, prop, qxy
[docs]
class InModelSC[
Tepoch: EpochMS = EpochMS,
Tstate: StateMS = StateMS,
Tsampling: SamplingMS = SamplingMS,
Tstep: SingleMasked = SingleMasked,
](
AllChains[Tepoch, Tstate, Tsampling] # type: ignore[type-var]
):
"""In model single chain MH step. This is constructed out of a list of
single chain proposals each one defined on a subspacei of the corresponding
model space.
:param steps: optional list of one (single chain MH) step instance for each model.
:param stepclass: step class. If `steps` argument is not specified, the `build`
method will instantiate one step for each model usng this step
constructor.
:param stepargs: step instantiation arguments."""
steps: list[type[Tstep] | None]
"""List of steps classes per model."""
stepsargs: list[dict]
"""Step instantiation arguments."""
_proplist: list
"""List of proposal functions."""
def __init__(
self,
steps: type[Tstep] | list[type[Tstep] | None],
stepsargs: dict | list[dict] | None,
):
super().__init__()
if not isinstance(steps, list):
steps = [steps]
self.steps = steps
if stepsargs is None:
stepsargs = {}
if isinstance(stepsargs, dict):
stepsargs = [stepsargs]
self.stepsargs = stepsargs
# pylint: disable=protected-access
@staticmethod
def _prop_wrap(key, p, ch, step):
prop, qxy = step.sc_masked__proposal(key, p[step.mask], ch)
return p.at[step.mask].set(prop), qxy
# pylint: disable=unused-argument
[docs]
def build(self, epoch: Tepoch) -> None:
super().build(epoch)
nmod = self.sampling.nmodels
if (len(self.steps) == 1) and (nmod > 1):
self.steps = [self.steps[0] for _ in range(nmod)]
if (len(self.stepsargs) == 1) and (nmod > 1):
self.stepsargs = [self.stepsargs[0] for _ in range(nmod)]
if not ((len(self.stepsargs) == len(self.steps)) and (len(self.steps) == nmod)):
raise ValueError(
"The size of the lists of steps and arguments should be the number of models"
)
self._proplist = []
for _ind, _cls in enumerate(self.steps):
# A None element in the list is a identity proposal
if _cls is None:
self._proplist.append(lambda key, p, ch: (p, jnp.ones(p.shape[1])))
continue
# Instantiate a step with mask equal to the mask of the model of not
# otherwise defined in the arguments.
_step = _cls(**({"mask": self.sampling.masks[_ind]} | self.stepsargs[_ind]))
_step.build(epoch) # type: ignore[arg-type]
# If the mask of the step is not fully contained in the model mask
# then we use an identity proposal
if not jnp.all(
jnp.isin(cast(jax.Array, _step.mask), self.sampling.masks[_ind])
):
self._proplist.append(lambda key, p, ch: (p, jnp.ones(p.shape[1])))
continue
self._proplist.append(
lambda key, p, ch, stp=_step: self._prop_wrap(key, p, ch, stp)
)
[docs]
def proposal(
self, key: jax.Array, state: Tstate
) -> tuple[jax.Array, Tstate, jax.Array]:
"""In model single chain proposal.
:param key: PRNG key
:param state: current state
:return: the new PRNG key, the proposed state, the transition log probability.
"""
chains = jnp.arange(state.p.shape[0])
keys = jax.random.split(key, self.sampling.nchain + 1)
prop, qxy = jax.vmap(
lambda _key, _p, _ch: jax.lax.switch(
state.k[_ch, 0], self._proplist, _key, _p, _ch
),
in_axes=(0, 0, 0),
)(keys[1:], state.p, chains)
state = state.set_val(prop)
return keys[0], state, qxy