MDL (minimum description length) utility functions#
\(\require{mathtools} \newcommand{\notag}{} \newcommand{\tag}{} \newcommand{\label}[1]{} \newcommand{\sfrac}[2]{#1/#2} \newcommand{\bm}[1]{\boldsymbol{#1}} \newcommand{\num}[1]{#1} \newcommand{\qty}[2]{#1\,#2} \renewenvironment{align} {\begin{aligned}} {\end{aligned}} \renewenvironment{alignat} {\begin{alignedat}} {\end{alignedat}} \newcommand{\pdfmspace}[1]{} % Ignore PDF-only spacing commands \newcommand{\htmlmspace}[1]{\mspace{#1}} % Ignore PDF-only spacing commands \newcommand{\scaleto}[2]{#1} % Allow to use scaleto from scalerel package \newcommand{\RR}{\mathbb R} \newcommand{\NN}{\mathbb N} \newcommand{\PP}{\mathbb P} \newcommand{\EE}{\mathbb E} \newcommand{\XX}{\mathbb X} \newcommand{\ZZ}{\mathbb Z} \newcommand{\QQ}{\mathbb Q} \newcommand{\fF}{\mathcal F} \newcommand{\dD}{\mathcal D} \newcommand{\lL}{\mathcal L} \newcommand{\gG}{\mathcal G} \newcommand{\hH}{\mathcal H} \newcommand{\nN}{\mathcal N} \newcommand{\pP}{\mathcal P} \newcommand{\BB}{\mathbb B} \newcommand{\Exp}{\operatorname{Exp}} \newcommand{\Binomial}{\operatorname{Binomial}} \newcommand{\Poisson}{\operatorname{Poisson}} \newcommand{\linop}{\mathcal{L}(\mathbb{B})} \newcommand{\linopell}{\mathcal{L}(\ell_1)} \DeclareMathOperator{\trace}{trace} \DeclareMathOperator{\Var}{Var} \DeclareMathOperator{\Span}{span} \DeclareMathOperator{\proj}{proj} \DeclareMathOperator{\col}{col} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\gt}{>} \definecolor{highlight-blue}{RGB}{0,123,255} % definition, theorem, proposition \definecolor{highlight-yellow}{RGB}{255,193,7} % lemma, conjecture, example \definecolor{highlight-orange}{RGB}{253,126,20} % criterion, corollary, property \definecolor{highlight-red}{RGB}{220,53,69} % criterion \newcommand{\logL}{\ell} \newcommand{\eE}{\mathcal{E}} \newcommand{\oO}{\mathcal{O}} \newcommand{\defeq}{\stackrel{\mathrm{def}}{=}} \newcommand{\Bspec}{\mathcal{B}} % Spectral radiance \newcommand{\X}{\mathcal{X}} % X space \newcommand{\Y}{\mathcal{Y}} % Y space \newcommand{\M}{\mathcal{M}} % Model \newcommand{\Tspace}{\mathcal{T}} \newcommand{\Vspace}{\mathcal{V}} \newcommand{\Mtrue}{\mathcal{M}_{\mathrm{true}}} \newcommand{\MP}{\M_{\mathrm{P}}} \newcommand{\MRJ}{\M_{\mathrm{RJ}}} \newcommand{\qproc}{\mathfrak{Q}} \newcommand{\D}{\mathcal{D}} % Data (true or generic) \newcommand{\Dt}{\tilde{\mathcal{D}}} \newcommand{\Phit}{\widetilde{\Phi}} \newcommand{\Phis}{\Phi^*} \newcommand{\qt}{\tilde{q}} \newcommand{\qs}{q^*} \newcommand{\qh}{\hat{q}} \newcommand{\AB}[1]{\mathtt{AB}~\mathtt{#1}} \newcommand{\LP}[1]{\mathtt{LP}~\mathtt{#1}} \newcommand{\NML}{\mathrm{NML}} \newcommand{\iI}{\mathcal{I}} \newcommand{\true}{\mathrm{true}} \newcommand{\dist}{D} \newcommand{\Mtheo}[1]{\mathcal{M}_{#1}} % Model (theoretical model); index: param set \newcommand{\DL}[1][L]{\mathcal{D}^{(#1)}} % Data (RV or generic) \newcommand{\DLp}[1][L]{\mathcal{D}^{(#1')}} % Data (RV or generic) \newcommand{\DtL}[1][L]{\tilde{\mathcal{D}}^{(#1)}} % Data (RV or generic) \newcommand{\DpL}[1][L]{{\mathcal{D}'}^{(#1)}} % Data (RV or generic) \newcommand{\Dobs}[1][]{\mathcal{D}_{\mathrm{obs}}^{#1}} % Data (observed) \newcommand{\calibset}{\mathcal{C}} \newcommand{\N}{\mathcal{N}} % Normal distribution \newcommand{\Z}{\mathcal{Z}} % Partition function \newcommand{\VV}{\mathbb{V}} % Variance \newcommand{\T}{\mathsf{T}} % Transpose \newcommand{\EMD}{\mathrm{EMD}} \newcommand{\dEMD}{d_{\mathrm{EMD}}} \newcommand{\dEMDtilde}{\tilde{d}_{\mathrm{EMD}}} 
\newcommand{\dEMDsafe}{d_{\mathrm{EMD}}^{\text{(safe)}}} \newcommand{\e}{ε} % Model confusion threshold \newcommand{\falsifythreshold}{ε} \newcommand{\bayes}[1][]{B_{#1}} \newcommand{\bayesthresh}[1][]{B_{0}} \newcommand{\bayesm}[1][]{B^{\mathcal{M}}_{#1}} \newcommand{\bayesl}[1][]{B^l_{#1}} \newcommand{\bayesphys}[1][]{B^{{p}}_{#1}} \newcommand{\Bconf}[1]{B^{\mathrm{epis}}_{#1}} \newcommand{\Bemd}[1]{B^{\mathrm{EMD}}_{#1}} \newcommand{\Bconfbin}[1][]{\bar{B}^{\mathrm{conf}}_{#1}} \newcommand{\Bemdbin}[1][]{\bar{B}_{#1}^{\mathrm{EMD}}} \newcommand{\bin}{\mathcal{B}} \newcommand{\Bconft}[1][]{\tilde{B}^{\mathrm{conf}}_{#1}} \newcommand{\fc}{f_c} \newcommand{\fcbin}{\bar{f}_c} \newcommand{\paramphys}[1][]{Θ^{{p}}_{#1}} \newcommand{\paramobs}[1][]{Θ^{ε}_{#1}} \newcommand{\test}{\mathrm{test}} \newcommand{\train}{\mathrm{train}} \newcommand{\synth}{\mathrm{synth}} \newcommand{\rep}{\mathrm{rep}} \newcommand{\MNtrue}{\mathcal{M}^{{p}}_{\text{true}}} \newcommand{\MN}[1][]{\mathcal{M}^{{p}}_{#1}} \newcommand{\MNA}{\mathcal{M}^{{p}}_{Θ_A}} \newcommand{\MNB}{\mathcal{M}^{{p}}_{Θ_B}} \newcommand{\Me}[1][]{\mathcal{M}^ε_{#1}} \newcommand{\Metrue}{\mathcal{M}^ε_{\text{true}}} \newcommand{\Meobs}{\mathcal{M}^ε_{\text{obs}}} \newcommand{\Meh}[1][]{\hat{\mathcal{M}}^ε_{#1}} \newcommand{\MNa}{\mathcal{M}^{\mathcal{N}}_a} \newcommand{\MeA}{\mathcal{M}^ε_A} \newcommand{\MeB}{\mathcal{M}^ε_B} \newcommand{\Ms}{\mathcal{M}^*} \newcommand{\MsA}{\mathcal{M}^*_A} \newcommand{\MsB}{\mathcal{M}^*_B} \newcommand{\Msa}{\mathcal{M}^*_a} \newcommand{\MsAz}{\mathcal{M}^*_{A,z}} \newcommand{\MsBz}{\mathcal{M}^*_{B,z}} \newcommand{\Msaz}{\mathcal{M}^*_{a,z}} \newcommand{\MeAz}{\mathcal{M}^ε_{A,z}} \newcommand{\MeBz}{\mathcal{M}^ε_{B,z}} \newcommand{\Meaz}{\mathcal{M}^ε_{a,z}} \newcommand{\zo}{z^{0}} \renewcommand{\lL}[2][]{\mathcal{L}_{#1|{#2}}} % likelihood \newcommand{\Lavg}[2][]{\mathcal{L}^{/#2}_{#1}} % Geometric average of likelihood \newcommand{\lLphys}[2][]{\mathcal{L}^{{p}}_{#1|#2}} \newcommand{\Lavgphys}[2][]{\mathcal{L}^{{p}/#2}_{#1}} % Geometric average of likelihood \newcommand{\lLL}[3][]{\mathcal{L}^{(#3)}_{#1|#2}} \newcommand{\lLphysL}[3][]{\mathcal{L}^{{p},(#3)}_{#1|#2}} \newcommand{\lnL}[2][]{l_{#1|#2}} % Per-sample log likelihood \newcommand{\lnLt}[2][]{\widetilde{l}_{#1|#2}} \newcommand{\lnLtt}{\widetilde{l}} % Used only in path_sampling \newcommand{\lnLh}[1][]{\hat{l}_{#1}} \newcommand{\lnLphys}[2][]{l^{{p}}_{#1|#2}} \newcommand{\lnLphysL}[3][]{l^{{p},(#3)}_{#1|#2}} \newcommand{\Elmu}[2][1]{μ_{{#2}}^{(#1)}} \newcommand{\Elmuh}[2][1]{\hat{μ}_{{#2}}^{(#1)}} \newcommand{\Elsig}[2][1]{Σ_{{#2}}^{(#1)}} \newcommand{\Elsigh}[2][1]{\hat{Σ}_{{#2}}^{(#1)}} \newcommand{\pathP}{\mathop{{p}}} % Path-sampling process (generic) \newcommand{\pathPhb}{\mathop{{p}}_{\mathrm{Beta}}} % Path-sampling process (hierarchical beta) \newcommand{\interval}{\mathcal{I}} \newcommand{\Phiset}[1]{\{\Phi\}^{\small (#1)}} \newcommand{\Phipart}[1]{\{\mathcal{I}_Φ\}^{\small (#1)}} \newcommand{\qhset}[1]{\{\qh\}^{\small (#1)}} \newcommand{\Dqpart}[1]{\{Δ\qh_{2^{#1}}\}} \newcommand{\LsAzl}{\mathcal{L}_{\smash{{}^{\,*}_A},z,L}} \newcommand{\LsBzl}{\mathcal{L}_{\smash{{}^{\,*}_B},z,L}} \newcommand{\lsA}{l_{\smash{{}^{\,*}_A}}} \newcommand{\lsB}{l_{\smash{{}^{\,*}_B}}} \newcommand{\lsAz}{l_{\smash{{}^{\,*}_A},z}} \newcommand{\lsAzj}{l_{\smash{{}^{\,*}_A},z_j}} \newcommand{\lsAzo}{l_{\smash{{}^{\,*}_A},z^0}} \newcommand{\leAz}{l_{\smash{{}^{\,ε}_A},z}} \newcommand{\lsAez}{l_{\smash{{}^{*ε}_A},z}} \newcommand{\lsBz}{l_{\smash{{}^{\,*}_B},z}} 
\newcommand{\lsBzj}{l_{\smash{{}^{\,*}_B},z_j}} \newcommand{\lsBzo}{l_{\smash{{}^{\,*}_B},z^0}} \newcommand{\leBz}{l_{\smash{{}^{\,ε}_B},z}} \newcommand{\lsBez}{l_{\smash{{}^{*ε}_B},z}} \newcommand{\LaszL}{\mathcal{L}_{\smash{{}^{*}_a},z,L}} \newcommand{\lasz}{l_{\smash{{}^{*}_a},z}} \newcommand{\laszj}{l_{\smash{{}^{*}_a},z_j}} \newcommand{\laszo}{l_{\smash{{}^{*}_a},z^0}} \newcommand{\laez}{l_{\smash{{}^{ε}_a},z}} \newcommand{\lasez}{l_{\smash{{}^{*ε}_a},z}} \newcommand{\lhatasz}{\hat{l}_{\smash{{}^{*}_a},z}} \newcommand{\pasz}{p_{\smash{{}^{*}_a},z}} \newcommand{\paez}{p_{\smash{{}^{ε}_a},z}} \newcommand{\pasez}{p_{\smash{{}^{*ε}_a},z}} \newcommand{\phatsaz}{\hat{p}_{\smash{{}^{*}_a},z}} \newcommand{\phateaz}{\hat{p}_{\smash{{}^{ε}_a},z}} \newcommand{\phatseaz}{\hat{p}_{\smash{{}^{*ε}_a},z}} \newcommand{\Phil}[2][]{Φ_{#1|#2}} % Φ_{\la} \newcommand{\Philt}[2][]{\widetilde{Φ}_{#1|#2}} % Φ_{\la} \newcommand{\Philhat}[2][]{\hat{Φ}_{#1|#2}} % Φ_{\la} \newcommand{\Philsaz}{Φ_{\smash{{}^{*}_a},z}} % Φ_{\lasz} \newcommand{\Phileaz}{Φ_{\smash{{}^{ε}_a},z}} % Φ_{\laez} \newcommand{\Philseaz}{Φ_{\smash{{}^{*ε}_a},z}} % Φ_{\lasez} \newcommand{\mus}[1][1]{μ^{(#1)}_*} \newcommand{\musA}[1][1]{μ^{(#1)}_{\smash{{}^{\,*}_A}}} \newcommand{\SigsA}[1][1]{Σ^{(#1)}_{\smash{{}^{\,*}_A}}} \newcommand{\musB}[1][1]{μ^{(#1)}_{\smash{{}^{\,*}_B}}} \newcommand{\SigsB}[1][1]{Σ^{(#1)}_{\smash{{}^{\,*}_B}}} \newcommand{\musa}[1][1]{μ^{(#1)}_{\smash{{}^{*}_a}}} \newcommand{\Sigsa}[1][1]{Σ^{(#1)}_{\smash{{}^{*}_a}}} \newcommand{\Msah}{{\color{highlight-red}\mathcal{M}^{*}_a}} \newcommand{\Msazh}{{\color{highlight-red}\mathcal{M}^{*}_{a,z}}} \newcommand{\Meah}{{\color{highlight-blue}\mathcal{M}^{ε}_a}} \newcommand{\Meazh}{{\color{highlight-blue}\mathcal{M}^{ε}_{a,z}}} \newcommand{\lsazh}{{\color{highlight-red}l_{\smash{{}^{*}_a},z}}} \newcommand{\leazh}{{\color{highlight-blue}l_{\smash{{}^{ε}_a},z}}} \newcommand{\lseazh}{{\color{highlight-orange}l_{\smash{{}^{*ε}_a},z}}} \newcommand{\Philsazh}{{\color{highlight-red}Φ_{\smash{{}^{*}_a},z}}} % Φ_{\lasz} \newcommand{\Phileazh}{{\color{highlight-blue}Φ_{\smash{{}^{ε}_a},z}}} % Φ_{\laez} \newcommand{\Philseazh}{{\color{highlight-orange}Φ_{\smash{{}^{*ε}_a},z}}} % Φ_{\lasez} \newcommand{\emdstd}{\tilde{σ}} \DeclareMathOperator{\Mvar}{Mvar} \DeclareMathOperator{\AIC}{AIC} \DeclareMathOperator{\epll}{epll} \DeclareMathOperator{\elpd}{elpd} \DeclareMathOperator{\MDL}{MDL} \DeclareMathOperator{\comp}{COMP} \DeclareMathOperator{\Lognorm}{Lognorm} \DeclareMathOperator{\erf}{erf} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator{\Image}{Image} \DeclareMathOperator{\sgn}{sgn} \DeclareMathOperator{\SE}{SE} % standard error \DeclareMathOperator{\Unif}{Unif} \DeclareMathOperator{\Poisson}{Poisson} \DeclareMathOperator{\SkewNormal}{SkewNormal} \DeclareMathOperator{\TruncNormal}{TruncNormal} \DeclareMathOperator{\Exponential}{Exponential} \DeclareMathOperator{\exGaussian}{exGaussian} \DeclareMathOperator{\IG}{IG} \DeclareMathOperator{\NIG}{NIG} \DeclareMathOperator{\Gammadist}{Gamma} \DeclareMathOperator{\Lognormal}{Lognormal} \DeclareMathOperator{\Beta}{Beta} \newcommand{\sinf}{{s_{\infty}}}\)
Specifically, we compute MDL measures using the normalized maximum likelihood (NML), which is usually the preferred choice within the MDL framework. (In practice, an approximation is almost always used to make it tractable.)
In NML parlance, a model contains free parameters which can be fitted to data; this corresponds to a hypothesis class in learning theory. Model comparison within the MDL framework is in fact a comparison of hypothesis classes, rather than a comparison of specific hypotheses; in this it is not dissimilar from other common methods like Bayes factors, WAIC, and even information criteria. However, it is distinct from \(\Bemd{}\), which seeks to compare specific hypotheses.
More complex models tend to fit data better simply by virtue of having a larger hypothesis class. Model comparison methods therefore seek to penalize more complex models, to avoid overfitting. In some sense the main differentiator between these methods is how they compute the complexity of a model.
For the model evidence (the normalization factor of a Bayesian posterior), model complexity is related to the volume of the prior. Specifically, the model evidence \(\eE\) is computed as
\[\eE(\M, π) = \int_{Θ} p_θ(\D)\, π(θ)\, \mathrm{d}θ \,.\]
This quantity is maximized when the prior \(π\) is concentrated on those parameters \(θ\) which fit the data best. Conversely, a broader prior may afford more flexibility to find a parameter which fits the data, but will lower the model evidence.
The normalized maximum likelihood (NML) takes a different approach: instead of a posterior, one uses the maximum likelihood fit of the model on those data:
\[\max_{θ \in Θ}\; p_θ(\D)\, π(θ) \,.\]
In a sense this is a more aggressive hypothesis (a MAP fit instead of a posterior), and it therefore needs a more aggressive normalization than the model evidence. In MDL this is achieved by integrating over all possible datasets, which yields what is called the model complexity:
\[\comp(\M, π) := \log \int \max_{θ \in Θ}\; p_θ(z^n)\, π(θ)\, \mathrm{d}z^n \,.\]
Note here that the integral over \(z^n\) is agnostic to what the data actually look like. So whereas a method like cross-validation will focus on plausible variations of the dataset, NML penalizes all dataset variations equally.
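To make this concrete, here is a brute-force toy computation (not part of the implementation below) of the complexity of a Bernoulli model over \(n\) coin flips, taking a flat prior \(π(θ)=1\) so that the integral reduces to a plain sum over all \(2^n\) datasets:
# Toy illustration only: COMP = log Σ_{z^n} max_θ p_θ(z^n) π(θ), with the flat prior π ≡ 1.
import itertools, math

def bernoulli_comp(n: int) -> float:
    total = 0.0
    for zn in itertools.product([0, 1], repeat=n):
        k = sum(zn)
        θ_hat = k / n                              # maximum-likelihood fit to this dataset
        total += θ_hat**k * (1 - θ_hat)**(n - k)   # 0**0 == 1, so the edge cases are fine
    return math.log(total)

bernoulli_comp(10)   # ≈ 1.54; for a one-parameter model this grows roughly like ½·log n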
The main challenge of MDL is usually to compute the complexity term \(\comp(\M, π)\), since the integral over all datasets is almost never tractable. In the code below, we exploit two particular features of our data to make this feasible:
The Poisson noise means that possible values are discretized. Technically there are infinitely many possible values, but their probability quickly vanishes, so we need to consider only a dozen or so.
Fixed, exact \(λ\) values: for a given dataset size \(L\), we always generate data with the abscissa \(λ_j = λ_{\mathrm{min}} + j \frac{λ_{\mathrm{max}} - λ_{\mathrm{min}}}{L-1}\), where \(j=0, \dotsc, L-1\).
Taken together, these features make it just about possible to enumerate the possible datasets, and thereby calculate \(\comp(\M, π)\) exactly. Concretely, we implement the enumeration as follows (a toy illustration follows the schematic below):
For each \(λ_j\), determine a set of possible values for \(\Bspec_j\).
Sort those values according to their likelihood. Values below a certain threshold are discarded, so that each set \(\{\Bspec_j\}\) is finite.
A dataset is generated by picking a random value from each set: \(\D_{\vec{i}} = \{(λ_j, \Bspec_{ji_j}) : 0 \leq i_j < \lvert \{\Bspec_j\} \rvert\}\).
The total number of datasets is therefore \(\prod_{j=0}^{L-1} \lvert \{\Bspec_j\} \rvert\).
Moreover, we can identify each dataset with its index tuple \(\vec{i} = (i_0, i_1, \dotsc, i_{L-1})\).
Fig. 1 Schematic of the dataset sampling algorithm.#
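As a toy illustration of this indexing scheme (the value sets below are made up, and already sorted from most to least likely):
# Made-up value sets for three λ positions, each sorted from most to least likely value of ℬ_j
import itertools
import numpy as np

value_sets = [[5, 4, 6],          # |{ℬ_0}| = 3 possible values at λ_0
              [2, 3, 1, 4, 0],    # |{ℬ_1}| = 5 possible values at λ_1
              [7, 8]]             # |{ℬ_2}| = 2 possible values at λ_2
sizes = [len(vals) for vals in value_sets]

i_vec = (1, 0, 1)                                        # one index tuple ↔ one dataset
print([vals[i] for vals, i in zip(value_sets, i_vec)])   # [4, 2, 8]

# The total number of datasets is the product of the set sizes
assert len(list(itertools.product(*value_sets))) == np.prod(sizes) == 30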
Note that because a) each set is ordered, and b) the likelihood factorizes, if we generate index tuples in order of increasing index total
(0, …, 0, 0) # Total: 0
(0, …, 0, 1) # Total: 1
(0, …, 1, 0) # Total: 1
(1, …, 0, 0) # Total: 1
(0, …, 0, 2) # Total: 2
(0, …, 1, 1) # Total: 2
(1, …, 0, 1) # Total: 2
(0, …, 2, 0) # Total: 2
...
then the corresponding datasets should be approximately ordered according to their likelihood. Moreover, datasets with the same “index total” will have more similar likelihoods.
Note
Those familiar with Lagrangian mechanics may find the following picture helpful. What we are doing is akin to starting from the function which maximizes the overall likelihood (index (0, 0, …, 0)), and then iterating through all possible variational changes, from the smallest (a single impulse \(δ_{ij}\) at one location) to the largest (many impulses at many locations).
In practice there is still an infeasibly large number of datasets – it would take years to generate all possible combinations of \(\{\Bspec_{ji_j}\}\). Therefore, instead of generating them all, we group them according to their total index – which we can check correlates strongly with their fitted likelihood. There are only \(K := \sum_{j=1}^{L} \bigl(\lvert \{\Bspec_j\} \rvert - 1\bigr) + 1 \sim 1000\) such groups, and we can estimate the average likelihood within each with \(r \sim 30\) exemplars. Thus, if we decompose the complexity into a sum over “index total” classes \(\iI_k\), we can get a reasonable estimate from only about 30,000 sample datasets – something that can be done in a few minutes:
\[\comp(\M, π) \approx \log \sum_{k} \frac{\lvert \iI_k \rvert}{r} \sum_{\vec{i} \sim \iI_k}^{r} \max_{θ \in Θ}\, p_θ(\D_{\vec{i}})\, π(θ) \,.\]
(In the expression, \(\displaystyle \sum_{\vec{i}\sim \iI_k}^r\) means to draw \(r\) samples \(\vec{i}\) from \(\iI_k\).)
import logging
import sys
import itertools
import multiprocessing # Only used to detect if we are a child process, and offset progress bars
from collections.abc import Callable, Generator
from dataclasses import dataclass
from functools import wraps, cache
from tqdm.auto import tqdm
logger = logging.getLogger(__name__)
import math
import numpy as np
from scipy import stats
from scipy.special import logsumexp
#from scipy.special import binom, gammaln
#def logbinom(M, N): return gammaln(M+1) - gammaln(N+1) - gammaln(M-N+1)
#def polytopic(d, n): return binom(d + (n-1), d)
from Ex_UV import memory, Dataset, DataModel, Bunits
Utility: support for big integers.
For moderate-sized datasets, we can easily run into multiplicities so large that they can’t even be stored as double-precision floats. Therefore we need a function to compute the log of arbitrarily large Python integers (math.log doesn’t work because it first casts its argument as a float). We do this by using integer division to reduce the value until it fits in a float:
\[\log x = \log \Bigl( 2^{b} \cdot \tfrac{x}{2^{b}} \Bigr) = b \log 2 + \log \tfrac{x}{2^{b}} \approx b \log 2 + \log \Bigl\lfloor \tfrac{x}{2^{b}} \Bigr\rfloor \,,\]
where the approximation holds as long as \(\bigl\lfloor\frac{x}{2^b}\bigr\rfloor \gg 1\). CPython uses 30 bits per digit for its arbitrary-precision integers, so below we will use \(b=30\). We can stop dividing once the value fits into a float, i.e. when \(x \lessapprox 10^{30}\), which is plenty large to ensure \(\bigl\lfloor\frac{x}{2^b}\bigr\rfloor \gg 1\).
As a final optimization, we use bit shifts (x >> 30) instead of integer division (x // 2**30), which are about 8× faster.
def log(x, _convert_to_base_e=math.log2(math.exp(1)), _max_bit_length=sys.float_info.max_exp-1):
    # Number of bits to shift away so that the remainder fits comfortably in a float
    d = max(0, x.bit_length() - _max_bit_length)
    # log x ≈ d·log 2 + log ⌊x / 2^d⌋   (the bit shift is the integer division by 2^d)
    return d/_convert_to_base_e + math.log(x >> d)
assert math.isclose(log(int(1e100), _max_bit_length=30),
math.log(1e100))
large_int = 8764497313364666504477093928516953288566469882762807028232063745072159479579895458028096313509522235728747629335545934701428474028035170831623118616342985187161325731829558936866587537745728559685870721919847387389580294277953449931849502390335763437286213073335784806230330753580352813054111832792700469841406590733261958771123
log(large_int) # This would fail with stdlib log()
assert log(1) == 0
assert log(400) == math.log(400)
Generating index tuples with a given total index \(k\)#
Generating all index tuples#
We need two generators, which respectively produce
all index tuples with a given total index;
all index tuples, ordered by their total index.
def gen_idcs_with_total(sizes, tot_index) -> Generator[tuple[int]]:
"""
Yield all indices for an array of size `sizes`
which sum up to `tot_index`.
"""
sizes = np.asarray(sizes, dtype=int)
max_idx = sizes - 1
remaining_index_space = np.cumsum(max_idx[::-1])[::-1] # The amount of combined "index space" remaining at positions ⩾ j
yield from _gen_idcs_with_total(sizes, tot_index, remaining_index_space)
def _gen_idcs_with_total(sizes, tot_index, remaining_index_space=None) -> Generator[tuple[int]]:
if len(sizes) == 1:
if tot_index < sizes[0]:
yield (tot_index,)
else:
j = len(sizes) // 2
remaining_left = remaining_index_space[:j] - remaining_index_space[j]
remaining_right = remaining_index_space[j:]
for w in range(max(0,tot_index-remaining_right[0]), min(remaining_left[0]+1, tot_index+1)):
for left_idx, right_idx in itertools.product(
_gen_idcs_with_total(sizes[:j], w, remaining_left),
_gen_idcs_with_total(sizes[j:], tot_index-w, remaining_right)):
yield left_idx + right_idx
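A quick check on an arbitrary small size tuple:
print(list(gen_idcs_with_total([3, 5, 2], 2)))
# [(0, 1, 1), (0, 2, 0), (1, 0, 1), (1, 1, 0), (2, 0, 0)]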
Simpler version
This version has easier-to-follow recursion, but it hits the recursion limit much faster because each recursion level reduces the length of sizes only by 1. It is also substantially slower.
def gen_idcs_with_total(sizes, tot_index, remaining_index_space=None) -> Generator[tuple[int]]:
"""
Yield all indices for an array of size `sizes`
which sum up to `tot_index`.
"""
sizes = np.asarray(sizes, dtype=int)
if remaining_index_space is None:
max_idx = sizes - 1
remaining_index_space = max_idx.sum() - np.cumsum(max_idx) # The amount of combined "index space" we have remaining at positions to the right of i
if len(sizes) == 1:
if tot_index < sizes[0]:
yield (tot_index,)
else:
for w in range(max(0,tot_index-remaining_index_space[0]), min(sizes[0], tot_index+1)):
for nested_ws in gen_idcs_with_total(sizes[1:], tot_index-w, remaining_index_space[1:]):
yield (w, *nested_ws)
def gen_idcs_by_total(sizes) -> Generator[tuple[int]]:
"""
Return a generator which acts like a nested `range` on `sizes`,
with the important difference that the returned indices are sorted
according to their total instead of lexicographically.
So all “low” indices are returned first.
For example, the iterator
gen_idcs_by_total([3, 5, 2])
generates indices in the following order:
(0, 0, 0) # Total 0
(0, 0, 1) # Total 1
(0, 1, 0) # Total 1
(1, 0, 0) # Total 1
    (0, 1, 1) # Total 2
    (0, 2, 0) # Total 2
...
(2, 3, 0) # Total 5
(1, 4, 1) # Total 6
(2, 3, 1) # Total 6
(2, 4, 0) # Total 6
(2, 4, 1) # Total 7
    For a fixed total, the index tuples in this example additionally come out in lexicographic order.
"""
sizes = np.asarray(sizes, dtype=int)
#tot_size = int(np.prod(sizes.astype(float))) # Even with floats this can overflow
max_idx = sizes - 1
    for k in range(max_idx.sum() + 1):  # +1 so that the maximal total is included
yield from gen_idcs_with_total(sizes, k)
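For example, the first few index tuples for the size tuple used in the docstring:
print(list(itertools.islice(gen_idcs_by_total([3, 5, 2]), 6)))
# [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 1, 1), (0, 2, 0)]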
Generating exemplar index tuples#
The “evenly distributed” index, where the index \(i_j\) is approximately proportional to the size \(\lvert \Bspec_j \rvert\). For the example above, this would return \((1, 2, 1)\) and \((2, 2, 1)\) for \(k=4\) and \(k=5\) respectively.
Sample one index tuple approximately uniformly from all index tuples that have total \(k\).
Sample multiple random index tuples without replacement.
def get_unif_idx(sizes, tot_index) -> np.ndarray[int]:
"""
Return an index tuple which sums to `tot_index` and
is approximately proportional to `sizes`. (So that
index values are all about “equally far” from 0.)
"""
sizes = np.asarray(sizes)
idx = np.round(tot_index/sizes.sum() * sizes) \
.astype(int) \
.clip(0, sizes-1)
for i in np.argsort(sizes)[::-1]:
if idx.sum() < tot_index:
if idx[i] < sizes[i]-1:
idx[i] += 1
elif idx.sum() > tot_index:
if idx[i] > 0:
idx[i] -= 1
else:
break
else:
raise AssertionError("Failed to find a valid index tuple for sizes\n"
f" {sizes}\nand total index {tot_index}.")
assert idx.sum() == tot_index, "Returned index tuple does not have the correct sum."
return idx
def get_rnd_idx(sizes, tot_index, rng=None) -> tuple[int]:
"""
    Return an index tuple which sums to `tot_index`.
    Each unit of the total is assigned to a position with probability proportional
    to that position's remaining capacity; this samples index tuples approximately
    (though not exactly) uniformly.
"""
rng = np.random.default_rng(rng)
idx = np.zeros(len(sizes), dtype=int)
idx_space = sizes-1 # Number of “slots” available to draw at each index position
for _ in range(tot_index):
i = rng.choice(len(sizes), p=idx_space/idx_space.sum())
idx[i] += 1
idx_space[i] -= 1
return tuple(idx)
def get_multiple_rnd_idcs(sizes, tot_index, num, rng=None) \
-> set[tuple[int]]:
"""
Call `get_rnd_idx` `num` times to get multiple index tuples.
The number of returned values may be less, if `num` is
larger than the total number of possible index tuples.
The returned index tuples are guaranteed to be distinct.
"""
rng = np.random.default_rng(rng)
with np.errstate(over="ignore"): # If we get overflow, there are *definitely* more index tuples than `tot_index`
if index_multiplicity(tuple(sizes), tot_index) < num:
# No point in generating them randomly if we will generate them all
return list(gen_idcs_with_total(sizes, tot_index))
# We use a set to ensure no duplicates
# The probability of duplicates should be negligible, unless num is close to the multiplicity
idx_tuples = set()
while len(idx_tuples) < num:
idx_tuples.add(get_rnd_idx(sizes, tot_index, rng))
return idx_tuples
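A quick sanity check of the exemplar generators on an arbitrary size tuple:
sizes = np.array([3, 5, 2])
print(get_unif_idx(sizes, 4))           # [1 2 1]  (indices roughly proportional to the sizes)
idx = get_rnd_idx(sizes, 4, rng=123)    # one random index tuple with total 4
assert sum(idx) == 4 and all(i < s for i, s in zip(idx, sizes))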
Computing the multiplicity of an index class \(\iI_k\).#
This is equivalent to counting the number of ways of placing \(k\) balls into \(L\) bins, where bin \(j\) has capacity \(\lvert \Bspec_j \rvert - 1\).
Important
Notation For brevity, this section defines \(s_j = \lvert \Bspec_j \rvert\) for the number of possible values at \(λ_j\). The ordered set of all sizes is \(\vec{s} := (s_1, \dotsc, s_{L})\). The multiplicity of an index class is \(m(\vec{s}, k) := \lvert \iI_k \rvert\).
Note
Multiplicity sequences are symmetric
We can halve the required function evaluations (and thus cache memory) with one simple trick™: if \(m\) is the index_multiplicity function, then we always have
\[m(\vec{s}, k) = m\Bigl(\vec{s},\; \textstyle\sum_{j} s_j - L - k\Bigr) \,.\]
This is easiest to see by realizing that an equivalent way to frame the problem is to start with full bins and count the number of ways of removing \(k\) balls.
Note that since each index runs from \(0\) to \(s_j - 1\) inclusive, the maximum value \(k\) can take is \(\sum_j s_j - L\). Thus Eq. 3 implies that the sequence of multiplicities \(\bigl(m(\vec{s}, 0),\, m(\vec{s}, 1),\, \dotsc,\, m(\vec{s}, \textstyle\sum_j s_j - L)\bigr)\) is symmetric.
Simple recursive algorithm#
# Implementation of the basic idea, but without important optimizations
@cache
def index_multiplicity_simple(sizes: tuple[int,...], tot_index: int) -> int:
if len(sizes) == 1:
return 1 if tot_index < sizes[0] else 0 # Same time as the “fancier” int(tot_index < sizes[0])
return sum(index_multiplicity_simple(sizes[1:], tot_index-i)
for i in range(min(sizes[0], tot_index+1)))
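A sanity check against the explicit enumeration further up:
# There are exactly five index tuples over sizes (3, 5, 2) whose components sum to 2
assert index_multiplicity_simple((3, 5, 2), 2) == 5
assert index_multiplicity_simple((3, 5, 2), 2) == len(list(gen_idcs_with_total([3, 5, 2], 2)))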
We can improve on the *_simple solution in a few ways:
A guard against recursion errors, which pre-computes values if necessary to populate the cache. Otherwise an exception is raised when len(sizes) ⩾ 1000.
Exploit the remark above to avoid computing and caching half of the tuples.
Managing the cache ourselves allows us to normalize tot_index at the top of the function call, before checking the cache.
Precomputation of the reverse cumulative sum.
Identify when only one index combination is possible and terminate the recursion early.
The runtime of this function is very dependent on the value of \(k\).
At a recursion level \(r\), it spawns (in expectation) \(\sim (k-r)\) evaluations at the next level. So the runtime (and cache memory requirements) grows almost exponentially in \(L\).
On the one hand, this means that when \(k\) is close to zero or to \(\sum_j s_j - L\), index_multiplicity_recursive terminates in milliseconds after just a few recursions. This is true even for large dimensions \(L \gtrsim 1000\).
On the other hand, for \(100 \lesssim k \lnapprox \sum_j s_j - L\), the function might run for days and use many GB of RAM for the cache.
Note
The basic idea of this algorithm is the same as the “Simple Divide and Conquer” approach described by Glück and Köppl, although the version here is even more naive. Glück and Köppl use a cleverer scheme for splitting the size tuple, which likely results in fewer recursion steps than what we do here, although I have not checked this. (The “Divide and conquer” approach is already mentioned in their earlier paper (Glück et al., 2013), albeit with a less explicit description.)
def index_multiplicity_recursive(sizes: tuple[int,...], tot_index: int) -> int:
"""
Args:
sizes: Maximum value of the index at each position.
tot_index: Only count index tuples which sum to this.
"""
neg_cum_sizes = tuple((np.array(sizes)-1)[::-1].cumsum()[::-1].tolist())
return _index_multiplicity_recursive(sizes, tot_index, neg_cum_sizes)
# Below we manage the function cache ourselves, instead of using @cache
# This allows us to ignore `neg_cum_sizes` as a caching argument, and
# to convert `tot_index` to its symmetric equivalent before querying the cache
def _index_multiplicity_recursive(sizes: tuple[int,...], tot_index: int,
neg_cum_sizes: tuple[int,...]=None, mycache={}
)-> int:
"""
Args:
sizes: Maximum value of the index at each position.
tot_index: Only count index tuples which sum to this.
neg_cum_sizes: The remaining available index space, including the current bin.
Todo:
For large lists of sizes, indexing a NumPy array can be ~5x faster than
indexing the equivalent tuple. Is there a way to pass `sizes` as an array,
while still preserving the cache (which requires hashable args)?
"""
# Use symmetry to skip half of the computations
if tot_index > neg_cum_sizes[0] // 2: # This is 3x faster than min(tot_index, neg_cum_sizes[0]-tot_index)
tot_index = neg_cum_sizes[0] - tot_index
# Check the cache
res = mycache.get((sizes, tot_index))
if res is not None:
return res
# Exit condition for the recursion
if tot_index < 0:
return 0
elif tot_index == 0 or len(sizes) == 1:
return 1
# Avoid RecursionError by precomputing smaller array sizes and putting them in cache.
# We don’t want to precompute too much, because it does add a bit of overhead – with chunks of 400, overhead is ~6%
chunksize = 400
if len(sizes) > chunksize:
#for i in range(min(sum(sizes[chunksize:]), tot_index+1)):
for i in range(min(neg_cum_sizes[chunksize]+1, tot_index+1)):
            _index_multiplicity_recursive(sizes[chunksize:], tot_index-i, neg_cum_sizes[chunksize:])
# Now perform the actual calculation with recursion
    res = sum(_index_multiplicity_recursive(sizes[1:], tot_index-i, neg_cum_sizes[1:])
for i in range(max(0, tot_index-neg_cum_sizes[1]),
min(sizes[0], tot_index+1))
)
mycache[(sizes, tot_index)] = res
return res
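A quick consistency check between the two recursive implementations:
assert index_multiplicity_recursive((3, 5, 2), 2) == index_multiplicity_simple((3, 5, 2), 2) == 5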
An inferior solution: Counting indices
Generating (and counting) all possible indices is orders of magnitude slower than even the recursive solution, even though almost the entire iteration takes place within a single set comprehension (and therefore runs at near-C speed). This implementation also tends to crash the kernel with large size tuples :-/
def index_mult_enumerate(sizes, k):
"""
IMPORTANT: `sizes` must be sorted from highest to lowest.
"""
sizes = tuple(sorted(sizes, reverse=True))
sizes_m1 = tuple(sj-1 for sj in sizes)
iset = {(k, *(0,)*(len(sizes)-1))}
# Iterate through the "top" indices. We don’t count them because they
# are invalid (i0 > s1), but this serves to generate all indices with i0=s1
for _ in range(k-sizes_m1[0]):
iset = set((itup[0]-1, *itup[1:j], itup[j]+1, *itup[j+1:])
for itup in iset for j in range(1, len(itup))
if itup[0]>0 and itup[j]<sizes_m1[j]) # Ensures only valid indices are generated
# Now we start counting
count = len(iset) # The current indices in iset are valid
while iset:
iset = set((itup[0]-1, *itup[1:j], itup[j]+1, *itup[j+1:])
for itup in iset for j in range(1, len(itup))
if itup[0]>0 and itup[j]<sizes_m1[j])
count += len(iset)
return count
Polynomial algorithm (current state-of-the-art)#
A search of the literature turns up a paper by Glück et al. (2013), along with a C++ implementation and a follow-up paper by Glück and Köppl (2020).
Somewhat annoyingly, neither paper gives a complete, end-to-end description of the algorithm, leaving some parts only implicitly defined. (The follow-up paper does improve in this regard, though.) It is therefore hard to say exactly how their algorithm works without implementing it ourselves, but from what I gather it is something like this:
Construct a special data structure in \(\oO(2^L L^3)\) time.
For any desired total \(k\), query this data structure by solving a simple polynomial inverse problem; this takes \(\oO(L)\) time.
From what I understand, their data structure could be exploiting the same symmetries we describe below, although clearly doing so differently. I did not investigate this further since the solution I arrived at seems even better, at least for my requirements where I need to compute multiplicities for most values of \(k\).
Exploit differential symmetry between multiplicity sequences#
The best solution by far seems to be to exploit a strong symmetry between different multiplicity sequences. Consider the sequences for size tuples \([4, 4, 4]\), \([4, 4, 5]\), \([4, 5, 5]\):
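For reference, these sequences are easy to generate by brute force: the multiplicity sequence of a size tuple \(\vec{s}\) is the coefficient sequence of the polynomial \(\prod_j (1 + x + \dotsb + x^{s_j-1})\), so a few convolutions suffice. (This construction is shown only for illustration; it is not used in the implementation below.)
def mult_sequence(sizes):
    """Multiplicity m(sizes, k) for every total k, via polynomial convolution."""
    seq = np.ones(1, dtype=int)
    for s in sizes:
        seq = np.convolve(seq, np.ones(s, dtype=int))
    return seq

for S in [(4, 4, 4), (4, 4, 5), (4, 5, 5)]:
    print(S, mult_sequence(S))
# (4, 4, 4) [ 1  3  6 10 12 12 10  6  3  1]
# (4, 4, 5) [ 1  3  6 10 13 14 13 10  6  3  1]
# (4, 5, 5) [ 1  3  6 10 14 16 16 14 10  6  3  1]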
As we would expect from Eq. 3, the three multiplicity sequences are symmetric. Note, however, that the differences between them are also symmetric. This holds for any pair of size tuples, at least if they differ in only one position.[1] We can use this to quickly deduce the entire multiplicity sequence for a new size tuple (a worked example follows these steps):
Start from a known multiplicity sequence, e.g. that of \([4, 4, 4]\).
Choose another size tuple which differs in only one position, e.g. \([4, 4, 5]\). It need not be consecutive with the previous one.
The first 4 elements of the sequence for \([4, 4, 5]\) are the same as for \([4, 4, 4]\); copy those over.
Reflect the copied values to the bottom of the sequence for \([4, 4, 5]\). Compute the difference between these copied values and their corresponding values in \([4, 4, 4]\).
These differences must also be symmetric, so reflect the 4 differences to positions 5–8 and use them to compute the \([4, 4, 5]\) multiplicities at those positions.
Repeat until the entire sequence for \([4, 4, 5]\) is filled.
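For instance, going from \([4,4,4]\) (sequence \(1,3,6,10,12,12,10,6,3,1\)) to \([4,4,5]\): the first \(4\) values \(1,3,6,10\) are copied; reflected to the last four positions they give \(\dotsc,10,6,3,1\); the differences with the (zero-padded) \([4,4,4]\) sequence at those positions are \(4,3,2,1\); reflecting these onto positions 5–8 (i.e. \(1,2,3,4\)) and adding them to the \([4,4,4]\) values \(12,12,10,6\) yields \(13,14,13,10\), completing the sequence \(1,3,6,10,13,14,13,10,6,3,1\).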
The runtime and memory requirements are essentially linear in L (the length of an index tuple).
Compared to the recursive one, this algorithm will be notably slower for evaluating multiplicity for a single small total index \(k\) and large \(L\). It might take 100 s where the recursive one takes 10 ms. However
The recursive one will require additional computations if it is evaluated at a new value \(k_2 > k\). This algorithm caches the entire sequence on the first call.
For large total index \(k\), where the recursive algorithm may take days to finish, this algorithm will still only take ~100s.
The cache is much more efficient:
Here we only cache a single array of all multiplicities for a given size tuple \(\vec{s}\).
In contrast, the recursive algorithm has separate entries for each total \(k\), and for each subset \((s_l, \dotsc, s_L)\), \(1 \leq l \leq L\) of the size vector.
There are obvious further opportunities to optimize, such as determining the most efficient way to build up to a size tuple with arbitrary values \(s_j\) at each position, but the implementation we have below is already good enough for our purposes.
An important advantage of this algorithm is that it can be implemented entirely in terms of log multiplicities. A log-space implementation returns floats instead of integers, with maybe only 8-10 significant digits when we exponentiate them to recover actual multiplicities. However, in the MDL calculation we actually want log counts, and having the entire computation in terms of logs means that we effectively never run into overflow issues, even with the astronomical multiplicities that we encounter.
Version without logarithm
More exact at low values, but runs into overflows at high values.
def index_mult_reflect(S1: np.ndarray[int], S2: np.ndarray[int], mults_S1: np.ndarray[float]
) -> np.ndarray:
"""
Return a sequence of index multiplicities (for all totals k) for the
size tuple `S2`.
This works by taking a known sequence of multiplicities `mults_S1`,
for the size tuple `S1`. By using symmetries in the difference between
the multiplicity sequences of any two size tuples, we can reconstruct
    the values of `mults_S2` from those of `mults_S1`.
Returns:
mults_S2
"""
# Infer the changed index position from S1 and S2
ΔS = S2 - S1
assert ΔS.ndim == 1
jΔ = np.nonzero(ΔS)[0] # Returns one array per dim; we only have one dim
# S2 must be bigger than S1 and different in only one position
assert (len(jΔ) == 1) and (ΔS[jΔ] > 0)
jΔ = jΔ[0] # It’s length one, so replace array by an integer
# Initialize the new mult array for S2. k can run from 0 to sum(S2-1)
mults_S2 = np.zeros((S2-1).sum()+1)
assert mults_S1.size == (S1-1).sum()+1
# Pad mults_S1 with zeros so it has the same length as mults_S2
mults_S1 = np.pad(mults_S1, (0, mults_S2.size-mults_S1.size))
# Multiplicities will be the same for S1 & S2 until k=S1[jΔ]
# This will serve as our block size (or "width")
w = S1[jΔ]
Δmults = np.zeros(w)
k0 = 0
kn = mults_S2.size
while k0 <= kn:
mults_S2[k0:k0+w] = mults_S1[k0:k0+w] + Δmults[::-1]
mults_S2[kn-w:kn] = mults_S2[k0:k0+w][::-1]
Δmults = mults_S2[kn-w:kn] - mults_S1[kn-w:kn]
k0 += w
kn -= w
return mults_S2
def _index_logmult_reflect(S1: np.ndarray[int], S2: np.ndarray[int], logmults_S1: np.ndarray[float]
) -> np.ndarray:
"""
Return a sequence of index multiplicities (for all totals k) for the
size tuple `S2`.
    This works by taking a known sequence of log multiplicities `logmults_S1`,
    for the size tuple `S1`. By using symmetries in the difference between
    the multiplicity sequences of any two size tuples, we can reconstruct
    the values of `logmults_S2` from those of `logmults_S1`.
Returns:
logmults_S2
"""
# Infer the changed index position from S1 and S2
ΔS = S2 - S1
assert ΔS.ndim == 1, "`S1` and `S2` size arrays should differ in only one position."
jΔ = np.nonzero(ΔS)[0] # Returns one array per dim; we only have one dim
# S2 must be bigger than S1 and different in only one position
assert (len(jΔ) == 1) and (ΔS[jΔ] > 0), f"{jΔ=}, {ΔS[jΔ]=}"
jΔ = jΔ[0] # It’s length one, so replace array by an integer
# Initialize the new mult array for S2. k can run from 0 to sum(S2-1)
logmults_S2 = np.empty((S2-1).sum()+1); logmults_S2.fill(-np.inf) # -∞ is the log equivalent of initializing with zero
assert logmults_S1.size == (S1-1).sum()+1, "The provided `logmults_S1` does not match the size expected for size array `S1`."
# Pad mults_S1 with zeros so it has the same length as mults_S2
logmults_S1 = np.pad(logmults_S1, (0, logmults_S2.size-logmults_S1.size), constant_values=-np.inf)
# Multiplicities will be the same for S1 & S2 until k=S1[jΔ]
# This will serve as our block size (or "width")
w = int(S1[jΔ])
Δlogmults = np.empty(w); Δlogmults.fill(-np.inf) # -∞ is the log equivalent of initializing with zero
k0 = 0
kn = logmults_S2.size
while k0 <= kn:
logmults_S2[k0:k0+w] = np.logaddexp(logmults_S1[k0:k0+w], Δlogmults[::-1])
logmults_S2[kn-w:kn] = logmults_S2[k0:k0+w][::-1]
# In contrast to `logaddexp`, `logsumexp` supports the scaling factor b, which we use to implement subtraction.
Δlogmults = logsumexp( [ lmS2 := logmults_S2[kn-w:kn],
#logmults_S1[kn-w:kn] ],
np.clip(logmults_S1[kn-w:kn], None, lmS2) ], # clip to avoid difference becoming neg due to numerical precision
axis=0, b=[[1], [-1]])
k0 += w
kn -= w
return logmults_S2
index_logmult_cache = {}
def index_logmultiplicity(sizes: tuple[int,...], k: int) -> float:
logmults = index_logmult_cache.get(sizes)
if k < 0:
return -np.inf
if logmults is None:
# Build up the target size tuple from the (1,1,…) case
S2 = np.ones(len(sizes), dtype=int) # With indices of size 1, the only possible value is 0
logmults = np.zeros(1) # There is only one case, k=0, and it has multiplicity m=1 (ergo log(m)=0)
for i, s in enumerate(sizes):
while S2[i] < s:
S1 = S2.copy()
S2[i] = min(2*S2[i], s)
logmults = _index_logmult_reflect(S1, S2, logmults)
# Store in the cache
index_logmult_cache[sizes] = logmults
if k >= logmults.size:
return -np.inf
return logmults[k]
def index_multiplicity(sizes: tuple[int,...], k: int) -> float:
return np.exp(index_logmultiplicity(sizes, k))
Timing: This takes about 90 seconds to execute on a size tuple of dimension 1000. Afterwards every index total \(k\) is a quick cache lookup.
Confirm that the optimized version returns the same values.
for k in range(16):
#assert index_multiplicity((2, 3, 1, 5, 4), k) == index_multiplicity_simple((2, 3, 1, 5, 4), k)
assert math.isclose(index_multiplicity((2, 3, 1, 5, 4), k), index_multiplicity_simple((2, 3, 1, 5, 4), k))
#assert index_multiplicity((5,)*300, 1000) == index_multiplicity_simple((5,)*300, 1000)
assert math.isclose(index_multiplicity((5,)*300, 1000), index_multiplicity_simple((5,)*300, 1000))
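We can also verify numerically that the multiplicity sequence is symmetric, as claimed in Eq. 3 (the size tuple here is arbitrary):
S = (2, 3, 1, 5, 4)
K = sum(S) - len(S)            # maximum possible index total
for k in range(K + 1):
    assert math.isclose(index_logmultiplicity(S, k), index_logmultiplicity(S, K - k))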
Augment the Dataset class with methods to generate and count possible datasets.#
value_sets
value_set_sizes
num_possible_datasets
num_index_classes
index_logmultiplicities
gen_possible_datasets generates all datasets.
gen_representative_datasets generates up to \(r\) different datasets from each index class \(\iI_k\). (It may generate fewer datasets, if \(\lvert \iI_k \rvert < r\).)
@dataclass(frozen=True)
class EnumerableDataset(Dataset):
def value_sets(self, thresh=0.01):
"""
Return a list of L arrays, one for each λ position.
Each array contains the possible radiance values for that wavelength λ.
"""
if thresh < 0 or thresh > 1:
raise ValueError(f"`thresh` must be between 0 and 1. Received {thresh}.")
# Implementation is mirrored on poisson_noise.__call__, except that instead of
# drawing from the Poisson, we use it to compute the range of possible values
λ, B = DataModel(λ_min=self.λmin, λ_max=self.λmax, T=self.T, phys_model="Planck")(self.L)
sB = self.s*B
        assert sB.dimensionless, "Dimension error: s*B should be dimensionless"
# Determine possible values for ℬ at each λ
lower_counts = stats.poisson(sB.m).ppf(thresh).astype(int)
upper_counts = stats.poisson(sB.m).ppf(1-thresh).astype(int)
# For each k, sort the ℬ values from most to least likely
ℬsets = []
for _sB, _lc, _uc in zip(sB, lower_counts, upper_counts):
ℬprobs = stats.poisson(_sB.m).pmf(np.arange(_lc, _uc+1))
ℬvals = np.arange(_lc, _uc+1) / self.s + self.B0
ℬsets.append(ℬvals[np.argsort(ℬprobs)[::-1]].m) # Need to work with magnitudes to assign in-place below. Also assigning magnitudes is faster since it skips the unit checks
return λ, ℬsets
def value_set_sizes(self, thresh=0.01):
_, ℬsets = self.value_sets(thresh)
return np.array([len(ℬvals) for ℬvals in ℬsets])
def num_possible_datasets(self, thresh=0.01):
# Take the product of set sizes |ℬk|
        ℬsizes = self.value_set_sizes(thresh=thresh)
return int(np.prod(ℬsizes, dtype=float)) # float prevents overflow
def num_index_classes(self, thresh=0.01):
        ℬsizes = self.value_set_sizes(thresh=thresh)
return (ℬsizes-1).sum() + 1
def index_logmultiplicities(self, thresh=0.01) -> np.ndarray[float]:
"""Returns a array of the logarithm of index multiplicities."""
        ℬsizes = self.value_set_sizes(thresh=thresh)
num_classes = (ℬsizes-1).sum() + 1
return np.fromiter( ( index_logmultiplicity(tuple(ℬsizes), k)
for k in range(num_classes) ),
dtype=float, count=num_classes) # Float avoids overflow
def gen_possible_datasets(self, thresh=0.01):
"""
Iterator for every dataset.
        We exclude, at each λ, values in the upper and lower tails, which together have probability at most 2*thresh (2% with the default thresh=0.01).
Datasets are returned in (approximate) order of most to least likely: the set of possible
values at each λ stop is ordered from most to least likely, and a dataset is constructed
by choosing an index i for the value at each λ stop.
Since each set of values is ordered by likelihood, the sum of indices should strongly correlate
with the likelihood of the returned dataset.
"""
        λ, ℬsets = self.value_sets(thresh)
        ℬsizes = np.array([len(ℬvals) for ℬvals in ℬsets])
        # Go through all combinations which pick one ℬk from each ℬset
        ℬarr = np.zeros(self.L) * Bunits  # Pre-allocate array for efficiency; modified in-place below
        for idcs in gen_idcs_by_total(ℬsizes):
            ℬarr.m[:] = tuple(ℬvals[i] for ℬvals, i in zip(ℬsets, idcs))
            yield λ, ℬarr
def gen_representative_datasets(self, thresh=0.0001, r=10, rng=None) \
-> Generator[tuple[np.ndarray, list[np.ndarray]]]:
"""
Iterator for a sampling of (hopefully) representative datasets.
As in `gen_possible_datasets`, datasets are ordered from most to least likely;
the “index total” should strongly correlate with dataset likelihood.
        Index totals partition the set of all possible datasets into classes; we return
        up to `r` “representative” datasets per class. (The actual number of representatives
        will be less if the class multiplicity is less than `r`.) The representatives are
        drawn randomly and guaranteed to be distinct. The multiplicity of each class can be
        obtained separately with `index_logmultiplicities`.
        This allows one to quickly compute preliminary estimates for averages over all
        possible datasets, by evaluating a functional on the representative datasets and
        weighting their average by the class multiplicity.
Yields
------
λ : array
{ℬ} : list[array]
"""
rng = np.random.default_rng(rng)
λ, ℬsets = self.value_sets(thresh)
ℬsizes = np.array([len(ℬvals) for ℬvals in ℬsets])
# Iterate over index total classes and return one rep per class
#tot_size = int(np.prod(ℬsizes.astype(float))) # float prevents overflow; casting to Python int is safe (worst case we get a bigint)
max_idx = ℬsizes - 1
#L = len(ℬsizes)
# Pre-allocate value arrays. We will modify them in-place
ℬarrs = [np.zeros(self.L) * Bunits for _ in range(r)]
for k in range(max_idx.sum()+1):
#rep_idx = get_unif_idx(ℬsizes, k)
rep_idx_tuples = get_multiple_rnd_idcs(ℬsizes, k, r, rng=rng)
for ℬarr, idx_tuple in zip(ℬarrs, rep_idx_tuples):
ℬarr.m[:] = tuple(ℬvals[i] for ℬvals, i in zip(ℬsets, idx_tuple))
#yield index_multiplicity(tuple(ℬsizes), k), λ, ℬarrs[:len(rep_idx_tuples)]
yield λ, ℬarrs[:len(rep_idx_tuples)]
@cache # Add ability to strip units (not done in base class in case it would invalidate a cache)
def get_data(self, rng=None, strip_units=False):
λ, B = self.__call__(self.L, rng)
return (λ.m, B.m) if strip_units else (λ, B)
Compute the complexity by decomposing the sum over index classes.#
If we define \(l(\vec{i}) := \log \max_{θ \in Θ} p_θ(\D_{\vec{i}})\, π(θ)\), then the expression for the complexity becomes
\[\comp(\M, π) = \log \sum_{\vec{i}} e^{l(\vec{i})} = \log \Bigl[\, \sum_{k} \Braket{e^{\log \lvert \iI_k \rvert + l(\vec{i})}}_{\sum \vec{i} = k} \Bigr] \,.\]
Below,
rep_lk is a list of representative log likelihoods for a given \(k\);
rep_l = [rep_l0, rep_l1, …] is the collection of all representative log likelihoods;
logmults is the list of log multiplicities \(\log \lvert \iI_k \rvert\), one per index class \(k\).
In _comp, we then compute avg_lk: an estimator of \(\log \Braket{e^{l(\vec{i})}}_{\sum \vec{i} = k}\), based on the representative datasets for each index class \(k\).
The complexity is then finally computed as
\[\comp(\M, π) \approx \log \sum_{k} \exp\bigl( \log \lvert \iI_k \rvert + \mathtt{avg\_l}_k \bigr) \,,\]
i.e. a single logsumexp over \(\log \lvert \iI_k \rvert + \mathtt{avg\_l}_k\).
Storing intermediate values as logarithms avoids issues with numerical underflow/overflow.
One more optimization becomes necessary for datasets with larger \(L\): the number of index classes \(K\) grows with \(L\) (roughly in proportion to it) and can become very large. On the other hand, the likelihood becomes more peaked with large \(L\), so the contribution of datasets with large \(k\) becomes increasingly negligible. (This is not as pronounced as in a Bayesian calculation, because the model is refit to each dataset, but it still happens.)
Therefore, to avoid spending all our time estimating irrelevant terms, we need to check whether we can truncate Eq. 4. For this we use the fact that the average likelihoods decrease with \(k\), so that
\[\Braket{e^{l(\vec{i})}}_{\sum \vec{i} = k} \;\lesssim\; \Braket{e^{l(\vec{i})}}_{\sum \vec{i} = k'} \qquad \text{for } k \geq k' \,.\]
Let \(\comp_{k'}(\M, π) := \log \Bigl[ \sum_{k=0}^{k'} \Braket{e^{\log \lvert \iI_k \rvert + l(\vec{i})}}_{\sum \vec{i} = k} \Bigr]\) denote the estimated complexity after truncating the sum at \(k'\). The relative error we make is then bounded by
\[\left\lvert \frac{\comp(\M, π) - \comp_{k'}(\M, π)}{\comp_{k'}(\M, π)} \right\rvert \;\lesssim\; \frac{1}{\bigl\lvert \comp_{k'}(\M, π) \bigr\rvert}\, \log\Biggl[ 1 + \frac{\Braket{e^{l(\vec{i})}}_{\sum \vec{i} = k'} \,\sum_{k > k'} \lvert \iI_k \rvert}{e^{\comp_{k'}(\M, π)}} \Biggr] \,.\]
The implementation below truncates when this estimated relative error drops below rtol (by default 0.01, i.e. about two significant digits on the log complexity).
In practice, even in the worst cases, this approach allows us to roughly halve the computation time, since the multiplicities follow a binomial-like curve: symmetric, with a maximum at \(k=K/2\). Once we have computed the 50% largest terms, the number of remaining contributions quickly decreases and becomes negligible.
def gen_representative_likelihoods(D: EnumerableDataset, r: int, m:int,
fitΘ: Callable, logL: Callable, thresh=0.01, rng=None) \
-> Generator[tuple[int, np.ndarray[float]]]:
#mults = []
#rep_l = []
mp_process_id = multiprocessing.current_process()._identity # Stores a tuple; parent has (), 1st level children (i,), 2nd level children (i,j), etc.
progbar_offset = 0 if (len(mp_process_id) == 0) else 1 + mp_process_id[0] # Add one to child processes to leave space for a progbar of the parent process
for (_λ, ℬs) in tqdm(D.gen_representative_datasets(thresh, r=r, rng=rng),
desc="index class", total=D.num_index_classes(thresh),
position=progbar_offset):
rep_lk = []
for _ℬ in ℬs:
σ, T, coeffs = fitΘ((_λ,_ℬ), m)
rep_lk.append(logL(σ, T, coeffs)((_λ,_ℬ)).sum())
yield np.array(rep_lk)
#rep_l.append(np.array(rep_lk))
#mults.append(mult)
#return mults, rep_l
def _comp(D: EnumerableDataset, r: int, m: int, fitΘ: Callable, logL: Callable,
value_thresh: float=0.01, rtol=0.01, cache_key=None, rng=None) -> float:
"""
Return a Monte Carlo estimate of the MDL model complexity.
    :param fitΘ: Used to find the maximum likelihood parameters.
    :param logL: Used to evaluate the log likelihood at those parameters.
For each index class, up to `r` datasets are generated.
This function is cached using `joblib.Memory` based on argument values which
__exclude__ `fitΘ` and `logL`. This is because callables aren’t reproducibly hashed.
Use `cache_key` instead to pass a hashable value which is unique to every
`fitΘ` & `logL` combination.
"""
logmults = D.index_logmultiplicities(value_thresh)
#n_remaining_terms = mults.sum() - mults.cumsum() # <- Version below equiv to this, except that it uses log multiplicities
logn_remaining_terms = np.r_[np.logaddexp.accumulate(logmults[::-1])[::-1][1:], -np.inf]
avg_l = np.zeros_like(logmults, dtype=float)
# We use a running update of the complexity to track when to stop the sum.
# This is faster but less accurate than computing all at once with `logsumexp`.
compk = -np.inf # Log equivalent to initializing a sum accumulator at 0
for k, rep_lk in enumerate(gen_representative_likelihoods(
D, r, m, fitΘ, logL, value_thresh, rng)):
avg_l[k] = logsumexp(rep_lk - np.log(rep_lk.size))
compk = np.logaddexp(compk, logmults[k] + avg_l[k])
        #abs_err = np.log1p(np.exp(avg_l[k] - compk) * n_remaining_terms[k])
        abs_err = np.logaddexp(0, avg_l[k] - compk + logn_remaining_terms[k])  # = log(1 + remaining/current), computed in log space
if abs(abs_err/compk) < rtol:
logger.info(f"Truncated complexity sum after {k+1}/{len(logmults)} terms, "
f"representing 10^{logsumexp(logmults[:k])/math.log(10)}"
f"/10^{logsumexp(logmults)/math.log(10)} datasets.")
break
    # Recalculate the full sum with logsumexp for better accuracy (including the k-th term)
    return logsumexp(logmults[:k+1] + avg_l[:k+1])
#mults, rep_l = get_representative_likelihoods(D, r, m, fitΘ, logL, thresh, rng)
#avg_l = np.fromiter((logsumexp(rep_lk - np.log(rep_lk.size)) for rep_lk in rep_l),
# count=len(rep_l), dtype=float)
#return logsumexp(np.log(mults) + avg_l)
_cached_comp = memory.cache(_comp, ignore=["fitΘ", "logL"])
# Small wrapper which skips caching if either `cache_key` or `rng` is None
@wraps(_comp)
def comp(D: EnumerableDataset, r: int, m: int, fitΘ: Callable, logL: Callable,
value_thresh: float=0.01, rtol=0.01, cache_key=None, rng=None,
no_compute: bool=False) -> float|None:
"""
If `no_compute` is True, then the function only returns a float if a pre-computed
cached value is available. Otherwise it returns None.
"""
if cache_key is None or rng is None:
return None if no_compute else _comp(D, r, m, fitΘ, logL, value_thresh, rtol, cache_key, rng)
else:
in_cache = _cached_comp.check_call_in_cache(D, r, m, fitΘ, logL, value_thresh, rtol, cache_key, rng)
return None if (no_compute and not in_cache) \
else _cached_comp(D, r, m, fitΘ, logL, value_thresh, rtol, cache_key, rng)