jexplore.steps.colored
This module defines base classes for colored proposals.
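Exceptions
- BadColoring: Exception for uneven groups.
Classes
- Colored: Generic class for colored chains step.
- ColoredMH: Generic class for colored Metropolis-Hastings step.
- ColoredSC: Generic class for “single” chain colored Metropolis-Hastings (example: stretch).
Module Contents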
- exception BadColoring(groups, cgroups)
Bases: ValueError
Exception for uneven groups. A colored step runs the substep of each color in a jax.lax.scan loop, and this requires all colors to have the same number of chains and of complementary chains. If this is not the case, this exception is raised.
- Parameters:
groups (List[jax.Array]) – list of the chain groups of each color
cgroups (List[jax.Array]) – list of the complementary chain groups of each color
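The requirement comes from the fact that jax.lax.scan iterates over the leading axis of a single array, so the per-color index arrays must be stackable. A minimal sketch of that check, assuming groups and cgroups are lists of jax arrays; `stack_colors` and `BadColoringSketch` are hypothetical names, not the library's API:

```python
import jax.numpy as jnp


class BadColoringSketch(ValueError):
    """Hypothetical stand-in for BadColoring (same idea, not the library class)."""

    def __init__(self, groups, cgroups):
        sizes = [g.size for g in groups]
        csizes = [c.size for c in cgroups]
        super().__init__(f"uneven colors: group sizes {sizes}, complementary sizes {csizes}")


def stack_colors(groups, cgroups):
    """Stack per-color index arrays so that jax.lax.scan can loop over colors.

    All colors must share one shape for their groups and one shape for their
    complementary groups, otherwise they cannot form a single scanned array.
    """
    if len({g.shape for g in groups}) != 1 or len({c.shape for c in cgroups}) != 1:
        raise BadColoringSketch(groups, cgroups)
    return jnp.stack(groups), jnp.stack(cgroups)
```

With two colors of three chains each, `stack_colors` returns arrays of shape (2, 3); a 3/2 split raises the exception, which is also a ValueError.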
- class Colored[Tepoch: jexplore.sampling.Epoch, Tstate: jexplore.sampling.State, Tsampling: jexplore.sampling.Sampling](ngroups=2, permute=False, duplicate=False, adjust=False)
Bases: jexplore.steps.step.Step[Tepoch,Tstate,Tsampling]
Generic class for colored chain steps. It provides:
a generic build method that calls a grouping method defining the chain groups of the different colors. It also provides a prototype for the grouping method.
a generic step method that permutes the chains by calling the permuting method and then calls the colored_step method for every color. It also provides prototypes for permuting and colored_step.
- Parameters:
ngroups (int) – number of groups
permute (bool) – if true, chains are permuted at each iteration
duplicate (bool) – duplicate chains within groups to make the color groups even
adjust (bool) – adjust the number of groups to make the color groups even
- permute: bool
permute chains at each iteration
- adjust: bool
adjust the number of groups to have even colors
- duplicate: bool
duplicate chains to have even colors
- groups: List[jax.Array]
list of indexes of the chains of each color
- cgroups: List[jax.Array]
list of indexes of the complementary chains of each color
- ngroups: int
number of colors
- build(epoch)
Step initialisation method. It extends jexplore.steps.step.Step.build by adding a call to grouping, which defines the color groups, and checks that the defined color groups are even, raising a BadColoring exception otherwise.
- Parameters:
epoch (Tepoch) – current epoch.
- Return type:
None
- abstract permuting(key)
Permutes chains. This is just a prototype. The actual implementation depends on the type of coloring.
- Parameters:
key (jax.Array) – PRNG key
- Returns:
new PRNG key and permuted chains indexes.
- Return type:
tuple[jax.Array, jax.Array]
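The prototype's contract (take a key, return a fresh key plus permuted chain indexes) can be illustrated with the simplest possible coloring, one that permutes all chains together with no temperature structure. `permuting_sketch` and its `nchains` argument are hypothetical; the real method takes only `key` and reads sizes from the instance:

```python
import jax


def permuting_sketch(key, nchains):
    """Return a new PRNG key and a random permutation of 0..nchains-1.

    Hypothetical free-function version of a permuting implementation that
    shuffles all chains at once.
    """
    key, subkey = jax.random.split(key)          # keep `key` for later steps
    perm = jax.random.permutation(subkey, nchains)
    return key, perm
```

Splitting before drawing keeps the returned key independent of the randomness already consumed by the permutation.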
- abstract grouping()
Defines the groups and complementary groups of chains for each color. This is just a prototype. The actual implementation depends on the type of coloring.
- Returns:
indexes of the chains of each color, indexes of the complementary chains for each color.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
- colored_step(key, state, group, cgroup)
Step to be executed on a single color. This is a prototype.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – indexes of the color chains
cgroup (jax.Array) – indexes of the complementary chains for this color.
- Returns:
new state and boolean mask of the changed chains
- Return type:
tuple[Tstate, jax.Array]
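A toy implementation of the colored_step contract: only the chains listed in `group` may change, and the returned mask flags them. This is a sketch assuming the state is a plain (nchains, dim) array rather than a Tstate object; the Gaussian jitter stands in for a real update and `colored_step_sketch` is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def colored_step_sketch(key, state, group, cgroup):
    """Toy colored_step: jitter only the chains of one color.

    `state` is assumed to be a (nchains, dim) array of positions and
    `cgroup` is unused by this toy update. Returns the new state and a
    boolean mask of the chains that changed.
    """
    noise = 0.1 * jax.random.normal(key, (group.size, state.shape[1]))
    new_state = state.at[group].add(noise)                          # other rows untouched
    mask = jnp.zeros(state.shape[0], dtype=bool).at[group].set(True)
    return new_state, mask
```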
- runtime_grouping(key, groups, cgroups)
Runtime redefinition of the groups and complementary groups. This is called at each step.
In this default implementation it simply calls the permuting method, when permute is true, to permute the chains in the color definitions.
- Parameters:
key (jax.Array) – PRNG key.
groups (list[jax.Array]) – indexes of the chains of each color.
cgroups (list[jax.Array]) – indexes of the complementary chains of each color.
- Returns:
new PRNG key and the newly defined groups and cgroups.
- Return type:
tuple[jax.Array, list[jax.Array], list[jax.Array]]
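One way to apply a permutation to the color definitions is to treat it as an index map, so that chain i in any group becomes chain perm[i]. This is an assumption about how the permutation is used, and `runtime_grouping_sketch` is a hypothetical name; it also assumes the groups partition all chains. Because the same bijection is applied to groups and cgroups, each color stays disjoint from its complementary chains:

```python
import jax


def runtime_grouping_sketch(key, groups, cgroups, permute=True):
    """Optionally relabel the chains in every color through one random permutation."""
    if not permute:
        return key, groups, cgroups
    nchains = sum(g.size for g in groups)        # assumes groups partition the chains
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, nchains)
    new_groups = [perm[g] for g in groups]       # chain i -> chain perm[i]
    new_cgroups = [perm[c] for c in cgroups]
    return key, new_groups, new_cgroups
```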
- step(key, state)
General implementation of the step sampling method for a colored step. In principle there is no need to extend or override this method when defining an actual colored step implementation; the implementation details should instead go in the grouping, permuting and colored_step methods.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – the current state
- Returns:
new state and the boolean mask of the chains modified by the step.
- Return type:
tuple[Tstate, jax.Array]
- class ColoredMH[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](ngroups=2, permute=False, duplicate=False, adjust=False)
Bases: Colored[Tepoch,Tstate,Tsampling], jexplore.steps.mh.MHStep[Tepoch,Tstate,Tsampling]
Generic class for colored Metropolis-Hastings steps. Compared with its two parent classes, this class explicitly implements a permuting method that defines permutations of the walker chains for each temperature.
- Parameters:
ngroups (int) – number of groups
permute (bool) – if true, chains are permuted at each iteration
duplicate (bool) – duplicate chains within groups to make the color groups even
adjust (bool) – adjust the number of groups to make the color groups even
- permuting(key)
Permutes the walker chains for each temperature.
- Parameters:
key (jax.Array) – PRNG key
- Returns:
new PRNG key and permuted chains indexes.
- Return type:
tuple[jax.Array, jax.Array]
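Permuting within each temperature means a chain is never moved to a different temperature. A minimal sketch, assuming a flat chain layout in which chain index = t * nwalkers + w (so each temperature is a contiguous block); `permute_within_temperatures` and its arguments are hypothetical:

```python
import jax
import jax.numpy as jnp


def permute_within_temperatures(key, ntemps, nwalkers):
    """Permute walkers independently inside each temperature.

    Returns a new key and a flat array of permuted chain indexes.
    """
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, ntemps)
    # one independent permutation of 0..nwalkers-1 per temperature
    local = jax.vmap(lambda k: jax.random.permutation(k, nwalkers))(subkeys)
    offsets = (jnp.arange(ntemps) * nwalkers)[:, None]   # shift back to global indexes
    return key, (local + offsets).ravel()
```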
- abstract grouping()
Defines the groups and complementary groups of chains for each color. This is just a prototype. The actual implementation depends on the type of coloring.
- Returns:
indexes of the chains of each color, indexes of the complementary chains for each color.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
- class ColoredSC[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](ngroups=2, permute=False, duplicate=False, adjust=False)
Bases: ColoredMH[Tepoch,Tstate,Tsampling]
Generic class for “single” chain colored Metropolis-Hastings steps (example: stretch). It defines:
the color grouping as an equal-size split of the walker chains (at every temperature), where the complementary group of each chain is the set of all chains that are not in its color group and have the same temperature.
the colored_step as an MH acceptance test applied after a call to the “reduced” in-color proposal encoded in the proposal method, for which it also provides a prototype.
- Parameters:
ngroups (int) – number of groups
permute (bool) – if true, chains are permuted at each iteration
duplicate (bool) – duplicate chains within groups to make the color groups even
adjust (bool) – adjust the number of groups to make the color groups even
- npart: int
number of partners needed to build the proposal
- get_partners(key, state, group, cgroup)
Gets partner samples for each chain of a group. This implementation simply draws, at random, one partner for each chain in the group among the chains of the complementary group with the same temperature.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – group chains
cgroup (jax.Array) – complementary group chains
- Returns:
the partners as an array with shape (self.npart, group.size, dim)
- Return type:
jax.Array
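A sketch of partner drawing that produces the documented (npart, group.size, dim) shape. It assumes positions are a (nchains, dim) array and that `cgroup` has shape (group.size, ncomp), row i listing the same-temperature complementary chains of group[i]; partners are drawn uniformly with replacement. `get_partners_sketch` is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def get_partners_sketch(key, positions, group, cgroup, npart):
    """Draw `npart` random partners for each chain of a color."""
    ncomp = cgroup.shape[1]
    cols = jax.random.randint(key, (npart, group.size), 0, ncomp)
    rows = jnp.arange(group.size)[None, :]           # broadcast against cols
    partner_idx = cgroup[rows, cols]                  # (npart, group.size) chain indexes
    return positions[partner_idx]                     # (npart, group.size, dim)
```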
- abstract proposal(key, state, group, cgroup)
Proposal restricted to chains of one color. This is just a prototype.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – current color group
cgroup (jax.Array) – complementary chains for each chain in the group.
- Returns:
PRNG key, proposed state, transition probability for each chain.
- Return type:
tuple[jax.Array, Tstate, jax.Array]
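Since the stretch move is named as an example, here is a sketch of a Goodman-Weare stretch proposal restricted to one color: y = x_j + z (x - x_j) with z drawn from g(z) ∝ 1/sqrt(z) on [1/a, a], and log correction (dim - 1) log z. The signature differs from the prototype: it assumes a (nchains, dim) position array, a (1, group.size, dim) partner array (npart = 1), and returns the transition term as a log correction. `stretch_proposal_sketch` is hypothetical:

```python
import jax
import jax.numpy as jnp


def stretch_proposal_sketch(key, positions, group, partners, a=2.0):
    """Goodman-Weare stretch move for the chains of one color (illustrative)."""
    key, subkey = jax.random.split(key)
    x = positions[group]                      # (group.size, dim)
    xj = partners[0]                          # one partner per chain
    u = jax.random.uniform(subkey, (group.size,))
    z = ((a - 1.0) * u + 1.0) ** 2 / a        # inverse-CDF draw from g(z) on [1/a, a]
    proposed = xj + z[:, None] * (x - xj)
    log_q = (x.shape[1] - 1) * jnp.log(z)     # log proposal correction
    return key, proposed, log_q
```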
- grouping()
Color grouping implementation. For each temperature, the walkers are divided into ngroups equal sets. For each chain of each group, the complementary chains are all the chains that belong to the other groups and have the same temperature.
- Returns:
indexes of the chains of each color, indexes of the complementary chains for each color.
- Return type:
tuple[list[jax.Array], list[jax.Array]]
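The grouping described above can be sketched under two assumptions: chain index = t * nwalkers + w, and nwalkers divisible by ngroups. Color k takes a contiguous block of walkers at every temperature, and cgroups[k] holds one row per chain of that color listing the other colors' chains at the same temperature. `grouping_sketch` and its arguments are hypothetical (the real method reads them from the instance):

```python
import jax.numpy as jnp


def grouping_sketch(ntemps, nwalkers, ngroups):
    """Split walkers into `ngroups` equal colors at every temperature."""
    m = nwalkers // ngroups
    idx = jnp.arange(ntemps * nwalkers).reshape(ntemps, ngroups, m)
    groups, cgroups = [], []
    for k in range(ngroups):
        groups.append(idx[:, k, :].ravel())
        others = jnp.delete(idx, k, axis=1).reshape(ntemps, -1)  # (ntemps, nwalkers - m)
        cgroups.append(jnp.repeat(others, m, axis=0))            # one row per chain of color k
    return groups, cgroups
```

For 2 temperatures, 4 walkers and 2 colors, color 0 is chains [0, 1, 4, 5] and its complementary rows are [2, 3] for chains 0 and 1 and [6, 7] for chains 4 and 5.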
- colored_step(key, state, group, cgroup)
Colored step implementation. It calls the proposal method and then accepts or rejects the proposed point (for each chain separately) according to the MH acceptance encoded in the jexplore.steps.single.MHStep.mh method.
- Parameters:
key (jax.Array) – PRNG key
state (Tstate) – current state
group (jax.Array) – indexes of the color chains
cgroup (jax.Array) – indexes of the complementary chains for this color.
- Returns:
new state and boolean mask of the changed chains
- Return type:
tuple[Tstate, jax.Array]
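The per-chain accept/reject can be sketched as a standard Metropolis-Hastings test, log u < log p(y) - log p(x) + log q, applied row by row. This assumes a (nchains, dim) position array instead of Tstate, a `log_prob` mapping (n, dim) to (n,), and a log proposal correction `log_q`; `mh_accept_sketch` is hypothetical and not the library's mh method:

```python
import jax
import jax.numpy as jnp


def mh_accept_sketch(key, log_prob, positions, group, proposed, log_q):
    """Accept or reject a proposal for each chain of one color separately."""
    current = positions[group]
    log_alpha = log_prob(proposed) - log_prob(current) + log_q
    u = jax.random.uniform(key, (group.size,))
    accept = jnp.log(u) < log_alpha                               # one decision per chain
    new_rows = jnp.where(accept[:, None], proposed, current)
    new_positions = positions.at[group].set(new_rows)
    mask = jnp.zeros(positions.shape[0], dtype=bool).at[group].set(accept)
    return new_positions, mask
```

The returned mask marks accepted chains only, so chains outside the color and rejected chains both report False.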