jexplore.sampler

Jexplore: JAX-based Markov sampler

Attributes

logger

logging instance

Tepoch

placeholder for epoch class

Tsampling

placeholder for sampling definition class

Tbackend

placeholder for backend class

Tstate

placeholder for state class

StepBuilder

Prototype definition of a step builder.

Classes

Steps

Sampler steps class. This is a list of dictionaries.

JaxSampler

Main jexplore sampler class.

Module Contents

logger

logging instance

Tepoch

placeholder for epoch class

Tsampling

placeholder for sampling definition class

Tbackend

placeholder for backend class

Tstate

placeholder for state class

type StepBuilder = Callable[[Tepoch], StepSample[Tstate]]

Prototype definition of a step builder: a function that takes an epoch as its argument and returns a jexplore.sampling.epoch.StepSample function.
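A minimal sketch of the builder pattern, written without jexplore. It assumes, hypothetically, that a StepSample is a callable taking an integer random key and a state and returning the next state; `ToyEpoch` and `shift_builder` are illustrative names, not part of the library.

```python
from dataclasses import dataclass
from typing import Callable, TypeVar

Tepoch = TypeVar("Tepoch")
Tstate = TypeVar("Tstate")

# Hypothetical stand-in for jexplore.sampling.epoch.StepSample:
# a step takes an integer random key and a state, and returns the next state.
StepSample = Callable[[int, Tstate], Tstate]

# A step builder maps an epoch to a step function.
StepBuilder = Callable[[Tepoch], StepSample[Tstate]]


@dataclass
class ToyEpoch:
    """Hypothetical epoch carrying a tuned step size."""
    step_size: float


def shift_builder(epoch: ToyEpoch) -> StepSample[float]:
    """Build a deterministic step that shifts the state by the epoch's step size."""
    def step(key: int, state: float) -> float:
        # This toy step ignores the key; a real MCMC step would use it.
        return state + epoch.step_size
    return step


builder: StepBuilder = shift_builder
step = builder(ToyEpoch(step_size=0.5))
print(step(0, 1.0))  # 1.5
```

Because the step closes over the epoch, the builder can bake epoch-level quantities (such as a tuned step size) into the step once, before the iterations run.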

class Steps

Bases: list[dict[StepBuilder[Tepoch, Tstate], float]]

Sampler steps class. This is a list of dictionaries. Each element of the list is a sequential substep: the keys of its dictionary are the builders of the alternative steps for that substep, and the values are their weights.
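The layout described above can be illustrated with plain Python objects. This is a sketch only: the three builder functions are hypothetical, and the weights are assumed here to be relative selection weights among the alternatives of a substep.

```python
from typing import Callable, Dict, List


# Hypothetical builders: each takes an epoch and returns a step function.
def small_move_builder(epoch):
    return lambda key, state: state + 0.1


def big_move_builder(epoch):
    return lambda key, state: state + 1.0


def identity_builder(epoch):
    return lambda key, state: state


# Two sequential substeps: substep 0 has two alternatives weighted
# 0.9 / 0.1; substep 1 has a single alternative.
steps: List[Dict[Callable, float]] = [
    {small_move_builder: 0.9, big_move_builder: 0.1},
    {identity_builder: 1.0},
]

for i, alternatives in enumerate(steps):
    names = {builder.__name__: weight for builder, weight in alternatives.items()}
    print(f"substep {i}: {names}")
```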

epoch_cycle(epoch)

Defines the epoch cycle. Each StepBuilder key of the dictionaries is called with epoch as its argument and replaced by the returned jexplore.sampling.epoch.StepSample instance.

Parameters:

epoch (Tepoch) – this epoch

Returns:

the epoch cycle, a list of dictionaries.

Return type:

jexplore.sampling.EpochCycle
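A minimal sketch of the key replacement that epoch_cycle performs, under the assumption that weights are carried over unchanged. The `Steps` class below is a stand-in subclass of `list`, not the jexplore class, and `scale_builder` is a hypothetical builder whose epoch is simply a float.

```python
from typing import Any, Callable

StepBuilder = Callable[[Any], Callable]


class Steps(list[dict[StepBuilder, float]]):
    """Sketch of a Steps-like list of {builder: weight} dictionaries."""

    def epoch_cycle(self, epoch: Any) -> list[dict[Callable, float]]:
        # Call every builder with the epoch and key the new dictionary by the
        # returned step function, keeping the original weight.
        return [
            {builder(epoch): weight for builder, weight in substep.items()}
            for substep in self
        ]


def scale_builder(epoch: float) -> Callable:
    # Hypothetical builder: the "epoch" here is just a float scale factor.
    return lambda key, state: state * epoch


steps = Steps([{scale_builder: 1.0}])
cycle = steps.epoch_cycle(2.0)
step_fn, weight = next(iter(cycle[0].items()))
print(step_fn(None, 3.0), weight)  # 6.0 1.0
```

Building the cycle once per epoch means the builders run only at epoch boundaries, while the returned step functions are reused for every iteration of that epoch.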

class JaxSampler

Bases: Generic[Tsampling, Tepoch, Tbackend]

Main jexplore sampler class.

Parameters:
  • sampling – sampling definition

  • steps – steps definition

  • backend – backend instance

sampling: Tsampling

sampling parameters definition

steps: Steps

steps definition

backend: Tbackend

backend instance

run(epoch, niters, nepoch=1, seed=None)

Runs the MCMC sampler for nepoch epochs of niters iterations each, stores the results in the backend and returns the last epoch.

Parameters:
  • epoch (Tepoch) – starting epoch

  • niters (int) – number of iterations per epoch

  • nepoch (int) – number of epochs

  • seed (int | None) – random seed. If None, a random seed is generated.

Returns:

the last epoch.

Return type:

Tepoch
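A sketch of the outer loop of run, assuming that each finished epoch is handed to the backend and that the PRNG key is threaded from one epoch to the next. To stay self-contained it uses the standard library: an integer stands in for `jax.random.PRNGKey(seed)`, a plain list stands in for the backend, and `run_epoch` is passed in explicitly instead of being a static method. These substitutions are hypothetical.

```python
import random
from typing import Callable, List, Optional, Tuple, TypeVar

Tepoch = TypeVar("Tepoch")


def run(
    epoch: Tepoch,
    niters: int,
    run_epoch: Callable[[Tepoch, int, int], Tuple[int, Tepoch]],
    backend: List[Tepoch],
    nepoch: int = 1,
    seed: Optional[int] = None,
) -> Tepoch:
    """Sketch of the run loop; run_epoch and backend are hypothetical stand-ins."""
    if seed is None:
        # No seed given: draw one at random.
        seed = random.randrange(2**32)
    key = seed  # stand-in for jax.random.PRNGKey(seed)
    for _ in range(nepoch):
        key, epoch = run_epoch(epoch, niters, key)
        backend.append(epoch)  # store each finished epoch
    return epoch


def toy_run_epoch(epoch: int, niters: int, key: int) -> Tuple[int, int]:
    # Toy epoch: an integer counting the iterations performed so far.
    return key + 1, epoch + niters


store: List[int] = []
last = run(0, niters=10, run_epoch=toy_run_epoch, backend=store, nepoch=3, seed=42)
print(last, store)  # 30 [10, 20, 30]
```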

static run_epoch(epoch, niters, key)

Runs the iterations of one MCMC epoch and returns the resulting epoch.

Parameters:
  • epoch (Tepoch) – last epoch

  • niters (int) – number of MCMC iterations.

  • key (PRNGKey) – JAX PRNG key.

Returns:

the PRNG key and the new epoch.

Return type:

tuple[jax.Array, Tepoch]
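A sketch of what one epoch of iterations could look like, under stated assumptions: substeps run in sequence, and within each substep one alternative is drawn with probability proportional to its weight. `ToyEpoch` (a scalar state plus its epoch cycle) is hypothetical, `random.Random` stands in for JAX PRNG keys, and plain Python loops replace whatever JAX control flow the library actually uses.

```python
import random
from dataclasses import dataclass, replace
from typing import Callable, Dict, List, Tuple

Step = Callable[[random.Random, float], float]


@dataclass
class ToyEpoch:
    """Hypothetical epoch: a scalar chain state plus its epoch cycle."""
    state: float
    cycle: List[Dict[Step, float]]


def run_epoch(epoch: ToyEpoch, niters: int, key: int) -> Tuple[int, ToyEpoch]:
    rng = random.Random(key)  # stand-in for a JAX PRNG key
    state = epoch.state
    for _ in range(niters):
        # Substeps run in sequence; within each, one alternative is drawn
        # with probability proportional to its weight, then applied.
        for substep in epoch.cycle:
            step = rng.choices(list(substep), weights=list(substep.values()))[0]
            state = step(rng, state)
    new_key = rng.getrandbits(32)  # stand-in for splitting the JAX key
    return new_key, replace(epoch, state=state)


def increment(rng: random.Random, x: float) -> float:
    return x + 1.0


def stay(rng: random.Random, x: float) -> float:
    return x


epoch0 = ToyEpoch(state=0.0, cycle=[{increment: 1.0}, {stay: 1.0}])
key, epoch1 = run_epoch(epoch0, niters=5, key=0)
print(epoch1.state)  # 5.0
```

Returning a fresh key alongside the new epoch lets the caller chain epochs without reusing randomness, which mirrors the tuple[jax.Array, Tepoch] return type above.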