# -*- coding: utf-8 -*-
#
# sovtree.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
import itertools
import copy
from . import morphtree
from .stree import STree
from .morphtree import MorphLoc, MorphNode
from .phystree import PhysNode, PhysTree
from .netree import NETNode, NET, Kernel
from ..tools.fittools import zerofinding as zf
from ..tools.fittools import histogramsegmentation as hs
def _consecutive(data, stepsize=1):
    """
    Split a 1-d array into runs of consecutive values.

    A new run starts wherever two neighbouring elements do not differ by
    exactly ``stepsize``.

    Parameters
    ----------
    data: np.ndarray (ndim = 1)
        the values to split
    stepsize: int or float
        the expected difference between neighbouring elements of a run

    Returns
    -------
    list of np.ndarray
        the runs, in their original order
    """
    breaks = np.diff(data) != stepsize
    split_points = np.where(breaks)[0] + 1
    return np.split(data, split_points)
[docs]
class SOVNode(PhysNode):
    """
    Node that defines functions and stores quantities to implement separation
    of variables calculation (Major, 1993)

    The functions ``q_m``, ``mu_m`` and their derivatives are evaluated at an
    argument ``x`` (real or complex). The roots ``x`` found on the soma are
    converted into mode rates by `SOVTree.get_sov_matrices` through
    ``x**2 / tau_0``.
    """

    def __init__(self, index, p3d=None):
        super().__init__(index, p3d)

    def _set_sov(self, channel_storage, tau_0=0.02):
        """
        Compute and store the segment parameters used by the SOV recursion.

        Parameters
        ----------
        channel_storage:
            forwarded to `calc_g_tot` to obtain the total membrane
            conductance of the segment
        tau_0: float
            reference time scale (s); the membrane time constant is
            normalized by it (``eps_m = tau_m / tau_0``)
        """
        # visit counter used by the forward sweep in `SOVTree._sov_from_leaf`
        self.counter = 0
        # segment parameters
        self.g_m = self.calc_g_tot(channel_storage)  # uS/cm^2
        # parameters for SOV approach
        self.R_sov = self.R * 1e-4  # convert um to cm
        self.L_sov = self.L * 1e-4  # convert um to cm
        self.tau_m = self.c_m / self.g_m  # s
        self.eps_m = self.tau_m / tau_0  # dimensionless (s / s)
        self.lambda_m = np.sqrt(self.R_sov / (2.0 * self.g_m * self.r_a))  # cm
        self.tau_0 = tau_0  # s
        self.z_a = self.r_a / (np.pi * self.R_sov**2)  # MOhm/cm
        self.g_inf_m = 1.0 / (self.z_a * self.lambda_m)  # uS
        # segment amplitude information; placeholders until the roots are
        # known and `_set_kappa_factors` / `_set_mu_vals` / `_set_q_vals` run
        self.kappa_m = np.nan
        self.mu_vals_m = np.nan
        self.q_vals_m = np.nan

    def __str__(self, with_parent=True, with_morph_info=False):
        """
        String representation; appends the SOV parameters once `_set_sov`
        has been called (detected through the presence of ``R_sov``).

        ``with_morph_info`` selects which base-class ``__str__`` is used.
        """
        if with_morph_info:
            node_str = super(PhysNode, self).__str__(with_parent=with_parent)
        else:
            node_str = super(MorphNode, self).__str__(with_parent=with_parent)
        if hasattr(self, "R_sov"):
            node_str += (
                f" --- "
                f"(g_m = {self.g_m:.8f} uS/cm^2, "
                f"tau_m = {self.tau_m:.8f} s, "
                f"eps_m = {self.eps_m:.8f}, "
                f"R_sov = {self.R_sov:.8f} cm, "
                f"L_sov = {self.L_sov:.8f} cm)"
            )
        return node_str

    def q_m(self, x):
        """Return ``sqrt(eps_m * x**2 - 1)``."""
        return np.sqrt(self.eps_m * x**2 - 1.0)

    def dq_dp_m(self, x):
        """
        Return the derivative of `q_m` with respect to ``p``.

        ``-tau_m / (2 q)`` is the derivative for ``p = -x**2 / tau_0``, which
        agrees with the chain-rule factor ``-2 x / tau_0`` used for ``dfdx``
        in `_set_zeros_poles`.
        """
        return -self.tau_m / (2.0 * self.q_m(x))

    def mu_m(self, x):
        """
        Return the ``mu`` function of this segment evaluated at ``x``.

        For a leaf it is ``z_a * lambda_m * g_shunt / q_m(x)``. Otherwise the
        contributions of all child segments (through their own ``mu_m`` and
        ``q_m``) are combined with this node's shunt and divided by
        ``g_inf_m * q_m(x)``. Children are evaluated recursively.
        """
        cns = self.child_nodes
        if len(cns) == 0:
            return self.z_a * self.lambda_m / self.q_m(x) * self.g_shunt
        else:
            x = zf._to_complex(x)
            mu_ds = [cn.mu_m(x) for cn in cns]
            q_ds = [cn.q_m(x) for cn in cns]
            # each child term is g_inf * q * (1 - mu cot(qL/l)) / (cot(qL/l) + mu)
            return (
                self.g_shunt
                - np.sum(
                    [
                        cn.g_inf_m
                        * q_ds[i]
                        * (1.0 - mu_ds[i] / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m))
                        / (1.0 / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m) + mu_ds[i])
                        for i, cn in enumerate(cns)
                    ],
                    0,
                )
            ) / (self.g_inf_m * self.q_m(x))

    def dmu_dp_m(self, x):
        """
        Return the derivative of `mu_m` with respect to ``p``
        (see `dq_dp_m`), evaluated at ``x``.

        Obtained by differentiating the expression in `mu_m` term by term.
        Children are evaluated recursively.
        """
        cns = self.child_nodes
        if len(cns) == 0:
            # mu = c / q  =>  dmu/dp = -(dq/dp) * mu / q
            return -self.dq_dp_m(x) * self.mu_m(x) / self.q_m(x)
        else:
            x = zf._to_complex(x)
            mu_ds = [cn.mu_m(x) for cn in cns]
            dmu_dp_ds = [cn.dmu_dp_m(x) for cn in cns]
            q_ds = [cn.q_m(x) for cn in cns]
            dq_dp_ds = [cn.dq_dp_m(x) for cn in cns]
            # derivative of each child term: product rule on q * ratio, with
            # d(ratio)/d(qL/l) = (1 + mu^2) / (cos + mu sin)^2 and
            # d(ratio)/d(mu) = -1 / (cos + mu sin)^2
            return (
                -self.dq_dp_m(x) * self.mu_m(x)
                - np.sum(
                    [
                        cn.g_inf_m
                        * (
                            dq_dp_ds[i]
                            * (
                                1.0
                                - mu_ds[i] / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m)
                            )
                            / (
                                1.0 / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m)
                                + mu_ds[i]
                            )
                            + q_ds[i]
                            * (
                                (1.0 + mu_ds[i] ** 2)
                                * dq_dp_ds[i]
                                * cn.L_sov
                                / cn.lambda_m
                                - dmu_dp_ds[i]
                            )
                            / (
                                np.cos(q_ds[i] * cn.L_sov / cn.lambda_m)
                                + mu_ds[i] * np.sin(q_ds[i] * cn.L_sov / cn.lambda_m)
                            )
                            ** 2
                        )
                        for i, cn in enumerate(cns)
                    ],
                    0,
                )
                / self.g_inf_m
            ) / self.q_m(x)

    def _set_kappa_factors(self, xzeros):
        """
        Set ``kappa_m`` for every root in ``xzeros``: the parent's
        ``kappa_m`` divided by ``cos(qL/l) + mu sin(qL/l)`` of this segment.
        Requires the parent's ``kappa_m`` to be set first.
        """
        xzeros = zf._to_complex(xzeros)
        self.kappa_m = self.parent_node.kappa_m / (
            np.cos(self.q_m(xzeros) * self.L_sov / self.lambda_m)
            + self.mu_m(xzeros) * np.sin(self.q_m(xzeros) * self.L_sov / self.lambda_m)
        )

    def _set_mu_vals(self, xzeros):
        """Store `mu_m` evaluated at the roots ``xzeros`` in ``mu_vals_m``."""
        xzeros = zf._to_complex(xzeros)
        self.mu_vals_m = self.mu_m(xzeros)

    def _set_q_vals(self, xzeros):
        """Store `q_m` evaluated at the roots ``xzeros`` in ``q_vals_m``."""
        xzeros = zf._to_complex(xzeros)
        self.q_vals_m = self.q_m(xzeros)

    def _find_local_poles(self, maxspace_freq=500):
        """
        Return the poles of ``cot(q L / lambda)`` of this segment.

        The poles lie where ``q = n pi lambda / L``, i.e. at
        ``x = sqrt((1 + q**2) / eps_m)``. ``n`` increases until ``q`` reaches
        ``maxspace_freq``. The ``n = 0`` pole gets multiplicity 0.5, all
        others 1.

        Returns
        -------
        poles: list of float
        pmultiplicities: list of float
        """
        poles = []
        pmultiplicities = []
        n = 0
        val = 0.0
        while val < maxspace_freq:
            poles.append(np.sqrt((1.0 + val**2) / self.eps_m))
            if val == 0:
                pmultiplicities.append(0.5)
            else:
                pmultiplicities.append(1.0)
            n += 1
            val = n * np.pi * self.lambda_m / self.L_sov
        return poles, pmultiplicities

    def _set_zeros_poles(self, maxspace_freq=500, pprint=False):
        """
        Find the real zeros of ``f(x) = cot(q L / lambda) + mu(x)`` and store
        them in ``self.poles`` / ``self.pmultiplicities``.

        These zeros are the poles of the level above: the parent extends its
        own pole list with them. The child nodes must therefore have been
        processed before this node.
        """
        cns = self.child_nodes
        # find the poles of (1 + mu*cot(qL/l))/(cot(qL/l) + mu)
        lpoles, lpmultiplicities = self._find_local_poles(maxspace_freq)
        for cn in cns:
            lpoles.extend(cn.poles)
            lpmultiplicities.extend(cn.pmultiplicities)
        inds = np.argsort(lpoles)
        lpoles = np.array(lpoles)[inds]
        lpmultiplicities = np.array(lpmultiplicities)[inds]
        # construct the function cot(qL/l) + mu
        f = lambda x: 1.0 / np.tan(
            self.q_m(x) * self.L_sov / self.lambda_m
        ) + self.mu_m(x)
        # its derivative with respect to x (chain rule with dp/dx = -2x/tau_0)
        dfdx = (
            lambda x: -2.0
            * x
            * (
                -(self.L_sov / self.lambda_m)
                / np.sin(self.q_m(x) * self.L_sov / self.lambda_m) ** 2
                * self.dq_dp_m(x)
                + self.dmu_dp_m(x)
            )
            / self.tau_0
        )
        # find its zeros, these are the poles of the next level
        # split point between the two search strategies: the largest
        # 1.5 / sqrt(eps_m) over this node and its children
        xval = 1.5 / np.sqrt(self.eps_m)
        for cn in cns:
            c_eps_m = cn.eps_m
            xval_ = 1.5 / np.sqrt(c_eps_m)
            if xval_ > xval:
                xval = xval_
        # avoid a split point that is itself (numerically) a zero;
        # NOTE(review): assumes at least two entries in lpoles
        if np.abs(f(xval)) < 1e-20:
            xval = (xval + lpoles[1]) / 2.0
        if pprint:
            print("")
            print("xval: ", xval)
        # find zeros larger than xval
        if pprint:
            print("finding real poles")
        PF = zf.poleFinder(
            fun=f,
            dfun=dfdx,
            global_poles={"poles": lpoles, "pmultiplicities": lpmultiplicities},
        )
        poles, pmultiplicities = PF.find_real_zeros(vmin=xval)
        # find the zeros on [0, xval]
        if pprint:
            print("finding first pole")
        p1 = []
        pm1 = []
        zf.find_zeros_on_segment(
            p1, pm1, 0.0, xval, f, dfdx, lpoles, lpmultiplicities, pprint=pprint
        )
        self.poles = np.concatenate((p1, poles)).real
        self.pmultiplicities = np.concatenate((pm1, pmultiplicities)).real
class SomaSOVNode(SOVNode):
    """
    Subclass of SOVNode to treat the special case of the soma
    The following member functions are not supposed to work properly,
    calling them may result in errors:
    `neat.SOVNode._set_kappa_factors()`
    `neat.SOVNode._set_mu_vals()`
    `neat.SOVNode._set_q_vals()`
    `neat.SOVNode._find_local_poles()`

    The soma surface is computed as ``4 pi R**2``. The roots of the
    transcendental equation of the whole tree are stored on this node
    (``zeros``, ``zmultiplicities``, ``prefactors``).
    """

    def __init__(self, index, p3d=None):
        super().__init__(index, p3d)

    def _set_sov(self, channel_storage, tau_0=0.02):
        """
        Compute and store the somatic parameters used by the SOV recursion.

        Parameters
        ----------
        channel_storage:
            forwarded to `calc_g_tot` to obtain the total membrane
            conductance
        tau_0: float
            reference time scale (s)
        """
        # visit counter used by the forward sweep in `SOVTree._sov_from_leaf`
        self.counter = 0
        # convert to cm
        self.R_sov = self.R * 1e-4  # convert um to cm
        self.L_sov = self.L * 1e-4  # convert um to cm
        # surface
        self.A = 4.0 * np.pi * self.R_sov**2  # cm^2
        # total conductance
        self.g_m = self.calc_g_tot(channel_storage=channel_storage)  # uS/cm^2
        # parameters for the SOV approach
        self.tau_m = self.c_m / self.g_m  # s
        self.eps_m = self.tau_m / tau_0  # dimensionless (s / s)
        self.g_s = self.g_m * self.A + self.g_shunt  # uS
        self.c_s = self.c_m * self.A  # uF
        self.tau_0 = tau_0  # s
        # segment amplitude factors; the soma is the reference (kappa = 1)
        self.kappa_m = 1.0

    def f_transc(self, x):
        """
        Evaluate the transcendental function whose roots are the SOV modes.

        It is ``g_s * (1 - eps_m x**2)`` minus the summed child terms
        (the same child terms as in `SOVNode.mu_m`).

        NOTE(review): ``g_s`` includes ``g_shunt``, so the shunt is scaled by
        ``(1 - eps_m x**2)`` here. `dN_dp` uses ``c_s = c_m * A``, which has
        no ``g_shunt`` contribution. The two only agree when
        ``g_shunt == 0``; confirm the intended convention.
        """
        cns = self.child_nodes
        x = zf._to_complex(x)
        mu_ds = [cn.mu_m(x) for cn in cns]
        q_ds = [cn.q_m(x) for cn in cns]
        return self.g_s * (1.0 - self.eps_m * x**2) - np.sum(
            [
                cn.g_inf_m
                * q_ds[i]
                * (1.0 - mu_ds[i] / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m))
                / (1.0 / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m) + mu_ds[i])
                for i, cn in enumerate(cns)
            ],
            0,
        )

    def dN_dp(self, x):
        """
        Derivative of the transcendental function with respect to ``p``
        (see `SOVNode.dq_dp_m`).

        The somatic part contributes ``c_s``; the child terms are
        differentiated as in `SOVNode.dmu_dp_m`. Evaluated at the roots, it
        gives ``self.prefactors``.
        """
        cns = self.child_nodes
        x = zf._to_complex(x)
        mu_ds = [cn.mu_m(x) for cn in cns]
        dmu_dp_ds = [cn.dmu_dp_m(x) for cn in cns]
        q_ds = [cn.q_m(x) for cn in cns]
        dq_dp_ds = [cn.dq_dp_m(x) for cn in cns]
        return self.c_s - np.sum(
            [
                cn.g_inf_m
                * (
                    dq_dp_ds[i]
                    * (1.0 - mu_ds[i] / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m))
                    / (1.0 / np.tan(q_ds[i] * cn.L_sov / cn.lambda_m) + mu_ds[i])
                    + q_ds[i]
                    * (
                        (1.0 + mu_ds[i] ** 2) * dq_dp_ds[i] * cn.L_sov / cn.lambda_m
                        - dmu_dp_ds[i]
                    )
                    / (
                        np.cos(q_ds[i] * cn.L_sov / cn.lambda_m)
                        + mu_ds[i] * np.sin(q_ds[i] * cn.L_sov / cn.lambda_m)
                    )
                    ** 2
                )
                for i, cn in enumerate(cns)
            ],
            0,
        )

    def _set_zeros_poles(self, maxspace_freq=500, pprint=False):
        """
        Find the real roots of `f_transc` and store them as ``self.zeros``
        (with ``self.zmultiplicities``), together with
        ``self.prefactors = dN_dp(zeros).real``.

        ``maxspace_freq`` is accepted for interface compatibility with
        `SOVNode._set_zeros_poles` but is not used here.
        """
        # the poles of the transcendental function are the zeros found on
        # the child nodes
        lpoles = []
        lpmultiplicities = []
        for cn in self.child_nodes:
            lpoles.extend(cn.poles)
            lpmultiplicities.extend(cn.pmultiplicities)
        inds = np.argsort(lpoles)
        lpoles = np.array(lpoles)[inds]
        lpmultiplicities = np.array(lpmultiplicities)[inds]
        # construct the transcendental function and its x-derivative
        # (chain rule with dp/dx = -2x/tau_0)
        f = lambda x: self.f_transc(x)
        dfdx = lambda x: -2.0 * x * self.dN_dp(x) / self.tau_0
        # find its zeros, these are the inverse timescales of the model
        xval = 1.5 / np.sqrt(self.eps_m)
        for cn in self.child_nodes:
            c_eps_m = cn.eps_m
            xval_ = 1.5 / np.sqrt(c_eps_m)
            if xval_ > xval:
                xval = xval_
        # avoid a split point that is itself (numerically) a zero;
        # NOTE(review): assumes at least two entries in lpoles
        if np.abs(f(xval)) < 1e-20:
            xval = (xval + lpoles[1]) / 2.0
        if pprint:
            print("xval: ", xval)
        # find zeros larger than xval
        PF = zf.poleFinder(
            fun=f,
            dfun=dfdx,
            global_poles={"poles": lpoles, "pmultiplicities": lpmultiplicities},
        )
        zeros, multiplicities = PF.find_real_zeros(vmin=xval)
        # find the zeros on [0, xval]
        z1 = []
        zm1 = []
        zf.find_zeros_on_segment(
            z1, zm1, 0.0, xval, f, dfdx, lpoles, lpmultiplicities, pprint=pprint
        )
        self.zeros = np.concatenate((z1, zeros)).real
        self.zmultiplicities = np.concatenate((zm1, multiplicities)).real
        self.prefactors = self.dN_dp(self.zeros).real
[docs]
class SOVTree(PhysTree):
"""
Class that computes the separation of variables time scales and spatial
mode functions for a given morphology and electrical parameter set. Employs
the algorithm by (Major, 1994). This tree defines a special
`neat.SomaSOVNode` as a derived class from `neat.SOVNode`, as some
functions required for SOV calculation are different and thus overridden.
The SOV calculation proceeds on the computational tree (see docstring of
`neat.MorphNode`). Thus it makes no sense to look for sov quantities in the
original tree.
"""
def __init__(self, *args, **kwargs):
    """
    Initialize the tree. All arguments are forwarded unchanged to the
    `PhysTree` constructor.

    ``maxspace_freq`` is created before the parent constructor runs and
    remains ``None`` until `calc_sov_equations` is called.
    """
    self.maxspace_freq = None
    super(SOVTree, self).__init__(*args, **kwargs)
def _get_repr_dict(self):
    """
    Extend the parent's representation dict with ``maxspace_freq``.

    ``maxspace_freq`` is ``None`` until `calc_sov_equations` has been called.
    Formatting ``None`` with the float spec ``1.6g`` raises a TypeError, so
    ``repr()`` of a tree whose SOV quantities were not yet computed used to
    fail. In that case the string ``"None"`` is stored instead.
    """
    repr_dict = super()._get_repr_dict()
    if self.maxspace_freq is None:
        maxspace_freq_str = "None"
    else:
        maxspace_freq_str = f"{self.maxspace_freq:1.6g}"
    repr_dict.update({"maxspace_freq": maxspace_freq_str})
    return repr_dict
def __repr__(self):
    """
    Representation built from the base `STree` repr, followed by the repr of
    `_get_repr_dict()`.
    """
    base_repr = STree.__repr__(self)
    return f"{base_repr}{self._get_repr_dict()!r}"
[docs]
def create_corresponding_node(self, node_index, p3d=None):
    """
    Creates a node with the given index corresponding to the tree class.

    Index 1 (the soma) yields a `SomaSOVNode`; any other index yields a
    `SOVNode`.

    Parameters
    ----------
    node_index: int
        index of the new node
    p3d: optional
        forwarded unchanged to the node constructor
    """
    node_cls = SomaSOVNode if node_index == 1 else SOVNode
    return node_cls(node_index, p3d=p3d)
@morphtree.computational_tree_decorator
def get_sov_matrices(self, loc_arg):
    """
    Return the mode rates and the spatial mode functions at ``locs``.

    Parameters
    ----------
    loc_arg: see :func:`neat.MorphTree.convert_loc_arg_to_locs()`
        the locations at which to evaluate the SOV matrices

    Returns
    -------
    alphas: np.ndarray (ndim = 1)
        the reciprocals of mode time-scales (kHz), computed as
        ``zeros**2 / (tau_0 * 1e3)`` from the somatic roots
    gammas: np.ndarray of complex (ndim = 2)
        the spatial function associated with each mode, evaluated at
        each location. Dimension 0 is number of modes and dimension 1
        number of locations
    """
    locs = self.convert_loc_arg_to_locs(loc_arg)

    if len(self) <= 1:
        # single compartment: one mode set by the somatic time constant
        alphas = np.array([1e-3 / self.root.tau_m])
        gammas = np.array([[np.sqrt(alphas[0] / self.root.g_s)]])
        return alphas, gammas

    soma = self.root
    alphas = soma.zeros**2 / (self.tau_0 * 1e3)
    # normalization of each mode, identical for all locations
    mode_norm = np.sqrt(soma.prefactors * 1e3)
    gammas = np.zeros((len(alphas), len(locs)), dtype=complex)
    for col, loc in enumerate(locs):
        # somatic locations are evaluated at the start of the first child
        if loc["node"] == 1:
            x = 0.0
            node = soma.child_nodes[0]
        else:
            x = loc["x"]
            node = self[loc["node"]]
        phase = node.q_vals_m * (1.0 - x) * node.L_sov / node.lambda_m
        gammas[:, col] = (
            node.kappa_m
            * (np.cos(phase) + node.mu_vals_m * np.sin(phase))
            / mode_norm
        )
    return alphas, gammas
@morphtree.computational_tree_decorator
def calc_sov_equations(self, maxspace_freq=500.0, pprint=False):
    """
    Compute the time scales and spatial functions of the separation of
    variables approach (Major, 1993).

    The roots of the transcendental equation (``zeros``, from which the
    mode rates follow as ``zeros**2 / tau_0``) are stored on the soma node.
    The spatial factors (``kappa_m``, ``mu_vals_m``, ``q_vals_m``) are
    stored on every node of the computational tree.

    Parameters
    ----------
    maxspace_freq: float (default is 500)
        roughly corresponds to the maximal spatial frequency of the
        smallest time-scale mode
    pprint: bool (default is False)
        print progress information during the forward sweep
    """
    self.maxspace_freq = maxspace_freq
    self.tau_0 = np.pi  # 1.
    tau_0 = self.tau_0
    for node in self:
        node._set_sov(self.channel_storage, tau_0=tau_0)

    if len(self) <= 1:
        # single compartment: only the somatic parameters are required
        self[1]._set_sov(self.channel_storage, tau_0=tau_0)
        return

    # forward sweep (leafs -> root): zeros and poles at every level
    leafs = self.leafs
    self._sov_from_leaf(
        leafs[0],
        leafs[1:],
        maxspace_freq=maxspace_freq,
        pprint=pprint,
    )
    # backward sweep (root -> leafs): spatial factors at the somatic roots
    self._sov_from_root(self.root, self.root.zeros)
    # reset the visit counters used by the forward sweep
    for node in self:
        node.counter = 0
def _sov_from_leaf(self, node, leafs, count=0, maxspace_freq=500.0, pprint=False):
    """
    Forward sweep of the SOV computation, from the leafs towards the root.

    A node's zeros and poles can only be computed once all of its children
    have been processed. Every visit increments ``node.counter``. When it
    equals the number of children, `_set_zeros_poles` runs and the sweep
    moves to the parent. Otherwise the sweep restarts at the next leaf.

    This was previously implemented with one recursive call per visited
    node. On large computational trees that exceeded Python's recursion
    limit, so it is now an equivalent loop.

    Parameters
    ----------
    node: `SOVNode`
        the node at which the sweep starts (normally the first leaf)
    leafs: sequence of `SOVNode`
        the remaining leafs, in the order in which they are started
    count: int
        number of steps already taken; kept for backward compatibility,
        not used
    maxspace_freq: float
        forwarded to `_set_zeros_poles`
    pprint: bool
        print every visited node and forward the flag to
        `_set_zeros_poles` (previously it was not forwarded)
    """
    next_leaf = 0
    while True:
        if pprint:
            print("Forward sweep: " + str(node))
        # log how many times the sweep has passed through ``node``
        if not self.is_leaf(node):
            node.counter += 1
        if node.counter == len(node.child_nodes):
            # all subtrees below ``node`` are done: compute its zeros/poles
            node._set_zeros_poles(maxspace_freq=maxspace_freq, pprint=pprint)
            if self.is_root(node):
                break
            node = node.parent_node
        elif next_leaf < len(leafs):
            # an unfinished subtree remains: restart from the next leaf
            node = leafs[next_leaf]
            next_leaf += 1
        else:
            break
def _sov_from_root(self, node, zeros):
    """
    Backward sweep: set ``kappa_m``, ``mu_vals_m`` and ``q_vals_m`` on every
    node below ``node``, evaluated at the roots ``zeros``.

    Nodes are visited in pre-order (parents before children). This is
    needed because `_set_kappa_factors` reads the parent's ``kappa_m``. An
    explicit stack is used instead of recursion.
    """
    # reversed so that popping yields the children in their original order
    pending = list(reversed(node.child_nodes))
    while pending:
        child = pending.pop()
        child._set_kappa_factors(zeros)
        child._set_mu_vals(zeros)
        child._set_q_vals(zeros)
        pending.extend(reversed(child.child_nodes))
[docs]
def get_mode_importance(
    self, loc_arg=None, sov_data=None, importance_type="simple"
):
    """
    Gives the overall importance of the SOV modes for a certain set of
    locations

    Parameters
    ----------
    loc_arg: None or list of locations
    sov_data: None or tuple of mode matrices
        One of the keyword arguments ``loc_arg`` or ``sov_data``
        must not be ``None``. If ``loc_arg`` is not ``None``, the importance
        is evaluated at these locations (see
        :func:`neat.MorphTree.convert_loc_arg_to_locs`).
        If ``sov_data`` is not ``None``, it is a tuple of a vector of
        the reciprocals of the mode timescales and a matrix with the
        corresponding spatial mode functions.
    importance_type: string ('simple' or 'full')
        'simple': ``sum_i |gamma_ki| / |alpha_k|``
        'full': ``sqrt(sum_ij |gamma_ki gamma_kj| / |alpha_k|)``
        Defaults to 'simple'.

    Returns
    -------
    np.ndarray (ndim = 1)
        the importances associated with each mode for the provided set
        of locations, normalized so that the maximum importance is one

    Raises
    ------
    IOError
        if both ``loc_arg`` and ``sov_data`` are ``None``
    ValueError
        if ``importance_type`` is not recognized
    """
    if loc_arg is not None:
        locs = self.convert_loc_arg_to_locs(loc_arg)
        alphas, gammas = self.get_sov_matrices(locs)
    elif sov_data is not None:
        alphas = sov_data[0]
        gammas = sov_data[1]
    else:
        raise IOError(
            "One of the kwargs `loc_arg` or `sov_data` must not be ``None``"
        )

    if importance_type == "simple":
        absolute_importance = np.sum(np.abs(gammas), 1) / np.abs(alphas)
    elif importance_type == "full":
        # sum_ij |g_i g_j| = (sum_i |g_i|)**2, so
        # sqrt(sum_ij |g_i g_j| / |a|) = sum_i |g_i| / sqrt(|a|).
        # The closed form avoids building an n_loc x n_loc outer product per
        # mode.
        absolute_importance = np.sum(np.abs(gammas), 1) / np.sqrt(np.abs(alphas))
    else:
        raise ValueError("`importance_type` argument can be 'simple' or 'full'")
    return absolute_importance / np.max(absolute_importance)
[docs]
def get_important_modes(
    self,
    loc_arg=None,
    sov_data=None,
    eps=1e-4,
    sort_type="timescale",
    return_importance=False,
):
    """
    Returns the most important eigenmodes (those whose importance is above
    the threshold defined by `eps`)

    Parameters
    ----------
    loc_arg: None or list of locations
    sov_data: None or tuple of mode matrices
        One of the keyword arguments ``loc_arg`` or ``sov_data``
        must not be ``None``. If ``loc_arg`` is not ``None``, the importance
        is evaluated at these locations (see
        :func:`neat.MorphTree.convert_loc_arg_to_locs`).
        If ``sov_data`` is not ``None``, it is a tuple of a vector of
        the reciprocals of the mode timescales and a matrix with the
        corresponding spatial mode functions.
    eps: float
        the cutoff threshold in relative importance (see
        `get_mode_importance`, 'simple' measure) below which modes
        are truncated
    sort_type: string ('timescale' or 'importance')
        specifies in which order the modes are returned. If 'timescale',
        modes are sorted in order of decreasing time-scale, if
        'importance', modes are sorted in order of decreasing importance.
    return_importance: bool
        if ``True``, returns the importance metric associated with each
        mode

    Returns
    -------
    alphas: np.ndarray (ndim = 1)
        the reciprocals of mode time-scales ``[kHz]``
    gammas: np.ndarray of complex (ndim = 2)
        the spatial function associated with each mode, evaluated at
        each locations. Dimension 0 is number of modes and dimension 1
        number of locations
    importance: np.ndarray (`shape` matches `alphas`, only if `return_importance` is ``True``)
        value of importance metric for each mode

    Raises
    ------
    IOError
        if both ``loc_arg`` and ``sov_data`` are ``None``
    ValueError
        if ``sort_type`` is not recognized
    """
    if loc_arg is not None:
        locs = self.convert_loc_arg_to_locs(loc_arg)
        alphas, gammas = self.get_sov_matrices(locs)
    elif sov_data is not None:
        alphas = sov_data[0]
        gammas = sov_data[1]
    else:
        raise IOError(
            "One of the kwargs `loc_arg` or `sov_data` must not be ``None``"
        )
    importance = self.get_mode_importance(
        sov_data=(alphas, gammas), importance_type="simple"
    )
    # only modes above importance cutoff
    inds = np.where(importance > eps)[0]
    alphas, gammas, importance = alphas[inds], gammas[inds, :], importance[inds]
    if sort_type == "timescale":
        # increasing |alpha| == decreasing time scale
        inds_sort = np.argsort(np.abs(alphas))
    elif sort_type == "importance":
        inds_sort = np.argsort(importance)[::-1]
    else:
        # single-line literal: a backslash continuation inside the string
        # would embed the next line's indentation into the message
        raise ValueError("`sort_type` argument can be 'timescale' or 'importance'")
    if return_importance:
        return alphas[inds_sort], gammas[inds_sort, :], importance[inds_sort]
    else:
        return alphas[inds_sort], gammas[inds_sort, :]
def get_kernels(
    self,
    loc_arg=None,
    sov_data=None,
    eps=0.0,
):
    """
    Returns the impulse response kernels as a double nested list of "neat.Kernel".
    The element at the position i,j represents the transfer impedance kernel
    between compartments i and j.

    Parameters
    ----------
    loc_arg: None or list of locations
    sov_data: None or tuple of mode matrices
        One of the keyword arguments ``loc_arg`` or ``sov_data``
        must not be ``None``. If ``loc_arg`` is not ``None``, the importance
        is evaluated at these locations (see
        :func:`neat.MorphTree.convert_loc_arg_to_locs`).
        If ``sov_data`` is not ``None``, it is a tuple of a vector of
        the reciprocals of the mode timescales and a matrix with the
        corresponding spatial mode functions.
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated. With ``sov_data`` the truncation is applied when
        ``eps > 0`` (previously ``eps`` was ignored for ``sov_data``)

    Returns
    -------
    list of list of `neat.Kernel`
        The kernels of the model

    Raises
    ------
    IOError
        if both ``loc_arg`` and ``sov_data`` are ``None``
    """
    if loc_arg is not None:
        locs = self.convert_loc_arg_to_locs(loc_arg)
        alphas, gammas = self.get_important_modes(locs, eps=eps)
    elif sov_data is not None:
        alphas = sov_data[0]
        gammas = sov_data[1]
        if eps > 0.0:
            alphas, gammas = self.get_important_modes(
                sov_data=(np.asarray(alphas), np.asarray(gammas)), eps=eps
            )
    else:
        raise IOError(
            "One of the kwargs `loc_arg` or `sov_data` must not be ``None``"
        )
    n_loc = gammas.shape[1]
    # cc[k, i, j] = gammas[k, i] * gammas[k, j]: amplitude of mode k in the
    # kernel between locations i and j
    cc = np.einsum("ik,kj->kij", gammas.T, gammas)
    return [
        [Kernel((alphas, cc[:, ii, jj])) for jj in range(n_loc)]
        for ii in range(n_loc)
    ]
def calc_zf(
    self,
    loc0,
    loc1,
    freqs=None,
    eps=0.0,
):
    """
    Computes the transfer impedance between two locations for the provided
    frequencies

    Parameters
    ----------
    loc0: dict, tuple or `:class:MorphLoc`
        One of two locations between which the transfer impedance is computed
    loc1: dict, tuple or `:class:MorphLoc`
        One of two locations between which the transfer impedance is computed
    freqs: np.ndarray of complex or None (default)
        the frequencies at which to evaluate the impedance. If ``None``,
        only the steady state (zero frequency) is evaluated
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated

    Returns
    -------
    the output of `neat.Kernel.ft` for the kernel between ``loc0`` and
    ``loc1`` at ``freqs`` (presumably one impedance value per frequency;
    see `neat.Kernel.ft`)
    """
    if freqs is None:
        freqs = np.array([0.0])
    kernel = self.get_kernels(loc_arg=[loc0, loc1], eps=eps)[0][1]
    return kernel.ft(freqs)
def calc_zt(
    self,
    loc0,
    loc1,
    times=None,
    eps=0.0,
):
    """
    Computes the impulse response kernel between two locations at the
    provided time points

    Parameters
    ----------
    loc0: dict, tuple or `:class:MorphLoc`
        One of two locations between which the transfer impedance is computed
    loc1: dict, tuple or `:class:MorphLoc`
        One of two locations between which the transfer impedance is computed
    times: np.array
        The time-points at which to evaluate the kernel. If not provided,
        evaluates `t = np.linspace(0.1,100.,1000)`
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated

    Returns
    -------
    the output of `neat.Kernel.t` for the kernel between ``loc0`` and
    ``loc1`` at ``times`` (presumably one value per time point; see
    `neat.Kernel.t`)
    """
    if times is None:
        times = np.linspace(0.1, 100.0, 1000)
    kernel = self.get_kernels(loc_arg=[loc0, loc1], eps=eps)[0][1]
    return kernel.t(times)
def calc_impulse_response_matrix(
    self,
    times=None,
    loc_arg=None,
    sov_data=None,
    eps=0.0,
):
    """
    Evaluate all pairwise impulse response kernels between a set of
    locations.

    Parameters
    ----------
    times: np.array
        The time-points at which to evaluate the kernels. If not provided,
        evaluates `t = np.linspace(0.1,100.,1000)`
    loc_arg: None or list of locations
    sov_data: None or tuple of mode matrices
        One of the keyword arguments ``loc_arg`` or ``sov_data``
        must not be ``None`` (see `get_kernels`).
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated, forwarded to `get_kernels`

    Returns
    ------
    `np.ndarray` (``ndim=3``)
        the matrix of impulse responses, first dimension corresponds to the
        time axis, second and third dimensions contain the impulse response
        in ``[MOhm/ms]`` at that time point
    """
    if times is None:
        times = np.linspace(0.1, 100.0, 1000)
    kernels = self.get_kernels(loc_arg=loc_arg, sov_data=sov_data, eps=eps)
    # stacked as (location, location, time) ...
    stacked = np.array(
        [[kernel.t(times) for kernel in kernel_row] for kernel_row in kernels]
    )
    # ... then the time axis is moved to the front
    return stacked.transpose(2, 0, 1)
[docs]
def calc_impedance_matrix(
    self,
    freqs=None,
    loc_arg=None,
    sov_data=None,
    eps=0.0,
    mem_limit=500,
):
    """
    Compute the impedance matrix for a set of locations

    Parameters
    ----------
    freqs: np.ndarray of complex, sequence or None (default)
        if ``None``, returns the steady state impedance matrix, if
        an array of complex numbers, returns the impedance matrix for
        each Fourier frequency in the array
    loc_arg: None or list of locations
    sov_data: None or tuple of mode matrices
        One of the keyword arguments ``loc_arg`` or ``sov_data``
        must not be ``None``. If ``loc_arg`` is not ``None``, the importance
        is evaluated at these locations (see
        :func:`neat.MorphTree.convert_loc_arg_to_locs`).
        If ``sov_data`` is not ``None``, it is a tuple of a vector of
        the reciprocals of the mode timescales and a matrix with the
        corresponding spatial mode functions.
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated. Applied when ``eps > 0`` (previously this argument
        was accepted but ignored); with the default ``0.`` all modes are kept
    mem_limit: int
        parameter governs whether the fast (but memory intense) method
        or the slow method is used for the steady state matrix

    Returns
    -------
    np.ndarray (ndim = 2 or 3)
        steady state: real 2d matrix ``[MOhm]``.
        With ``freqs``: complex 3d array with the frequency dependence
        along the first dimension ``[MOhm]``

    Raises
    ------
    IOError
        if both ``loc_arg`` and ``sov_data`` are ``None``
    """
    # TODO: a shortened formulation based on `get_kernels` and `Kernel.ft`
    # causes some tests to fail, have to figure out why:
    # remove_dim = False
    # if freqs is None:
    #     remove_dim = True
    #     freqs = np.array([0.])
    # kernels = self.get_kernels(loc_arg=loc_arg, sov_data=sov_data, eps=eps)
    # zf_mat = np.array([[zk.ft(freqs) for zk in row] for row in kernels])
    # if remove_dim:
    #     zf_mat[:,:,0]
    # return np.transpose(zf_mat, axes=(2,0,1))
    if loc_arg is not None:
        locs = self.convert_loc_arg_to_locs(loc_arg)
        alphas, gammas = self.get_sov_matrices(locs)
    elif sov_data is not None:
        alphas = sov_data[0]
        gammas = sov_data[1]
    else:
        raise IOError(
            "One of the kwargs `loc_arg` or `sov_data` must not be ``None``"
        )
    if eps > 0.0:
        # truncate the modes below the relative importance threshold
        alphas, gammas = self.get_important_modes(
            sov_data=(np.asarray(alphas), np.asarray(gammas)), eps=eps
        )
    n_loc = gammas.shape[1]
    if freqs is None:
        # construct the 2d steady state matrix
        y_activation = 1.0 / alphas
        # compute the matrix, methods depends on memory limit
        if gammas.shape[1] < mem_limit and gammas.shape[0] < int(mem_limit / 2.0):
            z_mat = np.sum(
                gammas[:, :, np.newaxis]
                * gammas[:, np.newaxis, :]
                * y_activation[:, np.newaxis, np.newaxis],
                0,
            ).real
        else:
            z_mat = np.zeros((n_loc, n_loc))
            for ii, jj in itertools.product(range(n_loc), range(n_loc)):
                z_mat[ii, jj] = np.sum(
                    gammas[:, ii] * gammas[:, jj] * y_activation
                ).real
    else:
        # accept any sequence of frequencies, not only arrays
        freqs = np.asarray(freqs)
        # construct the 3d Fourier matrix; alphas are in kHz, hence the 1e3
        y_activation = 1e3 / (alphas[np.newaxis, :] * 1e3 + freqs[:, np.newaxis])
        z_mat = np.zeros((len(freqs), n_loc, n_loc), dtype=complex)
        for ii, jj in itertools.product(range(n_loc), range(n_loc)):
            z_mat[:, ii, jj] = np.sum(
                gammas[np.newaxis, :, ii]
                * gammas[np.newaxis, :, jj]
                * y_activation,
                1,
            )
    return z_mat
[docs]
def construct_net(
    self,
    dz=50.0,
    dx=10.0,
    eps=1e-4,
    use_hist=False,
    add_lin_terms=True,
    improve_input_impedance=False,
    pprint=False,
):
    """
    Construct a Neural Evaluation Tree (NET) for this cell. The locations
    for which impedance values are computed are stored under the name
    `net eval`

    Parameters
    ----------
    dz: float
        the impedance step for the NET model derivation
    dx: float
        the distance step to evaluate the impedance matrix
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated
    use_hist: bool
        whether or not to use histogram segmentations to find well
        separated parts of the dendritic tree (such as the apical tree)
    add_lin_terms: bool
        take into account that the obtained NET will be used in conjunction
        with linear terms; if ``True``, the linear terms are also computed
        and returned
    improve_input_impedance: bool
        if ``True``, calls `_improve_input_impedance` on the NET after it
        has been constructed
    pprint: bool
        print information while the NET is being built

    Returns
    -------
    `neat.NET`, or tuple (`neat.NET`, lin_terms) if ``add_lin_terms``
        The neural evaluation tree (Wybo et al., 2019) associated with the
        morphology and, when requested, the output of `compute_lin_terms`
    """
    # create a set of location at which to evaluate the impedance matrix
    self.distribute_locs_uniform(dx=dx, name="net eval")
    # compute the z_mat matrix
    alphas, gammas = self.get_important_modes(loc_arg="net eval", eps=eps)
    z_mat = self.calc_impedance_matrix(sov_data=(alphas, gammas))
    # derive the NET
    net = NET()
    self._add_layer_a(
        net,
        None,
        z_mat,
        alphas,
        gammas,
        0.0,
        0,
        np.arange(len(self.get_locs("net eval"))),
        dz=dz,
        use_hist=use_hist,
        add_lin_terms=add_lin_terms,
        pprint=pprint,
    )
    net.set_new_loc_idxs()
    if improve_input_impedance:
        self._improve_input_impedance(net, alphas, gammas)
    if add_lin_terms:
        lin_terms = self.compute_lin_terms(net, sov_data=(alphas, gammas))
        return net, lin_terms
    else:
        return net
def _add_layer_a(
    self,
    net,
    pnode,
    z_mat,
    alphas,
    gammas,
    z_max_prev,
    z_ind_0,
    true_loc_idxs,
    dz=100.0,
    use_hist=True,
    add_lin_terms=False,
    pprint=False,
):
    """
    Add the top layer(s) of the NET below ``pnode``.

    The impedances in the first row of ``z_mat`` (between location 0 and
    all locations) are histogrammed, and the histogram is partitioned into
    at most three segments. For every non-empty segment a `NETNode` is
    added, each new node becoming the parent of the next. The segment's
    sub-domains are then passed on to `_add_layer_b`.

    ``z_max_prev`` and ``z_ind_0`` are not used; they are kept for interface
    compatibility.
    """
    # create a histogram of the impedances relative to location 0
    n_bin = 15
    z_hist = np.histogram(z_mat[0, :], n_bin, density=False)
    # find the histogram partition
    h_ftc = hs.histogramSegmentator(z_hist)
    s_inds, _ = h_ftc.partition_fine_to_coarse(eps=1.4)
    # keep at most three segment boundaries by dropping interior ones
    while len(s_inds) > 3:
        s_inds = np.delete(s_inds, 1)
    # identify the necessary node indices and kernel computation indices
    node_inds = []
    kernel_inds = []
    min_inds = []
    for ii, si in enumerate(s_inds[:-1]):
        if si > 0:
            n_inds = np.where(z_mat[0, :] > z_hist[1][si + 1])[0]
            k_inds = np.where(
                np.logical_and(
                    z_mat[0, :] > z_hist[1][si + 1],
                    z_mat[0, :] <= z_hist[1][s_inds[ii + 1] + 1],
                )
            )[0]
        else:
            n_inds = np.where(z_mat[0, :] >= z_hist[1][0])[0]
            k_inds = np.where(
                np.logical_and(
                    z_mat[0, :] >= z_hist[1][0],
                    z_mat[0, :] <= z_hist[1][s_inds[ii + 1] + 1],
                )
            )[0]
        # ``np.argmin`` raises on an empty selection. Empty segments are
        # skipped when nodes are added below, so store a placeholder for them.
        if len(k_inds) > 0:
            min_inds.append(np.argmin(z_mat[0, k_inds]))
        else:
            min_inds.append(None)
        node_inds.append(n_inds)
        kernel_inds.append(k_inds)
    # add NET nodes to the NET tree
    for ii, n_inds in enumerate(node_inds):
        k_inds = kernel_inds[ii]
        if len(k_inds) != 0:
            if add_lin_terms:
                # get the minimal kernel
                gammas_avg = gammas[:, 0] * gammas[:, k_inds[min_inds[ii]]]
            else:
                # get the average kernel, subsampling very large segments
                if len(k_inds) < 100000:
                    gammas_avg = np.mean(gammas[:, 0:1] * gammas[:, k_inds], 1)
                else:
                    inds_ = np.random.choice(k_inds, size=100000)
                    gammas_avg = np.mean(gammas[:, 0:1] * gammas[:, inds_], 1)
            self._subtract_parent_kernels(gammas_avg, pnode)
            # add a node to the tree
            node = NETNode(
                len(net), true_loc_idxs[n_inds], z_kernel=(alphas, gammas_avg)
            )
            if pnode is not None:
                net.add_node_with_parent(node, pnode)
            else:
                net.root = node
            # set new pnode
            pnode = node
            # print stuff
            if pprint:
                print(node)
                print("n_loc =", len(node.loc_idxs))
                print("(locind0, size) = ", (k_inds[0], z_mat.shape[0]))
                print("")
            if k_inds[0] == 0:
                # start new branches, split where they originate from soma by
                # checking where input impedance is close to somatic transfer
                # impedance
                z_max = z_hist[1][s_inds[ii + 1]]
                # check where new dendritic branches start
                z_diag = z_mat[k_inds, k_inds]
                z_x0 = z_mat[k_inds, 0]
                b_inds = np.where(np.abs(z_diag - z_x0) < dz / 2.0)[0][1:].tolist()
                if len(b_inds) > 0:
                    if b_inds[0] != 1:
                        b_inds = [1] + b_inds
                    # drop branch starts that directly follow the previous one
                    kk = len(b_inds) - 1
                    while kk > 0:
                        if b_inds[kk] - 1 == b_inds[kk - 1]:
                            del b_inds[kk]
                        kk -= 1
                else:
                    b_inds = [1]
                for jj, i0 in enumerate(b_inds):
                    # make new z_mat matrix
                    i1 = len(k_inds) if i0 == b_inds[-1] else b_inds[jj + 1]
                    inds = np.meshgrid(k_inds[i0:i1], k_inds[i0:i1], indexing="ij")
                    z_mat_new = copy.deepcopy(z_mat[inds[0], inds[1]])
                    # move further in the tree
                    # NOTE(review): ``pprint`` is not forwarded here, unlike in
                    # the call below. Before forwarding it, check that
                    # `_add_layer_b`'s pprint branch does not print ``node``
                    # before defining it.
                    self._add_layer_b(
                        net,
                        node,
                        z_mat_new,
                        alphas,
                        gammas,
                        z_max,
                        k_inds[i0:i1],
                        dz=dz,
                        use_hist=use_hist,
                        add_lin_terms=add_lin_terms,
                    )
            else:
                # make new z_mat matrix per run of consecutive indices
                k_seqs = _consecutive(k_inds)
                if pprint:
                    print("\n>>> consecutive")
                    print("nseq:", len(k_seqs))
                    for k_seq in k_seqs:
                        print("sequence:", k_seq)
                for k_seq in k_seqs:
                    inds = np.meshgrid(k_seq, k_seq, indexing="ij")
                    z_mat_new = copy.deepcopy(z_mat[inds[0], inds[1]])
                    z_max = z_mat[0, 0] + 1
                    # move further in the tree
                    self._add_layer_b(
                        net,
                        node,
                        z_mat_new,
                        alphas,
                        gammas,
                        z_max,
                        k_seq,
                        dz=dz,
                        pprint=pprint,
                        use_hist=use_hist,
                        add_lin_terms=add_lin_terms,
                    )
def _add_layer_b(
    self,
    net,
    pnode,
    z_mat,
    alphas,
    gammas,
    z_max_prev,
    true_loc_idxs,
    dz=100.0,
    use_hist=True,
    pprint=False,
    add_lin_terms=False,
):
    """
    Recursively add a layer of nodes to the NET below ``pnode``.

    A new node is created whose kernel is the average of the mode products
    ``gammas[:, i] * gammas[:, j]`` over all location pairs ``(i, j)`` with
    ``z_mat[i, j]`` below an upper bound ``z_max``, minus the kernels of all
    ancestors of the node. The locations whose input impedance (diagonal of
    ``z_mat``) exceeds ``z_max`` are split into contiguous domains, and each
    domain is passed recursively to a deeper layer.

    Parameters
    ----------
    net: `neat.NET`
        the tree to which the new node(s) are added
    pnode: `neat.NETNode` or None
        parent of the new node; if ``None`` the new node becomes the root
    z_mat: 2d numpy array
        impedance matrix between the locations in ``true_loc_idxs``
    alphas: 1d numpy array
        reciprocals of the mode timescales
    gammas: 2d numpy array
        spatial mode functions, indexed as ``gammas[mode, location]``
    z_max_prev: float
        upper impedance bound of the parent layer, used as the lower bound
        of this layer
    true_loc_idxs: 1d numpy array of int
        indices (columns of ``gammas``) of the locations ``z_mat`` refers to
    dz: float
        maximal impedance step between consecutive layers
    use_hist: bool
        whether to use histogram segmentation to choose the layer bound
    pprint: bool
        print debug information
    add_lin_terms: bool
        if ``True`` and this layer contains location 0, the next layer
        receives a single domain consisting of all but the first location
    """
    if pprint:
        # ``node`` is only created further down; it will get index len(net)
        print(">>> node index = ", len(net))
        if pnode is not None:
            print("parent index = ", pnode._index)
        else:
            print("start")
    # input impedances of the locations in this layer
    z_diag = np.diag(z_mat)
    if true_loc_idxs[0] == 0 and z_mat[0, 0] > z_max_prev:
        # location 0 (the soma) belongs to this layer: put the upper bound
        # just above its input impedance
        n_bins = "soma"
        z_max = z_mat[0, 0] + 1.0
        z_min = z_max_prev
    else:
        # histogram GF
        n_bins = max(
            int(z_mat.size / 50.0), int((np.max(z_mat) - np.min(z_mat)) / dz)
        )
        if n_bins > 1:
            if np.all(np.diff(z_diag) > 0):
                # input impedance strictly increases along the locations:
                # advance the bound by a fixed step
                z_min = z_max_prev
                z_max = z_min + dz
                if pprint:
                    print("--> +", dz)
            elif use_hist:
                z_hist = np.histogram(z_mat.flatten(), n_bins, density=False)
                # find the histogram partition
                h_ftc = hs.histogramSegmentator(z_hist)
                s_ind, _ = h_ftc.partition_fine_to_coarse()
                # get the new min max values
                z_histx = z_hist[1]
                z_min = z_max_prev
                # first segment boundary that reaches the smallest input
                # impedance in this layer
                ii = 1
                while np.min(z_diag) > z_histx[s_ind[ii]]:
                    ii += 1
                z_max = z_histx[s_ind[ii]]
                # never step by more than ``dz``
                if z_max - z_min > dz:
                    z_max = z_min + dz
                if pprint:
                    print("--> hist: +", str(z_max - z_min))
            else:
                z_min = z_max_prev
                z_max = z_min + dz
                if pprint:
                    print("--> +", dz)
        else:
            # too few bins: take everything in a single layer
            z_min = z_max_prev
            z_max = np.max(z_mat)
            if pprint:
                print("--> all: +", str(z_max - z_min))
    d_inds = np.where(z_diag <= z_max + 1e-15)[0]
    # make sure that there is at least one element in the layer
    while len(d_inds) == 0:
        z_max += dz
        d_inds = np.where(z_diag <= z_max + 1e-15)[0]
    # comparison threshold with a small tolerance for round-off
    z_thr = z_max + 1e-15
    # identify different domains: [t0[k], t1[k]) are the contiguous runs of
    # locations whose input impedance exceeds the layer bound
    if add_lin_terms and true_loc_idxs[0] == 0:
        t0 = np.array([1])
        t1 = np.array([len(z_diag)])
    else:
        t0 = np.where(np.logical_and(z_diag[:-1] < z_thr, z_diag[1:] >= z_thr))[0]
        if len(t0) > 0:
            t0 += 1
        if z_diag[0] >= z_thr:
            t0 = np.concatenate(([0], t0))
        t1 = np.where(np.logical_and(z_diag[:-1] >= z_thr, z_diag[1:] < z_thr))[0]
        if len(t1) > 0:
            t1 += 1
        if z_diag[-1] >= z_thr:
            t1 = np.concatenate((t1, [len(z_diag)]))
    # location pairs whose impedance lies within the layer
    l_inds = np.where(z_mat <= z_thr)
    rows = true_loc_idxs[l_inds[0]]
    cols = true_loc_idxs[l_inds[1]]
    if l_inds[0].size >= 100000:
        # too many pairs: average over a random subsample
        inds_ = np.random.randint(l_inds[0].size, size=100000)
        rows = rows[inds_]
        cols = cols[inds_]
    # get the average kernel
    gammas_avg = np.mean(gammas[:, rows] * gammas[:, cols], 1)
    self._subtract_parent_kernels(gammas_avg, pnode)
    # add a node to the tree
    node = NETNode(len(net), true_loc_idxs, z_kernel=(alphas, gammas_avg))
    if pnode is not None:
        net.add_node_with_parent(node, pnode)
    else:
        net.root = node
    if pprint:
        print("(locind0, size) = ", (true_loc_idxs[0], z_mat.shape[0]))
        print("(zmin, zmax, n_bins) = ", (z_min, z_max, n_bins))
        print("")
    # move on to the next layers if some locations lie above the bound
    if len(d_inds) < len(z_diag):
        for ind0, ind1 in zip(t0, t1):
            z_mat_new = copy.deepcopy(z_mat[ind0:ind1, ind0:ind1])
            true_loc_idxs_new = true_loc_idxs[ind0:ind1]
            self._add_layer_b(
                net,
                node,
                z_mat_new,
                alphas,
                gammas,
                z_max,
                true_loc_idxs_new,
                dz=dz,
                use_hist=use_hist,
                pprint=pprint,
                add_lin_terms=add_lin_terms,
            )
def _subtract_parent_kernels(self, gammas, pnode):
if pnode != None:
gammas -= pnode.z_kernel["c"]
self._subtract_parent_kernels(gammas, pnode.parent_node)
def _improve_input_impedance(self, net, alphas, gammas):
nmaxind = np.max([n.index for n in net])
for node in net:
if len(node.loc_idxs) == 1:
# recompute the kernel of this single loc layer
if node.parent_node is not None:
p_kernel = net.calc_total_kernel(node.parent_node)
p_k_c = p_kernel.c
else:
p_k_c = np.zeros_like(gammas)
gammas_real = gammas[:, node.loc_idxs[0]] ** 2
node.z_kernel.c = gammas_real - p_k_c
elif len(node.newloc_idxs) > 0:
z_k_approx = net.calc_total_kernel(node)
# add new input nodes for the nodes that don't have one
for ind in node.newloc_idxs:
nmaxind += 1
gammas_real = gammas[:, ind] ** 2
z_k_real = Kernel(dict(a=alphas, c=gammas_real))
# add node
newnode = NETNode(nmaxind, [ind], z_kernel=z_k_real - z_k_approx)
newnode.newloc_idxs = [ind]
net.add_node_with_parent(newnode, node)
# empty the new indices
node.newloc_idxs = []
net.set_new_loc_idxs()
[docs]
def compute_lin_terms(self, net, sov_data=None, eps=1e-4):
    """
    Construct linear terms for `net` so that transfer impedance to soma is
    exactly matched

    Parameters
    ----------
    net: `neat.NETree`
        the neural evaluation tree (NET)
    sov_data: None or tuple of mode matrices
        If ``sov_data`` is not ``None``, it is a tuple of a vector of
        the reciprocals of the mode timescales and a matrix with the
        corresponding spatial mode functions. Otherwise the modes are
        obtained from ``self.get_important_modes`` at the 'net eval'
        locations.
    eps: float
        the cutoff threshold in relative importance below which modes
        are truncated (only used when ``sov_data`` is ``None``)

    Returns
    -------
    lin_terms: dict of {int: `neat.Kernel`}
        the kernels associated with linear terms of the NET, keys are
        indices of their corresponding location stored under 'net eval'.
        Locations on the root node get no linear term.
    """
    # ``is not None`` rather than ``!= None``: the latter compares
    # element-wise (and then fails) if ``sov_data`` is passed as an array
    if sov_data is not None:
        alphas = sov_data[0]
        gammas = sov_data[1]
    else:
        alphas, gammas = self.get_important_modes(loc_arg="net eval", eps=eps)
    lin_terms = {}
    for ii, loc in enumerate(self.get_locs("net eval")):
        if self.is_root(self[loc["node"]]):
            # no linear term for locations on the root node
            continue
        # exact transfer impedance kernel between location ``ii`` and
        # location 0 (the soma)
        z_k_true = Kernel((alphas, gammas[:, ii] * gammas[:, 0]))
        # NET approximation of that kernel: root kernel of the tree reduced
        # to locations 0 and ``ii``
        z_k_net = net.get_reduced_tree([0, ii]).get_root().z_kernel
        # the linear term corrects the difference
        lin_terms[ii] = z_k_true - z_k_net
    return lin_terms