"""
This module defines the default version of a Jexplore
backend for storing data in memory and on disk.
"""
import glob
import logging
import os
from copy import deepcopy
from importlib import import_module
from types import EllipsisType
from typing import Self, cast
import h5py # type: ignore
import numpy as np
from jexplore.sampling import Epoch
DEFAULTBKND_CONFIG = dict(
    outdir="./",
    outfile_fmt="epochs_%d.h5",
    epochs_per_file=1,
    inmem_epochs=0,
    save_stats=True,
)
"""Default backend configuration: output directory, ``%``-style file name
template (formatted with the file number), number of epochs per file, number
of epochs allowed to stay in memory, and the flag forwarded as ``stats`` when
an epoch object is converted for the backend."""
logger = logging.getLogger(__name__)
"""Module-level logger instance."""
[docs]
class DefaultBackend[Tepoch: Epoch]:
r"""Default backend dumping samples on disk per epoch.
Every ingested epoch gets an entry in a registry. The entry ``location`` is
``0`` while the epoch data is held in memory, ``-1`` once the epoch has been
entirely burned, and otherwise the number of the file the epoch was dumped to.
:param burn: number of samples to burn (these will not be registered by the
backend)
:param \**config: backend configuration options, merged over
``DEFAULTBKND_CONFIG``.
"""
burn: int
"""burned samples"""
# merged configuration: DEFAULTBKND_CONFIG overridden by the constructor options
_config: dict
# one bookkeeping entry per ingested epoch (index, name, location, nsamples, burn)
_registry: list[dict]
# epoch data held in memory, aligned with _registry; None once burned or dumped
_epochs: list[dict | None]
def __init__(self, burn: int = 0, **config: object):
    r"""Create an empty backend.

    :param burn: number of leading samples to burn.
    :param \**config: configuration overrides. They are merged over
                     ``DEFAULTBKND_CONFIG``; option values are strings,
                     integers or booleans, hence the ``object`` annotation
                     (annotating ``**config`` as ``dict`` would declare every
                     single value to be a dictionary).
    """
    # ``|`` builds a new dictionary, the module-level defaults stay untouched.
    self._config = DEFAULTBKND_CONFIG | config
    self.burn = burn
    self._registry = []
    self._epochs = []
[docs]
def ingest(self, epoch: Tepoch | dict | str, eind=-1) -> Self:
"""loads and store one epoch in the backend.
:param epoch: the epoch to be loaded. Can be an epoch
instance, a dictionnary or the path to a
backend file.
:param eind: if `epoch` is the path to a backend file this
is the index of the epoch. Same syntax that
:py:attr:`jexplore.backend.default.DefaultBackend.read_epoch`
:return: the backend object.
"""
_epoch: dict | None
if isinstance(epoch, dict):
_epoch = epoch
copy = True
elif isinstance(epoch, str):
_epoch = self.read_epoch(epoch, eind)
copy = False
else:
_, _, _epoch = cast(Tepoch, epoch).to_backend(
burn=0, stats=self._config["save_stats"]
)
copy = False
return self.ingest_dict(cast(dict, _epoch), copy=copy)
[docs]
def ingest_dict(self, epoch: dict, copy: bool = True) -> Self:
"""Loads and store one epoch dictionnary in the
backend.
:param epoch: the epoch dictionnary
:param copy: if true the dictionary is copied before storing
:return: the backend object
"""
if copy:
epoch = deepcopy(epoch)
_burn = self._get_nsamples()[1]
_nsamples = epoch["samples"]["p"].shape[-1]
if _burn >= _nsamples:
_burn = _nsamples
_nsamples = 0
else:
_nsamples -= _burn
epoch["samples"] = {
_nam: _vals[:, :, _burn:] for _nam, _vals in epoch["samples"].items()
}
_ind = len(self._registry)
self._epochs.append(epoch if _nsamples > 0 else None)
self._registry.append(
{
"index": _ind,
"name": f"epoch_{_ind}",
"location": 0 if _nsamples > 0 else -1,
"nsamples": _nsamples,
"burn": _burn,
}
)
self.update()
return self
[docs]
def get(self, ind: int = -1) -> dict:
    """Return the data of a registered epoch, from memory or from file.

    :param ind: index of the epoch in the registry.
    :return: epoch data as a dictionary.
    :raises ValueError: if the epoch has been burned.
    """
    _entry = self._registry[ind]
    _where = _entry["location"]
    if _where == -1:
        raise ValueError(f"Epoch {ind} was burned.")
    if _where != 0:
        # Any other location is the number of the file holding the epoch.
        return self.read_epoch(fn=self._get_fname(_where), iepoch=_entry["index"])
    return cast(dict, self._epochs[ind])
def _get_nsamples(self) -> tuple[int, int]:
    """Return ``(stored, left)``: the number of samples currently registered
    and the number of samples that still have to be burned."""
    _target = self.burn
    _stored = sum(_entry["nsamples"] for _entry in self._registry)
    _already_burned = sum(_entry["burn"] for _entry in self._registry)
    return _stored, _target - _already_burned
[docs]
def set_burn(self, burn: int) -> None:
    """Raise the number of burned samples and drop stored samples accordingly.

    :param burn: new total number of burned samples; it cannot be lower than
                 the current one.
    :raises ValueError: if `burn` is lower than the current burn value.
    """
    if burn < self.burn:
        raise ValueError(
            "Burned samples are burned. You can only augment burn value."
        )
    self.burn = burn
    # Let the storing logic remove whatever the new value covers.
    self.update()
[docs]
def update(self) -> None:
    """Apply the burn and storage policies to the registered epochs.

    The registry is walked in ingestion order and the samples that still
    have to be burned (``self.burn`` minus what the entries already burned)
    are removed: epochs are dropped entirely while the remaining burn covers
    them, then the first partially covered epoch is cut to size.

    Afterwards, if the most recent epoch is held in memory and more epochs
    are in memory than the ``inmem_epochs`` option allows, one in-memory
    epoch is dumped to file.

    Calling this on a backend without any registered epoch does nothing.
    """
    # Counted before burning: in-memory epochs and the file number to use.
    _inmem, _file = self._check_registry()
    _burn = self.burn
    for _ind, _entry in enumerate(self._registry):
        logger.debug(
            "Processing epoch %s: %s. Left to burn: %s.", _ind, _entry, _burn
        )
        # Whatever this entry already burned is no longer to be burned.
        _burn -= _entry["burn"]
        if _entry["location"] == -1:
            logger.debug(
                "Epoch %s already burned: %s. Left to burn: %s.",
                _ind,
                _entry,
                _burn,
            )
            continue
        if _burn <= 0:
            continue
        if _burn >= _entry["nsamples"]:
            # The remaining burn swallows the whole epoch.
            _burn -= _entry["nsamples"]
            self._burn_epoch(_ind)
            logger.debug(
                "Epoch %s burned: %s. Left to burn: %s.", _ind, _entry, _burn
            )
            continue
        # Partial burn: only the trailing samples of the epoch are kept.
        _entry["nsamples"] -= _burn
        self._cut_to_size(_ind, _entry["nsamples"])
        _entry["burn"] += _burn
        _burn = 0
        logger.debug(
            "Epoch %s cut to size: %s. Left to burn: %s.", _ind, _entry, _burn
        )
    # At this point we may have 1 (and only 1) in-memory epoch to put on disk.
    logger.debug("Pre-dump status: %s", self._registry)
    if (
        not self._registry  # nothing ingested yet: [-1] would raise IndexError
        or self._registry[-1]["location"] != 0
        or _inmem <= self._config["inmem_epochs"]
    ):
        logger.debug("No need to dump on file.")
        return
    # Oldest in-memory epoch; this relies on the in-memory epochs being the
    # trailing entries of the registry.
    _iepoch = len(self._registry) - _inmem
    logger.debug("Dumping on file %s epoch %s", _file, _iepoch)
    self._dump_to_file(_file, _iepoch)
def _check_registry(self) -> tuple[int, int]:
    """Return the number of in-memory epochs and the number of the file the
    next dumped epoch should go to."""
    _n_inmem = 0
    _fnum = 1
    _in_file = 0
    for _entry in self._registry:
        _where = _entry["location"]
        if _where == -1:
            # Burned epochs are stored nowhere.
            continue
        if _where == 0:
            _n_inmem += 1
        elif _where == _fnum:
            _in_file += 1
        else:
            # First epoch met for another file: restart the count there.
            _fnum, _in_file = _where, 1
    # The last file met is full: the next dump opens a new one.
    if _in_file == self._config["epochs_per_file"]:
        _fnum += 1
    return _n_inmem, _fnum
def _get_fname(self, fnum: int) -> str:
    """Build the path of the backend file number `fnum` from the ``outdir``
    and ``outfile_fmt`` configuration options."""
    _directory = self._config["outdir"]
    _basename = self._config["outfile_fmt"] % fnum
    return os.path.join(_directory, _basename)
def _burn_epoch(self, ind: int) -> None:
    """Drop all the data of epoch `ind` and mark its registry entry as burned.

    In-memory data is simply released. Data on file is removed from the file,
    and the file itself is deleted when no epoch is left in it.
    """
    _entry = self._registry[ind]
    _where = _entry["location"]
    if _where != 0:
        _fname = self._get_fname(_where)
        with h5py.File(_fname, "a") as _fhand:
            del _fhand[_entry["name"]]
            _is_empty = len(_fhand.keys()) == 0
        # Only delete the file once its handle has been closed.
        if _is_empty:
            os.remove(_fname)
    else:
        self._epochs[ind] = None
    # All the samples the epoch still held are now counted as burned.
    _entry.update(
        location=-1,
        burn=_entry["burn"] + _entry["nsamples"],
        nsamples=0,
    )
def _cut_to_size(self, ind: int, nsamples: int) -> None:
    """Keep only the last `nsamples` samples of epoch `ind`, wherever the
    epoch is stored (in memory or on file)."""
    _in_memory = self._registry[ind]["location"] == 0
    _cut = self._cut_to_size_inmem if _in_memory else self._cut_to_size_h5
    _cut(ind, nsamples)
def _cut_to_size_inmem(self, ind: int, nsamples: int) -> None:
    """Keep only the last `nsamples` samples of the in-memory epoch `ind`.

    Every array found under the ``"samples"`` entry of the epoch, at any
    nesting depth, is cut along its third axis. The dictionaries are updated
    in place.

    :param ind: index of the epoch in the registry.
    :param nsamples: number of trailing samples to keep; ``0`` empties the
                     arrays, a value larger than the array keeps it whole.
    """

    def _cutter(_par, _pkey):
        _val = _par[_pkey]
        if isinstance(_val, dict):
            # Only values of existing keys are replaced, so iterating the
            # dictionary while recursing into it is safe.
            for _key in _val:
                _cutter(_val, _key)
            return
        # An explicit start index is used because ``[-nsamples:]`` would
        # keep the whole array when nsamples is 0 (``-0 == 0``).
        _start = max(_val.shape[2] - nsamples, 0)
        _par[_pkey] = _val[:, :, _start:]

    _cutter(self._epochs[ind], "samples")
def _cut_to_size_h5(self, ind: int, nsamples: int) -> None:
    """Keep only the last `nsamples` samples of epoch `ind` stored on file.

    Every dataset found under the ``samples`` group of the epoch, at any
    nesting depth, is read, cut along its third axis, deleted and created
    again with the remaining data.

    :param ind: index of the epoch in the registry.
    :param nsamples: number of trailing samples to keep; ``0`` empties the
                     datasets, a value larger than the dataset keeps it whole.
    """
    _entry = self._registry[ind]
    _fname = self._get_fname(_entry["location"])
    # The context manager closes the file even if rewriting a dataset fails.
    with h5py.File(_fname, "a") as _fhand:

        def _cutter(_path):
            _node = _fhand[_path]
            if isinstance(_node, h5py.Dataset):
                _dat = _node[...]
                # Explicit start index: ``[-nsamples:]`` would keep
                # everything when nsamples is 0 (``-0 == 0``).
                _dat = _dat[:, :, max(_dat.shape[2] - nsamples, 0):]
                del _fhand[_path]
                _fhand.create_dataset(_path, data=_dat, dtype=_dat.dtype)
                return
            # Snapshot of the member names: the group members are deleted
            # and created again while walking through them.
            for _key in list(_node.keys()):
                _cutter(_path + "/" + _key)

        _cutter(_entry["name"] + "/samples")
def _dump_to_file(self, find: int, iepoch: int) -> None:
    """Write the in-memory epoch `iepoch` to the backend file number `find`.

    The epoch dictionary is written recursively under a group named after
    the epoch: dictionaries become groups, lists become arrays (or groups
    keyed by position when they are not homogeneous), strings are stored as
    UTF-8. Once the file is written and closed, the in-memory copy is
    released and the registry entry points to the file.

    :param find: number of the destination file.
    :param iepoch: index of the epoch in the registry.
    """
    _entry = self._registry[iepoch]
    _fname = self._get_fname(find)
    os.makedirs(self._config["outdir"], exist_ok=True)
    # The context manager closes the file even if writing fails half-way.
    with h5py.File(_fname, "a") as _fhand:

        def _writer(_obj, _path):
            if isinstance(_obj, list):
                try:
                    _obj = np.array(_obj)
                # An inhomogeneous list is turned into a dictionary keyed
                # by the position of its items.
                except ValueError:
                    _obj = {str(_ind): _val for _ind, _val in enumerate(_obj)}
            if isinstance(_obj, dict):
                for _key, _val in _obj.items():
                    _writer(_val, _path + "/" + _key)
                return
            if isinstance(_obj, str):
                _dtype = h5py.string_dtype(encoding="utf-8")
            else:
                _dtype = _obj.dtype
            _fhand.create_dataset(_path, data=_obj, dtype=_dtype)

        _writer(self._epochs[iepoch], _entry["name"])
    # Reached only if the write succeeded: the file is now the reference.
    self._epochs[iepoch] = None
    _entry["location"] = find
[docs]
@staticmethod
def flatten_samples(
samples: dict[str, np.ndarray] | None,
) -> dict[str, np.ndarray] | None:
"""
Flatten a (chain, dim, nsamples) chain into a (n. all samples, dim) format.
:param samples: the samples dictionnary
:return: flatten samples dictionnary
"""
if samples is None:
return samples
_new_samples = {}
for _key, _smpl in samples.items():
_dim = _smpl.shape[1]
_new_samples[_key] = _smpl.swapaxes(1, 2).reshape(-1, _dim)
return _new_samples
# pylint: disable=too-many-arguments,too-many-positional-arguments
[docs]
def get_samples(
self,
burn: int = 0,
thin: int = 1,
pars: list[str] | None = None,
mask: np.ndarray | EllipsisType = Ellipsis,
flatten: bool = False,
) -> dict[str, np.ndarray] | None:
"""Get the samples from the backend.
:param burn: number of initial samples to discart. This is added
on top of the backend instance `burn` parameter.
:param thin: thinning factor.
:param pars: list of names of the sample parameters to be returned.
Default: only "p".
:param mask: mask (boolean mask or list of index) for the chains to be
retrieved.
:param flatten: returns flattened samples with shape (n. samples, dim).
:return: a dictionnary which keys are the parameters names and values
the returned chains. None if no samples could be retrieved.
"""
pars = ["p"] if pars is None else pars
_samples: dict[str, list] = {_par: [] for _par in pars}
for _eind, _entry in enumerate(self._registry):
if _entry["location"] == -1:
continue
if _entry["nsamples"] <= burn:
burn -= _entry["nsamples"]
continue
_dat = self.get(_eind)["samples"]
for _par in pars:
_samples[_par].append(_dat[_par][mask, :, burn::thin])
burn = 0
samples: dict[str, np.ndarray] | None = {
_par: np.concatenate(_sam, axis=2) for _par, _sam in _samples.items()
}
if flatten:
samples = self.flatten_samples(samples)
return samples
[docs]
def get_stats(
self, epochs: list[int] | None = None, fname: str | None = None
) -> dict[str, dict[str, np.ndarray]]:
"""
Returns a dictionnary of epoch stats. Optionnally save the dictionnary to a
h5 file.
:param epochs: list of epochs. If None (default) all non burned epochs are
considered.
:param fname: if not None, the epochs stats are saved to file.
:return: the stats of the selected epochs.
"""
_ret = {}
for _entry in self._registry:
if (epochs is not None) and (_entry["index"] not in epochs):
continue
_ename = f"epoch_{_entry['index']}"
try:
_ret[_ename] = self.get(_entry["index"])["stats"]
except ValueError:
continue
if _entry["location"] == -1:
continue
if fname is not None:
self._save_stats(_ret, fname)
return _ret
@staticmethod
def _save_stats(stats: dict, fname: str) -> None:
    """Write the stats of every epoch to `fname`, one dataset per stat under
    a group named after the epoch. An existing file is overwritten."""
    with h5py.File(fname, "w") as _out:
        _datasets = (
            (f"{_ename}/{_key}", _dat)
            for _ename, _epoch in stats.items()
            for _key, _dat in _epoch.items()
        )
        for _path, _dat in _datasets:
            _out.create_dataset(_path, data=_dat, dtype=_dat.dtype)
[docs]
def clean(self) -> None:
    """Burn every sample currently registered, leaving the backend without
    any epoch data."""
    _current = self.burn
    _stored, _ = self._get_nsamples()
    self.set_burn(_current + _stored)
[docs]
@classmethod
def read_epoch(cls, fn: str, iepoch: int = -1) -> dict:
    """Load an epoch from a backend `h5` file.

    Groups are read back as dictionaries, except groups whose member names
    are exactly ``0 .. n-1``, which are read back as lists ordered by that
    number. String datasets are decoded from UTF-8.

    :param fn: file name
    :param iepoch: index of the epoch in the file. Note that this is
                   the overall index in the backend that created the
                   file. If it is a negative integer `-i` it takes
                   the `-ith` epoch from the last (ordered by their
                   epoch index). E.g. if `iepoch` is -1 it takes the
                   last epoch.
    :return: epoch data as a dictionary.
    :raises ValueError: if the requested epoch is not in the file.
    """

    def _epoch_order(_name):
        # Sort on the numeric suffix: a plain string sort would put
        # "epoch_10" before "epoch_2". Other names go last, by name.
        _tail = _name.rpartition("_")[2]
        if _tail.isdecimal():
            return (0, int(_tail), _name)
        return (1, 0, _name)

    # The context manager closes the file on every path, errors included.
    with h5py.File(fn, "r") as fhand:
        _names = list(fhand.keys())
        if iepoch < 0:
            try:
                _ename = sorted(_names, key=_epoch_order)[iepoch]
            except IndexError as exc:
                raise ValueError(
                    f"There are only {len(_names)} epochs in this file."
                ) from exc
        else:
            _ename = f"epoch_{iepoch}"
            if _ename not in _names:
                raise ValueError(f"No epoch {_ename} in the file.")

        def _reader(_path):
            _node = fhand[_path]
            if isinstance(_node, h5py.Dataset):
                _val = _node[...][()]
                if isinstance(_val, bytes):
                    _val = _val.decode("utf-8")
                return _val
            _dict = {_key: _reader(_path + "/" + _key) for _key in _node.keys()}
            try:
                _positions = sorted(int(_key) for _key in _dict)
            except ValueError:
                return _dict
            # Explicit comparison rather than ``assert``, which is stripped
            # when Python runs with -O.
            if _positions != list(range(len(_dict))):
                return _dict
            # If this can be interpreted as a list... return a list, in
            # numeric order (the member names are not numerically ordered:
            # "10" may come before "2").
            return [_dict[_key] for _key in sorted(_dict, key=int)]

        # Everything is read into memory before the file is closed.
        return _reader(_ename)
[docs]
def load(
    self,
    iepoch: int = -1,
    eclass: type[Tepoch] | None = None,
    pars: list[str] | None = None,
) -> Tepoch:
    """Build an epoch object from an epoch registered in this backend.

    :param iepoch: index of the epoch in the registry, as accepted by
                   :py:meth:`get`. Default: the last registered epoch.
    :param eclass: epoch class. If None, the class is imported from the
                   dotted path stored under the ``"class"`` entry of the
                   epoch data.
    :param pars: samples parameters to be loaded. Default: all the fields
                 of the state class of the epoch class.
    :return: the loaded epoch.
    """
    _data = self.get(iepoch)
    if eclass is not None:
        _ecl = cast(type[Tepoch], eclass)
    else:
        # NOTE(review): the class path comes from the stored epoch data and
        # is imported dynamically; only load files from trusted sources.
        _modname, _, _clsname = _data["class"].rpartition(".")
        _ecl = cast(type[Tepoch], getattr(import_module(_modname), _clsname))
    if pars is None:
        pars = list(_ecl.statecls.__dataclass_fields__)
    _fields = {_key: _data["samples"][_key] for _key in pars}
    # The covariances are always passed along with the selected samples.
    _fields["covs"] = _data["covs"]
    return _ecl(_fields)
[docs]
@staticmethod
def load_epoch(
    fn: str,
    iepoch: int = -1,
    eclass: type[Tepoch] | None = None,
    pars: list[str] | None = None,
) -> Tepoch:
    """Build an epoch object straight from a backend `h5` file.

    The epoch is read into a temporary backend allowed to keep one epoch in
    memory, then loaded from there.

    :param fn: file name
    :param iepoch: index of the epoch in the file. Note that this is
                   the overall index in the backend that created the
                   file. If it is a negative integer `-i` it takes
                   the `-ith` epoch from the last (ordered by their
                   epoch index). E.g. if `iepoch` is -1 it takes the
                   last epoch.
    :param eclass: epoch class.
    :param pars: samples parameters to be loaded. Default: all.
    :return: the loaded epoch.
    """
    _scratch = DefaultBackend(inmem_epochs=1)  # type: ignore[arg-type]
    _filled = _scratch.ingest(fn, eind=iepoch)
    return _filled.load(eclass=eclass, pars=pars)  # type: ignore[arg-type]
[docs]
def reset(self):
    """Delete every file of the output directory whose name matches the
    ``outfile_fmt`` option, whatever its file number."""
    # Turn the file name template into a glob pattern.
    _pattern = self._config["outfile_fmt"].replace("%d", "%s") % "*"
    for _stale in glob.glob(os.path.join(self._config["outdir"], _pattern)):
        os.remove(_stale)