Pyloric network simulator#
This is an implementation of the three-cell model of the crustacean stomatogastric ganglion described in Marder and Abbott [1998] and used by Prinz et al. [2004] to demonstrate that disparate parameter sets can lead to comparable network activity. The main features of this particular implementation are:
- JAX-accelerated Python implementation
Combines the flexibility of Python with the speed of compiled JAX code.
The specification of update equations for ion channels was carefully designed to maximize vectorization.
- Choice of ODE solver
The official C++ implementation uses hard-coded Euler integration, which can make it vulnerable to numerical instabilities. (Conductance-based neuron models, after all, are stiff by design.)
In contrast, this implementation is designed to be used with Python’s ODE solvers, making it easy to use a more appropriate solver. In particular, adaptive step solvers can dramatically reduce simulation time, while still keeping enough time precision to resolve spikes.
- Modularity
New cell types are defined simply by extending Python arrays for the ion channel parameters.
If needed, an arbitrary function can also be given for specific ion channels.
- Flexibility
You are not limited to the published three-cell network: studying a smaller two-cell network, or a larger one with different LP cells, is a simple matter of changing three arrays.
- All-in-one documentation
The original specification of the pyloric circuit model is spread across at least three resources[1].
Here all definitions are included in the inlined documentation, are fully referenced and use consistent notation.
- Single source of truth
Both the documentation and the code use the same source for parameter values, so you can be sure the documented values are actually those used in the code.
In cases where this was not practical, Python values are specified with a string identical to the source of the documentation’s Markdown table. Agreement between documentation and code can be checked at any time by “diff”-ing the markdown and Python code blocks, or copy-pasting one onto the other.
import logging
import shelve
import re
import numpy as np
import pandas as pd
import holoviews as hv
from ast import literal_eval
from collections import namedtuple
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Union, Optional, Type, Generator, Callable, Tuple, ClassVar
from flufl.lock import Lock
from numpy import exp
from numpy.typing import ArrayLike
from scipy import integrate
from addict import Dict
from scityping.numpy import Array
logger = logging.getLogger(__name__)
from config import config
try:
import jax
from jax import config as jax_config; jax_config.update("jax_enable_x64", True); del jax_config
import jax.numpy as jnp
except ImportError:
import jax_shim as jax
import jax_shim.numpy as jnp
Usage example#
import holoviews as hv
from pyloric_simulator.prinz2004 import Prinz2004, neuron_models, dims
hv.extension("bokeh") # Bokeh plots allow to zoom in on the trace much more easily
Instantiate a new model by specifying:
- The number of neurons of each type (PD, AB, LP and PY);
- The connectivity `gs` between populations;
- The set of membrane conductances `g_cond` for each type.
The 16 sets which define the neuron models used in Prinz et al. [2004] are provided in `neuron_models`.
# Note that cholinergic (PD) and glutamatergic (AB, LP, PY) neurons are contiguous.
# This allows the model to use more efficient indexing based on slices,
# instead of slower indexing arrays.
# `pop_size` keys should match one of the four type labels (AB, PD, LP, PY) or the
# combined label "AB/PD". If there are multiple neurons of one type, they can be
# followed by a number: "AB 1", "AB 2", ...
# IMPORTANT: The `Prinz2004` class uses regular expressions on these labels to determine
# the right type of connectivity, so they must follow this labelling pattern exactly.
model = Prinz2004(pop_sizes = {"PD": 2, "AB": 1, "LP": 1, "PY": 5},
gs = [ [ 0 , 0 , 3 , 0 ],
[ 0 , 0 , 3 , 3 ],
[ 3 , 3 , 0 , 3 ],
[ 3 , 3 , 3 , 0 ] ],
g_ion = neuron_models.loc[["AB/PD 1", "AB/PD 1", "LP 1", "PY 1"]] # May be specified as DataFrame or Array
)
Define a set of time points and evaluate the model at those points. A model simulation is always initialized with the result of a cached warm-up simulation (see Initialization). If no such simulation matching the model parameters is found in the cache, it is performed first and cached for future runs.
The example below reproduces the procedure of Prinz et al. [2004]: After neurons are connected, an initial simulation of 3s is thrown away to let transients decay. Then 1s is simulated to generate the data for that network model.
Note
The time points we provide to `model` are only those which are recorded; they have no bearing on the integration time step.[2] To resolve spikes, a time resolution of 1 ms or less is needed.
# Generate 1001 sequential time points (3000–4000), and evaluate the model at those points
res = model(np.linspace(3000, 4000, 1001))
Results are returned as a `SimResult` object, which has attributes to retrieve the different traces: membrane voltage `V`, calcium concentration `Ca`, synaptic activation `s`, membrane activation `m` and membrane inactivation `h`.
These are returned as Pandas DataFrames.
TODO
The `m` and `h` variables are, for now, still returned as plain arrays.
# Retrieve trace for one neuron per population
Vtraces = res.V.loc[:,[("AB", 1), ("LP", 1), ("PY", 1)]].droplevel("index", axis="columns")
# In addition to voltage, the following time-dependent variables are also available:
res.Ca # Intracellular calcium concentration (n_time_bins x n_neurons)
res.s # Synapse activation (n_time_bins x n_neurons)
res.m # Channel activation (n_time_bins x n_channels x n_neurons)
res.h # Channel inactivation (n_time_bins x n_channels x n_neurons)
# A plain time-bin array can also be retrieved
res.t
array([3000., 3001., 3002., ..., 3998., 3999., 4000.])
For simple plots, one can use the DataFrame `plot` method directly. For more complex plots, it is often more convenient to retrieve the data and call the plotting functions directly.
Vtraces.plot(y="AB") # Simple plot using the DataFrame method
# Call plotting functions directly after retrieving data
hv.Curve(Vtraces.loc[:,"AB"]) + hv.Curve(Vtraces.loc[:,"LP"]) + hv.Curve(Vtraces.loc[:,"PY"])
To inspect the initialization curves, we use the private method `thermalize`; this is used internally to generate the thermalized state. It returns a `SimResult` object.
(NB: Since thermalization disconnects all neurons to obtain each one's individual spontaneous activity, it is only necessary to simulate one neuron per population. This is why the result returned by `thermalize` always has population sizes of 1.)
res = model.thermalize()
Vtraces = res.V.droplevel("index", axis="columns") # The initialization run has only one neuron per pop, since they are identical
hv.output(backend="bokeh") # Workaround: If we don’t switch the default, Bokeh ignores the new colors
ov = hv.Overlay([hv.Curve(Vtraces.loc[:,pop_name], label=pop_name)
for i, pop_name in enumerate(model.pop_model_list)]) \
.opts(ylabel="voltage (mV)") \
.redim(time=dims.t)
# Assign trace colours
for pop_name, c in zip(res.pop_slices, hv.Cycle("Dark2").values):
ov.opts(hv.opts.Curve(f"Curve.{hv.core.util.sanitize_identifier_fn(pop_name)}", color=c))
ov.opts(width=800, legend_position="right", backend="bokeh")
Definitions#
# IMPORTANT: List orders must be kept consistent, because for efficiency the
# implementation uses plain arrays, vectorized ops and indexing by position
act_varnames = ["minf", "hinf", "taum", "tauh"]
channels = ["INa", "ICaT", "ICaS", "IA", "IKCa", "IKd", "IH"]
neuron_types = ["PD", "AB", "LP", "PY"] # Putting PD first allows us to have all glutamatergic neurons in one contiguous slice
neuron_pops = ["AB/PD", "LP", "PY"]
dims = dict(
minf = r"$m_\infty$",
hinf = r"$h_\infty$",
taum = r"$τ_m$",
tauh = r"$τ_h$",
INa = r"$I_\mathrm{Na}$",
ICaT = r"$I_\mathrm{CaT}$",
ICaS = r"$I_\mathrm{CaS}$",
IA = r"$I_\mathrm{A}$",
IKCa = r"$I_\mathrm{K(Ca)}$",
IKd = r"$I_\mathrm{Kd}$",
IH = r"$I_\mathrm{H}$",
V = r"$V$",
Es = r"$E_\mathrm{s}$",
km = r"$k_{-}$",
Vth = r"$V_{\mathrm{th}}$",
Δ = r"$\Delta$",
t = "time"
)
dims = Dict({nm: hv.Dimension(nm, label=label) for nm, label in dims.items()})
dims.V.unit = "mV"
dims.t.unit = "ms"
Single compartment conductance model#
The conductance model is described in Prinz et al. [2003] (p.1-2, §Model). Where differences in notation occur, we prefer those in Prinz et al. [2004].
In the equations, $V$ denotes the membrane potential, $[\mathrm{Ca}^{2+}]$ the intracellular calcium concentration, $m$ and $h$ the channel activation and inactivation variables, and $s$ the synaptic activation.
Note
Although a concentration must always be positive, the differential equation for $[\mathrm{Ca}^{2+}]$ does not enforce this: a discrete integration step can overshoot and produce a negative value. A similar thing can be said of the activation variables $s$, $m$ and $h$, which must remain within $[0, 1]$. For this reason the implementation integrates $\log [\mathrm{Ca}^{2+}]$ and the logit transforms of $s$, $m$ and $h$ instead (see the sketch below).
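As a short worked complement (the corresponding code appears later in `dX` and `State`), the chain rule gives the evolution equations actually integrated for the transformed quantities:

$$\frac{d}{dt}\log [\mathrm{Ca}^{2+}] = \frac{1}{[\mathrm{Ca}^{2+}]}\frac{d[\mathrm{Ca}^{2+}]}{dt}\,, \qquad \frac{d}{dt}\,\operatorname{logit}(x) = \frac{1}{x(1-x)}\frac{dx}{dt}\,, \quad x \in \{s, m, h\}\,.$$

These are exactly the expressions used in the implementation (`dlogCa = dCa / Ca`, `dlogitm = dm / (m * (1-m))`, etc.).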
Constant low-dimensional parameters#
| | $I_\mathrm{Na}$ | $I_\mathrm{CaT}$ | $I_\mathrm{CaS}$ | $I_\mathrm{A}$ | $I_\mathrm{K(Ca)}$ | $I_\mathrm{Kd}$ | $I_\mathrm{H}$ | $I_\mathrm{leak}$ | unit |
|---|---|---|---|---|---|---|---|---|---|
| $E$ | 50 | Eq. (3) | Eq. (3) | -80 | -80 | -80 | -20 | -50 | mV |
| $p$ | 3 | 3 | 3 | 3 | 4 | 4 | 1 | | |
Calcium reversal potentials
The calcium reversal potential is not a constant but is computed from the instantaneous intracellular calcium concentration via the Nernst equation,
$$E_\mathrm{Ca} = \gamma \log \frac{[\mathrm{Ca}^{2+}]_\mathrm{out}}{[\mathrm{Ca}^{2+}]}\,,$$
where $\gamma$ corresponds to the `Nernstfactor` constant in the published source code. Following Prinz et al. [2003], we further set $[\mathrm{Ca}^{2+}]_\mathrm{out} = 3\,\mathrm{mM}$.
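For concreteness, a minimal sketch of this computation in log space, as it is done inside `dX` (the constant values are those listed in the table below):

```python
import numpy as np

gamma  = 12.2     # mV — the Nernst factor ("Nernstfactor" in the published source code)
Ca_out = 3000.0   # μM — extracellular calcium concentration

def E_Ca(Ca):
    """Calcium reversal potential (mV) for an intracellular concentration `Ca` (μM)."""
    return gamma * np.log(Ca_out / Ca)

print(E_Ca(0.05))   # ≈ 134 mV at the cold-initialization value [Ca] = 0.05 μM
```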
| Constant | Value | Unit |
|---|---|---|
| $A$ (membrane area) | 0.628 | $10^{-3}\,\mathrm{cm^2}$ |
| $\tau_\mathrm{Ca}$ | 200 | ms |
| $f$ | 14.96 | $\mathrm{\mu M/nA}$ |
| $[\mathrm{Ca}^{2+}]_0$ | 0.05 | $\mathrm{\mu M}$ |
| $[\mathrm{Ca}^{2+}]_\mathrm{out}$ | 3000 | $\mathrm{\mu M}$ |
| $\gamma$ | 12.2 | mV |
| $\Delta t$ (approximate numerical time step) | 0.025 | ms |

If we multiply all the units in Eq. (1) together, using
Maximum channel conductances#
These are the maximal channel conductance values (in $\mathrm{mS/cm^2}$) for each model neuron. Differences in these values are what differentiate the neuron models.
| Model neuron | $g(I_\mathrm{Na})$ | $g(I_\mathrm{CaT})$ | $g(I_\mathrm{CaS})$ | $g(I_\mathrm{A})$ | $g(I_\mathrm{K(Ca)})$ | $g(I_\mathrm{Kd})$ | $g(I_\mathrm{H})$ | $g(I_\mathrm{leak})$ |
|---|---|---|---|---|---|---|---|---|
| AB/PD 1 | 400 | 2.5 | 6 | 50 | 10 | 100 | 0.01 | 0.00 |
| AB/PD 2 | 100 | 2.5 | 6 | 50 | 5 | 100 | 0.01 | 0.00 |
| AB/PD 3 | 200 | 2.5 | 4 | 50 | 5 | 50 | 0.01 | 0.00 |
| AB/PD 4 | 200 | 5.0 | 4 | 40 | 5 | 125 | 0.01 | 0.00 |
| AB/PD 5 | 300 | 5.0 | 2 | 10 | 5 | 125 | 0.01 | 0.00 |
| LP 1 | 100 | 0.0 | 8 | 40 | 5 | 75 | 0.05 | 0.02 |
| LP 2 | 100 | 0.0 | 6 | 30 | 5 | 50 | 0.05 | 0.02 |
| LP 3 | 100 | 0.0 | 10 | 50 | 5 | 100 | 0.00 | 0.03 |
| LP 4 | 100 | 0.0 | 4 | 20 | 0 | 25 | 0.05 | 0.03 |
| LP 5 | 100 | 0.0 | 6 | 30 | 0 | 50 | 0.03 | 0.02 |
| PY 1 | 100 | 2.5 | 2 | 50 | 0 | 125 | 0.05 | 0.01 |
| PY 2 | 200 | 7.5 | 0 | 50 | 0 | 75 | 0.05 | 0.00 |
| PY 3 | 200 | 10 | 0 | 50 | 0 | 100 | 0.03 | 0.00 |
| PY 4 | 400 | 2.5 | 2 | 50 | 0 | 75 | 0.05 | 0.00 |
| PY 5 | 500 | 2.5 | 2 | 40 | 0 | 125 | 0.01 | 0.03 |
| PY 6 | 500 | 2.5 | 2 | 40 | 0 | 125 | 0.00 | 0.02 |
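These conductance sets are exposed through the `neuron_models` DataFrame imported in the usage example above; individual rows (or lists of rows) can be selected by label and passed as `g_ion`, e.g.:

```python
from pyloric_simulator.prinz2004 import neuron_models

neuron_models.loc["LP 3"]                        # One conductance set (a single row of the table above)
neuron_models.loc[["AB/PD 2", "LP 3", "PY 4"]]   # Several rows, e.g. to build the `g_ion` argument of Prinz2004
```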
Voltage-dependent activation variables#
Note
- When $a = 0$, the values of $b$ and $c$ are irrelevant, but we should not use $c = 0$, to avoid dividing by zero.
- The inactivation variable $h$ is not used for the $I_\mathrm{K(Ca)}$, $I_\mathrm{Kd}$ and $I_\mathrm{H}$ currents. Because we use vectorized operations, the computations are still performed, but the result is afterwards discarded. (`h_slice` is used to select only those channels with an $h$ variable.)
- The variable $\tau_m$ for $I_\mathrm{H}$ is defined purely in terms of the multiplier $f(V, [\mathrm{Ca}^{2+}])$. We do this by setting $a = 0$ and $C = 1$.
To implement the voltage equations in code, we write them as
$$x(V, [\mathrm{Ca}^{2+}]) = \left(\frac{a}{1 + \exp\!\left(\frac{V + b}{c}\right)} + C\right) \cdot f(V, [\mathrm{Ca}^{2+}])\,,$$
where $x$ stands for any of $m_\infty$, $h_\infty$, $\tau_m$, $\tau_h$; the constants $a$, $b$, $c$ and $C$ are specific to each (variable, channel) pair; and $f \equiv 1$ except for the five (variable, channel) combinations listed in `y.indices` below.
Only the first four ion channels have an $h$ variable; `h_slice` is used to select them.
h_slice = slice(0, 4)
nhchannels = 4 # Number of ion channels which have an h variable
Caution
The units for the denominator of
def y(V, Ca, exp=jnp.exp, array=jnp.array, ones_like=jnp.ones_like):
"""
Returns a list of multipliers for the voltage equations.
The indices in the (variable x channel) matrix to which each multiplier corresponds
are stored in `y.indices`. The format of `y.indices` is compatible with advanced
indexing, so one can do ``A[(*y.indices,...)] *= y(V, Ca)``.
"""
V = V.reshape(-1)
return array([
1.34/(1+exp(-(V+62.9)/10.)),
2.8 + 14./(exp((V+27.)/10.) + exp((V+70.)/-13.)),
120. + 300./(exp((V+55.)/9.) + exp((V+65.)/-16.)),
Ca/(Ca+3.) * ones_like(V),
2./(exp((V+169.7)/-11.6) + exp((V-26.7)/14.3))
])
y.indices = ([act_varnames.index("tauh"), act_varnames.index("taum"), act_varnames.index("tauh"), act_varnames.index("minf"), act_varnames.index("taum")],
[channels.index("INa"), channels.index("ICaS"), channels.index("ICaS"), channels.index("IKCa"), channels.index("IH")])
def act_vars(V, Ca, y=y, exp=jnp.exp, a=act_params.a[...,np.newaxis], b=act_params.b[...,np.newaxis], c=act_params.c[...,np.newaxis], C=act_params.C[...,np.newaxis]):
"""
Returns an array of dynamic activation variables used in the expressions for the different currents
with shape (channel x variable).
Variables are: minf, hinf, τm and τh.
"""
res = a/(1 + exp((V+b)/c)) + C # σ() + C
#res[(*y.indices, ...)] *= y(V, Ca) # (σ() + C) * f(V, Ca) — Indexing + in-place operation ensure that only the 5 elements with a multiplier are updated
res = res.at[(*y.indices, ...)].multiply(y(V, Ca)) # JAXMOD
return res
Note
A few activation functions are actually not defined (e.g. $h_\infty$ and $\tau_h$ for channels without an $h$ variable). For vectorization, `act_vars` nevertheless returns dummy values in those positions; they are discarded downstream via `h_slice`.
Electrical synapse model#
The AB and PD neurons are connected via an electrical synapse [Prinz et al., 2004], which is given as Eq. (13) in Marder and Abbott [1998]:[4]
$$I_e = g_e \,\bigl(V_\mathrm{post} - V_\mathrm{pre}\bigr)\,.$$
Unfortunately Prinz et al. [2004] do not seem to document the value of $g_e$; in this implementation it is simply fixed to 1 (and set to 0 during thermalization). The pairwise differences $V_\mathrm{post} - V_\mathrm{pre}$ are computed in vectorized form as `V[:,newaxis] - V`, which yields an antisymmetric matrix of voltage differences.
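A small self-contained sketch of that broadcasting trick, using NumPy only (the voltages below are illustrative):

```python
import numpy as np

V = np.array([-50., -38., -55.])     # Membrane potentials of the electrically coupled AB/PD cells (mV)
dV_pairs = V[:, np.newaxis] - V      # Antisymmetric matrix: entry (i, j) = V_i - V_j
I_e = 1.0 * dV_pairs.sum(axis=-1)    # With g_e = 1: total electrical current term for each cell
print(dV_pairs)
# [[  0. -12.   5.]
#  [ 12.   0.  17.]
#  [ -5. -17.   0.]]
```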
Chemical synapse model#
The chemical synapse model is defined in Prinz et al. [2004] (p.1351, §Methods) and Marder and Abbott [1998] (Eqs. (14), (18), (19)). Here again, in case of discrepancy, we use the notation from Prinz et al. [2004].[5]
Note on numerical stability
These are the equations as reported in Prinz et al. [2004], but as-is they are numerically unstable, because the synaptic time constant $\tau_s = (1 - s_\infty)/k_-$ vanishes as $s_\infty \to 1$, so that the derivative $(s_\infty - s)/\tau_s$ divides by a quantity that can become arbitrarily small.
In the implementation below, we mitigate this issue in two ways:
- We actually track $\operatorname{logit}(s) = \log \frac{s}{1-s}$, which is the logit transform of $s$. This has the advantage of being unbounded, so applying a discrete update will never produce an invalid value of $s$. The logit function is also monotone and thus invertible; its inverse is the logistic function $s = 1/(1 + e^{-\operatorname{logit}(s)})$, and its derivative is $\frac{d \operatorname{logit}(s)}{d s} = \frac{1}{s(1-s)}$.
- We add a small $\varepsilon$ to the denominator to ensure it is never exactly zero:
$$\frac{d \operatorname{logit}(s)}{dt} = \frac{s_\infty - s}{\tau_s\, s(1-s) + \varepsilon}\,, \qquad \varepsilon = 10^{-8}\,.$$
Memory layout of synapse variables
We have two types of synapses, glutamatergic and cholinergic, and an undetermined number of neurons. To take advantage of vectorized operations, we group synapse variables according to synapse type: since parameters are then constant within each group, we can treat them as scalars, and broadcasting works with any size of voltage vector $V$.
The
Here
Note that the implementation requires two different splittings of the set of neurons:
- By population: AB/PD, LP, PY
- By synapse type: cholinergic (PD) and glutamatergic (AB/LP/PY)
By choosing to place PD neurons first, we ensure that the synapse subsets are contiguous, which improves computational efficiency. (The slices this produces for the example network are sketched below.)
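As an illustration, these are simply the values the `pop_slices`, `syn_slices` and `elec_slice` properties take for the usage-example network with `pop_sizes = {"PD": 2, "AB": 1, "LP": 1, "PY": 5}`:

```python
# Neuron layout (order of the pop_sizes dict):  PD PD | AB | LP | PY PY PY PY PY
pop_slices = {"PD": slice(0, 2), "AB": slice(2, 3),
              "LP": slice(3, 4), "PY": slice(4, 9)}   # One contiguous block per population
syn_slices = {"chol": slice(0, 2),                    # Cholinergic synapses: PD
              "glut": slice(2, 9)}                    # Glutamatergic synapses: AB, LP, PY — contiguous because PD comes first
elec_slice = slice(0, 3)                              # Electrically coupled AB/PD cells
```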
The conductivity
The synapse parameters are chosen to match the inhibitory postsynaptic potentials (IPSPs) generated in the postsynaptic cell, which in turn depend on the type of neurotransmitter released by the presynaptic cell. For the purposes of the implementation, this means that these values depend on the type of presynaptic cell.
| Cell type | Neurotransmitter | $E_s$ (mV) | $k_-$ ($\mathrm{ms^{-1}}$) | $V_\mathrm{th}$ (mV) | $\Delta$ (mV) |
|---|---|---|---|---|---|
| pre = AB, LP or PY | glutamate | -70 | 1/40 | -35 | 5 |
| pre = PD | acetylcholine | -80 | 1/100 | -35 | 5 |
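The implementation below reads these values from a small `syn_constants` table whose definition is not shown here. A minimal sketch of an equivalent construction, assuming the column names used in `dX` (`Es`, `km`, `Vth`, `Δ`):

```python
import pandas as pd

syn_constants = pd.DataFrame(
    {"Es":  [-80.,  -70. ],   # Synaptic reversal potential (mV)
     "km":  [1/100, 1/40 ],   # Rate constant k₋ (1/ms)
     "Vth": [-35.,  -35. ],   # Half-activation voltage of s∞ (mV)
     "Δ":   [5.,    5.   ]},  # Slope of s∞ (mV)
    index=["chol", "glut"])   # Accessed as e.g. `syn_constants.loc["chol", "km"]`
```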
In general, if we have
syn_constants
Circuit model#
The circuit used in Prinz et al. [2004] is composed of three neuron populations:
| Population name | Cell names | Number of cells |
|---|---|---|
| AB/PD | anterior burster (AB), pyloric dilator (PD) | AB: 1, PD: 2 |
| LP | lateral pyloric | 1 |
| PY | pyloric | 5–8 |
In Prinz et al. [2004] the AB and PD are lumped together into a single model, with the only difference being that the AB cell has glutamatergic synapses while the PD cells have slower cholinergic synapses.
Fig. 1 Circuit diagram of the pyloric network, reproduced from Fig. 1 of Prinz et al. [2004].#
Network connectivity is determined by the synapse strength $g_s$. In Prinz et al. [2004], each synapse takes one of the following values: 0 nS, 1 nS, 3 nS, 10 nS, 30 nS or 100 nS.
Implementation of `dX`#
The evolution equations are implemented as a function `dX` for use with SciPy's ODE solvers. Since these integrate only a single vector argument, we concatenate $V$, $\log[\mathrm{Ca}^{2+}]$, $\operatorname{logit}(s)$, $\operatorname{logit}(m)$ and $\operatorname{logit}(h)$ into a single flat vector $X$ (see “Memory layout of $X$” below).
Solving an ODE requires a lot of function calls, and in this case the overhead of Python function calls does start to bite. To mitigate this, we pass all global constants and result arrays as arguments:
- Passing global constants reduces the time required to look up those values, since the local scope is searched before the global scope. Binding functions like `np.concatenate` to `concat` similarly saves the lookup of `np` in the global namespace, and the subsequent lookup of `concatenate` in the `np` namespace. (A minimal sketch of this trick follows this list.)
- Passing result arrays avoids the need to call `np.array`, which relative to other operations can have high overhead, since it needs to allocate a new contiguous block of memory as well as inspect the arguments for correctness. In practice this means things like the following: instead of `concat((E[:i], ECa, E[i+2:]))`, we do `E[i:i+2] = ECa`. This saves both the cost of allocating a new array and the overhead of all the other checks `concat` performs.
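A minimal sketch of the default-argument binding trick (the names here are purely illustrative): values bound as defaults are resolved once, at function definition time, so each call avoids repeated global and attribute lookups.

```python
import numpy as np

TAU = 200.0  # Some module-level constant

def drift(x, tau=TAU, exp=np.exp):   # `tau` and `exp` are bound once, as defaults
    return -x / tau + exp(-x)        # Inside the solver loop they are cheap local lookups

# An ODE solver may call `drift` tens of thousands of times, so these savings add up.
```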
The `dX` implementation is written such that as many operations as possible are performed outside the function (memory allocation, model value retrieval, etc.). The best way to use it is through the public API below, which combines all these setup operations and takes care of passing all those arguments defined for optimization purposes. It also provides packing/unpacking of the `X` vector.
Memory layout of $X$
The state vector $X$ is the row-wise stacking of $V$, $\log[\mathrm{Ca}^{2+}]$ and $\operatorname{logit}(s)$ (one row each of length $N$, the number of neurons), followed by the $\operatorname{logit}(m)$ and $\operatorname{logit}(h)$ blocks (one row per channel), flattened into a 1-d vector.
Any code which receives $X$ as a flat vector can therefore recover the 2-d layout with `X.reshape(-1, N)`. Reshaping is an extremely cheap operation since it doesn't require moving any data, and accessing either rows or columns is fast thanks to NumPy's efficient indexing. (A short packing/unpacking sketch follows.)
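For concreteness, a minimal packing/unpacking sketch mirroring `State.to_vector` and the first lines of `dX` (placeholder values):

```python
import numpy as np

N, n_channels, n_h = 9, 7, 4               # neurons, ion channels, channels with an h variable
V      = np.full(N, -50.)
logCa  = np.full(N, np.log(0.05))
logits = np.full(N, -10.)
logitm = np.full((n_channels, N), -10.)
logith = np.full((n_h, N), 10.)

# Pack: stack the rows, then flatten into the 1-d vector expected by solve_ivp
X = np.concatenate((V[None], logCa[None], logits[None], logitm, logith)).reshape(-1)

# Unpack (as done at the top of dX): recover the (rows x neurons) layout without copying
rows = X.reshape(3 + n_channels + n_h, N)
V2, logitm2 = rows[0], rows[3:3 + n_channels]
assert np.array_equal(V, V2) and logitm2.shape == (n_channels, N)
```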
Possibilities for future optimization
If you like the code below, but it isn’t quite fast enough for your needs, here are some possibilities.
- Vectorize along time axis
As it is, `dX` expects `t` to be a scalar and `X` to be a 1-d array. If it were reworked to accept an additional batch dimension (so `t` of shape `(T,)` and `X` of shape `(T,n)`), then we could pass the `vectorized=True` option to `solve_ivp`, allowing it to compute multiple time points at once and reducing function call overhead. The main challenge here is that `act_fn` currently only accepts arbitrary numbers of dimensions on the right, but the left-most one must be channels. This is partly for efficiency considerations. One would either need to adapt `act_fn`, or reorder the data to put the channel dimension before time. A second, smaller challenge is finding an efficient way to update the `E` array when the size of the time dimension is unknown in advance.
- Optimize the linear algebra operations
For simplicity, when evaluating currents, the implementation uses expressions of the form $g\,u v^T$, which construct an $N \times N$ outer-product matrix $u v^T$ as an intermediate step. Reworking the expression to use `.dot` would likely provide some improvement, although as long as the number of neurons is small, this improvement may be modest.
- Rewrite using JAX
If one really needs C-level performance, the most effective approach (in terms of gain/effort) may be to rewrite this function with JAX. This should be much simpler than porting to C, since the API of JAX is by design nearly identical to that of NumPy. Moreover, by keeping the code in terms of high-level linear algebra operations, we leave the optimization to JAX, which is likely to do a better job than any custom C code we would write. In particular, since all of the costly operations are standard linear algebra, one could use JAX's ability to compile for GPU, and possibly obtain a simulator capable of simulating very large numbers of neurons on a single workstation.
- Rewrite as a C function
SciPy has an object called a
LowLevelCallable
, which can be used to pass callback functions to SciPy functions which don’t carry the overhead of a Python function. This is useful for things like integrators, where the number of calls is large compared to the evaluation time of a single function call. However doing this is a lot of work, and we lose the possibility to use NumPy’s array broadcasting to keep code at a relatively high level. Rewriting in JAX is likely to be much faster, and might allow use of additional resources like GPUs.
Index = Union[int,None]
SliceReduceVal = Tuple[Type[slice], tuple[Index,Index,Index]]
@partial(jax.jit,
static_argnames=["pop_slices", "syn_slices", "elec_slice",
"nchannels", "n_neurons", "I_ext",
#"γ", "Caout", "Eleak", "p",
#"f", "Ca0", "τCa",
"idx_ICaT", "idx_ICaS",
#"Es_tuple", "km_chol", "km_glut", "Vth_chol", "Vth_glut", "Δ_chol", "Δ_glut", "ε",
]
)
def dX(t, X,
## Extra parameters which should be passed with the `args` argument of the ODE solver ##
g, # Shape: n_pops x n_channels (3 x 7) <-- DOES NOT INCLUDE LEAK
gleak,
gs, # Shape: n_pops x n_pops (3 x 3)
ge, # Electrical conductance is generally fixed to 1, except when we need to remove all connections to compute spontaneous activity
pop_slices: Tuple[SliceReduceVal], # Slices for each of the three neuron populations AB/PD, LP and PY, in that order
syn_slices: Tuple[SliceReduceVal], # Slices for selecting neurons according to their synapse type: cholinergic (PD) and glutamatergic (AB, LP, PY), in that order
# NOTE: We exploit the fact that types & populations can be ordered so that all subsets are contiguous blocks.
# This improves efficiency but is not essential: for non-contiguous blocks, we could use arrays to specify an advanced index
elec_slice: SliceReduceVal, # Slice for selecting the neurons which have electrical synapses (i.e. AB/PD)
E, # constants.E.reshape(..., n_neurons)
#Ii, # empty(n_channels, n_neurons)
#Is, # empty(n_neurons)
n_neurons, # Used for unpacking X
I_ext = None,
*,
## Below are constants, pulled from the global module variables. They should not be used as arguments ##
# X unpacking
nchannels = len(channels), # Number of non-leak ion channels (7)
nhchannels = nhchannels, # Number of channels with an h variable (4)
h_slice = h_slice,
# E_Ca - Eq (3)
γ = constants.γ,
logCaout = jnp.log(constants.Caout),
# dI - Eq (1)
Eleak = constants.Eleak,
p = constants.p[...,np.newaxis],
# dCa - Eq (2)
f = constants.f,
Ca0 = constants.Ca0,
τCa = constants.τCa,
idx_ICaT = channels.index("ICaT"),
idx_ICaS = channels.index("ICaS"),
# Ie - Eq (5)
# Is - Eq (6)
Es_tuple = tuple(syn_constants.loc[["chol", "glut", "glut", "glut"], "Es"]), # FIXME: Hard-coded pops. Follows order in syn_slices
km_chol = syn_constants.loc["chol", "km" ],
km_glut = syn_constants.loc["glut", "km" ],
Vth_chol = syn_constants.loc["chol", "Vth"],
Vth_glut = syn_constants.loc["glut", "Vth"],
Δ_chol = syn_constants.loc["chol", "Δ" ],
Δ_glut = syn_constants.loc["glut", "Δ" ],
ε = 1e-8, # Added to a denominator that may be zero
# functions
clip = jnp.clip,
concat = jnp.concatenate,
empty_like = jnp.empty_like,
exp = jnp.exp,
log = jnp.log,
minimum = jnp.minimum,
newax = jnp.newaxis,
swapaxes = jnp.swapaxes,
tile = jnp.tile
):
# Shape convention: parameter x voltage
# (equivalent): (channel) x (neuron)
# Reconstitute slices
pop_slices = tuple(T(*args) for T, args in pop_slices)
syn_slices = tuple(T(*args) for T, args in syn_slices)
elec_slice = elec_slice[0](*elec_slice[1])
# Unpack the vector argument
X = X.reshape(3+nchannels+nhchannels, n_neurons)
V, logCa, logits, logitm, logith, = (
X[0], X[1], X[2],
X[3:3+nchannels], X[3+nchannels:3+nchannels+nhchannels],
)
Ca = exp(logCa)
s = 1 / (1 + exp(-logits)) # Inverse of the logit function (NB: expression remains well defined even when exp(logits) -> oo)
m = 1 / (1 + exp(-logitm))
h = 1 / (1 + exp(-logith))
# Update reversal potentials using Eq. (3)
#E = tile(E, (len(X),1)) # I’m not happy with the copy, but I don’t see a better way to support modifying ECa when the size of the first dimension (time) is unknown
# JAXMOD E[...,[1,2],:] = γ*Caout/Ca # Modifying E in-place is fine as long as no other function does so
E = E.at[...,[idx_ICaT,idx_ICaS],:].set( γ * (logCaout - logCa) )
# Compute conductance inputs # g.shape : pop x channel
# pop_g.shape: channel
#mphVE = m**p * h * (V-E) # shape: channel x neuron
mphVE = m**p * (V-E) # shape: channel x neuron
mphVE = mphVE.at[h_slice,:].multiply(h)
Ii = concat(
[pop_g[:,newax] * mphVE[...,slc] # shape: channel x pop_size
for pop_g, slc in zip(g, pop_slices)],
axis=-1
)
# for pop_g, slc in zip(g, pop_slices):
# Ii[:,slc] = pop_g[:,newax] * mphVE[:,slc] # shape: channel x pop_size
# Ii has shape: channel x neuron
# Compute dCa using Eq. (2)
dCa = (-f * (Ii[...,idx_ICaT,:] + Ii[...,idx_ICaS,:]) - Ca + Ca0) / τCa
dlogCa = dCa / Ca # Chain rule w/ log transform
# Compute the voltage-dependent activation variables using definitions in §Voltage-dependent activation variables
m_inf, h_inf, τ_m, τ_h = act_vars(V, Ca)
h_inf = h_inf[h_slice] # For vectorization reasons, act_vars (at least for now)
τ_h = τ_h[h_slice] # returns dummy values for channels with no h var
# Compute dm and dh using Eq (1)
dm = (m_inf - m) / τ_m
dh = (h_inf - h) / τ_h
dlogitm = dm / (m * (1-m)) # Chain rule through logit
dlogith = dh / (h * (1-h)) # Chain rule through logit
# Compute electrical synapse inputs using Eq (5)
Ve = V[...,elec_slice,newax]
Ie = ge * (Ve - swapaxes(Ve,-1,-2)).sum(axis=-1) # This computes Ve - Ve.T, allowing for an additional time dimension on the left
# ge is assumed to be a scalar. Currently we always set it to 1, or 0 if we are computing the thermalization (which requires a completely disconnected network)
# Compute the synaptic inputs using Eq (6)
# (By splitting glut & chol, we can use scalars for all parameters and not have to worry about broadcasting with the shape of V)
cholslc, glutslc = syn_slices
Vglut = V[...,glutslc]
Vchol = V[...,cholslc]
s_glut = s[...,glutslc]
s_chol = s[...,cholslc]
sinf_glut = 1 / (1 + exp((Vth_glut - Vglut)/Δ_glut)) # shape: (n_glut,) — NB: Δ belongs inside the exponential: s∞ = 1/(1+exp((Vth−Vpre)/Δ))
sinf_chol = 1 / (1 + exp((Vth_chol - Vchol)/Δ_chol)) # shape: (n_chol,)
τs_glut = (1 - sinf_glut) / km_glut # shape: (n_glut,)
τs_chol = (1 - sinf_chol) / km_chol # shape: (n_chol,)
# ds_glut = (sinf_glut - s_glut) / τs_glut # shape: (n_glut,)
# ds_chol = (sinf_chol - s_chol) / τs_chol # shape: (n_glut,)
dlogits_glut = (sinf_glut - s_glut) / (τs_glut * s_glut * (1-s_glut) + ε) # Incl. chain rule w/ logit transform.
dlogits_chol = (sinf_chol - s_chol) / (τs_chol * s_chol * (1-s_chol) + ε) # ε=1e-8 is added for numerical stability
dlogits = empty_like(logits)
dlogits = dlogits.at[cholslc].set(dlogits_chol)
dlogits = dlogits.at[glutslc].set(dlogits_glut)
# Sum synaptic inputs from each population
Is = concat(
[sum((gs[i,j] * s[preslc] * (V[postslc,np.newaxis] - Es_tuple[j])).sum(axis=-1)
for j, preslc in enumerate(pop_slices))
for i, postslc in enumerate(pop_slices)],
axis=-1
)
# jax.debug.print("Is: {}", Is)
# jax.debug.print("dlogits: {}", dlogits)
# # Alternative 1: In-place updates of a pre-allocated array
# Is[:] = 0
# for i, postslc in enumerate(pop_slices):
# for j, preslc in enumerate(pop_slices):
# Is[postslc] += (gs[i,j] * s[preslc] * (V[postslc,np.newaxis] - Es_tuple[j])).sum(axis=-1)
# # Alternative 2: Construct a block matrix. The extra allocation of
# # memory makes this slower, but can be useful for diagnosis
# blocks = [[gs[i,j] * s_tuple[j] * (V[postslc,np.newaxis] - Es_tuple[j])
# for j, preslc in enumerate(pop_slices)]
# for i, postslc in enumerate(pop_slices)]
# Is = np.block(blocks).sum(axis=-1)
# # Asserts serve as documentation, and can be uncommented to re-validate code after changes.
# _A = np.arange(50)
# assert Ii.shape == (nchannels, len(V)), "Conductance input matrix doesn’t have the right shape."
# assert np.array(Is_chol.shape) + np.array(Is_glut.shape) == (len(V),), "Synaptic input matrix doesn’t have the right shape."
# assert np.all(concat([_A[slc] for slc in pop_slices]) == _A[:len(V)]), "Pop slices are not contiguous or don’t add up to the total number of neurons."
# assert np.all(concat([_A[slc] for slc in syn_slices]) == _A[:len(V)]), "Synapse type slices are not contiguous or don’t add up to the total number of neurons."
# assert ds_chol.shape == (2,)
# Compute dV by summing inputs over channels
dV = -Ii.sum(axis=-2) - Is
# Add external inputs
if I_ext:
dV -= I_ext(t)
# Add contributions currents from electrical synapses
#dV[..., elec_slice] -= Ie
dV = dV.at[..., elec_slice].add(-Ie) # JAXMOD
# Add leak currents
for pop_gleak, slc in zip(gleak, pop_slices):
#dV[...,slc] -= pop_gleak * (V[...,slc] - Eleak)
dV = dV.at[...,slc].add(- pop_gleak * (V[...,slc] - Eleak) ) # JAXMOD
# Combine all derivatives and return
return concat((dV.reshape(1,-1), dlogCa.reshape(1,-1), dlogits.reshape(1,-1), dlogitm, dlogith)).reshape(-1)
Initialization#
For the experiments, we want to initialize models in a steady state. To find this steady state, we first need to run the model, generally for much longer than the amount of simulation time we need for the experiment itself. We definitely don’t want to do this every time, so after simulating a model, we store the initial state in a cache on the disk. Thus we distinguish between two types of initialization:
- Cold initialization
This is a fixed initial state used for all models, when no previous simulations are available. It is used for the thermalization run, which integrates the model until it reaches a steady state.
- Thermalized (“warm”) initialization
The final state of the initialization run is the thermalized initialization. Subsequent calls with the same model will retrieve this state from the cache.
During the thermalization run, the model is disconnected (all connectivities $g_s$ and $g_e$ are set to zero).
Cold initialization values are provided on p. 4001 of Prinz et al. [2003]:[6]

| $m$ | $h$ | $s$ |
|---|---|---|
| 0 | 1 | 0 |

(The membrane potential and calcium concentration are cold-initialized to $V = -50\,\mathrm{mV}$ and $[\mathrm{Ca}^{2+}] = 0.05\,\mathrm{\mu M}$; cf. `State.cold_initialized`.)
There are five main technical differences between our procedure and that of Prinz et al. [2004]:
- For numerical stability, we track the values of $\operatorname{logit}(s)$, $\operatorname{logit}(m)$ and $\operatorname{logit}(h)$. This means we can't initialize them exactly at 0 or 1; instead we use logit values of $-10$ and $+10$, which correspond to approximately $4.5 \times 10^{-5}$ and $1 - 4.5 \times 10^{-5}$.
- Prinz et al. [2004] set $s = 0$ after the thermalization. This is presumably because the original neuron model catalog did not include simulations of $s$. In our case, we set $s \approx 0$ for the cold initialization, and use its subsequent thermalized value for the data run. This is both simpler on the implementation side, and more consistent with the desire to let all spontaneous transients relax before connecting neurons.
- We don't first compute the thermalization for individual model neurons separately, but instead recompute it for each combination of neuron models. (Specifically, each thermalization run is identified by a `g_cond` matrix – $g_s$ and $g_e$ are ignored, since they are set to zero during thermalization.) If we were to simulate the entire catalog of neuron model combinations this would be wasteful, but since we only need a handful, this approach is adequate and simpler to implement.
- Instead of fixed-step Euler integration, we use the Runge-Kutta 4(5) algorithm with adaptive step sizes included in `scipy.integrate`. This is not only more efficient than Euler, but also allows one to estimate numerical errors. (Although the integration scheme is currently hard-coded, since we are using standard `scipy` integrators, it would be easy enough to change it or make it user-configurable. Indeed, while standard, RK45 is not necessarily the best choice for Hodgkin-Huxley-type systems, since they are moderately stiff.)
- Finally, instead of detecting the steady state automatically, we use a fixed integration time and rely on visual inspection to determine whether this is enough to let transients decay in all considered models. Again we can do this because we only need to simulate a handful of models.
The main practical difference is that instead of pre-computing a neuron catalogue, the thermalization is done automatically and on-demand. For the user therefore it makes no difference whether a warm initialization for a particular circuit is available or not: when they request a simulation, if no warm initialization is available, they just need to wait longer to get their result. Subsequent runs then reuse the warm initialization computed on the first run.
Implementation
- The cold initialization is given by the class method `State.cold_initialized`.
- The warm-up simulation is implemented in the method `Prinz2004.get_thermalization`.
Three class attributes of `Prinz2004` are used to control the behaviour of the warm-up simulation (a short example of overriding them follows):
- `__thermalization_store__` determines where the cache is stored on disk.
- `__thermalization_time__` is the warm-up simulation time; it is currently set to 5 s.
- `__thermalization_time_step__` is the recording time step for the warm-up simulation. This is only relevant for inspecting the warm-up run. (The integrator uses an adaptive time step and discards non-recorded steps.)
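For instance, one might lengthen the warm-up and redirect the cache before running any simulations (a sketch; the path below is illustrative):

```python
from pathlib import Path
from pyloric_simulator.prinz2004 import Prinz2004

Prinz2004.__thermalization_time__  = 10000.                                     # 10 s warm-up instead of 5 s
Prinz2004.__thermalization_store__ = Path("my_cache") / "prinz2004_thermalize"  # Hypothetical cache location
# Subsequent thermalizations are keyed to the new time and written to the new location.
```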
Public API#
`State` object#
In order to integrate the model equations, we need to know not just the membrane potential of each neuron, but also its calcium concentration, synaptic activation $s$, and channel activation/inactivation variables $m$ and $h$. These are collected in a `State` object, which also provides
- automatic conversion to/from log-transformed values;
- convenience methods for converting between different storage layouts, for example concatenating and flattening all variables for the ODE integrator.
(A short usage example follows the class definition.)
@dataclass
class State:
V : Array[float, 1] # (n_neurons)
logCa : Array[float, 1] # (n_neurons)
logits: Array[float, 1] # (n_neurons)
logitm: Array[float, 2] # (n_channels, n_neurons)
logith: Array[float, 2] # (n_h_channels, n_neurons)
"""Storage container for simulator variables.
- Handles automatic conversion to/from log-transformed values.
- Provides methods for converting between storage layouts, in particular for
exporting to the 1-d vector format required by ODE integrators.
"""
@classmethod
def cold_initialized(cls, n_neurons):
"""Default state initialization. Used to initialize the thermalization run."""
return cls(
-50 * jnp.ones(n_neurons),
jnp.log(0.05) * jnp.ones(n_neurons),
-10 * jnp.ones(n_neurons), # “fully deactivated” would be s=0, which is not possible with the logit, but this should be close enough
-10 * jnp.ones((len(channels), n_neurons)), # “fully deactivated” — shape (n_channels, n_neurons) to match `from_array`/`to_array`
10 * jnp.ones((nhchannels, n_neurons)), # “fully activated”
)
@classmethod
def from_array(cls, X: Union[Array[float, 1], Array[float, 2]]) -> "State":
"""Unpack a flattened state vector."""
nchannels = len(channels)
if X.ndim == 1:
X = X.reshape(3+nchannels+nhchannels, -1)
elif X.ndim != 2:
raise ValueError(f"State array `X` has {X.ndim} dimensions, when it should have either 1 or 2.")
return cls(
jnp.array(X[0]), jnp.array(X[1]), jnp.array(X[2]),
jnp.array(X[3:3+nchannels]), jnp.array(X[3+nchannels:3+nchannels+nhchannels]),
)
def to_vector(self):
"""Return a flat 1-d vector, as is used in `solve_ivp`."""
return jnp.concatenate((self.V, self.logCa, self.logits, self.logitm.flatten(), self.logith.flatten()))
def to_array(self):
"""
Return a 2-d array of shape (3+2*n_channels, n_neurons)
First three rows are `V`, `logCa`, `logits` respectively.
Then `m` and `h` as two blocks of (n_channels, n_neurons).
"""
return jnp.concatenate((
self.V[np.newaxis,:], self.logCa[np.newaxis,:], self.logits[np.newaxis,:],
self.logitm, self.logith))
@property
def s(self):
return 1 / (1 + 1/np.exp(self.logits)) # Inverse of the logit function (NB: expression chosen to be well defined even when exp(logits) -> oo)
@s.setter
def s(self, s): # Setter must reuse the property name so that `state.s = ...` works
self.logits = np.log(s/(1-s))
@property
def m(self):
return 1 / (1 + 1/np.exp(self.logitm))
@m.setter
def m(self, m):
self.logitm = np.log(m/(1-m))
@property
def h(self):
return 1 / (1 + 1/np.exp(self.logith))
@h.setter
def h(self, h):
self.logith = np.log(h/(1-h))
@property
def Ca(self):
return np.exp(self.logCa)
@Ca.setter
def Ca(self, Ca):
self.logCa = np.log(Ca) # NB: log, not exp — we store the log-transformed concentration
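A quick round-trip sketch of how `State` is typically used (values are the cold-initialization defaults):

```python
state = State.cold_initialized(9)              # One entry per neuron for V, Ca, s, m, h
x = state.to_vector()                          # Flat 1-d vector, as passed to solve_ivp
assert x.shape == (9 * (3 + len(channels) + nhchannels),)
state2 = State.from_array(x)                   # Recover the structured state
print(state2.V[:3], state2.Ca[:3], state2.s[:3])   # ≈ [-50 -50 -50], [0.05 …], [~0 …]
```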
`SimResult` object#
The ODE integrators provided by `scipy.integrate` treat the data as a flat 1-d vector, which is not convenient for associating the trace of each component with the correct state variable.
A `SimResult` stores the data returned from the integrator, in the integrator's compact format, but provides a human-friendly interface for retrieving traces for individual variables. Conversions to/from log-transformed values are handled automatically.
@dataclass
class SimResult:
t : Array[float, 1]
data : Array[float, 3]
pop_slices: dict[str, slice]
"""Stores and makes accessible the results of a simulation run.
Underlying storage is very close to the output of the ODE simulator.
Properties for each variable (`V`, `logCa`, `Ca`, etc.) retrieve the correct rows from the data structure.
"""
def __post_init__(self):
self.data = self.data.reshape(3+len(channels)+nhchannels, -1, len(self.t))
self.data = np.moveaxis(self.data, -1, 0) # Both Pandas & Holoviews work better if time is the index axis
def __getitem__(self, key) -> Union[SimResult,State]:
if isinstance(key, (slice, list, np.ndarray)):
return SimResult(t=self.t[key], data=self.data[key,:,:], pop_slices=self.pop_slices)
elif isinstance(key, int):
return self.t, State.from_array(self.data[key,:,:])
else:
raise TypeError("`SimResult` only supports 1-d indexing along the time axis.")
@property
def V(self): return self._make_df(self.data[:, 0, :])
@property
def logCa(self): return self._make_df(self.data[:, 1, :])
@property
def Ca(self): return np.exp(self.logCa)
@property
def logits(self): return self._make_df(self.data[:, 2, :])
@property
def s(self): return 1 / (1 + np.exp(-self.logits))
@property
def logitm(self): return self.data[:, 3:3+len(channels), :]
@property
def m(self): return 1 / (1 + np.exp(-self.logitm))
@property
def logith(self): return self.data[:, 3+len(channels):3+len(channels)+nhchannels, :]
@property
def h(self): return 1 / (1 + np.exp(-self.logith))
def _make_df(self, data):
cols = pd.MultiIndex.from_tuples(
((pop, i)
for pop, slc in self.pop_slices.items()
for i in range(1, 1+slc.stop-slc.start)),
names=["pop", "index"]
)
return pd.DataFrame(data, index=pd.Index(self.t, name="time"), columns=cols)
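A brief sketch of typical `SimResult` access patterns (continuing the usage example above):

```python
res = model(np.linspace(3000, 4000, 1001))
res.V[("LP", 1)]              # Voltage trace of the first LP neuron (a pandas Series indexed by time)
res.Ca.mean()                 # Mean calcium concentration of each neuron
sub = res[res.t >= 3500]      # A new SimResult restricted to the last 500 ms
_, final_state = res[-1]      # A State object for the last recorded time point
```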
`Prinz2004` object#
This is the core simulator class. To perform a simulation, create an instance of this class and call its `integrate` method.
@dataclass(frozen=True)
class Prinz2004:
pop_sizes : dict[str, int]
gs : Array[float, 2]
g_ion: Optional[Array[float, 2]]=None
ge : float=1. # Currently only used to turn on/off electrical connectivity
""" Core simulator class.
To perform a simulation, create an instance of this class and call its `integrate` method.
"""
# Private attributes
__thermalization_store__ : ClassVar[Path] = config.paths.simresults/"prinz2004_thermalize"
__thermalization_time__ : ClassVar[float] = 5000. # Warm initialization is the state after this many milliseconds (i.e. 5 s) of spontaneous activity (no inputs, no synaptic connections)
__thermalization_time_step__: ClassVar[float] = 1. # Warm initialization trace is recorded with this time step
# TODO: Instead of a fixed time step, using an `events` callback to detect spikes could substantially reduce the data we need to store
_thermalization_store_lock: ClassVar[Lock] = Lock(str(__thermalization_store__.with_suffix(".lock")))
_thermalization_store_lock.lifetime = timedelta(seconds=15) # 15s is the default
def __post_init__(self):
# For convenience, we allow passing `g_ion` as a DataFrame
if isinstance(self.g_ion, pd.DataFrame):
object.__setattr__(self, "g_ion", self.g_ion.to_numpy())
object.__setattr__(self, "gs", np.asarray(self.gs))
## Public API ##
def __call__(self, t_array, I_ext: Optional[Callable]=None):
X0 = self.get_thermalization()
res = self.integrate(0, X0, t_array, I_ext)
return SimResult(t_array, res.y, self.pop_slices)
def derivative(self, t: float, X: State, I_ext: Optional[Callable]=None):
"""
Evaluate the model equations, returning the derivative at `t` if the state is `X`.
This is a convenience method: it evaluates the same equations used by `integrate`, but it is not the actual code path used during integration.
"""
X0 = X
# C.f. self.integrate()
if isinstance(X0, State):
X0_flat = X0.to_vector()
else:
X0_flat = X0.flatten()
X0 = State.from_array(X0) # NB: `State.from_array` accepts both flat and 2-d arrays
pop_slices = tuple(slc.__reduce__() for slc in self.pop_slices.values())
syn_slices = (self.syn_slices["chol"].__reduce__(), self.syn_slices["glut"].__reduce__())
elec_slice = self.elec_slice.__reduce__()
args=(self.g, self.gleak, self.gs, self.ge, # Conductances
pop_slices, syn_slices, elec_slice, # Population slices
self.E(X0.Ca), # Expanded from constants.E (except E[[1,2],:], corresponding to ECa, which is updated in-place)
self.tot_cells, # Pre-computed number of neurons
I_ext) # External input, if provided
return dX(t, X0_flat, *args)
## Semi-private methods ##
# Even though the dataclass is frozen, changing the init_store is allowed (at worst, it just means the thermalization is recomputed)
@classmethod
def clear_thermalization_store(cls):
cls.__thermalization_store__.unlink(missing_ok=True)
def integrate(self, t0: float, X0: Union[Array, State], t_eval: ArrayLike,
I_ext: Optional[Callable]=None) -> OdeResult:
"""
Integrate the model from the initial state `(t0, X0)`. The result will contain the states
at all values in `t_eval`.
Args:
t0: time corresponding to X0
X0: Initial state; either an instance of `State` or a flat vector.
t_eval: The time points at which to record the trace
I_ext: If provided, this should be a function with the signature ``(t) -> I``, where ``I``
is a 1-d vector with one element for every neuron. The units of I are nA; its value is
added directly to the derivative dV.
Alternatively, if all neurons should receive the same input, ``I`` can be a scalar.
(In case we do support a vectorized `dX` later, if the function accepts a vector as input,
the result should be of shape ``(time x neurons)``.
"""
if isinstance(X0, State):
X0_flat = X0.to_vector()
else:
X0_flat = X0.flatten()
X0 = State.from_array(X0)
t_eval = np.sort(t_eval)
T = t_eval[-1]
# NB: Order is important for syn_slices
pop_slices = tuple(slc.__reduce__() for slc in self.pop_slices.values())
syn_slices = (self.syn_slices["chol"].__reduce__(), self.syn_slices["glut"].__reduce__())
elec_slice = self.elec_slice.__reduce__()
res = integrate.solve_ivp(
dX, (t0, T), X0_flat, method="RK45", t_eval=t_eval,
args=(self.g, self.gleak, self.gs, self.ge, # Conductances
pop_slices, syn_slices, elec_slice, # Population slices
self.E(X0.Ca), # Expanded from constants.E (except E[[1,2],:], corresponding to ECa, which is updated in-place)
self.tot_cells, # Pre-computed number of neurons
I_ext), # External input, if provided
first_step=0.005)#constants.Δt)
if res.status < 0:
raise RuntimeError(f"Integration failed with the following message:\n{res.message}")
#logger.error(res.message)
return res
@property
def g_cond(self):
if self.g_ion is not None:
# Explicitly assigned g_cond
return pd.DataFrame(self.g_ion,
index=self.pop_model_list, columns=g_cond.columns)
else:
return g_cond.loc[self.pop_model_list]
@property
def g(self):
g_no_leak = self.g_cond.filter(regex=r"mathrm{(?!leak)")
#_pop_lst = [self.pop_models["AB/PD"], pop_models["LP"], pop_models["PY"]]
return g_no_leak.to_numpy()
@property
def gleak(self):
return self.g_cond.loc[:, '$g(I_\mathrm{leak})$'].to_numpy()
def E(self, Ca):
E = jnp.tile(constants.E[:,np.newaxis], (1, self.tot_cells))
E = E.at[[1,2],:].set(constants.γ*np.log(constants.Caout/Ca))
return E
@property
def tot_cells(self):
return sum(self.pop_sizes.values())
@property
def pop_model_list(self):
"""
Returns a list of the form ``["AB 3", "PD 3", "LP 2", "PY 4"]``, where values
correspond to model labels used in `g_cond`. They may omit numbers, if those are
also omitted in `g_cond`.
"""
return list(self.pop_slices.keys())
# # CODE FOR LATER: The following was originally used to merge AB and PD pop labels
# # AB 1 or PD 1 -> AB/PD 1 (where '1' is any number)
# slcs = self.create_slices(
# self.pop_sizes, label_sub={r"(?:AB|PD)( \d+)?": r"AB/PD\1"})
# return list(slcs.keys())
@property
def pop_slices(self):
# If pops are specified as AB/PD, split them AB: 1, PD: n-1
pop_sizes = {}
for name, size in self.pop_sizes.items():
m = re.match("AB/PD( \d+)?", name)
if m:
pop_sizes[f"PD{m[1]}"] = size - 1
pop_sizes[f"AB{m[1]}"] = 1
else:
pop_sizes[name] = size
return self.create_slices(pop_sizes, {})
@property
def syn_slices(self):
# If pops are specified as AB/PD, split them AB: 1, PD: n-1
pop_sizes = {}
for name, size in self.pop_sizes.items():
m = re.match("AB/PD( \d+)?", name)
if m:
pop_sizes[f"PD{m[1]}"] = size - 1
pop_sizes[f"AB{m[1]}"] = 1
else:
pop_sizes[name] = size
# When creating the slices:
# AB, LP, PY -> glut PD -> chol
slices = self.create_slices(
pop_sizes, label_sub={r"(?:AB|LP|PY)( \d+)?": r"glut",
r"(?:PD)( \d+)?": r"chol"})
# Corner case: If there are no cholinergic or glutamatergic neurons,
# add corresponding empty slices
if "glut" not in slices: slices["glut"] = slice(0,0)
if "chol" not in slices: slices["chol"] = slice(0,0)
return slices
@property
def elec_slice(self):
slcs = self.create_slices(
self.pop_sizes, label_sub={r"(?:AB|PD).*": "elec"}, include=[r"AB|PD"])
return slcs.get("elec", slice(0,0)) # If there are no neurons with electrical synapses, return an empty slice
def get_thermalization(self) -> State:
"""
The thermalization is obtained by integrating the network model
with all connections to zero, to find the spontaneous steady-state of
each neuron model.
This method either performs that integration, or, if it was already
done before, retrieves it from a cache.
"""
sim = self.thermalize()
t, X = sim[-1] # Get the final state
# Warm init is done with one neuron per population: reinflate the populations.
# NB: The thermalization uses the merged AB/PD populations of pop_slices
pop_sizes = tuple(slc.stop-slc.start for slc in self.pop_slices.values())
return State.from_array(
np.repeat(X.to_array(), pop_sizes, axis=-1)
)
def thermalize(self) -> SimResult:
"""
Compute the initial simulation used to thermalize the model.
- Constructs a surrogate model with all population sizes set to 1.
- Sets all connectivities to 0.
- Integrates the disconnected model for ``__thermalization_time__`` ms.
In addition, this method manages the thermalization cache, and will skip
all steps if it finds a result already in the cache.
.. TODO:: Currently the cache key is the entire model, which is a bit
wasteful. It would be better to cache the results of each neuron
separately.
Returns a `SimResult` instance, with recording resolution determined by
``self.__thermalization_time_step__``.
Normally this method is not used directly, but called by `get_thermalization`.
The main reason to use it directly would be to plot the initialization
run for inspection.
"""
# # Pre-compute the set of pop_slices, so we can validate it against the disconnected model
# # (Alternatively, we could create a disconnected model even when we find a pre-cached result, but that seems wasteful)
# pop_slices = {}; i = 0
# for pop in self.pop_slices:
# if pop.startswith("AB/PD"):
# pop_slices[pop] = slice(i,i+2); i += 2
# else:
# pop_slices[pop] = slice(i,i+1); i += 1
# We thermalize the model by turning off synaptic connections,
# so we also use a key which only depends on the model identities.
init_key = str((("T", self.__thermalization_time__), ("g_cond", self.g_cond.to_csv())))
if not self.__thermalization_store__.exists():
self.__thermalization_store__.parent.mkdir(parents=True, exist_ok=True)
warm_t_X = None
else:
with shelve.open(str(self.__thermalization_store__), 'r') as store:
warm_t_X = store.get(init_key)
if warm_t_X:
t, warm_X = warm_t_X
else:
# NB: We use pop_slices instead of pop_sizes in order to merge AB/PD populations when possible
# Get the cold initialized state (for a smaller model with 1 neuron / pop)
X0 = State.cold_initialized(len(self.pop_slices))
# Construct the array of times at which we will record
init_T = self.__thermalization_time__
Δt = self.__thermalization_time_step__
t_eval = np.concatenate((np.arange(0, init_T, Δt), [init_T])) # The final time init_T is always evaluated
# Build the model with no connections and only one neuron per pop
disconnected_model = Prinz2004(
pop_sizes={pop: 1 for pop in self.pop_slices},
gs=np.zeros_like(self.gs),
g_ion=self.g_ion,
ge=0)
#assert pop_slices == disconnected_model.pop_slices
# Integrate
res = disconnected_model.integrate(0, X0, t_eval)
t, warm_X = res.t, res.y
# Update the cache
with self._thermalization_store_lock: # We use a lock file because shelve doesn't support concurrent writes (not doing this can corrupt the store)
with shelve.open(str(self.__thermalization_store__)) as store:
store[init_key] = (t, warm_X)
# Return
return SimResult(t, warm_X,
pop_slices={pop: slice(i,i+1)
for i,pop in enumerate(self.pop_slices)})
## Private methods ##
@staticmethod
def create_slices(pop_sizes,
label_sub: dict[str,str],
include: Sequence[str]=(),
exclude: Sequence[str]=()):
"""
`label_sub` should be a substitution dictionary of ``pattern: repl`` pairs which
are passed to `re.sub`. In other words, `pattern` should be a regex pattern,
and when it matches a cell type name, it is replaced by `repl`.
"""
slcs = {}
i = 0
# Convert population sizes to slices, by tracking an index counter i
for name, size in pop_sizes.items():
# NB: Don’t use `continue` in this loop, otherwise the `i` counter is lost
skip = ((include and not any(re.search(pattern, name) for pattern in include))
or any(re.search(pattern, name) for pattern in exclude))
for pattern, repl in label_sub.items():
name = re.sub(pattern, repl, name)
if skip:
pass
elif name in slcs:
last_slc = slcs[name][-1]
if i == last_slc.stop:
# This slice is contiguous with the previous one: merge them
slcs[name][-1] = slice(last_slc.start, i+size)
else:
# Not contiguous: Add a new slice
slcs[name].append(slice(i, i+size))
else:
# New population model: Add a new slice
slcs[name] = [slice(i, i+size)]
i += size
# Convert to simple slices when possible, otherwise indexing arrays
iarr = np.arange(i) # i counter is now exactly the total number of cells
for name, slc_lst in slcs.items():
if len(slc_lst) == 1:
slcs[name] = slc_lst[0]
else:
# Multiple discontinuous arrays: must use indexing
slcs[name] = np.concatenate([iarr[slc] for slc in slc_lst])
return slcs
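To make the slice-building logic concrete, here is what `create_slices` produces for the usage-example populations, with the same substitution map used by `syn_slices`:

```python
slcs = Prinz2004.create_slices(
    {"PD": 2, "AB": 1, "LP": 1, "PY": 5},
    label_sub={r"(?:AB|LP|PY)( \d+)?": "glut", r"(?:PD)( \d+)?": "chol"})
print(slcs)   # {'chol': slice(0, 2), 'glut': slice(2, 9)} — contiguous blocks, merged across AB, LP and PY
```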
Anecdotal timings#
All timings use the circuit model described in the Circuit model section.
| Model size | Hardware | NumPy | JAX |
|---|---|---|---|
| 9 | 3.50GHz CPU | 1.26 ms ± 13.3 µs | |

| Model size | Hardware | Library | Simulation time: 200 ms | Simulation time: 1000 ms |
|---|---|---|---|---|
| 9 | 3.10GHz CPU | NumPy | 21.3 s | 2 min 49 s |
| 9 | 3.10GHz CPU | JAX | 2.33 s | 7.54 s |