Source code for jexplore.backends.default

"""
Ths module defines the default version of a Jexplore
backend for storing data in memory and on disk.
"""

import glob
import logging
import os
import re
from copy import deepcopy
from importlib import import_module
from types import EllipsisType
from typing import Self, cast

import h5py  # type: ignore
import numpy as np

from jexplore.sampling import Epoch

# Configuration used by the default backend; options given to the backend
# constructor are merged on top of these values.
DEFAULTBKND_CONFIG = dict(
    outdir="./",  # directory receiving the epoch files
    outfile_fmt="epochs_%d.h5",  # file name pattern, formatted with the file number
    epochs_per_file=1,  # maximum number of epochs written to one file
    inmem_epochs=0,  # number of epochs kept in memory before dumping to disk
    save_stats=True,  # whether epoch statistics are stored with the samples
)
"""Default backend config parameters"""

# Module-level logger shared by the backend implementation.
logger = logging.getLogger(__name__)
"""logger instance"""


[docs] class DefaultBackend[Tepoch: Epoch]: r"""Default backend dumping samples on disk per epoch. :param burn: number of samples to burn (these will not be registered by the backend) :param \**config: backend configuration options. """ burn: int """burned samples""" _config: dict _registry: list[dict] _epochs: list[dict | None] def __init__(self, burn: int = 0, **config: dict): self._config = DEFAULTBKND_CONFIG | config self.burn = burn self._registry = [] self._epochs = []
[docs] def ingest(self, epoch: Tepoch | dict | str, eind=-1) -> Self: """loads and store one epoch in the backend. :param epoch: the epoch to be loaded. Can be an epoch instance, a dictionnary or the path to a backend file. :param eind: if `epoch` is the path to a backend file this is the index of the epoch. Same syntax that :py:attr:`jexplore.backend.default.DefaultBackend.read_epoch` :return: the backend object. """ _epoch: dict | None if isinstance(epoch, dict): _epoch = epoch copy = True elif isinstance(epoch, str): _epoch = self.read_epoch(epoch, eind) copy = False else: _, _, _epoch = cast(Tepoch, epoch).to_backend( burn=0, stats=self._config["save_stats"] ) copy = False return self.ingest_dict(cast(dict, _epoch), copy=copy)
[docs] def ingest_dict(self, epoch: dict, copy: bool = True) -> Self: """Loads and store one epoch dictionnary in the backend. :param epoch: the epoch dictionnary :param copy: if true the dictionary is copied before storing :return: the backend object """ if copy: epoch = deepcopy(epoch) _burn = self._get_nsamples()[1] _nsamples = epoch["samples"]["p"].shape[-1] if _burn >= _nsamples: _burn = _nsamples _nsamples = 0 else: _nsamples -= _burn epoch["samples"] = { _nam: _vals[:, :, _burn:] for _nam, _vals in epoch["samples"].items() } _ind = len(self._registry) self._epochs.append(epoch if _nsamples > 0 else None) self._registry.append( { "index": _ind, "name": f"epoch_{_ind}", "location": 0 if _nsamples > 0 else -1, "nsamples": _nsamples, "burn": _burn, } ) self.update() return self
[docs] def get(self, ind: int = -1) -> dict: """ Loads an epoch from registry. :param ind: index of the epoch in the registry. :return: epoch data as a dictionnary. """ _reg = self._registry[ind] if _reg["location"] == -1: raise ValueError(f"Epoch {ind} was burned.") if _reg["location"] == 0: return cast(dict, self._epochs[ind]) return self.read_epoch( fn=self._get_fname(_reg["location"]), iepoch=_reg["index"] )
def _get_nsamples(self) -> tuple[int, int]: """return number of samples that still have to be burn""" _burn = self.burn _nsamples = 0 for _epoch in self._registry: _burn -= _epoch["burn"] _nsamples += _epoch["nsamples"] return _nsamples, _burn
[docs] def set_burn(self, burn: int) -> None: """Sets the number of samples to be burned and removes stored samples if necessary. :param burn: number of burned samples """ if self.burn > burn: raise ValueError( "Burned samples are burned. You can only augment burn value." ) self.burn = burn self.update()
[docs] def update(self) -> None: """This implements the storing logic of this backend. It parses the registry of the ingested epochs and implements storing policies according to the configuration paremeters.""" ( _inmem, _file, ) = self._check_registry() _burn = self.burn for _ind, _entry in enumerate(self._registry): logging.debug( "Processing epoch %s: %s. Left to burn: %s.", _ind, _entry, _burn ) if _entry["location"] == -1: _burn -= _entry["burn"] logger.debug( "Epoch %s already burned: %s. Left to burn: %s.", _ind, _entry, _burn, ) continue _burn -= _entry["burn"] if _burn > 0: if _burn >= _entry["nsamples"]: _burn -= _entry["nsamples"] self._burn_epoch(_ind) logger.debug( "Epoch %s burned: %s. Left to burn: %s.", _ind, _entry, _burn ) continue _entry["nsamples"] -= _burn self._cut_to_size(_ind, _entry["nsamples"]) _entry["burn"] += _burn _burn = 0 logger.debug( "Epoch %s cut to size: %s. Left to burn: %s.", _ind, _entry, _burn ) # At this point we may have 1 (and only 1) inmem epoch # to put on disk. The last one. logger.debug("Pre-dump status: %s", self._registry) if (self._registry[-1]["location"] != 0) or ( _inmem <= self._config["inmem_epochs"] ): logger.debug("No need to dump on file.") return _iepoch = len(self._registry) - _inmem logger.debug("Dumping on file %s epoch %s", _file, _iepoch) self._dump_to_file(_file, _iepoch)
def _check_registry(self) -> tuple[int, int]: _find = 1 _inmem = 0 _fcount = 0 for _entry in self._registry: if _entry["location"] == -1: continue if _entry["location"] == 0: _inmem += 1 continue if _find != _entry["location"]: _find = _entry["location"] _fcount = 1 continue _fcount += 1 if _fcount == self._config["epochs_per_file"]: _find += 1 return _inmem, _find def _get_fname(self, fnum: int) -> str: return os.path.join(self._config["outdir"], self._config["outfile_fmt"] % fnum) def _burn_epoch(self, ind: int) -> None: _epoch = self._registry[ind] if _epoch["location"] == 0: self._epochs[ind] = None else: _fname = self._get_fname(_epoch["location"]) with h5py.File(_fname, "a") as _fhand: del _fhand[_epoch["name"]] _left = len(_fhand.keys()) if _left == 0: os.remove(_fname) _epoch["location"] = -1 _epoch["burn"] += _epoch["nsamples"] _epoch["nsamples"] = 0 def _cut_to_size(self, ind: int, nsamples: int) -> None: if self._registry[ind]["location"] == 0: self._cut_to_size_inmem(ind, nsamples) else: self._cut_to_size_h5(ind, nsamples) def _cut_to_size_inmem(self, ind: int, nsamples: int) -> None: def _cutter(_par, _pkey): if isinstance(_par[_pkey], dict): for _key in _par[_pkey]: _cutter(_par[_pkey], _key) return _par[_pkey] = _par[_pkey][:, :, -nsamples:] return _cutter(self._epochs[ind], "samples") def _cut_to_size_h5(self, ind: int, nsamples: int) -> None: epoch = self._registry[ind] fname = self._get_fname(epoch["location"]) fhand = h5py.File(fname, "a") def _cutter(_path): if isinstance(fhand[_path], h5py.Dataset): _dat = cast(h5py.Dataset, fhand[_path])[...][:, :, -nsamples:] del fhand[_path] fhand.create_dataset(_path, data=_dat, dtype=_dat.dtype) return for _key in cast(h5py.Group, fhand[_path]).keys(): _cutter(_path + "/" + _key) _cutter(epoch["name"] + "/samples") fhand.close() def _dump_to_file(self, find: int, iepoch: int) -> None: epoch = self._registry[iepoch] fname = self._get_fname(find) os.makedirs(self._config["outdir"], exist_ok=True) fhand = 
h5py.File(fname, "a") def _writer(_obj, _path): if isinstance(_obj, list): try: _obj = np.array(_obj) # Inomogeneous list can be transformed in dictionnary except ValueError: _obj = {str(_ind): _val for _ind, _val in enumerate(_obj)} if isinstance(_obj, dict): for _key, _val in _obj.items(): _writer(_val, _path + "/" + _key) return if isinstance(_obj, str): _dtype = h5py.string_dtype(encoding="utf-8") else: _dtype = _obj.dtype fhand.create_dataset(_path, data=_obj, dtype=_dtype) _writer(self._epochs[iepoch], epoch["name"]) fhand.close() self._epochs[iepoch] = None epoch["location"] = find
[docs] @staticmethod def flatten_samples( samples: dict[str, np.ndarray] | None, ) -> dict[str, np.ndarray] | None: """ Flatten a (chain, dim, nsamples) chain into a (n. all samples, dim) format. :param samples: the samples dictionnary :return: flatten samples dictionnary """ if samples is None: return samples _new_samples = {} for _key, _smpl in samples.items(): _dim = _smpl.shape[1] _new_samples[_key] = _smpl.swapaxes(1, 2).reshape(-1, _dim) return _new_samples
# pylint: disable=too-many-arguments,too-many-positional-arguments
[docs] def get_samples( self, burn: int = 0, thin: int = 1, pars: list[str] | None = None, mask: np.ndarray | EllipsisType = Ellipsis, flatten: bool = False, ) -> dict[str, np.ndarray] | None: """Get the samples from the backend. :param burn: number of initial samples to discart. This is added on top of the backend instance `burn` parameter. :param thin: thinning factor. :param pars: list of names of the sample parameters to be returned. Default: only "p". :param mask: mask (boolean mask or list of index) for the chains to be retrieved. :param flatten: returns flattened samples with shape (n. samples, dim). :return: a dictionnary which keys are the parameters names and values the returned chains. None if no samples could be retrieved. """ pars = ["p"] if pars is None else pars _samples: dict[str, list] = {_par: [] for _par in pars} for _eind, _entry in enumerate(self._registry): if _entry["location"] == -1: continue if _entry["nsamples"] <= burn: burn -= _entry["nsamples"] continue _dat = self.get(_eind)["samples"] for _par in pars: _samples[_par].append(_dat[_par][mask, :, burn::thin]) burn = 0 samples: dict[str, np.ndarray] | None = { _par: np.concatenate(_sam, axis=2) for _par, _sam in _samples.items() } if flatten: samples = self.flatten_samples(samples) return samples
[docs] def get_stats( self, epochs: list[int] | None = None, fname: str | None = None ) -> dict[str, dict[str, np.ndarray]]: """ Returns a dictionnary of epoch stats. Optionnally save the dictionnary to a h5 file. :param epochs: list of epochs. If None (default) all non burned epochs are considered. :param fname: if not None, the epochs stats are saved to file. :return: the stats of the selected epochs. """ _ret = {} for _entry in self._registry: if (epochs is not None) and (_entry["index"] not in epochs): continue _ename = f"epoch_{_entry['index']}" try: _ret[_ename] = self.get(_entry["index"])["stats"] except ValueError: continue if _entry["location"] == -1: continue if fname is not None: self._save_stats(_ret, fname) return _ret
@staticmethod def _save_stats(stats: dict, fname: str) -> None: with h5py.File(fname, "w") as fhand: for _ename, _epoch in stats.items(): for _key, _dat in _epoch.items(): fhand.create_dataset( f"{_ename}/{_key}", data=_dat, dtype=_dat.dtype )
[docs] def clean(self) -> None: """ Clean up all epochs data from backend """ self.set_burn(self.burn + self._get_nsamples()[0])
[docs] @classmethod def read_epoch(cls, fn: str, iepoch: int = -1) -> dict: """ Loads an epoch from a backend `h5` file. :param fn: file name :param iepoch: index of the epoch in the file. Note that this is the overall index in the backend that created the file. If it is a negative integer `-i` it takes the `-ith` epoch from the last (ordered by their epoch index). E.g if `iepoch` is -1 it takes the last epoch. :return: epoch data as a dictionnary. """ fhand = h5py.File(fn, "r") if iepoch < 0: try: _ename = np.sort(list(fhand.keys()))[iepoch] except IndexError as exc: _enum = len(list(fhand.keys())) raise ValueError( f"There are only {_enum} epochs in this file." ) from exc else: _ename = f"epoch_{iepoch}" if _ename not in fhand.keys(): raise ValueError(f"No epoch {_ename} in the file.") def _reader(_path): if isinstance(fhand[_path], h5py.Dataset): _ret = cast(h5py.Dataset, fhand[_path])[...][()] if isinstance(_ret, bytes): _ret = _ret.decode("utf-8") return _ret _dict = { _key: _reader(_path + "/" + _key) for _key in cast(h5py.Group, fhand[_path]).keys() } try: assert set(map(int, _dict.keys())) == set(range(len(_dict))) except (ValueError, AssertionError): return _dict # If this can be interpreted as a list... return a list return list(_dict.values()) _ret = _reader(_ename) fhand.close() return _ret
[docs] def load( self, iepoch: int = -1, eclass: type[Tepoch] | None = None, pars: list[str] | None = None, ) -> Tepoch: """ Loads an epoch from a backend `h5` file. :param iepoch: index of the epoch in the file. Note that this is the overall index in the backend that created the file. If it is a negative integer `-i` it takes the `-ith` epoch from the last (ordered by their epoch index). E.g if `iepoch` is -1 it takes the last epoch. :param eclass: epoch class. :param pars: samples parameters to be loaded. Default: all. :return: the loaded epoch. """ _epoch = self.get(iepoch) if eclass is None: _path = _epoch["class"].split(".") _ecl = cast( type[Tepoch], getattr(import_module(".".join(_path[:-1])), _path[-1]) ) else: _ecl = cast(type[Tepoch], eclass) pars = list(_ecl.statecls.__dataclass_fields__) if pars is None else pars return _ecl( {_key: _epoch["samples"][_key] for _key in pars} | {"covs": _epoch["covs"]} )
[docs] @staticmethod def load_epoch( fn: str, iepoch: int = -1, eclass: type[Tepoch] | None = None, pars: list[str] | None = None, ) -> Tepoch: """ Loads an epoch from a backend `h5` file. :param fn: file name :param iepoch: index of the epoch in the file. Note that this is the overall index in the backend that created the file. If it is a negative integer `-i` it takes the `-ith` epoch from the last (ordered by their epoch index). E.g if `iepoch` is -1 it takes the last epoch. :param eclass: epoch class. :param pars: samples parameters to be loaded. Default: all. :return: the loaded epoch. """ return ( DefaultBackend(inmem_epochs=1) # type: ignore[arg-type] .ingest(fn, eind=iepoch) .load(eclass=eclass, pars=pars) # type: ignore[arg-type] )
[docs] def reset(self): """Remove all existing files""" fmt = self._config["outfile_fmt"].replace("%d", "%s") fnames = glob.glob(os.path.join(self._config["outdir"], fmt % "*")) for fname in fnames: os.remove(fname)