jexplore.sampler#
Jexplore: JAX-based Markov sampler
Attributes#
Classes#
Steps | Sampler steps class. This is a list of dictionaries.
JaxSampler | Main jexplore sampler class.
Module Contents#
- logger#
logging instance
- Tepoch#
placeholder for epoch class
- Tsampling#
placeholder for sampling definition class
- Tbackend#
placeholder for backend class
- Tstate#
placeholder for state class
- type StepBuilder = Callable[[Tepoch], StepSample[Tstate]]#
Prototype of a step builder: a function that takes an epoch as its argument and returns a
jexplore.sampling.epoch.StepSample function.
- class Steps[source]#
Bases:
list[dict[StepBuilder[Tepoch,Tstate],float]]
Sampler steps class. This is a list of dictionaries. Each element of the list is a sequential substep; the keys of its dictionary are the builders of the alternative steps for that substep, and the values are their weights.
- epoch_cycle(epoch)[source]#
Defines the epoch cycle. Each StepBuilder dictionary key is called with epoch as its argument and replaced by the returned
jexplore.sampling.epoch.StepSample instance.
- Parameters:
epoch (Tepoch) – this epoch
- Returns:
the epoch cycle, as a list of dictionaries.
- Return type:
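The substitution performed by epoch_cycle can be sketched in plain Python. This is only an illustration of the documented data structure, not the jexplore implementation: the `Epoch` dictionary, the `step_size` field, the two builders, and the assumption that a step maps a state to a new state are all hypothetical stand-ins.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins: the real jexplore epoch and StepSample types
# are not shown on this page.
Epoch = dict
StepSample = Callable[[float], float]
StepBuilder = Callable[[Epoch], StepSample]


def shift_builder(epoch: Epoch) -> StepSample:
    """Build a step that shifts the state by the epoch's step size."""
    size = epoch["step_size"]
    return lambda state: state + size


def scale_builder(epoch: Epoch) -> StepSample:
    """Build a step that rescales the state using the epoch's step size."""
    size = epoch["step_size"]
    return lambda state: state * (1.0 + size)


def epoch_cycle(
    steps: List[Dict[StepBuilder, float]], epoch: Epoch
) -> List[Dict[StepSample, float]]:
    """Call every builder key with `epoch`; the weights are kept unchanged."""
    return [
        {builder(epoch): weight for builder, weight in substep.items()}
        for substep in steps
    ]


steps = [
    {shift_builder: 0.75, scale_builder: 0.25},  # substep 1: two alternatives
    {shift_builder: 1.0},                        # substep 2: a single step
]
cycle = epoch_cycle(steps, {"step_size": 0.5})
```

The result has the same shape as the input: one dictionary per sequential substep, with each builder replaced by the step function it returned for this epoch.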
- class JaxSampler[source]#
Bases:
Generic[Tsampling,Tepoch,Tbackend]
Main jexplore sampler class.
- Parameters:
sampling – sampling definition
steps – steps definition
backend – backend instance
- sampling: Tsampling#
sampling parameters definition
- backend: Tbackend#
backend instance
- run(epoch, niters, nepoch=1, seed=None)[source]#
Runs the MCMC, stores the results in the backend, and returns the last epoch.
- Parameters:
epoch (Tepoch) – starting epoch
niters (int) – number of iterations per epoch
nepoch (int) – number of epochs
seed (int | None) – random seed. If None, a random seed is generated.
- Returns:
last epoch.
- Return type:
Tepoch
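The documented control flow (seed, then one PRNG key threaded through nepoch epochs) might look roughly like the sketch below. It is a toy under stated assumptions, not the library's code: the dictionary epoch, the Gaussian random-walk update, and the use of `secrets` to draw a seed are hypothetical, and storing results in the backend is omitted.

```python
import secrets

import jax


def run_epoch(epoch, niters, key):
    """Toy epoch (hypothetical): a Gaussian random walk of `niters` steps."""
    x = epoch["x"]
    for _ in range(niters):
        # Split so that each iteration consumes a fresh subkey.
        key, subkey = jax.random.split(key)
        x = x + jax.random.normal(subkey)
    return key, {"x": x, "index": epoch["index"] + 1}


def run(epoch, niters, nepoch=1, seed=None):
    """Run `nepoch` epochs of `niters` iterations and return the last epoch."""
    if seed is None:
        seed = secrets.randbits(31)  # assumption: draw a seed when none is given
    key = jax.random.PRNGKey(seed)
    for _ in range(nepoch):
        # The key returned by one epoch seeds the next one.
        key, epoch = run_epoch(epoch, niters, key)
    return epoch


last = run({"x": 0.0, "index": 0}, niters=3, nepoch=2, seed=0)
```

Passing the same integer seed twice reproduces the same chain, while seed=None gives a different chain on each call.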
- static run_epoch(epoch, niters, key)[source]#
Runs the MCMC iterations of one epoch and returns the resulting Epoch object.
- Parameters:
epoch (Tepoch) – last epoch
niters (int) – number of MCMC iterations
key (PRNGKey) – PRNG key
- Returns:
key and the new Epoch
- Return type:
tuple[jax.Array, Tepoch]