"""
This module defines base classes for colored proposals.
"""
from typing import List
import jax
import jax.numpy as jnp
from jexplore.sampling import Epoch, EpochMH, Sampling, SamplingMH, State, StateMH
from jexplore.steps.mh import MHStep
from jexplore.steps.step import Step
class BadColoring(ValueError):
    """Exception raised when the colors groups are uneven.

    A colored step runs the substep of each color in a `jax.lax.scan` loop,
    which requires all colors to have the same number of chains and the same
    number of complementary chains. When this is not the case the present
    exception is raised; its message reports the sizes that were found.

    :param groups: list of colors chain groups
    :param cgroups: list of colors complementary chain groups
    """

    def __init__(self, groups: List[jax.Array], cgroups: List[jax.Array]):
        g_sizes = [chains.size for chains in groups]
        cg_sizes = [chains.size for chains in cgroups]
        # The message starts with a newline so that the report is printed on
        # its own lines, below the exception name.
        message = "\n".join(
            (
                "",
                f"Groups sizes: {g_sizes}",
                f"Comp. groups sizes: {cg_sizes}",
                "Groups and complementary groups must have the same size.",
                "Use the adjust or duplicate flags to instruct the",
                "grouping implementation on how to ensure this.",
            )
        )
        super().__init__(message)
[docs]
class Colored[Tepoch: Epoch, Tstate: State, Tsampling: Sampling](
Step[Tepoch, Tstate, Tsampling] # type: ignore[type-var]
):
"""Generic class for colored chains step. This provides
* a generic `build` method that calls a `grouping` function that \
defines the different colors chains groups. It also provides a \
prototype for the `grouping` method.
* a generic `step` method that permutes chains callin the `permuting` \
method and calls `colored_step` method for all colors. It also provides \
prototypes for `permuting` and `colored_step`.
:param ngroups: number of groups
:param permute: if true chains are permuted at each iteration
:param duplicate: duplicate chains in group to make colors groups even
:param adjust: adjust the number of group to have even colors groups
"""
permute: bool
"""permute chains at each iteration"""
adjust: bool
"""adjust number of groups to have even colors"""
duplicate: bool
"""duplicate chains to have even colors"""
groups: List[jax.Array]
"""list of indexes of the chains of each color"""
cgroups: List[jax.Array]
"""list of indexes of the complementary chains of each color"""
ngroups: int
"""number of colors"""
def __init__(
self,
ngroups: int = 2,
permute: bool = False,
duplicate: bool = False,
adjust: bool = False,
):
super().__init__()
self.permute = permute
self.ngroups = ngroups
if adjust and duplicate:
raise ValueError("Only one among adjust and duplicate can be True")
self.adjust = adjust
self.duplicate = duplicate
def _check_groups(self) -> None:
"""check colors groups and rise `BadColoring` if they are
uneven"""
groups_check = jnp.unique(
jnp.array([_group.size for _group in self.groups])
).size
cgroups_check = jnp.unique(
jnp.array([_group.size for _group in self.cgroups])
).size
if groups_check * cgroups_check > 1:
raise BadColoring(self.groups, self.cgroups)
[docs]
def build(self, epoch: Tepoch) -> None:
"""Step initialisation method. It extends
:py:attr:`jexplore.steps.step.Step.build` adding a call to `grouping`,
to define the colors group, checks that the defined colors groups are
even, rising a :py:attr:`BadColoring` exception otherwise.
:param epoch: current epoch.
"""
super().build(epoch)
self.groups, self.cgroups = self.grouping()
self._check_groups()
[docs]
def permuting(self, key: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Permutes chains. This is just a prototype. The actual implementation
depends on the type of coloring.
:param key: PRNG key
:return: new PRNG key and permuted chains indexes."""
raise NotImplementedError(
"Colored is an abstract class. Need to implement this method."
)
[docs]
def grouping(self) -> tuple[list[jax.Array], list[jax.Array]]:
"""Defines the groups and complementary groups of chains for
each color. This is just a prototype. The actual implementation
depends on the type of coloring.
:return: indexes of the chains of each color, indexes of the complementary
chains for each color."""
raise NotImplementedError(
"Colored is an abstract class. Need to implement this method."
)
# pylint: disable=unused-argument
[docs]
def colored_step(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[Tstate, jax.Array]:
"""Step to be executed on a single color. This is a prototype.
:param key: PRNG key
:param state: current state
:param group: indexes of the color chains
:param cgroup: indexes of the complementary chains for this color.
:return: new state and boolean mask of the changed chains"""
return state, jnp.empty(0)
[docs]
def runtime_grouping(
self, key: jax.Array, groups: list[jax.Array], cgroups: list[jax.Array]
) -> tuple[jax.Array, list[jax.Array], list[jax.Array]]:
"""Runtime redefinition of groups and complementary groups. This will be called
at each step.
In this default implementation this simply optionnaly calls the permute method to
permute the chains in the colors definitions.
:param key: PRNG key.
:param group: indexes of the color chains.
:param cgroup: indexes of the complementary chains for this color.
:return: new PRNG key and the newly definef groups and cgroups."""
if self.permute:
key, ichain = self.permuting(key)
else:
ichain = jnp.arange(self.sampling.nchain)
groups = [ichain[_group] for _group in groups]
cgroups = [ichain[_cgroup] for _cgroup in cgroups]
return key, groups, cgroups
[docs]
def step(self, key: jax.Array, state: Tstate) -> tuple[Tstate, jax.Array]:
"""General implementation of the step sampling method for a colored step.
There is in principle no need to extend or override this method when
defining an actual colored step implemetation. The implementation details
should rather go in the definition of the `grouping`, `permuting` and
`colored_step` methods.
:param key: PRNG key
:param state: the current state
:return: new state and the boolean mask of the chains modified by the step.
"""
key, groups, cgroups = self.runtime_grouping(key, self.groups, self.cgroups)
keys = jax.random.split(key, self.ngroups)
def group_step(
_carry: tuple[Tstate, jax.Array], _idx: jax.Array
) -> tuple[tuple[Tstate, jax.Array], tuple[()]]:
_state, _stats = _carry
_group = jax.lax.select_n(_idx, *groups)
_state, _new_stats = self.colored_step(
jax.lax.select_n(_idx, *keys),
_state,
_group,
jax.lax.select_n(_idx, *cgroups),
)
_stats = _stats.at[_group].set(_new_stats)
return (_state, _stats), ()
(state, stats), _ = jax.lax.scan(
group_step,
(state, jnp.zeros(self.sampling.nchain)),
jnp.arange(len(groups)),
)
return state, stats
[docs]
class ColoredMH[Tepoch: EpochMH, Tstate: StateMH, Tsampling: SamplingMH](
Colored[Tepoch, Tstate, Tsampling], # type: ignore[type-var]
MHStep[Tepoch, Tstate, Tsampling],
):
"""Generic class for colored Metropolis-Hastings step. With respect
to the two parent classes this class implements explicitly a `permuting`
method that defines permutations of walkers chains for each temperature.
:param ngroups: number of groups
:param permute: if true chains are permuted at each iteration
:param duplicate: duplicate chains in group to make colors groups even
:param adjust: adjust the number of group to have even colors groups
"""
[docs]
def permuting(self, key: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Permutes walkers chains for each temperature.
:param key: PRNG key
:return: new PRNG key and permuted chains indexes."""
ntemps = self.sampling.temps.shape[0]
nwalkers = self.sampling.nwalker
key, kperm = jax.random.split(key)
ichain = jnp.arange(ntemps * nwalkers).reshape(nwalkers, ntemps).T
return (
key,
jax.random.permutation(kperm, ichain, axis=1, independent=True).T.flatten(),
)
[docs]
def grouping(self) -> tuple[list[jax.Array], list[jax.Array]]:
"""Defines the groups and complementary groups of chains for
each color. This is just a prototype. The actual implementation
depends on the type of coloring.
:return: indexes of the chains of each color, indexes of the complementary
chains for each color."""
raise NotImplementedError(
"ColoredMH is an abstract class. Need to implement this method."
)
[docs]
class ColoredSC[Tepoch: EpochMH, Tstate: StateMH, Tsampling: SamplingMH](
ColoredMH[Tepoch, Tstate, Tsampling]
):
"""Generic class for "single" chain colored Metropoling-Hasting (example: stretch).
This defines:
* color `grouping` as equal size split of walkers chains (for all temperature) with \
the complementary group for each chain defined by the set of all chains not in the \
color group and having the same temperature.
* `colored_step` as a MH acceptance test after a call to "reduced" in-color proposal \
encoded in the `proposal` method. It also provides a prototype for this method.
:param ngroups: number of groups
:param permute: if true chains are permuted at each iteration
:param duplicate: duplicate chains in group to make colors groups even
:param adjust: adjust the number of group to have even colors groups
"""
npart: int
"""number of partners needed to build the proposal"""
[docs]
def get_partners(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> jax.Array:
"""Method for getting partners samples fore each chain of a group. This
implementation simply draws randomly one partner for each chain in the
group among those in the chains of complementary group with the same
temperature.
:param key: PRNG key
:param state: current state
:param group: group chains
:param cgroup: complementary group chains
:return: the parners as an array with shape (self.npars, group.size, dim)
"""
ntemps = self.sampling.temps.shape[0]
cg = cgroup.reshape(-1, ntemps)
g = group.reshape(-1, ntemps)
# (group.size // temps, cgroup.size // temps, temps)
# broadcasting cgroup indexes (per temp) for all element of the group
# icg[i, j, h] = cgroup.reshape(-1, ntemps)[j, h]
icg = jnp.broadcast_to(cg[None, :, :], (g.shape[0], *cg.shape))
# (group.size // temps, temps)
# icg[i, j] = one icg indexes randomly chosed.
# Finally we flatten to the group list of chain
icg = jnp.take_along_axis(
icg[None, :, :, :],
jax.random.randint(
key, shape=(self.npart,) + g.shape, minval=0, maxval=cg.shape[0]
)[:, :, None, :],
axis=2,
)[:, :, 0, :].reshape(self.npart, -1)
return state.p[icg, :]
# pylint: disable=unused-argument
[docs]
def proposal(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[jax.Array, Tstate, jax.Array]:
"""Proposal restricted to chains of one color. This is just a prototype.
:param key: PRNG key
:param state: current state
:param group: current color group
:param cgroup: complementary chains for each chain in the group.
:return: PNRG key, proposed state, transition probability for each chain."""
raise NotImplementedError(
"ColoredSC is an abstract class. Need to implement proposal method."
)
[docs]
def grouping(self) -> tuple[list[jax.Array], list[jax.Array]]:
"""Color grouping implementation. For each temperature the walkers
are divided in `ngroups` equal sets. For each chain of each group the
complementary chains are all the chains belonging to the other groups
(and having the same temperature).
:return: indexes of the chains of each color, indexes of the complementary
chains for each color."""
if self.ngroups == -1:
self.ngroups = self.sampling.nwalker
if self.adjust or self.duplicate:
raise NotImplementedError("Adjusting grouping not implemented yet.")
ichain = jnp.arange(self.sampling.nchain)
ntemps = self.sampling.temps.shape[0]
grps = jnp.arange(self.ngroups)
sel = [ichain // ntemps % self.ngroups == _grp for _grp in grps]
return [ichain[sel[_grp]] for _grp in grps], [
ichain[~sel[_grp]] for _grp in grps
]
[docs]
def colored_step(
self, key: jax.Array, state: Tstate, group: jax.Array, cgroup: jax.Array
) -> tuple[Tstate, jax.Array]:
"""Colored step implementation. It calls the `proposal` method and then
accepts/reject (for each chain separately) the proposed point according
to the MH acceptance encoded in the :py:attr:`jexplore.steps.single.MHStep.mh`
method.
:param key: PRNG key
:param state: current state
:param group: indexes of the color chains
:param cgroup: indexes of the complementary chains for this color.
:return: new state and boolean mask of the changed chains
"""
masked = state.slice(group)
key, prop, qxy = self.proposal(key, state, group, cgroup)
new, stats = self.mh(key, masked, prop, self.beta[group], qxy)
return state.update_slice(new, group), stats