Pyloric network simulator#

This is an implementation of the three-cell model of the crustacean stomatogastric ganglion described in Marder and Abbott [1998] and used by Prinz et al. [2004] to demonstrate that disparate parameter sets can lead to comparable network activity. The main features of this particular implementation are:

JAX-accelerated Python implementation

Combines the flexibility of Python with the speed of compiled JAX code.

The specification of update equations for ion channels was carefully designed to maximize vectorization.

Choice of ODE solver

The official C++ implementation uses hard-coded Euler integration, which can make it vulnerable to numerical instabilities. (Conductance-based neuron models, after all, are stiff by design.)

In contrast, this implementation is designed to work with SciPy’s ODE solvers, making it easy to choose a more appropriate scheme. In particular, adaptive step solvers can dramatically reduce simulation time, while still keeping enough time precision to resolve spikes.

Modularity

New cell types are defined simply by extending Python arrays for the ion channel parameters.
If needed, an arbitrary function can also be given for specific ion channels.

Flexibility

You are not limited to the published three-cell network: studying a smaller two-cell network, or a larger one with different LP cells, is a simple matter of changing three arrays.

All-in-one documentation

The original specification of the pyloric circuit model is spread across at least three resources[1].
Here, all definitions are included in the inline documentation, fully referenced, and written in consistent notation.

Single source of truth

Both the documentation and the code use the same source for parameter values, so you can be sure the documented values are actually those used in the code.

In cases where this was not practical, Python values are specified with a string identical to the source of the documentation’s Markdown table. Agreement between documentation and code can be checked at any time by “diff”-ing the markdown and Python code blocks, or copy-pasting one onto the other.
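As a minimal sketch of this check (the table string and constant names here are hypothetical, not the module’s actual definitions):

```python
# The same Markdown string is rendered in the docs and parsed for the code
constants_md = """
| Constant | Value | Unit  |
|----------|-------|-------|
| τCa      | 200   | ms    |
| f        | 14.96 | μM/nA |
"""

def parse_md_table(md: str) -> dict:
    rows = [[cell.strip() for cell in line.strip().strip("|").split("|")]
            for line in md.strip().splitlines()]
    header, rule, *body = rows            # second row is the |---| separator
    return {name: float(value) for name, value, _unit in body}

constants = parse_md_table(constants_md)
assert constants["τCa"] == 200.0          # the code uses exactly the documented value
```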

from __future__ import annotations  # Needed so annotations such as "Prinz2004.State" resolve lazily

import logging
import shelve
import re
import numpy as np
import pandas as pd
import holoviews as hv

from ast import literal_eval
from collections import namedtuple
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Union, Optional, Type, Generator, Callable, Tuple, List, ClassVar
from flufl.lock import Lock
from numpy import exp
from numpy.typing import ArrayLike
from scipy import integrate
from addict import Dict

from scityping.numpy import Array

logger = logging.getLogger(__name__)
from config import config
try:
    import jax
    from jax import config as jax_config; jax_config.update("jax_enable_x64", True); del jax_config
    import jax.numpy as jnp
except ImportError:
    import jax_shim as jax
    import jax_shim.numpy as jnp

Usage example#

import holoviews as hv
from pyloric_simulator.prinz2004 import Prinz2004, neuron_models, dims
hv.extension("bokeh")  # Bokeh plots allow to zoom in on the trace much more easily

Instantiate a new model by specifying:

  • The number of neurons of each type (PD, AB, LP and PY)

  • The connectivity gs between populations

  • The set of membrane conductances g_ion for each population.
    The 16 sets which define the neuron models used in Prinz et al. [2004] are provided in neuron_models.

# Note that cholinergic (PD) and glutamatergic (AB, LP, PY) neurons are contiguous.
# This allows the model to use more efficient indexing based on slices,
# instead of slower indexing arrays.
# `pop_sizes` keys should match one of the four type labels (AB, PD, LP, PY) or the
# combined label "AB/PD". If there are multiple neurons of one type, they can be
# followed by a number: "AB 1", "AB 2", ...
# IMPORTANT: The `Prinz2004` class uses regular expressions on these labels to determine
# the right type of connectivity, so they must follow this labelling pattern exactly.
model = Prinz2004(pop_sizes = {"PD": 2, "AB": 1, "LP": 1, "PY": 5},
                  gs      = [ [    0  ,     0  ,     3  ,     0 ],
                              [    0  ,     0  ,     3  ,     3 ],
                              [    3  ,     3  ,     0  ,     3 ],
                              [    3  ,     3  ,     3  ,     0 ] ],
                  g_ion = neuron_models.loc[["AB/PD 1", "AB/PD 1", "LP 1", "PY 1"]]  # May be specified as DataFrame or Array
                  )

Define a set of time points and evaluate the model at those points. A model simulation is always initialized with the result of a cached warm-up simulation (see Initialization). If no such simulation matching the model parameters is found in the cache, it is performed first and cached for future runs.

The example below reproduces the procedure of Prinz et al. [2004]: after the neurons are connected, an initial 3 s simulation is discarded to let transients decay; a further 1 s is then simulated to generate the data for that network model.

Note

The time points we provide to the model are only those which are recorded; they have no bearing on the integration time step.[2] To resolve spikes, a recording resolution of 1 ms or finer is needed.

# Generate 1001 sequential time points (3000–4000), and evaluate the model at those points
res = model(np.linspace(3000, 4000, 1001))

Results are returned as a SimResult object, which has attributes to retrieve the different traces: membrane voltage V, calcium concentration Ca, synaptic activation s, membrane activation m and membrane inactivation h. These are returned as Pandas DataFrames.

TODO

For now, the m and h variables are still returned as plain arrays.

# Retrieve trace for one neuron per population
Vtraces = res.V.loc[:,[("AB", 1), ("LP", 1), ("PY", 1)]].droplevel("index", axis="columns")
# In addition to voltage, the following time-dependent variables are also available:
res.Ca  # Intracellular calcium concentration  (n_time_bins x n_neurons)
res.s   # Synapse activation                   (n_time_bins x n_neurons)
res.m   # Channel activation                   (n_time_bins x n_channels x n_neurons)
res.h   # Channel inactivation                 (n_time_bins x n_channels x n_neurons)
# A plain time-bin array can also be retrieved
res.t
array([3000., 3001., 3002., ..., 3998., 3999., 4000.])

For simple plots, one can use the DataFrame plot method directly. For more complex plots, it is often more convenient to retrieve the data and call the plotting functions directly.

Vtraces.plot("AB")  # Simple plot using the DataFrame method
# Call plotting functions directly after retrieving data
hv.Curve(Vtraces.loc[:,"AB"]) + hv.Curve(Vtraces.loc[:,"LP"]) + hv.Curve(Vtraces.loc[:,"PY"])

To inspect the initialization curves, we use the semi-private method thermalize, which is used internally to generate the thermalized state; it returns a SimResult object. (NB: since during thermalization all neurons are disconnected to obtain each one’s individual spontaneous activity, it suffices to simulate one neuron per population. This is why the result returned by thermalize always has population sizes of 1.)

res = model.thermalize()
Vtraces = res.V.droplevel("index", axis="columns")  # The initialization run has only one neuron per pop, since they are identical

hv.output(backend="bokeh")  # Workaround: If we don’t switch the default, Bokeh ignores the new colors
ov = hv.Overlay([hv.Curve(Vtraces.loc[:,pop_name], label=pop_name)
                 for i, pop_name in enumerate(model.pop_model_list)]) \
     .opts(ylabel="voltage (mV)") \
     .redim(time=dims.t)
# Assign trace colours
for pop_name, c in zip(res.pop_slices, hv.Cycle("Dark2").values):
    ov.opts(hv.opts.Curve(f"Curve.{hv.core.util.sanitize_identifier_fn(pop_name)}", color=c))
ov.opts(width=800, legend_position="right", backend="bokeh")

Definitions#

# IMPORTANT: List orders must be kept consistent, because for efficiency the
#            implementation uses plain arrays, vectorized ops and indexing by position
act_varnames = ["minf", "hinf", "taum", "tauh"]
channels     = ["INa", "ICaT", "ICaS", "IA", "IKCa", "IKd", "IH"]
neuron_types = ["PD", "AB", "LP", "PY"]  # Putting PD first allows us to have all glutamatergic neurons in one contiguous slice
neuron_pops  = ["AB/PD", "LP", "PY"]

dims = dict(
    minf = r"$m_\infty$",
    hinf = r"$h_\infty$",
    taum = r"$τ_m$",
    tauh = r"$τ_h$",
    INa  = r"$I_\mathrm{Na}$",
    ICaT = r"$I_\mathrm{CaT}$",
    ICaS = r"$I_\mathrm{CaS}$",
    IA   = r"$I_\mathrm{A}$",
    IKCa = r"$I_\mathrm{K(Ca)}$",
    IKd  = r"$I_\mathrm{Kd}$",
    IH   = r"$I_\mathrm{H}$",
    V    = r"$V$",
    
    Es   = r"$E_\mathrm{s}$",
    km   = r"$k_{-}$",
    Vth  = r"$V_{\mathrm{th}}$",
    Δ    = r"$\Delta$",
    
    t    = "time"
)
dims = Dict({nm: hv.Dimension(nm, label=label) for nm, label in dims.items()})

dims.V.unit = "mV"
dims.t.unit = "ms"

Single compartment conductance model#

The conductance model is described in Prinz et al. [2003] (p.1-2, §Model). Where differences in notation occur, we prefer those in Prinz et al. [2004].

In the equations, $I_i$ is the current through each channel, while $I_\mathrm{input}$ is the current from the synaptic inputs $I_e$ and $I_s$. These are computed using the electrical and chemical synapse models defined below. We also allow for an additional external input current $I_\mathrm{ext}$; this current is not necessary to drive the system (after all, a defining feature of the pyloric circuit is its spontaneous rhythm). Expected magnitudes for $I_\mathrm{ext}$ are 3–6 nA.[3]

Note

Although a concentration must always be positive, the differential equation for $[\mathrm{Ca}^{2+}]$ reproduced here from Prinz et al. [2004] does not per se prevent the occurrence of a negative concentration. (Sustained CaT and CaS currents could drive $[\mathrm{Ca}^{2+}]$ below zero.) In practice the rest of the ODE dynamics seem to prevent this, but nevertheless, to ensure $[\mathrm{Ca}^{2+}]$ is always positive and to improve numerical stability, our implementation below tracks $\log [\mathrm{Ca}^{2+}]$ instead. Since concentrations can span multiple orders of magnitude, considering them in log space is in fact rather natural.

The same applies to $m$ and $h$, which must remain bounded within $[0,1]$; for these we use a logit transformation to ensure the variables never exceed their bounds. (We do the same for the synapse activations $s$.)
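For reference, the chain-rule factors these transformations introduce are (with $x$ standing for any of $m$, $h$ or $s$):

$$\frac{d\log[\mathrm{Ca}^{2+}]}{dt} = \frac{1}{[\mathrm{Ca}^{2+}]}\frac{d[\mathrm{Ca}^{2+}]}{dt}, \qquad \frac{d\,\mathrm{logit}(x)}{dt} = \frac{d}{dt}\log\frac{x}{1-x} = \frac{1}{x(1-x)}\frac{dx}{dt}.$$

These factors appear in `dX` below as the lines `dlogCa = dCa / Ca` and `dlogitm = dm / (m * (1-m))`.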

(1)#

$$
\begin{aligned}
\frac{C}{A}\frac{dV}{dt} &= -\sum_i I_i - I_\mathrm{input}, &
I_i &= g_i\, m_i^p h_i\, (V - E_i), &
I_\mathrm{input} &= I_e + I_s + I_\mathrm{ext}, \\
\tau_m \frac{dm}{dt} &= m_\infty - m, &
\tau_h \frac{dh}{dt} &= h_\infty - h, &
\tau_{Ca} \frac{d[\mathrm{Ca}^{2+}]}{dt} &= -f\,(I_{CaT} + I_{CaS}) - [\mathrm{Ca}^{2+}] + [\mathrm{Ca}^{2+}]_0
\end{aligned}
$$

Constant low-dimensional parameters#

Table 1 (Values provided in Prinz et al. [2003].)#

|       | $I_{Na}$ | $I_{CaT}$ | $I_{CaS}$ | $I_A$ | $I_{K(Ca)}$ | $I_{Kd}$ | $I_H$ | $I_\mathrm{leak}$ | unit |
|-------|----------|-----------|-----------|-------|-------------|----------|-------|-------------------|------|
| $E_i$ | 50       | $E_{Ca}$  | $E_{Ca}$  | −80   | −80         | −80      | −20   | −50               | mV   |
| $p$   | 3        | 3         | 3         | 3     | 4           | 4        | 1     |                   |      |

Calcium reversal potentials $E_{Ca}$ are determined by solving the Nernst equation:

$$E_{Ca} = \frac{RT}{zF} \ln \frac{[\mathrm{Ca}^{2+}]_\mathrm{out}}{[\mathrm{Ca}^{2+}]_\mathrm{in}} = \underbrace{\frac{k_B T}{ze}}_{\gamma} \ln \frac{[\mathrm{Ca}^{2+}]_\mathrm{out}}{[\mathrm{Ca}^{2+}]_\mathrm{in}} = \gamma \ln \frac{[\mathrm{Ca}^{2+}]_\mathrm{out}}{[\mathrm{Ca}^{2+}]_\mathrm{in}} \;\;[\mathrm{mV}]$$

where $R$ is the ideal gas constant (or $k_B$ the Boltzmann constant), $T$ the temperature, $z$ the charge of the ions and $F$ the Faraday constant (or $e$ the elementary charge). Since lobsters are ectotherms (cold-blooded) and live in cold, shallow Atlantic waters, a reasonable temperature range might be 1–10 °C, which corresponds to the range $\gamma \in (11.8, 12.2)\,\mathrm{mV}$ for ions with $z=2$. In this implementation we fix $\gamma$ to $12.2\,\mathrm{mV}$ (this matches the value of Nernstfactor in the published source code) and, following Prinz et al. [2003], further set $[\mathrm{Ca}^{2+}]_\mathrm{out}$ to 3 mM.
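This range is easy to verify with a standalone check, independent of the simulator:

```python
# γ = k_B T / (z e), for Ca²⁺ (z = 2), expressed in mV
kB_over_e = 8.617333e-5                  # Boltzmann constant over elementary charge, in V/K
for T_celsius in (1, 10):
    γ = kB_over_e * (273.15 + T_celsius) / 2 * 1000
    print(f"T = {T_celsius:2d} °C  ->  γ = {γ:.1f} mV")
# T =  1 °C  ->  γ = 11.8 mV
# T = 10 °C  ->  γ = 12.2 mV
```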

Table 2 Constants#

| Constant | Value | Unit |
|----------|-------|------|
| $A$ (membrane area) | 0.628 | $10^{-3}\,\mathrm{cm}^2$ |
| $C$ | 0.628 | nF |
| $\tau_{Ca}$ | 200 | ms |
| $f$ | 14.96 | μM/nA |
| $[\mathrm{Ca}^{2+}]_0$ | 0.05 | μM |
| $[\mathrm{Ca}^{2+}]_\mathrm{out}$ | 3000 | μM |
| $\gamma$ | 12.2 | mV |
| approximate numerical time step | 0.025 | ms |

If we multiply all the units in Eq. (1) together, using $10^{-3}\,\mathrm{cm}^2$ for the unit of $A$, we find that they simplify to $1\,\mathrm{mV/ms}$ – the desired units for $dV/dt$. We can therefore write the implementation using only the magnitudes in the Value column of Table 2 and ignore the units. Moreover, we omit $C$ and $A$, since their magnitudes cancel.
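Explicitly, for the conductance term:

$$\frac{[g]\,[V]}{[C]/[A]} = \frac{\mathrm{mS/cm^2}\cdot\mathrm{mV}}{\mathrm{nF}/(10^{-3}\,\mathrm{cm^2})} = \frac{\mathrm{\mu A/cm^2}}{\mathrm{\mu F/cm^2}} = 1\,\frac{\mathrm{mV}}{\mathrm{ms}}.$$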

Maximum channel conductances#

These are the channel conductance values (g) to use in Eq. (1); they are reproduced from Table 2 in Prinz et al. [2004].

Differences in these values are what differentiate the neuron models.

Table 3 Maximal conductance densities of model neurons#

Maximal membrane conductance ($g$) in mS/cm²

| Model neuron | $g(I_{Na})$ | $g(I_{CaT})$ | $g(I_{CaS})$ | $g(I_A)$ | $g(I_{K(Ca)})$ | $g(I_{Kd})$ | $g(I_H)$ | $g(I_\mathrm{leak})$ |
|--------------|-------------|--------------|--------------|----------|----------------|-------------|----------|----------------------|
| AB/PD 1 | 400 | 2.5 | 6  | 50 | 10 | 100 | 0.01 | 0.00 |
| AB/PD 2 | 100 | 2.5 | 6  | 50 | 5  | 100 | 0.01 | 0.00 |
| AB/PD 3 | 200 | 2.5 | 4  | 50 | 5  | 50  | 0.01 | 0.00 |
| AB/PD 4 | 200 | 5.0 | 4  | 40 | 5  | 125 | 0.01 | 0.00 |
| AB/PD 5 | 300 | 5.0 | 2  | 10 | 5  | 125 | 0.01 | 0.00 |
| LP 1    | 100 | 0.0 | 8  | 40 | 5  | 75  | 0.05 | 0.02 |
| LP 2    | 100 | 0.0 | 6  | 30 | 5  | 50  | 0.05 | 0.02 |
| LP 3    | 100 | 0.0 | 10 | 50 | 5  | 100 | 0.00 | 0.03 |
| LP 4    | 100 | 0.0 | 4  | 20 | 0  | 25  | 0.05 | 0.03 |
| LP 5    | 100 | 0.0 | 6  | 30 | 0  | 50  | 0.03 | 0.02 |
| PY 1    | 100 | 2.5 | 2  | 50 | 0  | 125 | 0.05 | 0.01 |
| PY 2    | 200 | 7.5 | 0  | 50 | 0  | 75  | 0.05 | 0.00 |
| PY 3    | 200 | 10  | 0  | 50 | 0  | 100 | 0.03 | 0.00 |
| PY 4    | 400 | 2.5 | 2  | 50 | 0  | 75  | 0.05 | 0.00 |
| PY 5    | 500 | 2.5 | 2  | 40 | 0  | 125 | 0.01 | 0.03 |
| PY 6    | 500 | 2.5 | 2  | 40 | 0  | 125 | 0.00 | 0.02 |

Voltage-dependent activation variables#

To implement the voltage equations in code, we write them as

$$
\begin{aligned}
x(V, [\mathrm{Ca}^{2+}]; y, a, b, c, C) &= y(V, [\mathrm{Ca}^{2+}])\,\bigl(\sigma(V; a, b, c) + C\bigr) \\
\sigma(V; a, b, c) &= \frac{a}{1 + \exp\left(\frac{V + b}{c}\right)},
\end{aligned}
$$

where $x$ stands for one of $m_\infty$, $h_\infty$, $\tau_m$ or $\tau_h$. This form allows us to implement them as almost entirely vectorized operations, which compute all activation variables simultaneously. Only the $y(V, [\mathrm{Ca}^{2+}])$ function requires a loop, over the five entries for which it differs from one.

[glued table tbl_act_params: the per-channel activation parameters $a$, $b$, $c$ and $C$.]

Only the first four ion channels have an h variable. We define h_slice to select them.

h_slice = slice(0, 4)
nhchannels = 4         # Number of ion channels which have an h variable

Caution

The units for the denominator of $m(I_{K(Ca)})$ were not reported in the original publications. However, in other equations $[\mathrm{Ca}^{2+}]$ is given in μM, and checking the source code confirms that these are the correct units here as well.

Table 4 Voltage dependence - $y(V, [\mathrm{Ca}^{2+}])$#

|             | $m$ | $h$ | $\tau_m$ | $\tau_h$ |
|-------------|-----|-----|----------|----------|
| $I_{Na}$    |     |     |          | $\dfrac{1.34}{1+\exp\left(-\dfrac{V+62.9}{10}\right)}$ |
| $I_{CaT}$   |     |     |          |          |
| $I_{CaS}$   |     |     | $2.8+\dfrac{14}{\exp\left(\dfrac{V+27}{10}\right)+\exp\left(\dfrac{V+70}{-13}\right)}$ | $120+\dfrac{300}{\exp\left(\dfrac{V+55}{9}\right)+\exp\left(\dfrac{V+65}{-16}\right)}$ |
| $I_A$       |     |     |          |          |
| $I_{K(Ca)}$ | $\dfrac{[\mathrm{Ca}^{2+}]}{[\mathrm{Ca}^{2+}]+3\,\mu\mathrm{M}}$ | | | |
| $I_{Kd}$    |     |     |          |          |
| $I_H$       |     |     | $\dfrac{2}{\exp\left(\dfrac{V+169.7}{-11.6}\right)+\exp\left(\dfrac{V-26.7}{14.3}\right)}$ | |
def y(V, Ca, exp=jnp.exp, array=jnp.array, ones_like=jnp.ones_like):
    """
    Returns a list of multipliers for the voltage equations.
    The indices in the (variable x channel) matrix to which each multiplier corresponds
    are stored in `y.indices`. The format of `y.indices` is compatible with advanced
    indexing, so one can do ``A[(*y.indices,...)] *= y(V, Ca)``.
    """
    V = V.reshape(-1)
    return array([
        1.34/(1+exp(-(V+62.9)/10.)),
        2.8 + 14./(exp((V+27.)/10.) + exp((V+70.)/-13.)),
        120. + 300./(exp((V+55.)/9.) + exp((V+65.)/-16.)),
        Ca/(Ca+3.) * ones_like(V),
        2./(exp((V+169.7)/-11.6) + exp((V-26.7)/14.3))
    ])
y.indices = ([act_varnames.index("tauh"), act_varnames.index("taum"), act_varnames.index("tauh"), act_varnames.index("minf"), act_varnames.index("taum")],
             [channels.index("INa"), channels.index("ICaS"), channels.index("ICaS"), channels.index("IKCa"), channels.index("IH")])
def act_vars(V, Ca, y=y, exp=jnp.exp, a=act_params.a[...,np.newaxis], b=act_params.b[...,np.newaxis], c=act_params.c[...,np.newaxis], C=act_params.C[...,np.newaxis]):
    """
    Returns an array of dynamic activation variables used in the expressions for the different currents
    with shape (channel x variable).
    Variables are: minf, hinf, τm and τh.
    """
    res = a/(1 + exp((V+b)/c)) + C       # σ() + C
    #res[(*y.indices, ...)] *= y(V, Ca)   # (σ() + C) * f(V, Ca)  —  Indexing + in-place operation ensure that only the 5 elements with a multiplier are updated
    res = res.at[(*y.indices, ...)].multiply(y(V, Ca))  # JAXMOD
    return res

Note

A few activation functions are actually not defined (e.g. $h(I_{K(Ca)})$). These are represented by a constant unit function.

Electrical synapse model#

The AB and PD neurons are connected via an electrical synapse [Prinz et al., 2004], which is given as Eq. (13) in Marder and Abbott [1998]:[4]

(2)#

$$I_e = g_e (V_\mathrm{post} - V_\mathrm{pre}).$$

Unfortunately, Prinz et al. [2004] do not seem to document the value of $g_e$ they use. For our simulations we set it to 1, so it can be omitted from the implementation. $I_e$ is then simply implemented as V[:,newaxis] - V, which yields an antisymmetric matrix.
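For example, a small standalone illustration with made-up voltages:

```python
import numpy as np

V = np.array([-50., -45., -60.])   # hypothetical voltages of the electrically coupled AB/PD cells
Ie = V[:, np.newaxis] - V          # Ie[i, j] = V[i] - V[j], so Ie.T == -Ie (antisymmetric)
print(Ie.sum(axis=-1))             # net electrical input to each neuron: [  5.  20. -25.]
```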

Chemical synapse model#

The chemical synapse model is defined in Prinz et al. [2004] (p. 1351, §Methods) and Marder and Abbott [1998] (Eqs. (14), (18) and (19)). Here again, in case of discrepancy, we use the notation from Prinz et al. [2004].[5]

$$
\begin{aligned}
I_s &= g_s s (V_\mathrm{post} - E_s) \\
\frac{ds}{dt} &= \frac{\bar{s}(V_\mathrm{pre}) - s}{\tau_s} \\
\bar{s}(V_\mathrm{pre}) &= \frac{1}{1 + \exp\bigl((V_\mathrm{th} - V_\mathrm{pre})/\Delta\bigr)} \\
\tau_s &= \frac{1 - \bar{s}(V_\mathrm{pre})}{k_-}
\end{aligned}
$$

The conductance $g_s$ determines the synapse’s strength, and therefore the network’s connectivity. See Circuit model below.

The synapse parameters are chosen to match the inhibitory postsynaptic potentials (IPSPs) generated in the postsynaptic cell, which in turn depend on the type of neurotransmitter released by the presynaptic cell. For the purposes of the implementation, this means that these values depend on the type of presynaptic cell.

| Cell type | Neurotransmitter | $E_s$ (mV) | $k_-$ (ms⁻¹) | $V_\mathrm{th}$ (mV) | $\Delta$ (mV) |
|-----------|------------------|------------|--------------|----------------------|---------------|
| pre ∈ {AB, LP, PY} | glutamate | −70 | 1/40 | −35 | 5 |
| pre = PD | acetylcholine | −80 | 1/100 | −35 | 5 |

In general, if we have $N$ neurons, we would need to track an $N \times N$ matrix of synapse activations. Here however the dynamics of $s$ depend only on $V_\mathrm{pre}$, so it suffices to track $s$ as a vector of length $N$.
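Concretely, with one activation $s_j$ per presynaptic neuron, the total chemical synaptic current into neuron $i$ becomes

$$I_{s,i} = \sum_j g_{s,ij}\, s_j\, (V_i - E_{s,j}),$$

which is exactly the double sum over population blocks implemented in `dX` below.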


Circuit model#

The circuit used in Prinz et al. [2004] is composed of three neuron populations:

Table 7 Neuron populations of the pyloric network#

| Population name | Cell names | Number of cells |
|-----------------|------------|-----------------|
| AB/PD (“pacemaker kernel”) | anterior burster (AB), pyloric dilator (PD) | AB: 1, PD: 2 |
| LP | lateral pyloric | 1 |
| PY | pyloric | 5–8 |

In Prinz et al. [2004] the AB and PD are lumped together into a single model, with the only difference being that the AB cell has glutamatergic synapses while the PD cells have slower cholinergic synapses.

![Pyloric circuit diagram](../_images/prinz-model_circuit-diagram_fig-1b-from-Prinz-2004_grey.svg)

Fig. 1 Circuit diagram of the pyloric network, reproduced from Fig. 1 of Prinz et al. [2004].#

Network connectivity is determined by the synapse strength $g_s$. In Prinz et al. [2004], this parameter takes one of the following values:

Table 8 Possible synapse values (gs)#

| 0 nS | 1 nS (PY only) | 3 nS | 10 nS | 30 nS | 100 nS |
|------|----------------|------|-------|-------|--------|

Implementation of dV/dt#

The evolution equations are implemented as a function dX for use with SciPy’s ODE solvers. Since these integrate only a single vector argument, we concatenate $V$, $[\mathrm{Ca}^{2+}]$, $s$, $m$ and $h$ into a single array variable $X$.

Solving an ODE requires a great many function calls, and in this case the overhead of Python function calls does start to bite. To mitigate this, we pass all global constants and result arrays as arguments:

  • Passing global constants reduces the time required to look up those values, since the local scope is searched before the global scope. Binding functions like np.concatenate to concat similarly saves the lookup of np in the global namespace, and the subsequent lookup of concatenate in the np namespace.

  • Passing result arrays avoids the need to call np.array, which relative to other operations can have high overhead since it needs to allocate a new contiguous block of memory, as well as inspect the arguments for correctness. In practice this means things like the following:

    • Instead of concat((E[:i], ECa, E[i+2:])), we do E[i:i+2] = ECa. This saves both the cost of allocating a new array, and the overhead of all other checks concat performs.

The dX implementation is written such that as many operations as possible are performed outside the function (memory allocation, model value retrieval, etc.). The best way to use it is through the public API below, which combines all these setup operations and takes care of passing all those arguments defined for optimization purposes. It also provides packing/unpacking of the X vector.
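As an illustration of the packing convention, here is a standalone sketch (the shapes are those of the example network; the variable names mirror, but are not, the module’s internals):

```python
import numpy as np

n_neurons, n_channels, n_h = 9, 7, 4   # shapes of the example network
V      = np.full(n_neurons, -50.)      # one row per scalar variable...
logCa  = np.full(n_neurons, np.log(0.05))
logits = np.full(n_neurons, -10.)
logitm = np.full((n_channels, n_neurons), -10.)   # ...then one block per gating variable
logith = np.full((n_h, n_neurons), 10.)

# Pack: stack rows/blocks, then flatten to the 1-d vector the ODE solver expects
X = np.concatenate((V[None], logCa[None], logits[None], logitm, logith)).reshape(-1)

# Unpack (as done at the top of `dX`): undo the flattening, then split by rows
Xm = X.reshape(3 + n_channels + n_h, n_neurons)
V, logCa, logits = Xm[0], Xm[1], Xm[2]
logitm, logith   = Xm[3:3+n_channels], Xm[3+n_channels:]
```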

Index = Union[int,None]
SliceReduceVal = Tuple[Type[slice], tuple[Index,Index,Index]]
@partial(jax.jit,
         static_argnames=["pop_slices", "syn_slices", "elec_slice",
                          "nchannels", "n_neurons", "I_ext",
                          #"γ", "Caout", "Eleak", "p",
                          #"f", "Ca0", "τCa",
                          "idx_ICaT", "idx_ICaS",
                          #"Es_tuple", "km_chol", "km_glut", "Vth_chol", "Vth_glut", "Δ_chol", "Δ_glut", "ε",
                         ]
        )
def dX(t, X,
       ## Extra parameters which should be passed with the `args` argument of the ODE solver ##
       g,   # Shape: n_pops x n_channels  (3 x 7)  <-- DOES NOT INCLUDE LEAK
       gleak,
       gs,  # Shape: n_pops x n_pops  (3 x 3)
       ge,  # Electrical conductance is generally fixed to 1, except when we need to remove all connections to compute spontaneous activity
       pop_slices: Tuple[SliceReduceVal],  # Slices for each of the three neuron populations AB/PD, LP and PY, in that order
       syn_slices: Tuple[SliceReduceVal],  # Slices for selecting neurons according to their synapse type: cholinergic (PD) and glutamatergic (AB, LP, PY), in that order
                                  # NOTE: We exploit the fact that types & populations can be ordered so that all subsets are contiguous blocks.
                                  #       This improves efficiency but is not essential: for non-contiguous blocks, we could use arrays to specify an advanced index
       elec_slice: SliceReduceVal, # Slice for selecting the neurons which have electrical synapses (i.e. AB/PD)
       E,   # constants.E.reshape(..., n_neurons)
       #Ii,  # empty(n_channels, n_neurons)
       #Is,  # empty(n_neurons)
       n_neurons,   # Used for unpacking X
       I_ext = None,
       *,
       ## Below are constants, pulled from the global module variables. They should not be used as arguments ##    
       # X unpacking
       nchannels  = len(channels),    # Number of non-leak ion channels (7)
       nhchannels = nhchannels,       # Number of channels with an h variable (4)
       h_slice     = h_slice,
       # E_Ca - Eq (3)
       γ     = constants.γ,
       logCaout = jnp.log(constants.Caout),
       # dI - Eq (1)
       Eleak = constants.Eleak,
       p     = constants.p[...,np.newaxis],
       # dCa - Eq (2)
       f        = constants.f,
       Ca0      = constants.Ca0,
       τCa      = constants.τCa,
       idx_ICaT = channels.index("ICaT"),
       idx_ICaS = channels.index("ICaS"),
       # Ie - Eq (5)
       # Is - Eq (6)
       Es_tuple = tuple(syn_constants.loc[["chol", "glut", "glut", "glut"], "Es"]),  # FIXME: Hard-coded pops.  Follows order in syn_slices
       km_chol  = syn_constants.loc["chol", "km" ],
       km_glut  = syn_constants.loc["glut", "km" ],
       Vth_chol = syn_constants.loc["chol", "Vth"],
       Vth_glut = syn_constants.loc["glut", "Vth"],
       Δ_chol   = syn_constants.loc["chol", "Δ"  ],
       Δ_glut   = syn_constants.loc["glut", "Δ"  ],
       ε        = 1e-8,   # Added to a denominator that may be zero
       # functions
       clip       = jnp.clip,
       concat     = jnp.concatenate,
       empty_like = jnp.empty_like,
       exp        = jnp.exp,
       log        = jnp.log,
       minimum    = jnp.minimum,
       newax      = jnp.newaxis,
       swapaxes   = jnp.swapaxes,
       tile       = jnp.tile
      ):
    # Shape convention: parameter x voltage
    #     (equivalent): (channel) x (neuron)
    
    # Reconstitute slices
    pop_slices = tuple(T(*args) for T, args in pop_slices)
    syn_slices = tuple(T(*args) for T, args in syn_slices)
    elec_slice = elec_slice[0](*elec_slice[1])
    
    # Unpack the vector argument
    X = X.reshape(3+nchannels+nhchannels, n_neurons)
    V, logCa, logits, logitm, logith, = (
        X[0], X[1], X[2],
        X[3:3+nchannels], X[3+nchannels:3+nchannels+nhchannels],
    )
    Ca = exp(logCa)
    s = 1 / (1 + exp(-logits))  # Inverse of the logit function  (NB: expression remains well defined even when exp(logits) -> oo)
    m = 1 / (1 + exp(-logitm))
    h = 1 / (1 + exp(-logith))
    
    
    # Update reversal potentials using Eq. (3)
    #E = tile(E, (len(X),1))  # I’m not happy with the copy, but I don’t see a better way to support modifying ECa when the size of the first dimension (time) is unknown
    # JAXMOD E[...,[1,2],:] = γ*Caout/Ca  # Modifying E in-place is fine as long as no other function does so
    E = E.at[...,[idx_ICaT,idx_ICaS],:].set( γ * (logCaout - logCa) )
    
    # Compute conductance inputs                  # g.shape    :  pop x channel
                                                  # pop_g.shape: channel
    #mphVE  = m**p * h * (V-E)                    # shape: channel x neuron
    mphVE  = m**p * (V-E)                         # shape: channel x neuron
    mphVE  = mphVE.at[h_slice,:].multiply(h)
    Ii = concat(
        [pop_g[:,newax] * mphVE[...,slc]          # shape: channel x pop_size
         for pop_g, slc in zip(g, pop_slices)],
        axis=-1
    )
    # for pop_g, slc in zip(g, pop_slices):
    #     Ii[:,slc] = pop_g[:,newax] * mphVE[:,slc] # shape: channel x pop_size
    # Ii has shape: channel x neuron
    
    # Compute dCa using Eq. (2)
    dCa = (-f * (Ii[...,idx_ICaT,:] + Ii[...,idx_ICaS,:]) - Ca + Ca0) / τCa
    dlogCa = dCa / Ca   # Chain rule w/ log transform
    
    # Compute the voltage-dependent activation variables using definitions in §Voltage-dependent activation variables
    m_inf, h_inf, τ_m, τ_h = act_vars(V, Ca)
    h_inf = h_inf[h_slice]  # For vectorization reasons, act_vars (at least for now)
    τ_h   = τ_h[h_slice]    # returns dummy values for channels with no h var
    
    # Compute dm and dh using Eq (1)
    dm = (m_inf - m) / τ_m
    dh = (h_inf - h) / τ_h
    dlogitm = dm / (m * (1-m))  # Chain rule through logit
    dlogith = dh / (h * (1-h))  # Chain rule through logit
    
    # Compute electrical synapse inputs using Eq (5)
    Ve = V[...,elec_slice,newax]
    Ie = ge * (Ve - swapaxes(Ve,-1,-2)).sum(axis=-1)   # This computes Ve - Ve.T, allowing for an additional time dimension on the left
        # ge is assumed to be a scalar. Currently we always set it to 1, or 0 if we are computing the thermalization (which requires a completely disconnected network)
    
    # Compute the synaptic inputs using Eq (6)
    # (By splitting glut & chol, we can use scalars for all parameters and not have to worry about broadcasting with the shape of V)
    cholslc, glutslc = syn_slices
    
    Vglut = V[...,glutslc]
    Vchol = V[...,cholslc]
    s_glut = s[...,glutslc]
    s_chol = s[...,cholslc]
    
    sinf_glut = 1 / (1 + exp((Vth_glut - Vglut)/Δ_glut))  # shape: (n_glut,)
    sinf_chol = 1 / (1 + exp((Vth_chol - Vchol)/Δ_chol))  # shape: (n_chol,)
    τs_glut = (1 - sinf_glut) / km_glut                 # shape: (n_glut,)
    τs_chol = (1 - sinf_chol) / km_chol                 # shape: (n_chol,)
    # ds_glut = (sinf_glut - s_glut) / τs_glut            # shape: (n_glut,)
    # ds_chol = (sinf_chol - s_chol) / τs_chol            # shape: (n_glut,)
    dlogits_glut = (sinf_glut - s_glut) / (τs_glut * s_glut * (1-s_glut) + ε)   # Incl. chain rule w/ logit transform.
    dlogits_chol = (sinf_chol - s_chol) / (τs_chol * s_chol * (1-s_chol) + ε)   # ε=1e-8 is added for numerical stability
    
    dlogits = empty_like(logits)
    dlogits = dlogits.at[cholslc].set(dlogits_chol)
    dlogits = dlogits.at[glutslc].set(dlogits_glut)

    # Sum synaptic inputs from each population
    Is = concat(
        [sum((gs[i,j] * s[preslc] * (V[postslc,np.newaxis] - Es_tuple[j])).sum(axis=-1)
             for j, preslc in enumerate(pop_slices))
         for i, postslc in enumerate(pop_slices)],
        axis=-1
    )
    # jax.debug.print("Is: {}", Is)
    # jax.debug.print("dlogits: {}", dlogits)
    # # Alternative 1: In-place updates of a pre-allocated array
    # Is[:] = 0
    # for i, postslc in enumerate(pop_slices):
    #     for j, preslc in enumerate(pop_slices):
    #         Is[postslc] += (gs[i,j] * s[preslc] * (V[postslc,np.newaxis] - Es_tuple[j])).sum(axis=-1)
    # # Alternative 2: Construct a block matrix. The extra allocation of
    # # memory makes this slower, but can be useful for diagnosis
    # blocks = [[gs[i,j] * s_tuple[j] * (V[postslc,np.newaxis] - Es_tuple[j])
    #            for j, preslc in enumerate(pop_slices)]
    #           for i, postslc in enumerate(pop_slices)]
    # Is = np.block(blocks).sum(axis=-1)

    
    # # Asserts serve as documentation, and can be uncommented to re-validate code after changes.
    # _A = np.arange(50)
    # assert Ii.shape == (nchannels, len(V)), "Conductance input matrix doesn’t have the right shape."
    # assert np.array(Is_chol.shape) + np.array(Is_glut.shape) == (len(V),), "Synaptic input matrix doesn’t have the right shape."
    # assert np.all(concat([_A[slc] for slc in pop_slices]) == _A[:len(V)]), "Pop slices are not contiguous or don’t add up to the total number of neurons."
    # assert np.all(concat([_A[slc] for slc in syn_slices]) == _A[:len(V)]), "Synapse type slices are not contiguous or don’t add up to the total number of neurons."
    # assert ds_chol.shape == (2,)
    
    # Compute dV by summing inputs over channels
    dV = -Ii.sum(axis=-2) - Is
    # Add external inputs
    if I_ext:
        dV -= I_ext(t)
    # Add contributions currents from electrical synapses
    #dV[..., elec_slice] -= Ie
    dV = dV.at[..., elec_slice].add(-Ie)   # JAXMOD
    # Add leak currents
    for pop_gleak, slc in zip(gleak, pop_slices):
        #dV[...,slc] -= pop_gleak * (V[...,slc] - Eleak)
        dV = dV.at[...,slc].add(- pop_gleak * (V[...,slc] - Eleak) )  # JAXMOD

    # Combine all derivatives and return
    return concat((dV.reshape(1,-1), dlogCa.reshape(1,-1), dlogits.reshape(1,-1), dlogitm, dlogith)).reshape(-1)

Initialization#

For the experiments, we want to initialize models in a steady state. To find this steady state, we first need to run the model, generally for much longer than the simulation time of the experiment itself. We definitely don’t want to do this every time, so after this warm-up simulation we store the resulting state in an on-disk cache, to serve as the initial state for subsequent runs. We thus distinguish between two types of initialization:

Cold initialization

This is a fixed initial state used for all models, when no previous simulations are available. It is used for the thermalization run, which integrates the model until it reaches a steady state.

Thermalized (“warm”) initialization

The final state of the initialization run is the thermalized initialization. Subsequent calls with the same model will retrieve this state from the cache.

During the thermalization run, the model is disconnected (all connectivities $g_s$ are set to 0) and receives no external input; this mirrors the procedure followed by Prinz et al. [2003]. If we did not disconnect neurons, many of them would never reach a steady state, since the circuit is designed to oscillate spontaneously. Moreover, we would need to generate and store a different initialization for each circuit, rather than for each neuron, which is combinatorially many more.

Cold initialization values are provided on p. 4001 of Prinz et al. [2003]:[6]

Table 9 Cold initialization#

| $V$ | $[\mathrm{Ca}^{2+}]$ | $m$ | $h$ | $s$ |
|-----|----------------------|-----|-----|-----|
| −50 mV | 0.05 μM | 0 | 1 | 0 |

There are five main technical differences between our procedure and that of Prinz et al. [2004]:

  • For numerical stability, we track the values of $\mathrm{logit}(m)$, $\mathrm{logit}(h)$ and $\mathrm{logit}(s)$. This means we can’t initialize them exactly at 0 or 1; instead we use $\mathrm{logit}(m) = \mathrm{logit}(s) = -10$ and $\mathrm{logit}(h) = +10$, which correspond approximately to $m = s \approx 10^{-5}$ and $1 - h \approx 10^{-5}$.

  • Prinz et al. [2004] set s=0 after the thermalization. This is presumably because the original neuron model catalog did not include simulations of s. In our case, we set s=0 for the cold initialization, and use its subsequent thermalized value for the data run. This is both simpler on the implementation side, and more consistent with the desire to let all spontaneous transients relax before connecting neurons.

  • We don’t first compute the thermalization for individual model neurons separately, but instead recompute it for each combination of neuron models. (Specifically, each thermalization run is identified by a g_cond matrix – gs and Iext are ignored, since they are set to zero during thermalization.) If we were to simulate the entire catalog of neuron model combinations this would be wasteful, but since we only need a handful, this approach is adequate and simpler to implement.

  • Instead of fixed-step Euler integration, we use the adaptive-step Runge-Kutta 4(5) algorithm provided by scipy.integrate. This is not only more efficient than Euler, but also allows one to estimate numerical errors. (Although the integration scheme is currently hard-coded, since we use standard SciPy integrators it would be easy to change it or make it user-configurable; see the sketch after this list. Indeed, while standard, RK45 is not necessarily the best choice for Hodgkin-Huxley-type systems, which are moderately stiff.)

  • Finally, instead of detecting the steady state automatically, we use a fixed integration time and rely on visual inspection to determine whether this is enough to let transients decay in all considered models. Again we can do this because we only need to simulate a handful of models.
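As a minimal sketch of such a swap, hypothetically replacing RK45 with SciPy’s stiff-aware LSODA (this assumes the same `dX`, flat initial state and `args` tuple that `Prinz2004.integrate` builds in the Public API below):

```python
# Hypothetical: use LSODA, which switches automatically between stiff and
# non-stiff methods, instead of the hard-coded RK45
res = integrate.solve_ivp(dX, (t0, T), X0_flat, method="LSODA",
                          t_eval=t_eval, args=args)
```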

The main practical difference is that instead of pre-computing a neuron catalogue, the thermalization is done automatically and on demand. For the user it therefore makes no difference whether a warm initialization for a particular circuit is already available: when they request a simulation and none is found, they just need to wait longer for their result. Subsequent runs then reuse the warm initialization computed on the first run.

Implementation

  • The cold initialization is given by the class method State.cold_initialized.

  • The warm-up simulation is implemented in the method Prinz2004.get_thermalization.

Three class attributes of Prinz2004 are used to control the behaviour of the warm-up simulation (a usage sketch follows the list):

  • __thermalization_store__ determines where the cache is stored on disk.

  • __thermalization_time__ is the warm-up simulation time; it is currently set to 5 s (5000 ms).

  • __thermalization_time_step__ is the recording time step for the warm-up simulation. This is only relevant for inspecting the warm-up run. (The integrator uses an adaptive time step and discards non-recorded steps.)
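For example, one could redirect the cache and lengthen the warm-up before instantiating models (a sketch; the path and value are placeholders):

```python
from pathlib import Path

# Hypothetical overrides; set these before creating any Prinz2004 instance
Prinz2004.__thermalization_store__ = Path("local_cache/prinz2004_thermalize")
Prinz2004.__thermalization_time__  = 10_000.  # ms: thermalize for 10 s instead of 5 s
```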

Public API#

State object#

In order to integrate the model equations, we need to know not just the membrane potential of each neuron, but also its calcium concentration, synaptic variable (s) and activation and inactivation variables (m, h). These are stored together in a State object, which also provides

  • automatic conversion to/from log-transformed values;

  • convenience methods for converting between different storage layouts, for example concatenating and flattening all variables for the ODE integrator.

@dataclass
class State:
    V     : Array[float, 1]  # (n_neurons)
    logCa : Array[float, 1]  # (n_neurons)
    logits: Array[float, 1]  # (n_neurons)
    logitm: Array[float, 2]  # (n_channels, n_neurons)
    logith: Array[float, 2]  # (n_h_channels, n_neurons)

    """Storage container for simulator variables.

    - Handles automatic conversion to/from log-transformed values.
    - Provides methods for converting between storage layouts, in particular for
    exporting to the 1-d vector format required by ODE integrators.
    """

    @classmethod
    def cold_initialized(cls, n_neurons):
        """Default state initialization. Used to initialize the thermalization run."""
        return cls(
            -50   * jnp.ones(n_neurons),
            jnp.log(0.05) * jnp.ones(n_neurons),
            -10 * jnp.ones(n_neurons),   # “fully deactivated” would be s=0, which is not possible with the logit, but this should be close enough
            -10 * jnp.ones((len(channels), n_neurons)),  # “fully deactivated”; shape matches `from_array`/`dX`
            10 * jnp.ones((nhchannels, n_neurons)),   # “fully activated” (h≈1 means not inactivated)
        )   

    @classmethod
    def from_array(cls, X: Union[Array[float, 1], Array[float, 2]]) -> "Prinz2004.State":
        """Unpack a flattened state vector."""
        nchannels = len(channels)
        if X.ndim == 1:
            X = X.reshape(3+nchannels+nhchannels, -1)
        elif X.ndim != 2:
            raise ValueError(f"State array `X` has {X.ndim} dimensions, when it should have either 1 or 2.")
        return cls(
            jnp.array(X[0]), jnp.array(X[1]), jnp.array(X[2]),
            jnp.array(X[3:3+nchannels]), jnp.array(X[3+nchannels:3+nchannels+nhchannels]),
        )
    def to_vector(self):
        """Return a flat 1-d vector, as is used in `solve_ivp`."""
        return jnp.concatenate((self.V, self.logCa, self.logits, self.logitm.flatten(), self.logith.flatten()))
    def to_array(self):
        """
        Return a 2-d array of shape (3+2*n_channels, n_neurons)
        First three rows are `V`, `logCa`, `logits` respectively.
        Then `m` and `h` as two blocks of (n_channels, n_neurons).
        """
        return jnp.concatenate((
            self.V[np.newaxis,:], self.logCa[np.newaxis,:], self.logits[np.newaxis,:],
            self.logitm, self.logith))

    @property
    def s(self):
        return 1 / (1 + 1/np.exp(self.logits))  # Inverse of the logit function  (NB: expression chosen to be well defined even when exp(logits) -> oo)
    @s.setter
    def s(self, s):                             # NB: a setter must reuse the property’s name
        self.logits = np.log(s/(1-s))
    @property
    def m(self):
        return 1 / (1 + 1/np.exp(self.logitm))
    @m.setter
    def m(self, m):
        self.logitm = np.log(m/(1-m))
    @property
    def h(self):
        return 1 / (1 + 1/np.exp(self.logith))
    @h.setter
    def h(self, h):
        self.logith = np.log(h/(1-h))
    @property
    def Ca(self):
        return np.exp(self.logCa)
    @Ca.setter
    def Ca(self, Ca):
        self.logCa = np.log(Ca)                 # log, the inverse of the exp in the getter

SimResult object#

The ODE integrators provided by scipy.integrate treat the data as a flat 1-D vector, which is not convenient for associating the trace of each component to the correct state variable. A SimResult stores the data returned from the integrator, in the integrator’s compact format, but provides a human-friendly interface for retrieving traces for individual variables. Conversions to/from log-transformed values are handled automatically.

@dataclass
class SimResult:
    t         : Array[float, 1]
    data      : Array[float, 3]
    pop_slices: dict[str, slice]
    
    """Stores and makes accessible the results of a simulation run.

    Underlying storage is very close to the output of the ODE simulator.
    Properties for each variable (`V`, `logCa`, `Ca`, etc.) retrieve the correct rows from the data structure.
    """

    def __post_init__(self):
        self.data = self.data.reshape(3+len(channels)+nhchannels, -1, len(self.t))
        self.data = np.moveaxis(self.data, -1, 0)  # Both Pandas & Holoviews work better if time is the index axis
    
    def __getitem__(self, key) -> Union[SimResult,State]:
        if isinstance(key, (slice, list, np.ndarray)):
            return SimResult(t=self.t[key], data=self.data[key,:,:], pop_slices=self.pop_slices)
        elif isinstance(key, int):
            return self.t[key], State.from_array(self.data[key,:,:])
        else:
            raise TypeError("SimResult` only supports 1-d indexing along the time axis.")
    
    @property
    def V(self): return self._make_df(self.data[:, 0, :])
    @property
    def logCa(self): return self._make_df(self.data[:, 1, :])
    @property
    def Ca(self): return np.exp(self.logCa)
    @property
    def logits(self): return self._make_df(self.data[:, 2, :])
    @property
    def s(self): return 1 / (1 + np.exp(-self.logits))
    @property
    def logitm(self): return self.data[:, 3:3+len(channels), :]
    @property
    def m(self): return 1 / (1 + np.exp(-self.logitm))
    @property
    def logith(self): return self.data[:, 3+len(channels):3+len(channels)+nhchannels, :]
    @property
    def h(self): return 1 / (1 + np.exp(-self.logith))

    def _make_df(self, data):
        cols = pd.MultiIndex.from_tuples(
            ((pop, i)
             for pop, slc in self.pop_slices.items()
             for i in range(1, 1+slc.stop-slc.start)),
            names=["pop", "index"]
        )
        return pd.DataFrame(data, index=pd.Index(self.t, name="time"), columns=cols)
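For instance, assuming `res` is the `SimResult` returned by a model call (hypothetical usage):

```python
t_final, final_state = res[-1]   # integer index: (time, State) at one recorded point
sub = res[10:20]                 # slice: a SimResult restricted to ten recorded points
V_AB = res.V["AB"]               # voltage traces of the AB population, as a DataFrame
```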

Prinz2004 object#

This is the core simulator class. To perform a simulation, create an instance of this class and call its integrate method.

@dataclass(frozen=True)
class Prinz2004:
    pop_sizes  : dict[str, int]
    gs         : Array[float, 2]
    g_ion: Optional[Array[float, 2]]=None
    ge         : float=1.  # Currently only used to turn on/off electrical connectivity

    """ Core simulator class.

    To perform a simulation, create an instance of this class and call its `integrate` method.
    """
    
    # Private attributes
    __thermalization_store__ : ClassVar[Path] = config.paths.simresults/"prinz2004_thermalize"
    __thermalization_time__  : ClassVar[float] = 5000.   # Warm initialization is the state after this many milliseconds of spontaneous activity (no inputs, no synaptic connections)
    __thermalization_time_step__: ClassVar[float] = 1.   # Warm initialization trace is recorded with this time step
        # TODO: Instead of a fixed time step, using an `events` callback to detect spikes could substantially reduce the data we need to store
    _thermalization_store_lock: ClassVar[Lock] = Lock(str(__thermalization_store__.with_suffix(".lock")))
    _thermalization_store_lock.lifetime = timedelta(seconds=15)  # 15s is the default
        
    def __post_init__(self):
        # For convenience, we allow passing `g_ion` as a DataFrame
        if isinstance(self.g_ion, pd.DataFrame):
            object.__setattr__(self, "g_ion", self.g_ion.to_numpy())
        object.__setattr__(self, "gs", np.asarray(self.gs))

    ## Public API ##
    
    def __call__(self, t_array, I_ext: Optional[Callable]=None):
        X0 = self.get_thermalization()
        res = self.integrate(0, X0, t_array, I_ext)
        return SimResult(t_array, res.y, self.pop_slices)
    
    def derivative(self, t: float, X: State, I_ext: Optional[Callable]=None):
        """
        Evaluate the model equations, returning the derivative at `t` if the state is `X`.
        This is a convenience method: it mirrors the computation performed during integration, but is not the same code path.
        """
        X0 = X
        # C.f. self.integrate()
        if isinstance(X0, State):
            X0_flat = X0.to_vector()
        else:
            X0_flat = X0.flatten()
            X0 = State.from_array(X0)  # `from_array` accepts flat vectors; `State` is the module-level class
        pop_slices = tuple(slc.__reduce__() for slc in self.pop_slices.values())
        syn_slices = (self.syn_slices["chol"].__reduce__(), self.syn_slices["glut"].__reduce__())
        elec_slice = self.elec_slice.__reduce__()
        args=(self.g, self.gleak, self.gs, self.ge,  # Conductances
              pop_slices, syn_slices, elec_slice,    # Population slices
              self.E(X0.Ca),                         # Expanded from constants.E (except E[[1,2],:], corresponding to ECa, which is updated in-place)
              self.tot_cells,                        # Pre-computed number of neurons
              I_ext)                                 # External input, if provided
        return dX(t, X0_flat, *args)

    ## Semi-private methods ##
    
    # Even though the dataclass is frozen, changing the init_store is allowed (at worst, it just means the thermalization is recomputed)
    @classmethod
    def clear_thermalization_store(cls):
        cls.__thermalization_store__.unlink(missing_ok=True)
    
    def integrate(self, t0: float, X0: Union[Array, Prinz2004.State], t_eval: ArrayLike,
                  I_ext: Optional[Callable]=None) -> OdeResult:
        """
        Integrate the model from the initial state `(t0, X0)`. The result will contain the states
        at all values in `t_eval`.
        
        Args:
            t0: time corresponding to X0
            X0: Initial state; either an instance of Prinz2004 or a flat vector.
            t_eval: The time points at which to record the trace
            I_ext: If provided, this should be a function with the signature ``(t) -> I``, where ``I``
               is a 1-d vector with one element for every neuron. The units of I are nA; its value is
               added directly to the derivative dV.
               Alternatively, if all neurons should receive the same input, ``I`` can be a scalar.
               (In case we do support a vectorized `dX` later, if the function accepts a vector as input,
               the result should be of shape ``(time x neurons)``.
        """
        if isinstance(X0, State):
            X0_flat = X0.to_vector()
        else:
            X0_flat = X0.flatten()
            X0 = State.from_array(X0)  # `from_array` accepts flat vectors; `State` is the module-level class
        
        t_eval = np.sort(t_eval)
        T = t_eval[-1]
        
        # NB: Order is important for syn_slices
        pop_slices = tuple(slc.__reduce__() for slc in self.pop_slices.values())
        syn_slices = (self.syn_slices["chol"].__reduce__(), self.syn_slices["glut"].__reduce__())
        elec_slice = self.elec_slice.__reduce__()
        
        res = integrate.solve_ivp(
            dX, (t0, T), X0_flat, method="RK45", t_eval=t_eval,
            args=(self.g, self.gleak, self.gs, self.ge,  # Conductances
                  pop_slices, syn_slices, elec_slice,    # Population slices
                  self.E(X0.Ca),                         # Expanded from constants.E (except E[[1,2],:], corresponding to ECa, which is updated in-place)
                  self.tot_cells,                        # Pre-computed number of neurons
                  I_ext),                                # External input, if provided
            first_step=0.005)#constants.Δt)
        
        if res.status < 0:
            raise RuntimeError(f"Integration failed with the following message:\n{res.message}")
            #logger.error(res.message)
        return res
    
    @property
    def g_cond(self):
        if self.g_ion is not None:
            # Explicitely assigned g_cond
            return pd.DataFrame(self.g_ion,
                                index=self.pop_model_list, columns=g_cond.columns)
        else:
            return g_cond.loc[self.pop_model_list]
    @property
    def g(self):
        g_no_leak = self.g_cond.filter(regex=r"mathrm{(?!leak)")
        #_pop_lst = [self.pop_models["AB/PD"], pop_models["LP"], pop_models["PY"]]
        return g_no_leak.to_numpy()
    @property
    def gleak(self):
        return self.g_cond.loc[:, r'$g(I_\mathrm{leak})$'].to_numpy()
    
    def E(self, Ca):
        E = jnp.tile(constants.E[:,np.newaxis], (1, self.tot_cells))
        E = E.at[[1,2],:].set(constants.γ*np.log(constants.Caout/Ca))
        return E
    
    @property
    def tot_cells(self):
        return sum(self.pop_sizes.values())
    @property
    def pop_model_list(self):
        """
        Returns a list of the form ``["AB 3", "PD 3", "LP 2", "PY 4"]``, where values
        correspond to model labels used in `g_cond`. They may omit numbers, if those are
        also omitted in `g_cond`.
        """
        return list(self.pop_slices.keys())
        # # CODE FOR LATER: The following was originally used to merge AB and PD pop labels
        # # AB 1 or PD 1 -> AB/PD 1  (where '1' is any number)
        # slcs = self.create_slices(
        #     self.pop_sizes, label_sub={r"(?:AB|PD)( \d+)?": r"AB/PD\1"})
        # return list(slcs.keys())
    
    @property
    def pop_slices(self):
        # If pops are specified as AB/PD, split them AB: 1, PD: n-1
        pop_sizes = {}
        for name, size in self.pop_sizes.items():
            m = re.match("AB/PD( \d+)?", name)
            if m:
                pop_sizes[f"PD{m[1]}"] = size - 1
                pop_sizes[f"AB{m[1]}"] = 1
            else:
                pop_sizes[name] = size
        return self.create_slices(pop_sizes, {})
    
    @property
    def syn_slices(self):
        # If pops are specified as AB/PD, split them AB: 1, PD: n-1
        pop_sizes = {}
        for name, size in self.pop_sizes.items():
            m = re.match("AB/PD( \d+)?", name)
            if m:
                pop_sizes[f"PD{m[1]}"] = size - 1
                pop_sizes[f"AB{m[1]}"] = 1
            else:
                pop_sizes[name] = size
        # When creating the slices:
        # AB, LP, PY -> glut    PD -> chol
        slices = self.create_slices(
            pop_sizes, label_sub={r"(?:AB|LP|PY)( \d+)?": r"glut",
                                  r"(?:PD)( \d+)?": r"chol"})
        # Corner case: If there are no cholinergic or glutamatergic neurons,
        #    add corresponding empty slices
        if "glut" not in slices: slices["glut"] = slice(0,0)
        if "chol" not in slices: slices["chol"] = slice(0,0)
        return slices
    @property
    def elec_slice(self):
        slcs = self.create_slices(
            self.pop_sizes, label_sub={r"(?:AB|PD).*": "elec"}, include=[r"AB|PD"])
        return slcs.get("elec", slice(0,0))  # If there are no neurons with electrical synapses, return an empty slice
  
    def get_thermalization(self) -> State:
        """
        The thermalization is obtained by integrating the network model
        with all connections to zero, to find the spontaneous steady-state of
        each neuron model.
        This method either performs that integration, or, if it was already
        done before, retrieves it from a cache.
        """
        sim = self.thermalize()
        t, X = sim[-1]  # Get the final state
        # Warm init is done with one neuron per population: reinflate the populations.
        # NB: The thermalization uses the merged AB/PD populations of pop_slices
        pop_sizes = tuple(slc.stop-slc.start for slc in self.pop_slices.values())
        return State.from_array(
            np.repeat(X.to_array(), pop_sizes, axis=-1)
        )
    
    def thermalize(self) -> SimResult:
        """
        Compute the initial simulation used to thermalize the model.
        - Constructs a surrogate model with all population sizes set to 1.
        - Sets all connectivities to 0.
        - Integrate
        
        In addition, this method manages the thermalization cache, and will skip
        all steps if it finds a result already in the cache.
        
        .. TODO:: Currently the cache key is the entire model, which is a bit
           wasteful. It would be better to cache the results of each neuron
           separately.

        Returns a `SimResult` instance, with recording resolution determined by
        ``self.__thermalization_time_step__``.
        
        Normally this method is not used directly, but called by `get_thermalization`.
        The main reason to use it directly would be to plot the initialization
        run for inspection.
        """
        # # Pre-compute the set of pop_slices, so we can validate it against the disconnected model
        # # (Alternatively, we could create a disconnected model even when we find a pre-cached result, but that seems wasteful)
        # pop_slices = {}; i = 0
        # for pop in self.pop_slices:
        #     if pop.startswith("AB/PD"):
        #         pop_slices[pop] = slice(i,i+2); i += 2
        #     else:
        #         pop_slices[pop] = slice(i,i+1); i += 1
        # We thermalize the model by turning off synaptic connections,
        # so we also use a key which only depends on the model identities.
        init_key = str((("T", self.__thermalization_time__), ("g_cond", self.g_cond.to_csv())))
        if not self.__thermalization_store__.exists():
            self.__thermalization_store__.parent.mkdir(parents=True, exist_ok=True)
            warm_t_X = None
        else:
            with shelve.open(str(self.__thermalization_store__), 'r') as store:
                warm_t_X = store.get(init_key)
        if warm_t_X:
            t, warm_X = warm_t_X
        else:
            # NB: We use pop_slices instead of pop_sizes in order to merge AB/PD populations when possible
            # Get the cold initialized state (for a smaller model with 1 neuron / pop)
            X0 = State.cold_initialized(len(self.pop_slices))
            # Construct the array of times at which we will record
            init_T = self.__thermalization_time__
            Δt = self.__thermalization_time_step__
            t_eval = np.concatenate((np.arange(0, init_T, Δt), [init_T]))  # The final time init_T is always evaluated
            # Build the model with no connections and only one neuron per pop
            disconnected_model = Prinz2004(
                pop_sizes={pop: 1 for pop in self.pop_slices},
                gs=np.zeros_like(self.gs),
                g_ion=self.g_ion,
                ge=0)
            #assert pop_slices == disconnected_model.pop_slices
            # Integrate
            res = disconnected_model.integrate(0, X0, t_eval)
            t, warm_X = res.t, res.y
            # Update the cache
            with self._thermalization_store_lock:  # We use a lock file because shelve doesn't support concurrent writes (not doing this can corrupt the store)
                with shelve.open(str(self.__thermalization_store__)) as store:
                    store[init_key] = (t, warm_X)
        # Return
        return SimResult(t, warm_X,
                         pop_slices={pop: slice(i,i+1)
                                     for i,pop in enumerate(self.pop_slices)})
   
    ## Private methods ##
    
    @staticmethod
    def create_slices(pop_sizes,
                      label_sub: dict[str,str],
                      include: Sequence[str]=(),
                      exclude: Sequence[str]=()):
        """
        `label_sub` should be a substitution dictionary of ``pattern: repl`` pairs which
        are passed to `re.sub`. In other words, `pattern` should be a regex pattern,
        and when it matches a cell type name, it is replaced by `repl`.
        """
        slcs = {}
        i = 0
        # Convert population sizes to slices, by tracking an index counter i
        for name, size in pop_sizes.items():
            # NB: Don’t use `continue` in this loop, otherwise the `i` counter is lost
            skip = ((include and not any(re.search(pattern, name) for pattern in include))
                    or any(re.search(pattern, name) for pattern in exclude))
            for pattern, repl in label_sub.items():
                name = re.sub(pattern, repl, name)
            if skip:
                pass
            elif name in slcs:
                last_slc = slcs[name][-1]
                if i == last_slc.stop:
                    # This slice is contiguous with the previous one: merge them
                    slcs[name][-1] = slice(last_slc.start, i+size)
                else:
                    # Not contiguous: Add a new slice
                    slcs[name].append(slice(i, i+size))
            else:
                # New population model: Add a new slice
                slcs[name] = [slice(i, i+size)]
            i += size

        # Convert to simple slices when possible, otherwise indexing arrays
        iarr = np.arange(i)  # i counter is now exactly the total number of cells
        for name, slc_lst in slcs.items():
            if len(slc_lst) == 1:
                slcs[name] = slc_lst[0]
            else:
                # Multiple discontinuous arrays: must use indexing
                slcs[name] = np.concatenate([iarr[slc] for slc in slc_lst])

        return slcs
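To make the slice construction concrete, here is what `create_slices` produces for the population sizes of the usage example (outputs shown as comments):

```python
# Population labels map to contiguous slices over the neuron axis
slcs = Prinz2004.create_slices({"PD": 2, "AB": 1, "LP": 1, "PY": 5}, label_sub={})
# -> {'PD': slice(0, 2), 'AB': slice(2, 3), 'LP': slice(3, 4), 'PY': slice(4, 9)}

# With a substitution pattern, PD maps to "chol" and the rest to "glut";
# adjacent slices with the same label are merged into one contiguous slice
syn = Prinz2004.create_slices({"PD": 2, "AB": 1, "LP": 1, "PY": 5},
                              label_sub={r"(?:AB|LP|PY)( \d+)?": "glut",
                                         r"(?:PD)( \d+)?": "chol"})
# -> {'chol': slice(0, 2), 'glut': slice(2, 9)}
```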

Anecdotal timings#

All timings use the circuit model described in the Circuit model section.

Table 10 Timing single evaluations of the dX function.#

| Model size (# neurons) | Hardware | NumPy | JAX |
|------------------------|----------|-------|-----|
| 9 | 3.50 GHz CPU (Xeon W-2265, 12 core) | 1.26 ms ± 13.3 µs | |

Table 11 Timing integration (simulation)#

| Model size (# neurons) | Hardware | Library | Simulated 200 ms | Simulated 1000 ms |
|------------------------|----------|---------|------------------|-------------------|
| 9 | 3.10 GHz CPU (i5-5675C, 2 core) | NumPy | 21.3 s | 2 min 49 s |
| 9 | 3.10 GHz CPU (i5-5675C, 2 core) | JAX | 2.33 s | 7.54 s |