# -*- coding: utf-8 -*-
#
# compartmentfitter.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
import matplotlib.pyplot as pl
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from ..trees.stree import STree
from ..trees.phystree import PhysTree
from ..trees.compartmenttree import CompartmentTree
from ..tools.kernelextraction import Kernel
from ..channels.ionchannels import SPDict
from ..factorydefaults import FitParams, MechParams
from .cachetrees import CachedGreensTree, CachedSOVTree, EquilibriumTree
import copy
import pathlib
import warnings
def _statevar_is_activating(f_statevar):
"""
check whether a statevar is activating or inactivating
Parameters
----------
f_statevar: callable
the activation function of the state variable
"""
# test voltage values to check whether state variable is activating or
# inactivating
v_test = np.array([-43.22, -32.22])
sv_test = f_statevar(v_test)
return sv_test[0] < sv_test[1]
def _get_two_variable_holding_potentials(e_hs):
e_hs_aux_act = list(e_hs)
e_hs_aux_inact = list(e_hs)
for ii, e_h1 in enumerate(e_hs):
for jj, e_h2 in enumerate(e_hs):
e_hs_aux_act.append(e_h1)
e_hs_aux_inact.append(e_h2)
e_hs_aux_act = np.array(e_hs_aux_act)
e_hs_aux_inact = np.array(e_hs_aux_inact)
return e_hs_aux_act, e_hs_aux_inact
def get_expansion_points(e_hs, channel, only_e_h=False):
    """
    Return the expansion points around which to compute the impedance
    matrix, given a set of holding potentials.

    If the channel has a single state variable (or `only_e_h` is set), the
    expansion points are the state variable values returned by
    `channel.compute_varinf(e_hs)`. Otherwise, state variables are evaluated
    at different combinations of holding potentials.

    Parameters
    ----------
    e_hs: iterable collection
        The holding potentials around which the expansion points are computed
    channel: `neat.channels.ionchannels.IonChannel`
        The ion channel
    only_e_h: bool
        Only applicable for channels with at least two state variables.
        If True, returned expansion points are always for state variable
        combinations evaluated at the same holding potential. Otherwise,
        state variable activations are evaluated at different holding
        potentials.

    Returns
    -------
    sv_hs: dict
        the expansion points; the key "v" holds the associated potentials
    """
    same_potential_everywhere = len(channel.statevars) == 1 or only_e_h
    if not same_potential_everywhere:
        # different combinations of holding potentials
        e_hs_act, e_hs_inact = _get_two_variable_holding_potentials(e_hs)
        sv_hs = SPDict(v=e_hs_act)
        for svar, f_inf in channel.f_varinf.items():
            # activating variables use the first array of potentials,
            # inactivating ones the second
            if _statevar_is_activating(f_inf):
                e_hs_eval = e_hs_act
            else:
                e_hs_eval = e_hs_inact
            sv_hs[str(svar)] = f_inf(e_hs_eval)
        return sv_hs

    sv_hs = channel.compute_varinf(e_hs)
    sv_hs["v"] = e_hs
    return sv_hs
[docs]
class CompartmentFitter(EquilibriumTree):
    """
    Tree class that streamlines fitting reduced compartmental models

    Attributes
    ----------
    tree: `neat.PhysTree`
        The full tree based on which reductions are made
    fit_cfg: `neat.FitParams`
        The fit parameters
    concmech_cfg: `neat.MechParams`
        The concentration mechanisms parameters
    fitted_models: dict of `{str: dict}`
        Data structure with already performed model fits, where keys are the
        provided names. Each entry is a dict of the form
        `{'ctree': neat.CompartmentTree, 'locs': list of neat.MorphLoc,
        'complete': bool}`
    cache_name: str (default '')
        name of files in which intermediate trees required for the fit are
        cached.
    cache_path: str (default '')
        specify a path under which the intermediate files are cached.
        Default is empty string, which means that intermediate files are stored
        in the working directory.
    save_cache: bool (default `True`)
        Save the intermediate results in a cache (using `cache_path` and
        `cache_name`).
    recompute_cache: bool (default `False`)
        Forces recomputing the caches.
    """
def __init__(self, *args, fit_cfg=None, concmech_cfg=None, **kwargs):
    """
    Initialize the fitter.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent constructor.
    fit_cfg: `neat.FitParams` or None
        The fit parameters. If provided, they always take precedence. If
        `None`, the value left on the instance by the parent constructor is
        kept, or a default `FitParams()` is created when there is none.
    concmech_cfg: `neat.MechParams` or None
        The concentration mechanism parameters, resolved with the same
        precedence as `fit_cfg`.
    """
    # If no initialization argument is provided (empty tree), or if it is a
    # .swc-filename, or a tree class that is likely to require further build
    # operations after this constructor, `post_init()` is not called here and
    # a warning reminds the user to call it manually.
    tree_still_to_be_built = (
        len(args) == 0
        or isinstance(args[0], (str, pathlib.Path))
        or (isinstance(args[0], STree) and not isinstance(args[0], PhysTree))
    )
    if tree_still_to_be_built:
        warnings.warn(
            f"Initialization of a {self.__class__.__name__}"
            f"-instance as a tree that still has to be built, "
            f"be sure to call `{self.__class__.__name__}.post_init()` after building the tree."
        )

    self.fitted_models = {}
    # placeholders; NOTE(review): the `None` checks below suggest the parent
    # constructor may overwrite these (e.g. when copying a tree) -- confirm
    self.fit_cfg = None
    self.concmech_cfg = None

    super().__init__(*args, **kwargs)

    # Explicitly passed configurations win. Previously they were dropped
    # whenever the attribute was still `None` after the parent constructor.
    if fit_cfg is not None:
        self.fit_cfg = fit_cfg
    elif self.fit_cfg is None:
        self.fit_cfg = FitParams()

    if concmech_cfg is not None:
        self.concmech_cfg = concmech_cfg
    elif self.concmech_cfg is None:
        self.concmech_cfg = MechParams()

    if not tree_still_to_be_built:
        self.post_init()

    # boolean flag that is reset the first time `self.fit_passive` is called
    self.use_all_channels_for_passive = True
def set_cfg(self, fit_cfg=None, concmech_cfg=None):
    """
    Replace both configuration objects of the fitter.

    Parameters
    ----------
    fit_cfg: `neat.FitParams` or None
        The new fit parameters; a default `FitParams()` is used if `None`
    concmech_cfg: `neat.MechParams` or None
        The new concentration mechanism parameters; a default
        `MechParams()` is used if `None`
    """
    self.fit_cfg = FitParams() if fit_cfg is None else fit_cfg
    self.concmech_cfg = MechParams() if concmech_cfg is None else concmech_cfg
def post_init(self):
    """
    Finish the initialization once the tree has been built.

    Calls `self.set_e_eq(pprint=True)` while the `self.as_original_tree`
    context is active, to set the equilibrium potentials in the tree.
    """
    original_tree_context = self.as_original_tree
    with original_tree_context:
        self.set_e_eq(pprint=True)
def get_attributes_excluded_from_cache_override(self):
    """
    Returns a list of attributes that should NOT be overwritten by the
    cached tree

    Returns
    -------
    list of str
        The attribute names reported by the parent class, extended with
        "fit_cfg" and "concmech_cfg"
    """
    inherited = super().get_attributes_excluded_from_cache_override()
    own = ["fit_cfg", "concmech_cfg"]
    return inherited + own
[docs]
def convert_fit_arg(self, fit_arg):
    """
    Convert a fit argument, which can be a string, dict or tuple, to a tuple
    consisting of a `neat.CompartmentTree` that is either fitted, or in the
    process of being fitted, and the corresponding list of locations.

    Parameters
    ----------
    fit_arg : string, dict, or tuple
        If string, the provided argument is interpreted as the fit name and
        looked up in `self.fitted_models`.
        If dict, the provided argument is a dictionary of the form
        `{'ctree': neat.CompartmentTree, 'locs': <list of locations>}`.
        If tuple, the provided argument is a tuple of the form
        `(neat.CompartmentTree, <list of locations>)`.

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is (in the process of being) fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.

    Raises
    ------
    TypeError
        If `fit_arg` does not correspond to one of the above described
        arguments.
    KeyError
        If `fit_arg` is a string that is not a stored fit name, or a dict
        without the 'ctree' or 'locs' key.
    """
    if isinstance(fit_arg, str):
        stored_fit = self.fitted_models[fit_arg]
        return stored_fit["ctree"], stored_fit["locs"]

    if isinstance(fit_arg, dict):
        return fit_arg["ctree"], fit_arg["locs"]

    # Anything else has to behave like a `(ctree, locs)` pair. Inputs that
    # cannot be indexed that way (too short, not subscriptable, ...) are
    # reported with the documented TypeError instead of leaking an IndexError.
    try:
        is_pair = isinstance(fit_arg[0], CompartmentTree)
        candidate = (fit_arg[0], fit_arg[1]) if is_pair else None
    except (TypeError, IndexError, KeyError):
        candidate = None

    if candidate is None:
        raise TypeError(
            "Invalid type for `fit_arg`, should be string, "
            "dict with {'ctree': neat.CompartmentTree, 'locs': list of locations}, "
            "or a tuple of (neat.CompartmentTree, list of locations)"
        )
    return candidate
def _store_fit(self, ctree, locs, fit_name=""):
if len(fit_name) > 0:
self.store_locs(locs, name=fit_name)
self.fitted_models[fit_name] = {
"ctree": ctree,
"locs": self.get_locs(name=fit_name),
"complete": False,
}
def remove_fit(self, fit_name):
    """
    Remove a stored fit and the locations stored under the same name.

    A warning is issued when no fit named `fit_name` is stored;
    `self.remove_locs(fit_name)` is called in either case.

    Parameters
    ----------
    fit_name: str
        The name of the fit to remove
    """
    if fit_name in self.fitted_models:
        del self.fitted_models[fit_name]
    else:
        warnings.warn(f"Fit with name '{fit_name}' not in stored fits.")
    self.remove_locs(fit_name)
[docs]
def set_ctree(self, loc_arg, fit_name="", extend_w_bifurc=True, pprint=False):
    """
    Create an initial `neat.CompartmentTree`, providing a tree
    structure scaffold for the fit for a given set of locations. If
    `fit_name` is not empty, the tree and the locations are also stored on
    ``self`` under that name.

    Parameters
    ----------
    loc_arg: list of locations or string (see documentation of
            :func:`MorphTree.convert_loc_arg_to_locs` for details)
        The compartment locations
    fit_name: str (optional, default: '')
        The name of the fit. If provided, the resulting
        `neat.CompartmentTree` and list of fit locations will be
        stored. They can be accessed under the `fitted_models` attribute
        of `neat.CompartmentFitter`.
    extend_w_bifurc: bool (optional, default `True`)
        To extend the compartment locations with all intermediate
        bifurcations (see documentation of
        :func:`MorphTree.extend_with_bifurcation_locs`).
    pprint: bool
        whether to print additional info (currently not used by this method)

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    locs = self.convert_loc_arg_to_locs(loc_arg)
    if extend_w_bifurc:
        locs = self.extend_with_bifurcation_locs(locs)
    else:
        # note the trailing space after "set": the adjacent literals are
        # concatenated, and the message used to read "setkwarg"
        warnings.warn(
            "Not adding bifurcations to `loc_arg`, this could "
            "lead to inaccurate fits. To add bifurcation, set "
            "kwarg `extend_w_bifurc` to ``True``"
        )

    # create the reduced compartment tree
    ctree = self.create_compartment_tree(locs)

    # store the fit
    self._store_fit(ctree, locs, fit_name=fit_name)

    # add currents to compartmental model
    for c_name, channel in self.channel_storage.items():
        e_revs = [
            node.currents[c_name][1] for node in self if c_name in node.currents
        ]
        # the reversal potential is the same throughout the reduced model:
        # the mean over all nodes of the full tree that carry the channel
        ctree.add_channel_current(copy.deepcopy(channel), np.mean(e_revs))

    for node in ctree:
        concmechs = self[locs[node.loc_idx]["node"]].concmechs
        # try to set default parameters as the ones from the original tree;
        # if the concmech is not present at the corresponding location,
        # use the default parameters
        for ion in self.ions:
            if ion in concmechs:
                node.add_conc_mech(ion, **dict(concmechs[ion].items()))
            else:
                node.add_conc_mech(ion, **self.concmech_cfg.exp_conc_mech)

    return ctree, locs
[docs]
def create_tree_gf(
    self,
    channel_names=[],
    cache_name_suffix="",
    unmasked_nodes=None,
):
    """
    Create a `CachedGreensTree` copy of the original tree, but only with the
    channels in ``channel_names``. Leak 'L' is included in the tree by
    default.

    Parameters
    ----------
    channel_names: list of strings
        List of channel names of the channels that are to be included in the
        new tree.
    cache_name_suffix: str
        Suffix appended to `self.cache_name` to form the cache name of the
        new tree.
    unmasked_nodes: 'node_arg' (see documentation of `MorphTree.convert_node_arg_to_nodes`)
        The nodes where the channels in `channel_names` will be initialized
        to non-zero values

    Returns
    -------
    `CachedGreensTree()`
    """
    unmasked_nodes_resolved = self.convert_node_arg_to_nodes(unmasked_nodes)
    unmasked_node_indices = [n.index for n in unmasked_nodes_resolved]

    # create new tree and empty channel storage
    gf_tree = CachedGreensTree(
        self,
        cache_path=self.cache_path,
        cache_name=self.cache_name + cache_name_suffix,
        save_cache=self.save_cache,
        recompute_cache=self.recompute_cache,
    )
    gf_tree.channel_storage = {}

    # names of the requested channels that actually occur on unmasked nodes
    included_names = set()
    for new_node, orig_node in zip(gf_tree, self):
        # every node starts out with only its leak current
        new_node.currents = {}
        g_leak, e_leak = orig_node.currents["L"]
        new_node._add_current("L", g_leak, e_leak)

        if new_node.index in unmasked_node_indices:
            for name in channel_names:
                try:
                    g_max, e_rev = orig_node.currents[name]
                    new_node._add_current(name, g_max, e_rev)
                    included_names.add(name)
                except KeyError:
                    # channel not present at this node
                    pass

    gf_tree.channel_storage = {
        name: self.channel_storage[name] for name in included_names
    }
    gf_tree.set_comp_tree(eps=self.fit_cfg.fit_comptree_eps)
    return gf_tree
def _eval_channel(self, fit_arg, channel_name, pprint=False):
    """
    Evaluate the impedance matrix for the model restricted to a single ion
    channel type, and fit the conductance of that channel in the
    compartment tree.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.
    channel_name: string
        The name of the ion channel under consideration
    pprint: bool (optional, defaults to ``False``)
        whether to print information

    Return
    ------
    fit_mats
        list of fit matrices; one `[m_f, v_t, w_f]` entry per expansion
        point, where `w_f` is the inverse open probability
    """
    ctree, locs = self.convert_fit_arg(fit_arg)
    # find the expansion point parameters for the channel
    channel = self.channel_storage[channel_name]
    sv_h = get_expansion_points(self.fit_cfg.e_hs, channel)
    # create the trees with only a single channel and multiple expansion points
    fit_tree = self.create_tree_gf(
        [channel_name],
        cache_name_suffix=f"_{channel_name}_",
    )
    # set the impedances in the tree
    fit_tree.set_impedances_in_tree(
        freqs=self.fit_cfg.freqs, sv_h={channel_name: sv_h}, pprint=pprint
    )
    # compute the impedance matrix for this activation level; a leading axis
    # of length one is prepended, and the second axis is indexed per
    # expansion point below
    z_mats = fit_tree.calc_impedance_matrix(locs)[None, :, :, :]
    # compute the fit matrices for all holding potentials
    fit_mats = []
    for ii, e_h in enumerate(sv_h["v"]):
        # state variable values at the ii-th expansion point
        sv = SPDict(
            {
                str(svar): sv_h[svar][ii]
                for svar in channel.statevars
                if str(svar) != "v"
            }
        )
        # compute the fit matrices
        m_f, v_t = ctree.compute_g_single_channel(
            channel_name,
            z_mats[:, ii, :, :],
            e_h,
            np.array([self.fit_cfg.freqs]),
            sv=sv,
            other_channel_names=["L"],
            all_channel_names=[channel_name],
            action="return",
        )
        # compute open probability to weigh fit matrices
        po_h = channel.compute_p_open(e_h, **sv)
        w_f = 1.0 / po_h
        fit_mats.append([m_f, v_t, w_f])
    # fit the model for this channel
    # NOTE(review): `w_f /= w_norm` only rebinds the loop variable unless
    # `w_f` is a mutable array, so for scalar weights the entries of
    # `fit_mats` stay unchanged; also, dividing by `1 / sum` multiplies by
    # the sum rather than normalising -- confirm the intended weighting
    w_norm = 1.0 / np.sum([w_f for _, _, w_f in fit_mats])
    for _, _, w_f in fit_mats:
        w_f /= w_norm
    # store the fit matrices, skipping expansion points that produced NaNs
    for m_f, v_t, w_f in fit_mats:
        if not (np.isnan(m_f).any() or np.isnan(v_t).any() or np.isnan(w_f).any()):
            ctree._fit_res_action(
                "store", m_f, v_t, w_f, channel_names=[channel_name]
            )
    # run the fit
    ctree.run_fit()
    return fit_mats
[docs]
def fit_channels(self, fit_arg, pprint=False):
    """
    Fit the active ion channel parameters, one channel type at a time

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.
    pprint: bool (optional, defaults to ``False``)
        whether to print information

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    channel_names = self.get_channels_in_tree()
    for name in channel_names:
        self._eval_channel(fit_arg, name, pprint=pprint)

    fitted = self.convert_fit_arg(fit_arg)
    return fitted
def _calibrate_conc_mechs(self, ion, orig_node, comp_node):
"""
Set the `gamma` factor of the concentration mechanism based on the ratio
of fitted conducances over original conductances permeable to the
associated ion
Parameters
----------
orig_node: `neat.PhysNode`
the original node corresponding to the location of the compartment
comp_node: `neat.CompartmentNode`
the fitted compartment node
"""
if ion not in orig_node.concmechs:
return
channel_storage = self.channel_storage
currents_orig = copy.deepcopy(orig_node.currents)
currents_comp = copy.deepcopy(comp_node.currents)
# compute g_max for the ion
g_ion_orig, g_ion_comp = 0.0, 0.0
for cname in orig_node.currents:
if cname in channel_storage and ion == channel_storage[cname].ion:
g_ion_orig += currents_orig.pop(cname, [0.0, 0.0])[0]
g_ion_comp += currents_comp.pop(cname, [0.0, 0.0])[0]
try:
comp_node.concmechs[ion].gamma = (
orig_node.concmechs[ion].gamma * g_ion_orig / g_ion_comp
)
except ZeroDivisionError:
# no Ca current so we rescale based on leak
# maybe concmech can be removed at this node?
g_l_orig = orig_node.currents["L"][0]
g_l_comp = comp_node.currents["L"][0]
comp_node.concmechs[ion].gamma = (
orig_node.concmechs[ion].gamma * g_l_orig / g_l_comp
)
[docs]
def fit_concentration(self, fit_arg, ion):
    """
    Fits the concentration mechanisms parameters associated with the `ion`
    ion type. If no node of the full tree has a concentration mechanism for
    `ion`, the compartment tree is left untouched.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.
    ion: str
        The ion type that is to be fitted (e.g. 'ca').

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    ctree, locs = self.convert_fit_arg(fit_arg)

    # Nothing to calibrate without any concentration mechanism for this ion.
    # The documented `(ctree, locs)` pair is returned in that case as well
    # (this used to return the integer 0).
    if not any(ion in node.concmechs for node in self):
        return ctree, locs

    orig_nodes = [self[loc["node"]] for loc in locs]
    comp_nodes = ctree.get_nodes_from_loc_idxs(list(range(len(locs))))
    for orig_node, comp_node in zip(orig_nodes, comp_nodes):
        self._calibrate_conc_mechs(ion, orig_node, comp_node)

    return ctree, locs
[docs]
def fit_passive(self, fit_arg, use_all_channels=True, pprint=False):
    """
    Fit the steady state passive model, consisting only of leak and coupling
    conductances, but ensure that the coupling conductances takes the passive
    opening of all channels into account

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.
    use_all_channels: bool (optional)
        use leak at rest of all channels combined in the passive fit (passive
        leak has to be refit after capacitance fit). The value is remembered
        in `self.use_all_channels_for_passive`.
    pprint: bool (optional, defaults to ``False``)
        whether to print information

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    ctree, locs = self.convert_fit_arg(fit_arg)
    self.use_all_channels_for_passive = use_all_channels

    if use_all_channels:
        suffix = "_passified_"
        fit_tree = EquilibriumTree(self)
        fit_tree.set_cache_params(
            cache_path=self.cache_path,
            cache_name=self.cache_name + "_eq" + suffix,
            save_cache=self.save_cache,
            recompute_cache=self.recompute_cache,
        )
        # set the channels to passive
        fit_tree.as_passive_membrane()
        # convert to a greens tree for further evaluation
        fit_tree = CachedGreensTree(
            fit_tree,
            cache_path=self.cache_path,
            cache_name=self.cache_name + "_gf" + suffix,
            save_cache=self.save_cache,
            recompute_cache=self.recompute_cache,
        )
        fit_tree.set_comp_tree(eps=self.fit_cfg.fit_comptree_eps)
    else:
        suffix = "_pas_"
        fit_tree = self.create_tree_gf(
            [],  # empty list of channels to include, i.e. leak only
            cache_name_suffix=suffix,
        )

    # set the impedances in the tree
    fit_tree.set_impedances_in_tree(freqs=0.0, pprint=pprint)
    # compute the steady state impedance matrix
    z_mat = fit_tree.calc_impedance_matrix(locs)
    # fit the coupling+leak conductances to steady state impedance matrix
    ctree.compute_gmc(z_mat, channel_names=["L"])

    # print passive impedance matrices
    if pprint:
        z_mat_fit = ctree.calc_impedance_matrix(channel_names=["L"])
        # the context manager restores the caller's print options on exit,
        # instead of overwriting them with hard-coded values
        with np.printoptions(
            precision=2, edgeitems=10, linewidth=500, suppress=True
        ):
            print("\n----- Impedance matrix comparison -----")
            print("> Zmat orig =")
            print(z_mat)
            print("> Zmat fit  =")
            print(z_mat_fit)
            print("> Zmat diff =")
            print(z_mat - z_mat_fit)
            print("---------------------------------------\n")

    return ctree, locs
[docs]
def fit_leak_only(self, fit_arg, pprint=True):
    """
    Fit leak only. Coupling conductances have to have been fit already.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.
    pprint: bool (optional, defaults to ``True``)
        whether to print information

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    ctree, locs = self.convert_fit_arg(fit_arg)

    # tree that contains only the leak current
    fit_tree = self.create_tree_gf(
        [],
        cache_name_suffix="_only_leak_",
    )
    # set the impedances in the tree, at the frequencies of the fit config
    fit_tree.set_impedances_in_tree(self.fit_cfg.freqs, pprint=pprint)
    # compute the impedance matrix, with a leading axis of length one
    z_mat = fit_tree.calc_impedance_matrix(locs)[None, :, :]
    # fit the leak conductances to the impedance matrix
    # NOTE(review): the holding potential -75.0 is hard-coded here
    ctree.compute_g_single_channel(
        "L",
        z_mat,
        -75.0,
        np.array([self.fit_cfg.freqs]),
        other_channel_names=[],
        action="fit",
    )

    # print passive impedance matrices
    if pprint:
        z_mat_fit = ctree.calc_impedance_matrix(channel_names=["L"])
        # the context manager restores the caller's print options on exit,
        # instead of overwriting them with hard-coded values
        with np.printoptions(
            precision=2, edgeitems=10, linewidth=500, suppress=True
        ):
            print("\n----- Impedance matrix comparison -----")
            print("> Zmat orig =")
            print(z_mat)
            print("> Zmat fit  =")
            print(z_mat_fit)
            print("> Zmat diff =")
            print(z_mat - z_mat_fit)
            print("---------------------------------------\n")

    return ctree, locs
[docs]
def create_tree_sov(self):
    """
    Create a `CachedSOVTree` copy of the old tree.

    If `self.use_all_channels_for_passive` is ``False``, the channel storage
    of the copy is emptied and each of its nodes only keeps the leak current
    'L' of the corresponding original node.

    Returns
    -------
    `neat.tools.fittools.compartmentfitter.CachedSOVTree`
    """
    keep_all_channels = self.use_all_channels_for_passive
    cache_name_suffix = (
        "_SOV_allchans_" if keep_all_channels else "_SOV_only_leak_"
    )

    # create the new tree
    sov_tree = CachedSOVTree(
        self,
        cache_path=self.cache_path,
        cache_name=self.cache_name + cache_name_suffix,
        save_cache=self.save_cache,
        recompute_cache=self.recompute_cache,
    )

    if not self.use_all_channels_for_passive:
        # strip everything but the leak
        sov_tree.channel_storage = {}
        for new_node, orig_node in zip(sov_tree, self):
            new_node.currents = {}
            g_leak, e_leak = orig_node.currents["L"]
            new_node._add_current("L", g_leak, e_leak)

    # set the computational tree
    sov_tree.set_comp_tree(eps=self.fit_cfg.fit_comptree_eps)
    return sov_tree
def _calc_sov_mats(self, locs, pprint=False):
"""
Use a `neat.SOVTree` to compute SOV matrices for fit
"""
# create an SOV tree
sov_tree = self.create_tree_sov()
# compute the SOV expansion for this tree
sov_tree.set_sov_in_tree(pprint=pprint)
# get SOV constants
alphas, phimat, importance = sov_tree.get_important_modes(
loc_arg=locs, sort_type="importance", eps=1e-12, return_importance=True
)
alphas = alphas.real
phimat = phimat.real
return alphas, phimat, importance, sov_tree
[docs]
def fit_capacitance(
    self,
    fit_arg,
    inds=[0],
    check_fit=True,
    force_tau_m_fit=False,
    pprint=False,
    pplot=False,
):
    """
    Fit the capacitances of the model to the largest SOV time scale

    If the SOV computation raises, or if the fit is judged not sane, the
    capacitances are set by matching the local membrane time scales instead.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.
    inds: list of int (optional, defaults to ``[0]``)
        indices of eigenmodes used in the fit. Default is [0], indicating
        the largest eigenmode
    check_fit: bool (optional, default ``True``)
        Check whether the largest eigenmode of the reduced model is within
        tolerance of the largest eigenmode of the full tree. If not,
        capacitances are set to match membrane time scale
    force_tau_m_fit: bool (optional, default ``False``)
        force capacitance fit through membrane time scale matching
    pprint: bool (optional, defaults to ``False``)
        whether to print information
    pplot: bool (optional, defaults to ``False``)
        whether to plot the kernels; skipped with a warning when the SOV
        matrices could not be computed

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    ctree, locs = self.convert_fit_arg(fit_arg)

    # stay `None` if the SOV computation fails before producing them
    alphas, phimat = None, None
    # becomes `True` only when the complete SOV-based fit ran through
    sov_fit_done = False

    def calc_tau():
        # eigenmode time scales of the full tree and of the reduced model
        nm = len(locs)
        taus_orig = np.sort(np.abs(1.0 / alphas))[::-1][:nm]
        lambdas, _, _ = ctree.calc_eigenvalues()
        taus_fit = np.sort(np.abs(1.0 / lambdas))[::-1]
        return taus_orig, taus_fit

    try:
        # compute SOV matrices for fit
        alphas, phimat, importance, sov_tree = self._calc_sov_mats(
            locs, pprint=pprint
        )
        # fit the capacitances from SOV time-scales
        ctree.compute_c(
            -alphas[inds] * 1e3, phimat[inds, :], weights=importance[inds]
        )
        taus_orig, taus_fit = calc_tau()
        if check_fit:
            # not sane if the largest time scale deviates by more than 80%
            fit_not_sane = np.abs(taus_fit[0] - taus_orig[0]) > 0.8 * taus_orig[0]
        else:
            fit_not_sane = False
        sov_fit_done = True
    except Exception as e:
        if pprint:
            print(
                f"Issue in SOV calculations:\n{e}\n> reverting to membrane timescale matching"
            )
        sov_tree = self.create_tree_sov()
        fit_not_sane = True

    def calc_tau_m():
        # local membrane time scales of the full tree and the reduced model
        clocs = [locs[n.loc_idx] for n in ctree]
        taus_m = []
        for loc in clocs:
            g_m = sov_tree[loc[0]].calc_g_tot(
                channel_storage=sov_tree.channel_storage
            )
            taus_m.append(self[loc[0]].c_m / g_m * 1e3)
        taus_m_orig = np.array(taus_m)
        taus_m_fit = (
            np.array([node.ca / node.currents["L"][0] for node in ctree]) * 1e3
        )
        return taus_m_orig, taus_m_fit

    if fit_not_sane or force_tau_m_fit:
        taus_m_orig, taus_m_fit = calc_tau_m()
        # revert to the more basic membrane timescale match
        for ii, node in enumerate(ctree):
            node.ca = node.currents["L"][0] * taus_m_orig[ii] * 1e-3
        if fit_not_sane:
            # only warn when the fallback was not explicitly requested
            warnings.warn(
                "No sane capacitance fit achieved for this configuration, "
                "reverted to more basic membrane time scale matching."
            )

    if pprint:
        # membrane time scales
        taus_m_orig, taus_m_fit = calc_tau_m()
        # the context manager restores the caller's print options on exit
        with np.printoptions(
            precision=2, edgeitems=10, linewidth=500, suppress=False
        ):
            print("\n----- capacitances -----")
            print(("Ca (uF) =\n" + str([nn.ca for nn in ctree])))
            print("\n----- Eigenmode time scales -----")
            if sov_fit_done:
                # mode time scales, recomputed with the final capacitances
                taus_orig, taus_fit = calc_tau()
                print(("> Taus original (ms) =\n" + str(taus_orig)))
                print(("> Taus fitted (ms) =\n" + str(taus_fit)))
            else:
                print("> not available, SOV calculation failed")
            print("\n----- Membrane time scales -----")
            print(("> Tau membrane original (ms) =\n" + str(taus_m_orig)))
            print(("> Tau membrane fitted (ms) =\n" + str(taus_m_fit)))
            print("---------------------------------\n")

    if pplot:
        if alphas is None or phimat is None:
            warnings.warn(
                "SOV matrices could not be computed, skipping the kernel plot."
            )
        else:
            self.plot_kernels(fit_arg, alphas=alphas, phimat=phimat)

    return ctree, locs
def plot_sov(
    self, fit_arg, alphas=None, phimat=None, importance=None, n_mode=8, alphas2=None
):
    """
    Plot the SOV time scales, the mode importances and the spatial modes at
    the fit locations in a figure named "SOV".

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit for which the SOV expansion is plotted.
    alphas, phimat, importance: `np.ndarray` (optional)
        The SOV matrices; recomputed with `_calc_sov_mats` if any of them
        is `None`
    n_mode: int
        The maximal number of modes shown; clipped to the number of
        available modes
    alphas2: `np.ndarray` (optional)
        Second set of eigenvalues to compare against; defaults to the
        eigenvalues returned by `ctree.calc_eigenvalues()`
    """
    ctree, fit_locs = self.convert_fit_arg(fit_arg)

    if alphas is None or phimat is None or importance is None:
        alphas, phimat, importance, _ = self._calc_sov_mats(fit_locs, pprint=False)
    if alphas2 is None:
        alphas2, _, _ = ctree.calc_eigenvalues()

    # never try to show more modes than are available, otherwise the x and y
    # arrays passed to the plot functions differ in length
    n_mode = min(n_mode, len(alphas), len(importance), len(phimat))

    colours = list(pl.rcParams["axes.prop_cycle"].by_key()["color"])
    loc_colours = np.array(
        [colours[ii % len(colours)] for ii in range(len(fit_locs))]
    )
    markers = Line2D.filled_markers

    pl.figure("SOV", figsize=(10, 10))
    gs = GridSpec(2, 2)
    ax1, ax2, ax3 = pl.subplot(gs[0, 0]), pl.subplot(gs[0, 1]), pl.subplot(gs[1, :])

    # x axis modes
    x_arr = np.arange(n_mode)
    x_loc = np.arange(len(fit_locs))

    # time scales
    ax1.semilogy(x_arr, np.abs(1.0 / alphas[:n_mode]), "rD--")
    taus2 = np.sort(np.abs(1.0 / alphas2))[::-1][:n_mode]
    ax1.semilogy(x_arr[: len(taus2)], taus2, "bo--")
    ax1.set_xlabel(r"$k$")
    # this label belongs to the time-scale axis (it used to be set on `ax2`,
    # where it was immediately overwritten)
    ax1.set_ylabel(r"$\tau_k$ (ms)")

    # importance
    ax2.semilogy(x_arr, importance[:n_mode], "rD--")
    ax2.set_xlabel(r"$k$")
    ax2.set_ylabel(r"$I_k$")

    # spatial modes
    for kk in range(n_mode):
        ax3.plot(x_loc, phimat[kk, :], ls="--", c="DarkGrey")
        ax3.scatter(
            x_loc,
            phimat[kk, :],
            c=loc_colours,
            marker=markers[kk % len(markers)],
            label=str(kk),
        )
    ax3.set_xlabel(r"$x_i$")
    ax3.set_ylabel(r"$\phi_k(x_i)$")
    ax3.legend(loc=0)
def _construct_kernels(self, nn, a, c):
return [[Kernel((a, c[:, ii, jj])) for ii in range(nn)] for jj in range(nn)]
[docs]
def get_kernels(
    self,
    fit_arg,
    alphas=None,
    phimat=None,
    pprint=False,
):
    """
    Returns the impedance kernels as a double nested list of "neat.Kernel".
    The element at the position i,j represents the transfer impedance kernel
    between compartments i and j.

    If `alphas` and/or `phimat` are not provided, these SOV matrices are
    recomputed.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the compartment tree for which the kernels have to be
        computed.
    alphas: `np.array`
        The exponential coefficients, as follows from the SOV expansion
    phimat: `np.ndarray` (dim=2)
        The matrix to compute the exponential prefactors, as follows from
        the SOV expansion
    pprint: `bool`
        Is verbose if ``True``

    Returns
    -------
    k_orig: list of list of `neat.Kernel`
        The kernels of the full model
    k_comp: list of list of `neat.Kernel`
        The kernels of the reduced model (i.e. of the compartment tree)
    """
    ctree, locs = self.convert_fit_arg(fit_arg)
    if alphas is None or phimat is None:
        alphas, phimat, _, _ = self._calc_sov_mats(locs, pprint=pprint)
    n_loc = len(locs)

    # eigenvalue decomposition of the reduced model
    eigen_decomp = ctree.calc_eigenvalues(indexing="locs")
    alphas_comp, phimat_comp, phimat_inv_comp = eigen_decomp

    # exponential prefactors, first axis runs over the modes
    prefactors_orig = np.einsum("ik,kj->kij", phimat.T, phimat)
    prefactors_comp = np.einsum("ik,kj->kij", phimat_comp, phimat_inv_comp)

    k_orig = self._construct_kernels(n_loc, alphas, prefactors_orig)
    k_comp = self._construct_kernels(n_loc, -alphas_comp, prefactors_comp)
    return k_orig, k_comp
[docs]
def plot_kernels(
    self,
    fit_arg,
    alphas=None,
    phimat=None,
    t_arr=None,
):
    """
    Plots the impedance kernels in a figure named "Kernels".

    The kernel at the position i,j represents the transfer impedance kernel
    between compartments i and j; only positions with j >= i are drawn.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the compartment tree for which the kernels have to be
        plotted.
    alphas: `np.array`
        The exponential coefficients, as follows from the SOV expansion
    phimat: `np.ndarray` (dim=2)
        The matrix to compute the exponential prefactors, as follows from
        the SOV expansion
    t_arr: `np.array`
        The time-points at which the to be plotted kernels are evaluated.
        Default is ``np.linspace(0.,200.,int(1e3))``
    """
    ctree, fit_locs = self.convert_fit_arg(fit_arg)
    n_loc = len(fit_locs)

    if alphas is None or phimat is None:
        alphas, phimat, _, _ = self._calc_sov_mats(fit_locs, pprint=False)
    k_orig, k_comp = self.get_kernels(fit_arg, alphas=alphas, phimat=phimat)

    if t_arr is None:
        t_arr = np.linspace(0.0, 200.0, int(1e3))

    pl.figure("Kernels", figsize=(2.0 * n_loc, 1.5 * n_loc))
    grid = GridSpec(n_loc, n_loc)
    grid.update(top=0.98, bottom=0.04, left=0.04, right=0.98)

    palette = list(pl.rcParams["axes.prop_cycle"].by_key()["color"])
    loc_colours = np.array(
        [palette[idx % len(palette)] for idx in range(len(fit_locs))]
    )

    # walk over the upper triangle (including the diagonal) of the grid
    upper_triangle = (
        (ii, jj) for ii in range(n_loc) for jj in range(ii, n_loc)
    )
    for ii, jj in upper_triangle:
        kernel_orig = k_orig[ii][jj]
        kernel_comp = k_comp[ii][jj]
        ax = pl.subplot(grid[ii, jj])
        # full model in grey, reduced model dashed in the location colour
        ax.plot(t_arr, kernel_orig(t_arr), c="DarkGrey")
        ax.plot(t_arr, kernel_comp(t_arr), ls="--", c=loc_colours[jj])
        # limits
        ax.set_ylim((-0.5, 20.0))
        # kernel label
        ax.set_title(r"%d $\leftrightarrow$ %d" % (ii, jj), pad=-10)
def _store_sov_mats(self):
fit_locs = self.get_locs("fit locs")
self.alphas, self.phimat, _, _ = self._calc_sov_mats(fit_locs, pprint=False)
def kernel_objective(self, t_arr=None, fit_arg="fit locs"):
    """
    Sum of squared differences between the kernels of the full model and
    those of the reduced model, over all location pairs (i, j) with j >= i.

    The SOV matrices stored in `self.alphas` and `self.phimat` are used (see
    `_store_sov_mats`).

    Parameters
    ----------
    t_arr: `np.array` (optional)
        The time-points at which the kernels are evaluated. By default a
        log-spaced grid from 1e-2 to 1 followed by a linear grid up to 200.
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit for which the objective is computed. Defaults to
        the fit named "fit locs", the name `_store_sov_mats` reads its
        locations from.

    Returns
    -------
    float
        The summed squared kernel difference
    """
    _, fit_locs = self.convert_fit_arg(fit_arg)
    nn = len(fit_locs)

    if t_arr is None:
        t_arr = np.concatenate(
            (np.logspace(-2, 0, 200), np.linspace(1.0, 200.0, 400)[1:])
        )

    # `get_kernels` requires the fit argument; it used to be omitted here,
    # which made every call fail with a TypeError
    k_orig, k_comp = self.get_kernels(
        fit_arg, alphas=self.alphas, phimat=self.phimat
    )

    res = 0.0
    for ii in range(nn):
        for jj in range(ii, nn):
            ko, kc = k_orig[ii][jj], k_comp[ii][jj]
            res += np.sum((ko(t_arr) - kc(t_arr)) ** 2)
    return res
[docs]
def check_passive(
    self,
    loc_arg,
    alpha_inds=[0],
    use_all_channels_for_passive=True,
    force_tau_m_fit=False,
    pprint=False,
):
    """
    Checks the impedance kernels of the passive model.

    Runs the passive and the capacitance fit for the given locations, plots
    the kernels, draws the morphology with the fit locations marked and
    shows the figures.

    Parameters
    ----------
    loc_arg: list of locations or string (see documentation of
            :func:`MorphTree.convert_loc_arg_to_locs` for details)
        The compartment locations
    alpha_inds: list of ints
        Indices of all mode time-scales to be included in the fit
    use_all_channels_for_passive: bool
        Uses all channels in the tree to compute coupling conductances
    force_tau_m_fit: bool
        Force using the local membrane time-scale for capacitance fit
    pprint: bool
        is verbose if ``True``
    """
    # scaffold, passive steady state fit and capacitance fit; each step
    # returns the `(ctree, locs)` pair consumed by the next one
    fit_arg = self.set_ctree(loc_arg)
    fit_arg = self.fit_passive(
        fit_arg, use_all_channels=use_all_channels_for_passive, pprint=pprint
    )
    fit_arg = self.fit_capacitance(
        fit_arg,
        inds=alpha_inds,
        force_tau_m_fit=force_tau_m_fit,
        pprint=pprint,
        pplot=True,
    )
    _, fit_locs = self.convert_fit_arg(fit_arg)

    # one colour per fit location, cycling through the default colour cycle
    palette = list(pl.rcParams["axes.prop_cycle"].by_key()["color"])
    loc_colours = np.array(
        [palette[idx % len(palette)] for idx in range(len(fit_locs))]
    )
    marker_styles = [
        dict(marker="o", mec="k", mfc=colour, markersize=6.0)
        for colour in loc_colours
    ]

    pl.figure("tree")
    ax = pl.gca()
    self.plot_2d_morphology(
        ax, marklocs=fit_locs, loc_args=marker_styles, use_radius=False
    )
    pl.tight_layout()
    pl.show()
def get_net(self, c_loc, locs, channel_names=[], pprint=False):
    """
    Construct a NET from the steady-state impedances and reduce it to a
    given set of locations.

    Parameters
    ----------
    c_loc: location
        Location passed to `calc_net_steadystate()`; it is also prepended
        to `locs` when selecting the locations to retain.
    locs: list of locations
        Further locations that are retained in the reduced NET.
    channel_names: list of str
        Ion channels included in the Green's function tree. The default
        empty list is passed on unchanged.
    pprint: bool
        Accepted for interface compatibility; not used by this method.

    Returns
    -------
    The reduced tree returned by `net.get_reduced_tree()`.
    """
    gf_tree = self.create_tree_gf(
        channel_names=channel_names,
        cache_name_suffix="_for_NET_",
    )
    gf_tree.set_impedances_in_tree(self.fit_cfg.freqs, pprint=False)

    # build the full NET and correct its input resistance with the
    # impedance matrix that was computed alongside it
    full_net, z_mat = gf_tree.calc_net_steadystate(c_loc)
    full_net.improve_input_resistance(z_mat)

    # keep only the parts of the NET needed for `c_loc` and `locs`
    keep_idxs = gf_tree.get_nearest_loc_idxs([c_loc] + locs, "net eval")
    return full_net.get_reduced_tree(keep_idxs, indexing="locs")
[docs]
def fit_e_eq(self, fit_arg):
    """
    Fits the leak potentials of the reduced model to yield the same
    equilibrium potentials as the full model

    The equilibrium potentials and concentrations are computed with
    `self.calc_e_eq()` at the fit locations, stored on the compartment
    tree, and the leak is fitted afterwards.

    Parameters
    ----------
    fit_arg: see docstring of `CompartmentFitter.convert_fit_arg()`
        Specifying the fit that is being performed.

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is in the process of being fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    comp_tree, fit_locs = self.convert_fit_arg(fit_arg)

    # equilibrium potentials and concentrations at the fit locations
    v_eq, conc_eq = self.calc_e_eq(fit_locs)

    # store the equilibria on the compartment tree
    comp_tree.set_e_eq(v_eq)
    for ion_name in self.ions:
        comp_tree.set_conc_eq(ion_name, conc_eq[ion_name])

    # fit the leak so that the stored equilibria are reproduced
    comp_tree.fit_e_leak()

    return comp_tree, fit_locs
[docs]
def fit_model(
    self,
    loc_arg,
    fit_name="",
    alpha_inds=None,
    use_all_channels_for_passive=True,
    pprint=False,
):
    """
    Runs the full fit for a set of locations (the location are automatically
    extended with the bifurcation locs)

    The fit steps are executed in this order: passive steady state,
    capacitances, optional leak-only refit, ion channels, concentration
    mechanisms (one per ion in `self.ions`) and resting potentials.

    Parameters
    ----------
    loc_arg: list of locations or string (see documentation of
            :func:`MorphTree.convert_loc_arg_to_locs` for details)
        The compartment locations
    fit_name: string
        The name under which the fit will be stored. By default, the fit
        will not be stored: an empty name is replaced by ``"temp"`` and a
        fit named ``"temp"`` is removed again once it is complete.
    alpha_inds: list of ints or None
        Indices of all mode time-scales to be included in the fit.
        ``None`` (default) is equivalent to ``[0]``.
    use_all_channels_for_passive: bool (optional, default ``True``)
        Uses all channels in the tree to compute coupling conductances
    pprint: bool
        whether to print information

    Returns
    -------
    `neat.CompartmentTree`
        The compartmenttree that is fitted.
    list of <neat.MorphLoc>
        The corresponding list of fit locations.
    """
    if alpha_inds is None:
        # build a fresh list on every call instead of sharing one mutable
        # default object between calls
        alpha_inds = [0]

    if fit_name == "":
        fit_name = "temp"

    fit_arg = self.set_ctree(loc_arg, fit_name=fit_name, pprint=pprint)

    # fit the passive steady state model
    fit_arg = self.fit_passive(
        fit_arg,
        pprint=pprint,
        use_all_channels=use_all_channels_for_passive,
    )

    # fit the capacitances
    fit_arg = self.fit_capacitance(
        fit_arg, inds=alpha_inds, pprint=pprint, pplot=False
    )

    # refit with only leak
    if use_all_channels_for_passive:
        fit_arg = self.fit_leak_only(fit_arg, pprint=pprint)

    # fit the ion channels
    fit_arg = self.fit_channels(fit_arg, pprint=pprint)

    # fit the concentration mechanisms
    for ion in self.ions:
        fit_arg = self.fit_concentration(fit_arg, ion)

    # fit the resting potentials
    fit_arg = self.fit_e_eq(fit_arg)

    if fit_name == "temp":
        # temporary fits are not kept in storage
        self.remove_fit(fit_name)
    else:
        self.fitted_models[fit_name]["complete"] = True

    return fit_arg
def recalc_impedance_matrix(self, loc_arg, g_syns, channel_names=None):
    """
    Impedance matrix between locations in the presence of constant
    conductances at those locations.

    With ``Z`` the impedance matrix between the locations and ``G`` the
    diagonal matrix of `g_syns`, the result is the solution ``X`` of
    ``(1 + Z G) X = Z``.

    Parameters
    ----------
    loc_arg: list of locations or string
        Converted with `self.convert_loc_arg_to_locs()`.
    g_syns: list or numpy.array of floats
        One conductance per location.
    channel_names: list of str or `None` (default)
        Ion channels included in the impedance calculation. `None`
        includes all channels in `self.channel_storage`.

    Returns
    -------
    numpy.ndarray
        The rescaled impedance matrix, or ``np.array([[]])`` if there are
        no locations.

    Raises
    ------
    AssertionError
        If the number of conductances differs from the number of locations.
    """
    syn_locs = self.convert_loc_arg_to_locs(loc_arg)
    n_syn = len(syn_locs)
    assert n_syn == len(g_syns)
    if n_syn == 0:
        return np.array([[]])

    if channel_names is None:
        channel_names = list(self.channel_storage.keys())
    joined_names = "_".join(channel_names)

    # Green's function tree for the requested set of channels
    gf_tree = self.create_tree_gf(
        channel_names=channel_names,
        cache_name_suffix=f"_{joined_names}_",
    )
    gf_tree.set_impedances_in_tree(self.fit_cfg.freqs, pprint=False)

    # impedance matrix without the extra conductances
    z_mat = gf_tree.calc_impedance_matrix(syn_locs, explicit_method=False)

    # solve (1 + Z G) X = Z for the rescaled impedance matrix
    zg_mat = np.dot(z_mat, np.diag(g_syns))
    return np.linalg.solve(np.eye(n_syn) + zg_mat, z_mat)
[docs]
def fit_syn_rescale(
    self,
    c_loc_arg,
    s_loc_arg,
    comp_inds,
    g_syns,
    e_revs,
    fit_impedance=False,
    channel_names=None,
):
    """
    Computes the rescaled conductances when synapses are moved to compartment
    locations, assuming a given average conductance for each synapse.

    The rescale factors are obtained from a linear least-squares fit
    (`np.linalg.lstsq`).

    Parameters
    ----------
    c_loc_arg: list of locations or string (see documentation of
            :func:`MorphTree.convert_loc_arg_to_locs` for details)
        The compartment locations
    s_loc_arg: list of locations or string (see documentation of
            :func:`MorphTree.convert_loc_arg_to_locs` for details)
        The synapse locations
    comp_inds: list or numpy.array of ints
        for each location in [s_loc_arg], gives the index of the compartment
        location in [c_loc_arg] to which the synapse is assigned
    g_syns: list or numpy.array of floats
        The average conductances for each synapse
    e_revs: list or numpy.array of floats
        The reversal potential of each synapse
    fit_impedance: bool (optional, default `False`)
        Whether to also use the reproduction of the rescaled impedance matrix
        as target.
    channel_names: list of str or `None` (default)
        List of ion channels to be included in impedance matrix calculation.
        `None` includes all ion channels

    Returns
    -------
    g_resc: numpy.array of floats
        The rescale values for the synaptic weights. An empty array is
        returned if there are no synapse locations. Entries whose average
        conductance is not larger than 1e-9 are set to 1.

    Raises
    ------
    AssertionError
        If `g_syns` or `e_revs` do not have one entry per synapse location,
        or if there are no compartment locations.
    """
    # process input
    c_locs = self.convert_loc_arg_to_locs(c_loc_arg)
    s_locs = self.convert_loc_arg_to_locs(s_loc_arg)
    n_comp, n_syn = len(c_locs), len(s_locs)
    assert n_syn == len(g_syns) and n_syn == len(e_revs)
    assert len(c_locs) > 0
    if n_syn == 0:
        return np.array([])
    if channel_names is None:
        channel_names = list(self.channel_storage.keys())
    # combined location list: the first `n_comp` entries are compartment
    # locations, the remaining `n_syn` entries are synapse locations
    cs_locs = c_locs + s_locs
    # conductance vector aligned with `cs_locs`: zero at the compartment
    # locations, the average synaptic conductance at the synapse locations
    cg_syns = np.concatenate((np.zeros(n_comp), np.array(g_syns)))
    comp_inds, g_syns, e_revs = (
        np.array(comp_inds),
        np.array(g_syns),
        np.array(e_revs),
    )
    # create a greenstree with equilibrium potentials at rest
    greens_tree = self.create_tree_gf(
        channel_names=channel_names,
        cache_name_suffix=f"_{'_'.join(channel_names)}_",
    )
    greens_tree.set_impedances_in_tree(self.fit_cfg.freqs, pprint=False)
    # compute the impedance matrix of the compartment and synapse locations
    z_mat = greens_tree.calc_impedance_matrix(cs_locs, explicit_method=False)
    # compartment-to-compartment sub-block
    zc_mat = z_mat[:n_comp, :n_comp]
    # get the equilibrium potentials at the compartment and synapse locations
    e_eqs = self.calc_e_eq(cs_locs)[0]
    e_cs = e_eqs[:n_comp]
    e_ss = e_eqs[-n_syn:]
    # compute the ZG matrix
    gd_mat = np.diag(cg_syns)
    zg_mat_ = np.dot(z_mat, gd_mat)
    # solve (1 + Z G) X = Z G and keep the compartment-rows /
    # synapse-columns sub-block
    zg_mat = np.linalg.solve(np.eye(n_comp + n_syn) + zg_mat_, zg_mat_)
    zg_mat = zg_mat[:n_comp, n_comp:]
    # create the compartment assignment matrix & syn index vector
    # (`c_mat` has one row per compartment; for each column of `c_mat`,
    # `s_inds` holds the index of the first row with a non-zero entry)
    c_mat = np.array([comp_inds == cc for cc in range(n_comp)]).astype(int)
    s_inds = np.array([np.where(cc > 0)[0][0] for cc in c_mat.T])
    # compute the driving potential vectors
    es_vec = e_revs - e_ss
    ec_vec = e_revs - e_cs[s_inds]
    zc_mat = np.dot(zc_mat, c_mat)
    czg_mat = np.dot(c_mat.T, zg_mat)
    # create matrices for inverse fit
    a1_mat = np.einsum("ck,kn->cnk", zc_mat, np.diag(ec_vec))
    a2_mat = np.einsum("ck,kn->cnk", zc_mat, czg_mat * es_vec[None, :])
    b_mat = zg_mat * es_vec[None, :]
    # unravel first two indices
    a_mat = np.reshape(a1_mat - a2_mat, (n_syn * n_comp, -1))
    b_vec = np.reshape(b_mat, (n_syn * n_comp,))
    if fit_impedance:
        # fit based on impedance matrix
        zr_mat = np.linalg.solve(np.eye(n_comp + n_syn) + zg_mat_, z_mat)
        zr_mat = zr_mat[:n_comp, :n_comp]
        # `zc_mat` was multiplied with `c_mat` above; take the plain
        # compartment sub-block again
        zc_mat = z_mat[:n_comp, :n_comp]
        # b matrix for fit
        b_mat = zc_mat - zr_mat
        # a tensor for fit
        zcc = np.dot(zc_mat, c_mat)
        czr = np.dot(c_mat.T, zr_mat)
        aa_mat = np.einsum("ik,kn->ink", zcc, czr)
        # unravel first two indices
        a_mat_ = np.reshape(aa_mat, (n_comp * n_comp, -1))
        b_vec_ = np.reshape(b_mat, (n_comp * n_comp,))
        # append the impedance-based equations to the fit system
        a_mat = np.concatenate((a_mat, a_mat_), axis=0)
        b_vec = np.concatenate((b_vec, b_vec_), axis=0)
    # compute rescaled synaptic conductances
    g_resc = np.linalg.lstsq(a_mat, b_vec, rcond=None)[0]
    # convert fitted conductances to rescale factors; entries with a
    # negligible average conductance get a factor of one instead of being
    # divided by (almost) zero
    b_arr = g_syns > 1e-9
    g_resc[np.logical_not(b_arr)] = 1.0
    g_resc[b_arr] = g_resc[b_arr] / g_syns[b_arr]
    return g_resc
def assign_locs_to_comps(self, c_loc_arg, s_loc_arg, fz=0.8, channel_names=None):
    """
    Assign each location in `s_loc_arg` to one of the compartment
    locations in `c_loc_arg`, based on weighted differences in input
    impedance.

    Assumes the root node is in `c_loc_arg`.

    Parameters
    ----------
    c_loc_arg: list of locations or string
        The compartment locations.
    s_loc_arg: list of locations or string
        The locations that have to be assigned to a compartment.
    fz: float
        Weight factor used in the impedance differences; depending on the
        relative position of the two locations either ``fz`` or
        ``1 - fz`` is applied.
    channel_names: list of str or `None` (default)
        Ion channels included in the impedance calculation. `None`
        includes all channels in `self.channel_storage`.

    Returns
    -------
    list of ints
        For every location in `s_loc_arg`, the index of the assigned
        compartment location in `c_loc_arg`.
    """
    if channel_names is None:
        channel_names = list(self.channel_storage.keys())

    # Green's function tree for the requested set of channels
    joined_names = "_".join(channel_names)
    gf_tree = self.create_tree_gf(
        channel_names=channel_names,
        cache_name_suffix=f"_{joined_names}_at_rest_",
    )
    gf_tree.set_impedances_in_tree(self.fit_cfg.freqs, pprint=False)

    comp_locs = self.convert_loc_arg_to_locs(c_loc_arg)
    syn_locs = self.convert_loc_arg_to_locs(s_loc_arg)

    # nodes on which the locations sit
    comp_nodes = [self[loc["node"]] for loc in comp_locs]
    syn_nodes = [self[loc["node"]] for loc in syn_locs]

    # input impedance at every location
    comp_zins = [gf_tree.calc_zf(loc, loc)[0] for loc in comp_locs]
    syn_zins = [gf_tree.calc_zf(loc, loc)[0] for loc in syn_locs]

    # path from every location's node to the root
    comp_paths = [self.path_to_root(node) for node in comp_nodes]
    syn_paths = [self.path_to_root(node) for node in syn_nodes]

    assigned = []
    for syn_idx, syn_loc in enumerate(syn_locs):
        syn_node = syn_nodes[syn_idx]
        syn_zin = syn_zins[syn_idx]
        # root-first ordering of the path
        syn_down = syn_paths[syn_idx][::-1]

        # only the neighbouring compartment locations are candidates
        candidates = gf_tree.get_nearest_neighbour_loc_idxs(syn_loc, comp_locs)

        costs = []
        for cand in candidates:
            comp_node = comp_nodes[cand]
            comp_loc = comp_locs[cand]
            comp_zin = comp_zins[cand]
            comp_down = comp_paths[cand][::-1]

            # walk down from the root while both paths coincide; the last
            # shared node is the common node furthest from the root
            depth = 0
            shared_len = min(len(syn_down), len(comp_down))
            while depth < shared_len and syn_down[depth] == comp_down[depth]:
                fork_node = syn_down[depth]
                depth += 1

            fork_is_syn = fork_node == syn_node
            fork_is_comp = fork_node == comp_node

            if fork_is_syn and fork_is_comp:
                # both locations are on the same node
                if comp_loc["x"] > syn_loc["x"]:
                    weight = fz
                else:
                    weight = 1.0 - fz
                costs.append(weight * np.abs(syn_zin - comp_zin))
            elif fork_is_syn:
                # common node is the node of the location to be assigned
                costs.append(fz * np.abs(comp_zin - syn_zin))
            elif fork_is_comp:
                # common node is the node of the compartment location
                costs.append((1.0 - fz) * np.abs(syn_zin - comp_zin))
            else:
                # locations are on different branches: measure both
                # relative to the input impedance at the end of the
                # common node
                fork_loc = (fork_node.index, 1.0)
                fork_zin = gf_tree.calc_zf(fork_loc, fork_loc)[0]
                costs.append(
                    (1.0 - fz) * (comp_zin - fork_zin) + fz * (syn_zin - fork_zin)
                )

        # compartment location with minimal impedance difference
        best = np.argmin(costs)
        assigned.append(candidates[best])

    return assigned