jexplore.steps.modsel
=====================

.. py:module:: jexplore.steps.modsel

.. autoapi-nested-parse::

   Model Selection Markov Steps


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/jexplore/steps/modsel/inmodel_colored/index
   /autoapi/jexplore/steps/modsel/inmodel_single/index
   /autoapi/jexplore/steps/modsel/model_switch/index
   /autoapi/jexplore/steps/modsel/swap/index


Attributes
----------

.. autoapisummary::

   jexplore.steps.modsel.TSwapMS


Classes
-------

.. autoapisummary::

   jexplore.steps.modsel.IMDEStep
   jexplore.steps.modsel.IMStretch
   jexplore.steps.modsel.AllChainsWrap
   jexplore.steps.modsel.DrawModel
   jexplore.steps.modsel.DrawPseudo
   jexplore.steps.modsel.OrderByModel


Package Contents
----------------

.. py:class:: IMDEStep[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](gamma = 2.38, ngroups = 2, permute = False)

   Bases: :py:obj:`jexplore.steps.de.DEStep`\ [\ :py:obj:`Tepoch`\ , :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Class implementing a model-selection in-model differential evolution step.

   :param gamma: :math:`\gamma` scale parameter. Default 2.38.
   :param ngroups: number of groups. Default 2.
   :param permute: if True, walkers are permuted at each iteration. Default False.


   .. py:attribute:: sigmas
      :type: jax.Array

      :math:`\sigma` parameters of the :math:`\gamma` distribution for all models.


   .. py:attribute:: get_partners

      Method for getting partner samples for each chain of a group.

      The main difference from the base implementation in :py:class:`jexplore.steps.colored.ColoredSC` is that here each group chain gets its partners only among the complementary group chains (at the same temperature) that are in the same model.

      :param key: PRNG key
      :param state: current state
      :param group: group chains
      :param cgroup: complementary group chains
      :return: the partners as an array with shape ``(self.npars, group.size, dim)``


   .. py:method:: build(epoch)

      Step initialisation method.
      It extends :py:meth:`jexplore.steps.colored.Colored.build` by adding the computation of the :math:`\sigma` of the :math:`\gamma` distribution.

      :param epoch: current epoch.


   .. py:method:: sample_gamma(key, state)

      Sample :math:`\gamma` from a normal distribution.

      :param key: PRNG key
      :param state: current state
      :return: samples


.. py:class:: IMStretch[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](a = 2.0, ngroups = 2, permute = False)

   Bases: :py:obj:`jexplore.steps.stretch.Stretch`\ [\ :py:obj:`Tepoch`\ , :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Class implementing model-selection in-model steps based on the stretch proposal.

   :param a: stretch proposal :math:`a` parameter. Default 2.0.
   :param ngroups: number of groups. Default 2.
   :param permute: if True, walkers are permuted at each iteration. Default False.


   .. py:attribute:: mdims
      :type: jax.Array


   .. py:attribute:: get_partners

      Method for getting partner samples for each chain of a group.

      The main difference from the base implementation in :py:class:`jexplore.steps.colored.ColoredSC` is that here each group chain gets its partners only among the complementary group chains (at the same temperature) that are in the same model.

      :param key: PRNG key
      :param state: current state
      :param group: group chains
      :param cgroup: complementary group chains
      :return: the partners as an array with shape ``(self.npars, group.size, dim)``


   .. py:method:: build(epoch)

      Step initialisation method.
      It extends :py:meth:`jexplore.steps.step.Step.build` by adding a call to ``grouping`` to define the color groups, and checks that the defined color groups are even, raising a :py:attr:`BadColoring` exception otherwise.

      :param epoch: current epoch.


.. py:class:: AllChainsWrap[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS, Tstep: jexplore.steps.mh.AllChains](step)

   Bases: :py:obj:`jexplore.steps.mh.AllChains`\ [\ :py:obj:`Tepoch`\ , :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Model selection wrapper for :py:class:`jexplore.steps.mh.AllChains` steps.

   The step is wrapped so that it is applied only to chains whose current model has a (dimension) mask that fully contains the mask of the step. Note that the resulting step never proposes changes to the chains' models.

   :param step: the step to be wrapped.


   .. py:attribute:: wrapped_step
      :type: Tstep


   .. py:attribute:: active_models
      :type: jax.Array


   .. py:method:: build(epoch)

      Step epoch initialisation method.
      Extends :py:meth:`jexplore.steps.step.Step.build` by populating the ``betas`` attribute.

      :param epoch: current epoch.


   .. py:method:: proposal(key, state)

      Propose a new state.

      The proposed state is identical to the old one for all chains in inactive models (i.e. models whose mask does not contain the mask of the wrapped step) and corresponds to the state proposed by the wrapped step for all other chains.

      :param key: PRNG key
      :param state: current state
      :return: new state and the boolean mask of the chains modified by the step.


.. py:class:: DrawModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS](reorder = True, draw_pseudo = True)

   Bases: :py:obj:`jexplore.steps.mh.MHStep`\ [\ :py:obj:`Tepoch`\ , :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Draws a model for each chain from a categorical distribution whose log weights are the sum of the log likelihood, log prior and log pseudo prior.

   :param reorder: if True, reorder the chains by model (temperature by temperature) after the model draw. Default True.
   :param draw_pseudo: if True, draw from the pseudo prior before drawing the models. Default True.


   .. py:attribute:: reorder
      :type: bool

      Whether to reorder the chains by model (and temperature) after drawing models.


   .. py:attribute:: draw_pseudo
      :type: bool

      Whether to draw from the pseudo prior before drawing models.


   .. py:method:: build(epoch)

      Step epoch initialisation method.
      Extends :py:meth:`jexplore.steps.step.Step.build` by populating the ``betas`` attribute.

      :param epoch: current epoch.


   .. py:method:: step(key, state)

      Model drawing step.

      :param key: PRNG key
      :param state: current state
      :return: new state and the boolean mask of the chains modified by the step.


.. py:class:: DrawPseudo[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]

   Bases: :py:obj:`jexplore.steps.step.Step`\ [\ :py:obj:`Tepoch`\ , :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Draws model pseudo prior samples for each chain.


   .. py:attribute:: draw_funcs
      :type: list

      Static list of pseudo prior drawing functions, one for each model.


   .. py:method:: build(epoch)

      Epoch initialisation method.

      :param epoch: current epoch.


   .. py:method:: step(key, state)

      Pseudo prior drawing step.

      :param key: PRNG key
      :param state: current state
      :return: new state and the boolean mask of the chains modified by the step.


.. py:class:: OrderByModel[Tepoch: jexplore.sampling.EpochMS, Tstate: jexplore.sampling.StateMS, Tsampling: jexplore.sampling.SamplingMS]

   Bases: :py:obj:`jexplore.steps.step.Step`\ [\ :py:obj:`Tepoch`\ , :py:obj:`Tstate`\ , :py:obj:`Tsampling`\ ]

   Swaps chains to restore the model ordering (temperature by temperature).


   .. py:method:: step(key, state)

      Chain reordering step.

      :param key: PRNG key
      :param state: current state
      :return: new state and the boolean mask of the chains modified by the step.


.. py:data:: TSwapMS

   Temperature swap for model selection.
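
The model draw described for :py:class:`jexplore.steps.modsel.DrawModel`, followed by the optional reordering of the chains by model at each temperature (as done by :py:class:`jexplore.steps.modsel.OrderByModel`), can be sketched in JAX as follows. This is a minimal, hypothetical illustration rather than jexplore's implementation: the helpers ``draw_models``, ``order_by_model`` and ``apply_perm`` and the ``(ntemps, nchains, nmodels)`` array layout are assumptions, and the log weights follow the docstring literally (no temperature scaling of the likelihood is assumed).

.. code-block:: python

   import jax
   import jax.numpy as jnp

   # Hypothetical helpers illustrating DrawModel / OrderByModel; not jexplore code.
   # Assumed layout: per-chain arrays of shape (ntemps, nchains, ...).


   def draw_models(key, log_like, log_prior, log_pseudo):
       """Draw one model index per chain.

       Inputs have shape (ntemps, nchains, nmodels): entry [t, c, m] is the
       value for chain c at temperature t evaluated in model m.
       """
       # Categorical log weights: log likelihood + log prior + log pseudo prior.
       logits = log_like + log_prior + log_pseudo
       # One independent categorical draw per (temperature, chain) pair.
       return jax.random.categorical(key, logits, axis=-1)


   def order_by_model(models):
       """Permutation sorting the chains by model, separately at each temperature."""
       nchains = models.shape[1]
       # Unique keys (model first, chain index as tie-breaker), so the result
       # does not depend on whether argsort is stable.
       keys = models * nchains + jnp.arange(nchains)
       return jnp.argsort(keys, axis=1)


   def apply_perm(x, perm):
       """Reorder a per-chain array of shape (ntemps, nchains, ...) with perm."""
       # Expand perm over any trailing (e.g. parameter) dimensions of x.
       idx = perm.reshape(perm.shape + (1,) * (x.ndim - 2))
       idx = jnp.broadcast_to(idx, perm.shape + x.shape[2:])
       return jnp.take_along_axis(x, idx, axis=1)

Building the sort key as ``model * nchains + chain_index`` keeps chains of the same model in their original relative order, and the same permutation can then be applied to every per-chain array (models, positions, log densities) so that the state stays consistent.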