# Source code for jexplore.tools.diagnostic
"""Diagnostic tools for assessing the convergence of MCMC
sampling.
"""
import emcee
import jax.numpy as jnp
[docs]
def gelman_rubin_statistic(samples, ntemp=1):
    """Return the square root of the Gelman-Rubin statistic R.

    Convergence is assessed by comparing the variance between the
    chains with the variance within the chains.

    :param array samples: (nchain, ndim, npoints) array with nchain
        equal to nwalker if ntemp=1 (default) or nchain equal to
        nwalker x ntemp.
    :param int ntemp: number of temperatures; if ntemp > 1, R is
        returned per temperature.
    :return: (ndim,) array of the statistic for each dimension when
        ntemp == 1, or (ntemp, ndim) array (one row per temperature)
        when ntemp > 1.
    :rtype: array
    """
    ndim, nsamples = samples.shape[1:]
    if ntemp > 1:
        nwalker = samples.shape[0] // ntemp
        # C-order reshape: chain index = walker * ntemp + temperature.
        # NOTE(review): confirm this matches the sampler's chain layout.
        _samples = samples.reshape(nwalker, ntemp, ndim, nsamples)
        # -> (npoints, nwalker, ntemp, ndim)
        _samples = jnp.transpose(_samples, (3, 0, 1, 2))
    else:
        # -> (npoints, nchain, ndim)
        _samples = jnp.transpose(samples, (2, 0, 1))
    # ll: number of points per chain, jj: number of chains (walkers)
    ll, jj = _samples.shape[:2]

    c_mean = jnp.mean(_samples, axis=0)  # mean value of each chain
    g_mean = jnp.mean(c_mean, axis=0)  # global mean over chains
    within = jnp.var(_samples, axis=0, ddof=1)  # intra-chain variance
    # Between-chain variance B = ll / (jj - 1) * sum_j (mean_j - mean)^2
    between = ll / (jj - 1) * jnp.sum((c_mean - g_mean) ** 2, axis=0)
    w = jnp.mean(within, axis=0)  # W: averaged within-chain variance
    # Pooled variance V = (ll - 1) / ll * W + B / ll, and R = V / W
    r = ((ll - 1) / ll) * w + between / ll
    return jnp.sqrt(r / w)
[docs]
def auto_correlation_length(samples):
    """Return the integrated auto-correlation time estimated by emcee.

    :param array samples: (nchain, ndim, npoints) array with nchain
        equal to nwalker.
    :return: the result of ``emcee.autocorr.integrated_time`` on the
        samples, or ``nan`` if emcee raises ``AutocorrError`` (it cannot
        reliably estimate the auto-correlation time).
    :rtype: array or float
    """
    # Reorder to the (nsteps, nwalker, nparam) layout emcee expects.
    _samples = jnp.transpose(samples, (2, 0, 1))
    try:
        res = emcee.autocorr.integrated_time(_samples)
    except emcee.autocorr.AutocorrError:
        # Only jax.numpy is available in this module, not numpy.
        res = jnp.nan
    return res