Source code for neat.modelreduction.cachetrees

# -*- coding: utf-8 -*-
#
# cachetrees.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST.  If not, see <http://www.gnu.org/licenses/>.

import dill
import numpy as np

import os
import pickle
import warnings


from ..trees.morphtree import computational_tree_decorator, MorphLoc
from ..trees.phystree import PhysTree
from ..trees.greenstree import GreensTree, GreensTreeTime
from ..trees.sovtree import SOVTree
from ..trees.netree import NET, NETNode

try:
    from ..simulations.neuron import neuronmodel as neurm
except ModuleNotFoundError:
    warnings.warn(
        "NEURON not available, equilibrium evaluation not working", UserWarning
    )


def consecutive(inds):
    """
    Partition a sequence of ints into runs of consecutive values.

    A new run starts at every position where an entry does not exceed its
    predecessor by exactly one.

    Parameters
    ----------
    inds: array-like of int
        The values to partition

    Returns
    -------
    list of np.ndarray
        The runs, in the order in which they occur in `inds`
    """
    steps = np.diff(inds)
    # index of the last element of every run that is followed by another run
    run_ends = np.where(steps != 1)[0]
    return np.split(inds, run_ends + 1)


class CachedTree(PhysTree):
    """
    Subclass of `PhysTree` that can store the outcome of expensive
    computations on disk and restore it on later runs.

    The caching behaviour is controlled by four attributes

    recompute_cache: bool
        If ``True``, an existing cache file is ignored and the computation is
        always executed (default is ``False``)
    save_cache: bool
        If ``True``, the tree is written to the cache file after the
        computation has been executed (default is ``True``)
    cache_name: str
        Prefix of the cache file name (default is ``""``)
    cache_path: str
        Directory in which the cache files are stored (default is
        ``"neatcache/"``)
    """

    def __init__(
        self,
        *args,
        recompute_cache=None,
        save_cache=None,
        cache_name=None,
        cache_path=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        *args, **kwargs
            Forwarded to the parent constructor
        recompute_cache, save_cache, cache_name, cache_path
            Cache parameters, see the class docstring. ``None`` means
            "not provided".
        """
        cache_kwargs = dict(
            recompute_cache=recompute_cache,
            save_cache=save_cache,
            cache_name=cache_name,
            cache_path=cache_path,
        )
        # We want a behaviour where the cache parameters are initialized to
        # certain defaults if they are not provided, but where an
        # initialization based on copying a tree leaves the cache parameters
        # intact if the input tree is a subclass of `CachedTree`. However,
        # explicitly provided arguments must always win. Hence:
        # 1. set defaults (or the provided values),
        # 2. run the parent constructor, which may overwrite them,
        # 3. re-apply only the explicitly provided values.
        self.set_cache_params(**self.get_cache_defaults(**cache_kwargs))
        super().__init__(*args, **kwargs)
        self.set_cache_params(**cache_kwargs)

    def get_cache_defaults(
        self,
        recompute_cache=None,
        save_cache=None,
        cache_name=None,
        cache_path=None,
    ):
        """
        Fill in the default value for every cache parameter that is ``None``.

        Returns
        -------
        dict
            Maps each of the four cache parameter names to either the
            provided value or its default
        """
        return {
            "recompute_cache": False if recompute_cache is None else recompute_cache,
            "save_cache": True if save_cache is None else save_cache,
            "cache_name": "" if cache_name is None else cache_name,
            "cache_path": "neatcache/" if cache_path is None else cache_path,
        }

    def set_cache_params(
        self,
        recompute_cache=None,
        save_cache=None,
        cache_name=None,
        cache_path=None,
    ):
        """
        Set the cache parameters that are not ``None``; parameters that are
        ``None`` keep their current value.

        If a non-empty `cache_path` is given, the directory is created when it
        does not exist yet.
        """
        # `os.makedirs("")` raises, whereas an empty path is a valid prefix
        # for `os.path.join` (current working directory), so only create the
        # directory for a non-empty path
        if cache_path:
            os.makedirs(cache_path, exist_ok=True)

        if cache_name is not None:
            self.cache_name = cache_name
        if cache_path is not None:
            self.cache_path = cache_path
        if save_cache is not None:
            self.save_cache = save_cache
        if recompute_cache is not None:
            self.recompute_cache = recompute_cache

    def get_cache_params(self):
        """
        Returns
        -------
        dict
            The current values of the four cache parameters
        """
        return {
            "cache_name": self.cache_name,
            "cache_path": self.cache_path,
            "save_cache": self.save_cache,
            "recompute_cache": self.recompute_cache,
        }

    def get_cache_params_in_dict(self, kwarg_dict):
        """
        Extract the cache parameters from a dictionary of keyword arguments.

        Parameters
        ----------
        kwarg_dict: dict
            Arbitrary keyword arguments

        Returns
        -------
        dict
            The entries of `kwarg_dict` whose key is a cache parameter name
        """
        cache_keys = {"cache_name", "cache_path", "save_cache", "recompute_cache"}
        # `dict.iteritems()` only exists in Python 2; `items()` is the
        # Python 3 equivalent
        return {key: val for key, val in kwarg_dict.items() if key in cache_keys}

    def get_attributes_excluded_from_cache_override(self):
        """
        Returns a list of attributes that should NOT be overwritten by the cached tree

        Returns
        -------
        list of str
            Attribute names that should not be overwritten
        """
        return ["cache_name", "cache_path", "save_cache", "recompute_cache"]

    def maybe_execute_funcs(
        self,
        funcs_args_kwargs=None,
        pprint=False,
    ):
        """
        Restore the tree from its cache file or, if that is not possible,
        execute the given functions and optionally write the cache file.

        The cache file is located at
        ``<cache_path>/<cache_name>_cache_<unique_hash()>.p``. When it can be
        loaded, the attributes of the cached tree replace the attributes of
        this tree, except for the cache parameters themselves.

        Parameters
        ----------
        funcs_args_kwargs: list of tuple (callable, args, kwargs), optional
            The functions to execute, in order, as ``func(*args, **kwargs)``
            when no usable cache is available. Defaults to no functions.
        pprint: bool (optional, default is ``False``)
            Print info
        """
        if funcs_args_kwargs is None:
            funcs_args_kwargs = []

        file_name = os.path.join(
            self.cache_path,
            f"{self.cache_name}_cache_{self.unique_hash()}.p",
        )

        if pprint:
            print(
                f"\n>>>> Cache file for {self.__class__.__name__}:\n    {file_name}\n<<<<"
            )

        # the cache is skipped altogether if 'recompute_cache' is true
        if not self.recompute_cache:
            try:
                # NOTE(security): unpickling can execute arbitrary code, so
                # cache files must only be read from trusted directories
                with open(file_name, "rb") as file:
                    cached_tree = dill.load(file)

                # the cache parameters of this tree take precedence over the
                # ones that were stored with the cached tree
                for key in self.get_attributes_excluded_from_cache_override():
                    cached_tree.__dict__[key] = self.__dict__[key]

                self.__dict__.update(cached_tree.__dict__)
            except Exception:
                # deliberately broad: a missing, truncated or otherwise
                # unreadable cache file must never be fatal, the results are
                # simply recomputed below
                pass
            else:
                return

        if pprint:
            if self.recompute_cache:
                logstr = ">>> Force recomputing cache..."
            else:
                logstr = ">>> No cache found, recomputing..."
            print(logstr)

        # execute the functions
        for func, args, kwargs in funcs_args_kwargs:
            func(*args, **kwargs)

        if self.save_cache:
            with open(file_name, "wb") as file:
                dill.dump(self, file)


class EquilibriumTree(CachedTree):
    """
    Subclass of `neat.PhysTree` that allows for the calculation of the
    equilibrium potential at each node. Uses the NEURON simulator to evaluate
    the equilibrium potentials. Can cache the results of the computation.

    The equilibrium potential is stored under the `v_ep` attribute of each node.
    """

    def _calc_e_eq(self, loc_arg, ions=None, t_max=500.0, dt=0.1, factor_lambda=10.0):
        """
        Calculates equilibrium potentials and concentrations in the tree.
        Computes the equilibria through a NEURON simulation without inputs.

        Parameters
        ----------
        loc_arg: `list` of locations or string
            if `list` of locations, specifies the locations for which the
            equilibrium state is evaluated, if ``string``, specifies the
            name under which a set of locations is stored
        ions: `iterable` of `str`
            the names of the ions for which the concentration needs to be
            measured. Defaults to `self.ions`.
        t_max: float
            duration of the simulation; multiplied by 20 if `ions` is not empty
        dt: float
            time-step of the simulation
        factor_lambda: `float`
            multiplies the number of compartments suggested by the lambda-rule

        Returns
        -------
        np.ndarray
            For each location, the last recorded value of the membrane potential
        dict of {str: np.ndarray}
            For each ion, the last recorded concentration at each location
        """
        locs = self.convert_loc_arg_to_locs(loc_arg)
        if ions is None:
            ions = self.ions

        # use longer simulation for Eeq fit if concentration mechanisms are present
        t_max = t_max * 20.0 if len(ions) > 0 else t_max

        # create a biophysical simulation model
        sim_tree_biophys = neurm.NeuronSimTree(self)
        # compute equilibrium potentials
        sim_tree_biophys.init_model(dt=dt, factor_lambda=factor_lambda)
        sim_tree_biophys.store_locs(locs, "rec locs", warn=False)
        res_biophys = sim_tree_biophys.run(
            t_max, dt_rec=20.0, record_concentrations=ions
        )
        sim_tree_biophys.delete_model()

        # the final sample of each trace is taken as the equilibrium value
        return (
            np.array([v_m[-1] for v_m in res_biophys["v_m"]]),
            {
                ion: np.array([ion_eq[-1] for ion_eq in res_biophys[ion]])
                for ion in ions
            },
        )

    def calc_e_eq(
        self, loc_arg, ions=None, method="interp", L_eps=50.0, pprint=False, **kwargs
    ):
        """
        Calculates equilibrium potentials and concentrations in the tree.
        Uses either linear interpolations between the stored equilibria at the
        midpoints of the nodes or computes the equilibria through a NEURON
        simulation without inputs.

        Parameters
        ----------
        loc_arg: `list` of locations or string
            if `list` of locations, specifies the locations for which the
            equilibrium state is evaluated, if ``string``, specifies the
            name under which a set of locations is stored
        ions: `iterable` of `str`
            the names of the ions for which the concentration needs to be
            measured. Defaults to `self.ions`.
        method: Literal: 'interp' or 'sim'
            whether to use interpolation or simulation. Interpolation falls
            back to simulation if it can not be carried out for every location.
        L_eps: float
            distance (um) to a neighbouring node midpoint at or above which
            interpolation is abandoned and the method falls back to simulation
        pprint: bool
            Whether or not to print additional information
        **kwargs
            Forwarded to the simulation when it is used

        Returns
        -------
        np.ndarray
            The equilibrium potential at each location
        dict of {str: np.ndarray}
            For each ion, the equilibrium concentration at each location
        """
        if ions is None:
            ions = self.ions
        locs = self.convert_loc_arg_to_locs(loc_arg)

        # the midpoint of every node serves as reference location
        ref_locs = [(n.index, 0.5) for n in self]
        self.store_locs(ref_locs, name="ref locs")

        e_eqs = []
        conc_eqs = {ion: [] for ion in ions}

        if method == "interp":
            if pprint:
                print("> computing e_eq through interpolation")

            idxs0 = self.get_nearest_loc_idxs(locs, "ref locs", direction=1)
            idxs1 = self.get_nearest_loc_idxs(locs, "ref locs", direction=2)

            for loc, idx0, idx1 in zip(locs, idxs0, idxs1):
                if idx0 is None or idx1 is None:
                    if idx0 is None and idx1 is None:
                        # ref locs probably not defined, computations should be redone
                        break

                    # only one neighbouring ref loc exists (loc is more distal
                    # than the leaf ref loc), copy its values
                    idx = idx0 if idx0 is not None else idx1
                    ref_node = self[ref_locs[idx][0]]
                    e_eqs.append(ref_node.v_ep)
                    for ion in ions:
                        conc_eqs[ion].append(ref_node.conc_eps[ion])

                else:
                    L0 = self.path_length(loc, ref_locs[idx0])
                    L1 = self.path_length(loc, ref_locs[idx1])

                    if L0 < 1e-10 or L1 < 1e-10:
                        # loc coincides with a ref loc, copy the values of
                        # the closest one
                        idx = idx0 if L0 < L1 else idx1
                        ref_node = self[ref_locs[idx][0]]
                        e_eqs.append(ref_node.v_ep)
                        for ion in ions:
                            conc_eqs[ion].append(ref_node.conc_eps[ion])

                    elif L0 < L_eps and L1 < L_eps:
                        node0 = self[ref_locs[idx0][0]]
                        node1 = self[ref_locs[idx1][0]]

                        # linear interpolation to compute the equilibrium potential
                        e_eqs.append(
                            (node0.v_ep * L1 + node1.v_ep * L0) / (L1 + L0)
                        )

                        for ion in ions:
                            c_ep0 = node0.conc_eps[ion]
                            # the second sample has to come from the *other*
                            # neighbour, otherwise nothing is interpolated
                            c_ep1 = node1.conc_eps[ion]
                            # linear interpolation to compute the equilibrium concentration
                            conc_eqs[ion].append((c_ep0 * L1 + c_ep1 * L0) / (L1 + L0))

                    else:
                        # neighbours too far away to interpolate reliably
                        break

        # fewer results than locations means that interpolation was not
        # requested or was aborted, so the equilibria are simulated instead
        if len(e_eqs) < len(locs):
            if pprint:
                print("> computing e_eq through interpolation failed, simulating")
            return self._calc_e_eq(loc_arg, ions=ions, **kwargs)

        else:
            if pprint:
                print("> equilibria:")
                for ii, loc in enumerate(locs):
                    conc_eq_str = str(
                        {ion: f"{conc_eq[ii]:.8f}" for ion, conc_eq in conc_eqs.items()}
                    )
                    print(f" loc {loc}: e_eq = {e_eqs[ii]:.2f} mV, {conc_eq_str}")

            return (
                np.array(e_eqs),
                {ion: np.array(conc_eq) for ion, conc_eq in conc_eqs.items()},
            )

    def _set_e_eq(self, ions=None, t_max=500.0, dt=0.1, factor_lambda=10.0):
        """
        Simulate the equilibria at the midpoint of every node and store them
        in the nodes through `set_v_ep()` and `set_conc_ep()`.

        Parameters are the same as for `set_e_eq()`.
        """
        if ions is None:
            ions = self.ions

        locs = [(n.index, 0.5) for n in self]
        v_eqs, conc_eqs = self._calc_e_eq(
            locs, ions=ions, t_max=t_max, dt=dt, factor_lambda=factor_lambda
        )

        # `locs` was built by iterating over the tree, so the results follow
        # the same order as the nodes
        for ii, n in enumerate(self):
            n.set_v_ep(v_eqs[ii])
            for ion, conc in conc_eqs.items():
                n.set_conc_ep(ion, conc[ii])

    def set_e_eq(
        self, ions=None, t_max=500.0, dt=0.1, factor_lambda=10.0, pprint=False
    ):
        """
        Set equilibrium potentials and concentrations in the tree. Computes
        the equilibria through a NEURON simulation without inputs, unless they
        can be restored from the cache.

        Parameters
        ----------
        ions: `list` of `str`
            the names of the ions for which the concentration needs to be
            measured. Defaults to `self.ions`.
        t_max: float
            duration of the simulation
        dt: float
            time-step of the simulation
        factor_lambda: `float`
            multiplies the number of compartments suggested by the lambda-rule
        pprint: bool
            Whether or not to print additional information
        """
        self.maybe_execute_funcs(
            pprint=pprint,
            funcs_args_kwargs=[
                (
                    self._set_e_eq,
                    (),
                    dict(ions=ions, t_max=t_max, dt=dt, factor_lambda=factor_lambda),
                ),
            ],
        )


class CachedGreensTree(GreensTree, CachedTree):
    """
    Derived class of `neat.GreensTree` that caches the impedance calculation at
    each node.
    """

    def __init__(
        self,
        *args,
        recompute_cache=None,
        save_cache=None,
        cache_name=None,
        cache_path=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        *args, **kwargs
            Forwarded to the parent constructor
        recompute_cache, save_cache, cache_name, cache_path
            Cache parameters; ``None`` means "not provided"
        """
        # we want a behaviour where the cache parameters are initialized to certain defaults
        # if they are not provided, but where the initialization operation based on copying
        # the tree leaves the cache parameters intact if the input tree is a subclass
        # of CachedTree. However, we also want the optionally provided arguments to be
        # overwritten to their provided values. The constructor below achieves this
        self.set_cache_params(
            **self.get_cache_defaults(
                recompute_cache=recompute_cache,
                save_cache=save_cache,
                cache_name=cache_name,
                cache_path=cache_path,
            )
        )
        super().__init__(*args, **kwargs)
        self.set_cache_params(
            recompute_cache=recompute_cache,
            save_cache=save_cache,
            cache_name=cache_name,
            cache_path=cache_path,
        )

    def set_impedances_in_tree(self, freqs, sv_h=None, pprint=False, **kwargs):
        """
        Sets the impedances in the tree.

        Parameters
        ----------
        freqs: np.ndarray of float or complex
            The frequencies at which to evaluate the impedances
        sv_h: dict of {string: np.ndarray}
            Keys are the channel names and values are numpy arrays that contain
            the expansion point for each ion channel. If provided, its keys
            must match the keys of `self.channel_storage`.
        pprint: bool (optional, default is ``False``)
            Print info
        **kwargs
            Forwarded to `set_impedance()`
        """
        if pprint:
            cname_string = ", ".join(list(self.channel_storage.keys()))
            print(f">>> evaluating impedances with {cname_string}")

        # we set freqs here already because it needs to be included in the
        # representation to generate a hash
        self.freqs = np.array(freqs)

        if sv_h is not None:
            # check if expansion point for all channels is defined
            assert (
                sv_h.keys() == self.channel_storage.keys()
            ), "`sv_h` must define an expansion point for every channel in the tree"

            for c_name, sv in sv_h.items():
                # set the expansion point
                for node in self:
                    node.set_expansion_point(c_name, statevar=sv)

        self.maybe_execute_funcs(
            pprint=pprint,
            funcs_args_kwargs=[
                (self.set_comp_tree, [], {}),
                (self.set_impedance, [freqs], {"pprint": pprint, **kwargs}),
            ],
        )

    @computational_tree_decorator
    def calc_net_steadystate(self, root_loc=None, dx=5.0, dz=5.0):
        """
        Derive a `NET` from the impedance matrix of locations distributed over
        the subtree below `root_loc`.

        Parameters
        ----------
        root_loc: location (optional)
            Location whose node is the root of the subtree on which locations
            are distributed. Defaults to ``(1, 0.5)``.
        dx: float
            Spacing of the distances at which the locations are distributed
        dz: float
            Impedance step used to separate consecutive levels of the NET

        Returns
        -------
        net: `NET`
            The resulting tree
        z_mat: np.ndarray
            The impedance matrix from which `net` was derived
        """
        if root_loc is None:
            root_loc = (1, 0.5)
        root_loc = MorphLoc(root_loc, self)

        # distribute locs on nodes; they are stored in the tree under the
        # name "net eval", the returned list itself is not needed
        st_nodes = self.gather_nodes(self[root_loc["node"]])
        d2s_loc = self.path_length(root_loc, (1, 0.5))
        self.distribute_locs_at_d2s(
            d2s=np.arange(d2s_loc, 5000.0, dx), node_arg=st_nodes, name="net eval"
        )

        # compute the impedance matrix for net calculation
        # NOTE(review): only the first element of the result is used,
        # presumably the matrix at the first stored frequency -- confirm
        z_mat = self.calc_impedance_matrix("net eval", explicit_method=False)[0]

        # derive the NET
        net = NET()
        self._add_node_to_net(
            0.0,
            z_mat[0, 0],
            z_mat,
            np.arange(z_mat.shape[0]),
            None,
            net,
            alpha=1.0,
            dz=dz,
        )
        net.set_new_loc_idxs()

        return net, z_mat

    def _add_node_to_net(
        self, z_min, z_max, z_mat, loc_idxs, pnode, net, alpha=1.0, dz=20.0
    ):
        """
        Add a node to `net` for the impedance window ``(z_min, z_max)`` and
        recurse into the groups of locations whose input impedance lies above
        that window.

        Parameters
        ----------
        z_min, z_max: float
            Bounds of the impedance window; `z_max` is raised in steps of `dz`
            until the window contains at least one entry of `z_mat`
        z_mat: np.ndarray
            Impedance matrix restricted to the locations of this node
        loc_idxs: np.ndarray of int
            Indices of the locations that `z_mat` refers to
        pnode: `NETNode` or None
            Parent of the new node, ``None`` if the new node is the root
        net: `NET`
            The tree to which the node is added
        alpha: float
            Passed on as first element of the `z_kernel` of the new node
        dz: float
            Impedance step
        """
        # compute mean impedance of node, widening the window until it
        # contains at least one matrix entry
        inds = [[]]
        while len(inds[0]) == 0:
            inds = np.where((z_mat > z_min) & (z_mat < z_max))
            z_max += dz
        z_node = np.mean(z_mat[inds])

        # subtract impedances of parent nodes
        gammas = np.array([z_node])
        self._subtract_parent_kernels(gammas, pnode)

        # add a node to the tree
        node = NETNode(len(net), loc_idxs, z_kernel=(np.array([alpha]), gammas))
        if pnode is not None:
            net.add_node_with_parent(node, pnode)
        else:
            net.root = node

        # recursion for following nodes: every run of consecutive locations
        # whose input impedance exceeds the (already advanced) `z_max` becomes
        # a child node
        d_inds = consecutive(np.where(np.diag(z_mat) > z_max)[0])
        for di in d_inds:
            # `consecutive` yields a single empty run if no location qualifies
            if len(di) > 0:
                self._add_node_to_net(
                    z_max,
                    z_max + dz,
                    z_mat[di, :][:, di],
                    loc_idxs[di],
                    node,
                    net,
                    alpha=alpha,
                    dz=dz,
                )

    def _subtract_parent_kernels(self, gammas, pnode):
        """
        Subtract, in place, the ``"c"`` entry of the `z_kernel` of `pnode` and
        of all its ancestors from `gammas`.

        Parameters
        ----------
        gammas: np.ndarray
            Modified in place
        pnode: `NETNode` or None
            Node at which to start; ``None`` leaves `gammas` untouched
        """
        if pnode is not None:
            gammas -= pnode.z_kernel["c"]
            self._subtract_parent_kernels(gammas, pnode.parent_node)


class CachedGreensTreeTime(GreensTreeTime, CachedTree):
    """
    Derived class of `neat.GreensTreeTime` that caches the impedance
    computations required to evaluate the response kernels.
    """

    def __init__(
        self,
        *args,
        recompute_cache=None,
        save_cache=None,
        cache_name=None,
        cache_path=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        *args, **kwargs
            Forwarded to the parent constructor
        recompute_cache, save_cache, cache_name, cache_path
            Cache parameters; ``None`` means "not provided"
        """
        provided = dict(
            recompute_cache=recompute_cache,
            save_cache=save_cache,
            cache_name=cache_name,
            cache_path=cache_path,
        )
        # Three steps are needed to get the intended precedence:
        # defaults are set first, so that the attributes always exist; the
        # parent constructor may then replace them when the tree is
        # initialized by copying a `CachedTree`; finally, the arguments that
        # were explicitly provided overrule everything else.
        defaults = self.get_cache_defaults(**provided)
        self.set_cache_params(**defaults)
        super().__init__(*args, **kwargs)
        self.set_cache_params(**provided)

    def set_impedances_in_tree(self, t_arr, pprint=False):
        """
        Sets the impedances in the tree that are necessary for the evaluation
        of the response kernels, or restores them from the cache.

        Parameters
        ----------
        t_arr: np.ndarray of float
            The time-points at which to evaluate the response kernels
        pprint: bool (optional, default is ``False``)
            Print info
        """
        if pprint:
            channel_names = ", ".join(self.channel_storage.keys())
            print(f">>> evaluating response kernels with {channel_names}")

        # done before the cache lookup, i.e. also when the results are
        # subsequently restored from the cache
        self._set_freq_and_time_arrays(t_arr)

        computations = [
            (self.set_comp_tree, [], {}),
            (self.set_impedance, [t_arr], {}),
        ]
        self.maybe_execute_funcs(pprint=pprint, funcs_args_kwargs=computations)


class CachedSOVTree(SOVTree, CachedTree):
    """
    Derived class of `neat.SOVTree` that caches the separation of variables
    calculation.
    """

    def __init__(
        self,
        *args,
        recompute_cache=None,
        save_cache=None,
        cache_name=None,
        cache_path=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        *args, **kwargs
            Forwarded to the parent constructor
        recompute_cache, save_cache, cache_name, cache_path
            Cache parameters; ``None`` means "not provided"
        """
        # we want a behaviour where the cache parameters are initialized to certain defaults
        # if they are not provided, but where the initialization operation based on copying
        # the tree leaves the cache parameters intact if the input tree is a subclass
        # of CachedTree. However, we also want the optionally provided arguments to be
        # overwritten to their provided values. The constructor below achieves this
        self.set_cache_params(
            **self.get_cache_defaults(
                recompute_cache=recompute_cache,
                save_cache=save_cache,
                cache_name=cache_name,
                cache_path=cache_path,
            )
        )
        super().__init__(*args, **kwargs)
        self.set_cache_params(
            recompute_cache=recompute_cache,
            save_cache=save_cache,
            cache_name=cache_name,
            cache_path=cache_path,
        )

    def set_sov_in_tree(self, maxspace_freq=100.0, pprint=False):
        """
        Calculate the timescales and spatial functions of the separation of
        variables approach, using the algorithm by (Major, 1993), or restore
        them from the cache.

        The (reciprocals) of the timescales (i.e. the roots of the
        transcendental equation) are stored in the somanode. The spatial
        factors are stored in each (computational) node.

        Parameters
        ----------
        maxspace_freq: float (default is 100)
            roughly corresponds to the maximal spatial frequency of the
            smallest time-scale mode
        pprint: `bool`
            Verbose if ``True``.
        """
        if pprint:
            print(">>> evaluating SOV expansion")

        # stored on the tree before the cache lookup
        # NOTE(review): presumably so that it is part of the representation
        # from which the cache hash is generated -- confirm
        self.maxspace_freq = maxspace_freq

        self.maybe_execute_funcs(
            pprint=pprint,
            funcs_args_kwargs=[
                (self.set_comp_tree, [], {"eps": 1.0}),
                (
                    self.calc_sov_equations,
                    [],
                    {
                        "maxspace_freq": maxspace_freq,
                        "pprint": pprint,
                    },
                ),
            ],
        )