jexplore.steps.colored#

This module defines base classes for colored proposals.

Exceptions#

BadColoring

Exception for uneven groups.

Classes#

Colored

Generic class for a colored chains step.

ColoredMH

Generic class for a colored Metropolis-Hastings step.

ColoredSC

Generic class for "single" chain colored Metropolis-Hastings (example: stretch).

Module Contents#

exception BadColoring(groups, cgroups)[source]#

Bases: ValueError

Exception for uneven groups. A colored step runs the substep of each color in a jax.lax.scan loop and this requires all colors to have the same number of chains and complementary chains. If this is not the case the present exception is raised.

Parameters:
  • groups (List[jax.Array]) – list of the chain groups of each color

  • cgroups (List[jax.Array]) – list of the complementary chain groups of each color
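The evenness requirement follows from how jax.lax.scan works: it iterates over the leading axis of its inputs, so the per-color index arrays must all have the same shape to be stacked. The plain-Python sketch below illustrates the kind of check involved. The helper name `check_even` and the use of index lists in place of jax.Array are assumptions made for illustration; this is not the library's actual check.

```python
def check_even(groups, cgroups):
    """Raise ValueError unless all colors have equally sized groups.

    Hypothetical helper: `groups` and `cgroups` are lists of index
    lists here (the library uses jax.Array).
    """
    group_sizes = {len(g) for g in groups}
    cgroup_sizes = {len(c) for c in cgroups}
    if len(group_sizes) > 1 or len(cgroup_sizes) > 1:
        raise ValueError(
            f"uneven colors: group sizes {sorted(group_sizes)}, "
            f"complementary sizes {sorted(cgroup_sizes)}"
        )


# Two colors of two chains each: accepted.
check_even([[0, 1], [2, 3]], [[2, 3], [0, 1]])

# Five chains in two colors cannot be split evenly: rejected.
try:
    check_even([[0, 1, 2], [3, 4]], [[3, 4], [0, 1, 2]])
    uneven_detected = False
except ValueError:
    uneven_detected = True
```

Raising a ValueError mirrors the documented base class of BadColoring.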

class Colored[Tepoch: jexplore.sampling.Epoch, Tstate: jexplore.sampling.State, Tsampling: jexplore.sampling.Sampling](ngroups=2, permute=False, duplicate=False, adjust=False)[source]#

Bases: jexplore.steps.step.Step[Tepoch, Tstate, Tsampling]

Generic class for colored chains step. This provides

  • a generic build method that calls a grouping function defining the chain groups of the different colors. It also provides a prototype for the grouping method.

  • a generic step method that permutes the chains by calling the permuting method and then calls the colored_step method for all colors. It also provides prototypes for permuting and colored_step.

Parameters:
  • ngroups (int) – number of groups

  • permute (bool) – if true chains are permuted at each iteration

  • duplicate (bool) – duplicate chains in groups to make the color groups even

  • adjust (bool) – adjust the number of groups to have even color groups

permute: bool#

permute chains at each iteration

adjust: bool#

adjust number of groups to have even colors

duplicate: bool#

duplicate chains to have even colors

groups: List[jax.Array]#

list of indexes of the chains of each color

cgroups: List[jax.Array]#

list of indexes of the complementary chains of each color

ngroups: int#

number of colors

build(epoch)[source]#

Step initialisation method. It extends jexplore.steps.step.Step.build by calling grouping to define the color groups, then checks that these groups are even and raises a BadColoring exception otherwise.

Parameters:

epoch (Tepoch) – current epoch.

Return type:

None

abstract permuting(key)[source]#

Permutes chains. This is just a prototype. The actual implementation depends on the type of coloring.

Parameters:

key (jax.Array) – PRNG key

Returns:

new PRNG key and permuted chains indexes.

Return type:

tuple[jax.Array, jax.Array]

abstract grouping()[source]#

Defines the groups and complementary groups of chains for each color. This is just a prototype. The actual implementation depends on the type of coloring.

Returns:

indexes of the chains of each color, indexes of the complementary chains for each color.

Return type:

tuple[list[jax.Array], list[jax.Array]]

colored_step(key, state, group, cgroup)[source]#

Step to be executed on a single color. This is a prototype.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – indexes of the color chains

  • cgroup (jax.Array) – indexes of the complementary chains for this color.

Returns:

new state and boolean mask of the changed chains

Return type:

tuple[Tstate, jax.Array]

runtime_grouping(key, groups, cgroups)[source]#

Runtime redefinition of groups and complementary groups. This will be called at each step.

In this default implementation, this simply (and optionally) calls the permuting method to permute the chains in the color definitions.

Parameters:
  • key (jax.Array) – PRNG key.

  • groups (list[jax.Array]) – indexes of the chains of each color.

  • cgroups (list[jax.Array]) – indexes of the complementary chains of each color.

Returns:

new PRNG key and the newly defined groups and cgroups.

Return type:

tuple[jax.Array, list[jax.Array], list[jax.Array]]

step(key, state)[source]#

General implementation of the step sampling method for a colored step. In principle there is no need to extend or override this method when defining an actual colored step implementation. The implementation details should rather go in the grouping, permuting and colored_step methods.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – the current state

Returns:

new state and the boolean mask of the chains modified by the step.

Return type:

tuple[Tstate, jax.Array]
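The control flow described above can be sketched in plain Python. This sketch assumes the colors are processed one after another, so that each color sees the updates made by the previous ones. PRNG keys are omitted, the state is reduced to a list of floats, and `toy_colored_step` is a hypothetical update invented for the example; none of this is the library's code.

```python
def step(state, groups, cgroups, colored_step):
    """Apply `colored_step` to each color in turn and merge the masks."""
    changed = [False] * len(state)
    for group, cgroup in zip(groups, cgroups):
        state, mask = colored_step(state, group, cgroup)
        # A chain counts as modified if any color changed it.
        changed = [c or m for c, m in zip(changed, mask)]
    return state, changed


def toy_colored_step(state, group, cgroup):
    """Toy update: move each chain of the color to the mean of its complement."""
    target = sum(state[i] for i in cgroup) / len(cgroup)
    new_state = list(state)
    mask = [False] * len(state)
    for i in group:
        new_state[i] = target
        mask[i] = True
    return new_state, mask


state = [0.0, 2.0, 4.0, 6.0]
groups = [[0, 1], [2, 3]]
cgroups = [[2, 3], [0, 1]]
new_state, changed = step(state, groups, cgroups, toy_colored_step)
```

Here the second color is updated using the already updated chains of the first color, which is why every chain ends at the same value.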

class ColoredMH[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](ngroups=2, permute=False, duplicate=False, adjust=False)[source]#

Bases: Colored[Tepoch, Tstate, Tsampling], jexplore.steps.mh.MHStep[Tepoch, Tstate, Tsampling]

Generic class for a colored Metropolis-Hastings step. With respect to its two parent classes, this class explicitly implements a permuting method that defines permutations of the walker chains for each temperature.

Parameters:
  • ngroups (int) – number of groups

  • permute (bool) – if true chains are permuted at each iteration

  • duplicate (bool) – duplicate chains in groups to make the color groups even

  • adjust (bool) – adjust the number of groups to have even color groups

permuting(key)[source]#

Permutes walkers chains for each temperature.

Parameters:

key (jax.Array) – PRNG key

Returns:

new PRNG key and permuted chains indexes.

Return type:

tuple[jax.Array, jax.Array]
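As an illustration of a per-temperature permutation, the sketch below shuffles the walkers independently inside each temperature. It is plain Python and uses `random.Random` where the real method takes and returns a JAX PRNG key. The function name `permute_walkers` and the chain layout (chain index = t * nwalkers + w) are assumptions, not taken from the library.

```python
import random


def permute_walkers(rng, ntemps, nwalkers):
    """Return chain indexes with walkers shuffled inside each temperature.

    Assumed layout (hypothetical): chain index = t * nwalkers + w.
    """
    perm = []
    for t in range(ntemps):
        walkers = list(range(nwalkers))
        rng.shuffle(walkers)  # independent shuffle per temperature
        perm.extend(t * nwalkers + w for w in walkers)
    return perm


perm = permute_walkers(random.Random(0), ntemps=2, nwalkers=4)
```

Whatever the random draw, chains never move across temperatures: the first four entries of `perm` are a reordering of 0 to 3 and the last four a reordering of 4 to 7.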

abstract grouping()[source]#

Defines the groups and complementary groups of chains for each color. This is just a prototype. The actual implementation depends on the type of coloring.

Returns:

indexes of the chains of each color, indexes of the complementary chains for each color.

Return type:

tuple[list[jax.Array], list[jax.Array]]

class ColoredSC[Tepoch: jexplore.sampling.EpochMH, Tstate: jexplore.sampling.StateMH, Tsampling: jexplore.sampling.SamplingMH](ngroups=2, permute=False, duplicate=False, adjust=False)[source]#

Bases: ColoredMH[Tepoch, Tstate, Tsampling]

Generic class for “single” chain colored Metropolis-Hastings (example: stretch). This defines:

  • color grouping as an equal-size split of the walker chains (for all temperatures), with the complementary group of each chain defined as the set of all chains that are not in the color group and have the same temperature.

  • colored_step as an MH acceptance test applied after a call to the “reduced” in-color proposal encoded in the proposal method. It also provides a prototype for this method.

Parameters:
  • ngroups (int) – number of groups

  • permute (bool) – if true chains are permuted at each iteration

  • duplicate (bool) – duplicate chains in groups to make the color groups even

  • adjust (bool) – adjust the number of groups to have even color groups

npart: int#

number of partners needed to build the proposal

get_partners(key, state, group, cgroup)[source]#

Method for getting partner samples for each chain of a group. This implementation simply draws one random partner for each chain in the group from the chains of the complementary group with the same temperature.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – group chains

  • cgroup (jax.Array) – complementary group chains

Returns:

the partners as an array with shape (self.npart, group.size, dim)

Return type:

jax.Array
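A minimal plain-Python sketch of drawing one partner per chain is given below. It assumes that `cgroup[k]` lists the complementary chains of `group[k]` and that `samples[i]` holds the position of chain i. Both representations, and the use of `random.Random` instead of a JAX PRNG key, are illustrative assumptions.

```python
import random


def get_partners(rng, samples, group, cgroup):
    """Draw one partner sample per chain of `group`.

    `cgroup[k]` lists the complementary chains of `group[k]`
    (assumed representation); `samples[i]` is the position of chain i.
    """
    return [samples[rng.choice(cgroup[k])] for k in range(len(group))]


samples = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
group = [0, 1]
cgroup = [[2, 3], [2, 3]]
partners = get_partners(random.Random(1), samples, group, cgroup)
```

With a single partner per chain the result here has shape (group.size, dim); the documented return value carries an additional leading axis for the number of partners.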

abstract proposal(key, state, group, cgroup)[source]#

Proposal restricted to chains of one color. This is just a prototype.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – current color group

  • cgroup (jax.Array) – complementary chains for each chain in the group.

Returns:

PRNG key, proposed state, transition probability for each chain.

Return type:

tuple[jax.Array, Tstate, jax.Array]

grouping()[source]#

Color grouping implementation. For each temperature the walkers are divided into ngroups equal sets. For each chain of each group, the complementary chains are all the chains that belong to the other groups and have the same temperature.

Returns:

indexes of the chains of each color, indexes of the complementary chains for each color.

Return type:

tuple[list[jax.Array], list[jax.Array]]
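The grouping rule can be made concrete with a small plain-Python sketch. The chain layout (chain index = t * nwalkers + w), the contiguous split of the walkers and the nested-list representation of the complementary chains are assumptions chosen for the example, not the library's implementation.

```python
def grouping(ntemps, nwalkers, ngroups):
    """Split walkers into `ngroups` equal colors for every temperature.

    Assumed layout (hypothetical): chain index = t * nwalkers + w.
    Returns (groups, cgroups) where cgroups[g][k] lists the
    complementary chains of chain groups[g][k].
    """
    if nwalkers % ngroups:
        raise ValueError("nwalkers must be a multiple of ngroups")
    size = nwalkers // ngroups
    groups, cgroups = [], []
    for g in range(ngroups):
        in_color = range(g * size, (g + 1) * size)
        group, cgroup = [], []
        for t in range(ntemps):
            # Same temperature, walkers belonging to the other colors.
            others = [t * nwalkers + w for w in range(nwalkers) if w not in in_color]
            for w in in_color:
                group.append(t * nwalkers + w)
                cgroup.append(others)
        groups.append(group)
        cgroups.append(cgroup)
    return groups, cgroups


groups, cgroups = grouping(ntemps=2, nwalkers=4, ngroups=2)
# groups  == [[0, 1, 4, 5], [2, 3, 6, 7]]
# cgroups[0] == [[2, 3], [2, 3], [6, 7], [6, 7]]
```

With two temperatures and four walkers, chains 0 and 1 (first temperature) only ever see chains 2 and 3 as complementary chains, never chains of the second temperature.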

colored_step(key, state, group, cgroup)[source]#

Colored step implementation. It calls the proposal method and then accepts or rejects the proposed point (for each chain separately) according to the MH acceptance encoded in the jexplore.steps.single.MHStep.mh method.

Parameters:
  • key (jax.Array) – PRNG key

  • state (Tstate) – current state

  • group (jax.Array) – indexes of the color chains

  • cgroup (jax.Array) – indexes of the complementary chains for this color.

Returns:

new state and boolean mask of the changed chains

Return type:

tuple[Tstate, jax.Array]
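For readers unfamiliar with the per-chain acceptance test, the standard Metropolis-Hastings rule can be sketched in plain Python as follows. The function name `mh_accept` is hypothetical, and `log_q` is assumed to hold the logarithm of the proposal correction factor returned by the proposal; how the library's mh method actually combines these quantities is not stated in this page.

```python
import math
import random


def mh_accept(rng, logp_old, logp_new, log_q):
    """Per-chain Metropolis-Hastings test.

    Accept chain i when log(u) < logp_new[i] - logp_old[i] + log_q[i],
    with u uniform in (0, 1]. `log_q` is an assumed log correction factor.
    """
    accepted = []
    for old, new, q in zip(logp_old, logp_new, log_q):
        u = 1.0 - rng.random()  # in (0, 1], avoids log(0)
        accepted.append(math.log(u) < new - old + q)
    return accepted


logp_old = [-1.0, -1.0]
logp_new = [-0.5, -math.inf]  # second proposal has zero density
mask = mh_accept(random.Random(0), logp_old, logp_new, log_q=[0.0, 0.0])
```

The first chain moves to a point of higher density and is always accepted, while the second proposal has zero density and is always rejected; the resulting list plays the role of the boolean mask of changed chains.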