Source code for ajustador.features

import math
from collections import namedtuple
import pprint
import numpy as np
from scipy import optimize

from . import utilities, detect, vartype
from .signal_smooth import smooth

def _plot_line(ax, ranges, value, label, color, zorder=3):
    for (a,b) in ranges:
        ax.plot([a, b], [float(value)]*2,
                label=label, color=color, linestyle='-', zorder=zorder)
        label = None # we only need one line labeled
        if isinstance(value, vartype.vartype):
            ax.plot([a, b], [value.x - 3*value.dev]*2,
                      color, linestyle='--', zorder=zorder)
            ax.plot([a, b], [value.x + 3*value.dev]*2,
                      color, linestyle='--', zorder=zorder)

def _plot_spike(ax, wave, spikes, i, bottom=None, spike_bounds=None, lmargin=0.0010, rmargin=0.0015):
    ax.set_xlim(spikes.x[i] - lmargin, spikes.x[i] + rmargin)
    ax.plot(wave.x, wave.y, label='recording')
    if bottom is not None:
        ax.vlines(spikes.x[i:i+1], bottom, spikes.y, 'r', zorder=3)
    if spike_bounds is not None:
        ax.axvspan(spike_bounds[i].left, spike_bounds[i].right,
                   alpha=0.3, color='cyan')

def plural(n, word):
    return '{} {}{}'.format(n, word, '' if n == 1 else 's')

class Feature:
    requires = ()
    provides = ()
    array_attributes = ()
    mean_attributes = ()

    def __init__(self, obj):
        self._obj = obj

    def plot(self, figure=None):
        if figure is None:
            from matplotlib import pyplot
            figure = pyplot.figure()

        wave = self._obj.wave

        ax = figure.add_subplot(111)
        ax.plot(wave.x, wave.y, label='recording')
        ax.set_xlabel('time / s')
        ax.set_ylabel('membrane potential / V')
        return ax

    def spike_plot(self, figure=None, max_spikes=20,
                   bottom=None, spike_bounds=None,
                   lmargin=0.0010, rmargin=0.0015, rowsize=None):

        if figure is None:
            from matplotlib import pyplot
            figure = pyplot.figure()

        wave = self._obj.wave
        spikes = self._obj.spikes

        spike_count = min(len(spikes), max_spikes)
        if rowsize is None:
            rowsize = 3 if spike_count < 19 else 5
        rows = math.ceil(spike_count / rowsize)
        columns = min(spike_count, rowsize)

        axes = []
        sharey = None
        for i in range(spike_count):
            ax = figure.add_subplot(rows, columns, i+1, sharey=sharey)
            if i == 0:
                sharey = ax
            else:
                ax.tick_params(labelleft='off')
            axes.append(ax)

            _plot_spike(ax, wave, spikes, i, bottom=bottom, spike_bounds=spike_bounds,
                        lmargin=lmargin, rmargin=rmargin)

        figure.autofmt_xdate()
        return axes

    def report_attr(self, name):
        val = getattr(self, name)
        prefix = '{} = '.format(name)
        if hasattr(val, 'report'):
            ans = val.report(prefix=prefix)
        elif isinstance(val, np.ndarray) and hasattr(val, 'dev'):
            ans = vartype.vartype.format_array(val, prefix=prefix)
        elif hasattr(val, '__len__'):
            joiner = '\n' + len(prefix)*' '
            ans = prefix + joiner.join(str(x) for x in val)
        else:
            ans = prefix + str(val)
        if name in self.mean_attributes and hasattr(val, '__len__'):
            mean = vartype.vartype.average(val)
            ans += '\n{:{}} = {}'.format('', len(name), mean)
        return ans

    def report(self):
        return '\n'.join(self.report_attr(name) for name in self.provides)

[docs]class SteadyState(Feature): """Find the baseline and injection steady states The range *before* `baseline_before` and *after* `baseline_after` is used for `baseline`. The range *between* `steady_after` and `steady_before` is used for `steady`. """ requires = ('wave', 'baseline_before', 'baseline_after', 'steady_after', 'steady_before', 'steady_cutoff') provides = ('baseline', 'steady', 'response', 'baseline_pre', 'baseline_post') mean_attributes = ('baseline', 'steady', 'response', 'baseline_pre', 'baseline_post') array_attributes = ('baseline', 'steady', 'response', 'baseline_pre', 'baseline_post') @property @utilities.once def baseline(self): """The mean voltage of the area outside of injection interval Returns mean value of wave after excluding "outliers", values > 95th or < 5th percentile. """ wave = self._obj.wave before = self._obj.baseline_before after = self._obj.baseline_after if before is None and after is None: raise ValueError('cannot determine baseline') region = ((wave.x < before if before is not None else False) | (wave.x > after if after is not None else False)) what = wave.y[region] cutoffa, cutoffb = np.percentile(what, (40, 60)) cut = what[(what >= cutoffa) & (what <= cutoffb)] return vartype.array_mean(cut) @property @utilities.once def baseline_pre(self): """The mean voltage of the area before the injection interval Returns mean value of wave after excluding "outliers", values > 95th or < 5th percentile. """ wave = self._obj.wave before = self._obj.baseline_before if before is None: return vartype.vartype.nan what = wave.y[(wave.x < before)] cutoffa, cutoffb = np.percentile(what, (40, 60)) cut = what[(what >= cutoffa) & (what <= cutoffb)] return vartype.array_mean(cut) @property @utilities.once def baseline_post(self): """The mean voltage of the area after the injection interval Returns mean value of wave after excluding "outliers", values > 95th or < 5th percentile. """ wave = self._obj.wave after = self._obj.baseline_after if after is None: return vartype.vartype.nan what = wave.y[(wave.x > after)] cutoffa, cutoffb = np.percentile(what, (40, 60)) cut = what[(what >= cutoffa) & (what <= cutoffb)] return vartype.array_mean(cut) @property @utilities.once def steady(self): """Returns mean value of wave between `steady_after` and `steady_before`. "Outliers", values > 80th percentile (which is a parameter), are excluded. 80th percentile excludes the spikes. """ wave = self._obj.wave after = self._obj.steady_after before = self._obj.steady_before cutoff = self._obj.steady_cutoff data = wave.y[(wave.x > after) & (wave.x < before)] cutoff = np.percentile(data, cutoff) cut = data[data <= cutoff] return vartype.array_mean(cut) @property @utilities.once def response(self): return self.steady - self.baseline
[docs] def plot(self, figure=None, pre_post=False): wave = self._obj.wave before = self._obj.baseline_before after = self._obj.baseline_after steady_after = self._obj.steady_after steady_before = self._obj.steady_before time = wave.x[-1] ax = super().plot(figure) if not pre_post: _plot_line(ax, [(0, before), (after, time)], self.baseline, 'baseline', 'k') else: if before is not None: _plot_line(ax, [(0, before)], self.baseline_pre, 'baseline_pre', 'k') if after is not None: _plot_line(ax, [(after, time)], self.baseline_post, 'baseline_post', 'k') _plot_line(ax, [(steady_after, steady_before)], self.steady, 'steady', 'r') ax.annotate('response', xy=(time/2, self.steady.x), xytext=(time/2, self.baseline.x), arrowprops=dict(facecolor='black', shrink=0), horizontalalignment='center', verticalalignment='bottom') ax.legend(loc='upper right') ax.figure.tight_layout()
peak_and_threshold = namedtuple('peak_and_threshold', 'peaks thresholds') def _find_spikes(wave, min_height=0.0, max_charge_time=0.004, charge_threshold=0.02): peaks = detect.detect_peaks(wave.y, P_low=0.75, P_high=0.50) peaks = peaks[wave.y[peaks] > min_height] thresholds = np.empty(peaks.size) for i in range(len(peaks)): start = (wave.x >= wave.x[peaks[i]] - max_charge_time).argmax() x = wave.x[start:peaks[i] + 1] y = wave.y[start:peaks[i] + 1] yderiv = np.diff(y) #spike threshold is point where derivative is 2% of steepest try: ythresh = charge_threshold * yderiv.max() thresh = y[1:][yderiv > ythresh].min() thresholds[i] = thresh except Exception: thresholds[i] = np.nan return peak_and_threshold(peaks, thresholds) class WaveRegion: def __init__(self, wave, left_i, right_i): self._wave = wave self.left_i = left_i self.right_i = right_i @property def left(self): "x coordinate of the left edge of FWHM" if self.left_i == 0: # arr[-1:1] is an empty slice return self._wave.x[0] else: return self._wave.x[self.left_i-1:self.left_i+1].mean() @property def right(self): "x coordinate of the right edge of FWHM" return self._wave.x[self.right_i:self.right_i+2].mean() @property def width(self): return self.right - self.left @property def wave(self): return self._wave[self.left_i:self.right_i+1] @property def x(self): return self._wave.x[self.left_i:self.right_i+1] @property def y(self): return self._wave.y[self.left_i:self.right_i+1] def min(self): return self.wave.min() def relative_to(self, x, y): new = np.rec.fromarrays((self.x - x, self.y - y), names='x,y') return WaveRegion(new, 0, new.size-1) def __str__(self): y = self.y return 'WaveRegion[{} points, x={:.04f}-{:.04f}, y={:.03f}-{:.03f}]'.format( self.right_i - self.left_i + 1, self.left, self.right, self.y.min(), self.y.max()) def report(self, prefix='WaveRegion = '): return '{}{}'.format(prefix, self)
[docs]class Spikes(Feature): """Find the position and height of spikes """ requires = ('wave', 'injection_interval', 'injection_start') provides = ('spike_i', 'spikes', 'spike_count', 'spike_threshold', 'mean_isi', 'isi_spread', 'spike_latency', 'spike_bounds', 'spike_height', 'spike_width', 'mean_spike_height', # TODO: is it OK to have mean_spike_height as # here and as an aggregated attribute? ) array_attributes = ('spike_count', 'spike_height', 'spike_width', 'mean_isi', 'isi_spread', 'spike_latency') mean_attributes = ('spike_height', 'spike_width', 'spike_threshold') @property @utilities.once def spike_i_and_threshold(self): "Indices of spike maximums in the wave.x, wave.y arrays" return _find_spikes(self._obj.wave) @property def spike_i(self): "Indices of spike maximums in the wave.x, wave.y arrays" return self.spike_i_and_threshold.peaks @property def spike_threshold(self): "Indices of spike maximums in the wave.x, wave.y arrays" return self.spike_i_and_threshold.thresholds @property @utilities.once def spikes(self): "An array with .x and .y components marking the spike maximums" return self._obj.wave[self.spike_i] @property def spike_count(self): "The number of spikes" return len(self.spike_i) mean_isi_fallback_variance = 0.001 @property @utilities.once def mean_isi(self): """The mean interval between spikes Defined as: * :math:`<x_{i+1} - x_i>`, if there are at least two spikes, * the length of the depolarization interval otherwise (`injection_interval`) If there less than three spikes, the variance is fixed as `mean_isi_fallback_variance`. """ if self.spike_count > 2: return vartype.array_mean(np.diff(self.spikes.x)) elif self.spike_count == 2: d = self.spikes.x[1]-self.spikes.x[0] return vartype.vartype(d, 0.001) else: return vartype.vartype(self._obj.injection_interval, 0.001) @property @utilities.once def isi_spread(self): """The difference between the largest and smallest inter-spike intervals Only defined when `spike_count` is at least 3. """ if len(self.spikes) > 2: diff = np.diff(self.spikes.x) return diff.ptp() else: return np.nan @property @utilities.once def spike_latency(self): "Latency until the first spike or nan if no spikes" # TODO: add spike_latency to plot if len(self.spikes) > 0: return self.spikes[0].x - self._obj.injection_start else: return self._obj.injection_end - self._obj.injection_start @property @utilities.once def spike_bounds(self): "The FWHM box and other measurements for each spike" spikes, thresholds = self.spike_i_and_threshold ans = [] y = self._obj.wave.y halfheight = (self.spikes.y - thresholds) / 2 + thresholds for i, k in enumerate(self.spike_i): beg = end = k while beg > 1 and y[beg - 1] > halfheight[i]: beg -= 1 while end + 2 < y.size and y[end + 1] > halfheight[i]: end += 1 ans.append(WaveRegion(self._obj.wave, beg, end)) return ans @property @utilities.once def spike_height(self): "The difference between spike peaks and spike threshold" spikes, thresholds = self.spike_i_and_threshold height = self.spikes.y - thresholds return height @property @utilities.once def spike_width(self): return np.array([bounds.width for bounds in self.spike_bounds]) @property @utilities.once def mean_spike_height(self): "The mean absolute position of spike vertices" # TODO: is the variance too big? return vartype.array_mean(self.spikes.y)
[docs] def plot(self, figure=None): from . import drawing_util wave = self._obj.wave ax = super().plot(figure) bottom = -0.06 # self._obj.steady.x # Doing the "proper" thing makes the plot hard to read vline1 = ax.vlines(self.spikes.x, bottom, self.spikes.y, 'r', label='timing of spike maximum') ax.text(0.05, 0.5, plural(self.spike_count, 'spike'), horizontalalignment='left', transform=ax.transAxes) _plot_line(ax, [(self._obj.steady_after, self._obj.steady_before)], self.mean_spike_height, 'spike_height', 'y', zorder=0) ax.legend(loc='upper left', handler_map={vline1: drawing_util.HandlerVLineCollection()}) ax.figure.tight_layout() if self.spike_count > 0: ax2 = ax.figure.add_axes([.7, .45, .25, .4]) ax2.tick_params(labelbottom='off', labelleft='off') ax2.set_title('first spike', fontsize='smaller') _plot_spike(ax2, wave, self.spikes, i=0, bottom=-0.06, spike_bounds=self.spike_bounds)
[docs] def spike_plot(self, figure=None, **kwargs): x = self._obj.spikes.x.mean() spike_bounds = self.spike_bounds thresholds = self.spike_threshold height = self.spike_height lmargin = self.spike_width.max() rmargin = self.spike_width.max() * 2 axes = super().spike_plot(figure=None, spike_bounds=spike_bounds, lmargin=lmargin, rmargin=rmargin, **kwargs) for i in range(len(axes)): y = thresholds[i] + height[i] / 2 axes[i].annotate('FWHM', xy=(spike_bounds[i].left, y), xytext=(spike_bounds[i].right, y), arrowprops=dict(facecolor='black', shrink=0), verticalalignment='bottom') axes[i].axhline(thresholds[i], color='green', linestyle='--', linewidth=0.3)
[docs]class AHP(Feature): """Find the depth of "after hyperpolarization" """ requires = ('wave', 'injection_start', 'injection_end', 'injection_interval', 'spikes', 'spike_count', 'spike_bounds', 'spike_threshold') provides = ('spike_ahp_window', 'spike_ahp', 'spike_ahp_position') array_attributes = ('spike_ahp_window', 'spike_ahp', 'spike_ahp_position') mean_attributes = ('spike_ahp',) @property @utilities.once def spike_ahp_window(self): spikes = self._obj.spikes spike_bounds = self._obj.spike_bounds thresholds = self._obj.spike_threshold injection_start = self._obj.injection_start injection_end = self._obj.injection_end x = self._obj.wave.x y = self._obj.wave.y ans = [] for i in range(len(spikes)): beg = spike_bounds[i].right_i # Don't allow the ahp to straddle an injection start/stop edge. # The ahp will be invalid anyway. rlimit = min(spike_bounds[i+1].left if i < len(spikes)-1 else x[-1], injection_start if injection_start > x[beg] else np.infty, injection_end if injection_end > x[beg] else np.infty) w = spike_bounds[i].width if not np.isnan(w): n_rolling_window = int(w // (x[1] - x[0])) + 1 else: # FIXME: consider rejecting those outright n_rolling_window = 5 # if we are before the AHP, or mostly going down, advance while (beg < y.size - n_rolling_window and y[beg] >= thresholds[i] and x[beg + 1] < rlimit and y[beg] > y[beg + n_rolling_window]): beg += 1 end = beg + n_rolling_window while (end < x.size and (y[end] < thresholds[i] or end - beg < 5) and x[end] < rlimit): end += 1 ans.append(WaveRegion(self._obj.wave, beg, end)) return ans @property @utilities.once def spike_ahp(self): """Returns the (averaged) minimum in y of each AHP window `spike_ahp_window` is used to determine the extent of the AHP. An average of the bottom area of the window of the width of the spike is used. """ windows = self.spike_ahp_window spikes = self._obj.spikes spike_bounds = self._obj.spike_bounds ans = np.empty((len(windows), 2)) for i in range(len(windows)): w = spike_bounds[i].width left = windows[i].x[windows[i].y.argmin()] - w/2 right = windows[i].x[windows[i].y.argmin()] + w/2 cut = windows[i].wave[(windows[i].x >= left) & (windows[i].x <= right)] mean = vartype.array_mean(cut.y) ans[i] = mean.x, mean.dev return np.rec.fromarrays(ans.T, names='x,dev') @property @utilities.once def spike_ahp_position(self): """Returns the (averaged) x of the minimum in y of each AHP window `spike_ahp_window` is used to determine the extent of the AHP. An average of the bottom area of the window of the width of the spike is used. TODO: add to plot """ windows = self.spike_ahp_window spikes = self._obj.spikes spike_bounds = self._obj.spike_bounds ans = np.empty((len(windows), 2)) for i in range(len(windows)): step = windows[i].x[1] - windows[i].x[0] # Make sure that we have at least a few points in the window, # even if the spike is very narrow. w = max(spike_bounds[i].width, 8 * step) left = windows[i].x[windows[i].y.argmin()] - w/2 right = windows[i].x[windows[i].y.argmin()] + w/2 cut = windows[i].wave[(windows[i].x >= left) & (windows[i].x <= right)] bottom = vartype.array_mean(cut.y) relative = cut.y - bottom.x weights = (relative / relative.ptp())**-2 weights = np.fmin(weights, 100) avg = (cut.x * weights).sum() / weights.sum() assert not np.isnan(avg) dev = ((cut.x-avg)**2 * weights).sum()**0.5 / weights.sum()**0.5 assert not np.isnan(dev) # TODO: check the formula for dev ans[i] = (avg, dev) return np.rec.fromarrays(ans.T, names='x,dev') def _do_plots(self, axes): spikes = self._obj.spikes spike_bounds = self._obj.spike_bounds thresholds = self._obj.spike_threshold windows = self.spike_ahp_window ahps = self.spike_ahp low, high = np.inf, -np.inf spike_count = len(axes) for i in range(spike_count): window = windows[i] x = spikes.x[i] width = window.right - x axes[i].plot(window.x, window.y, 'r', label='AHP') _plot_line(axes[i], [(spikes[i].x - 3*spike_bounds[i].width, spikes[i].x + 3*spike_bounds[i].width)], thresholds[i], 'spike threshold', 'green') _plot_line(axes[i], [(x, window.right)], vartype.vartype(*ahps[i]), 'AHP bottom', 'magenta') axes[i].annotate('AHP', xytext=(x + width/2, ahps[i].x), xy=(x + width/2, thresholds[i]), arrowprops=dict(facecolor='black', shrink=0), horizontalalignment='center', verticalalignment='top') diff = abs(thresholds[i] - ahps[i].x) low = min(ahps[i].x - diff*0.5, thresholds[i] - diff*0.5, low) high = max(thresholds[i] + diff*0.5, high) axes[i].set_ylim(low, high)
[docs] def plot(self, figure=None): ax = super().plot(figure) if self._obj.spike_count == 0: ax.text(0.5, 0.5, 'no spikes', horizontalalignment='center', transform=ax.transAxes) else: ax.set_xlim(self._obj.spikes[0].x - self._obj.injection_interval*0.05, self._obj.spikes[-1].x + self._obj.injection_interval*0.05) self._do_plots([ax] * self._obj.spike_count) ax.figure.tight_layout()
[docs] def spike_plot(self, figure=None, **kwargs): spike_bounds = self._obj.spike_bounds axes = super().spike_plot(figure, **kwargs) self._do_plots(axes) for i in range(self._obj.spike_count): l = spike_bounds[i].left r = self.spike_ahp_window[i].right diff = r - l axes[i].set_xlim(l - diff*0.15, r + diff*0.15)
def _find_falling_curve(wave, window=20, after=0.2, before=0.6): d = vartype.array_diff(wave) dd = smooth(d.y, window='hanning', window_len=window)[(d.x > after) & (d.x < before)] start = end = dd.argmin() + (d.x <= after).sum() while start > 0 and wave[start - 1].y > wave[start].y and wave[start].x > after: start -= 1 sm = smooth(wave.y, window='hanning', window_len=window) smallest = sm[end] # find minimum while (end+window < wave.size and wave[end+window].x < before and sm[end:end + window].min() < smallest): smallest = sm[end] end += window // 2 start_override = (d.x > after).argmax() ccut = wave[start_override + 1 : end] return ccut def simple_exp(x, amp, tau): return float(amp) * np.exp(-(x-x[0]) / float(tau)) def negative_exp(x, amp, tau): return float(amp) * (1-np.exp(-(x-x[0]) / float(tau))) falling_param = namedtuple('falling_param', 'amp tau') function_fit = namedtuple('function_fit', 'function params good') def _fit_falling_curve(ccut, baseline, steady): if ccut.size < 5 or not (steady-baseline).negative: func = None params = falling_param(vartype.vartype.nan, vartype.vartype.nan) good = False else: init = (ccut.y.min()-baseline.x, ccut.x.ptp()) func = negative_exp popt, pcov = optimize.curve_fit(func, ccut.x, ccut.y-baseline.x, (-1,1)) pcov = np.zeros((2,2)) + pcov params = falling_param(vartype.vartype(popt[0], pcov[0,0]**0.5), vartype.vartype(popt[1], pcov[1,1]**0.5)) good = params.amp.negative and params.tau.positive return function_fit(func, params, good)
[docs]class FallingCurve(Feature): requires = ('wave', 'injection_start', 'steady_before', 'falling_curve_window', 'baseline', 'steady') provides = ('falling_curve', 'falling_curve_fit', 'falling_curve_amp', 'falling_curve_tau', 'falling_curve_function') array_attributes = ('falling_curve_amp', 'falling_curve_tau', 'falling_curve_function') @property @utilities.once def falling_curve(self): return _find_falling_curve(self._obj.wave, window=self._obj.falling_curve_window, after=self._obj.injection_start, before=self._obj.steady_before) @property @utilities.once def falling_curve_fit(self): return _fit_falling_curve(self.falling_curve, self._obj.baseline, self._obj.steady) @property def falling_curve_amp(self): fit = self.falling_curve_fit return fit.params.amp if fit.good else vartype.vartype.nan @property def falling_curve_tau(self): fit = self.falling_curve_fit return fit.params.tau if fit.good else vartype.vartype.nan @property def falling_curve_function(self): fit = self.falling_curve_fit return fit.function if fit.good else None
[docs] def plot(self, figure=None): ax = super().plot(figure) ccut = self.falling_curve baseline = self._obj.baseline steady = self._obj.steady ax.plot(ccut.x, ccut.y, 'r', label='falling curve') ax.set_xlim(self._obj.injection_start - 0.005, ccut.x.max() + .01) func, popt, good = self.falling_curve_fit if good: label = 'fitted {}'.format(func.__name__) ax.plot(ccut.x, baseline.x + func(ccut.x, *popt), 'g--', label=label) else: ax.text(0.2, 0.5, 'bad fit', horizontalalignment='center', transform=ax.transAxes, color='red') ax.legend(loc='upper right') ax.figure.tight_layout()
[docs]class Rectification(Feature): requires = ('injection_start', 'steady_after', 'steady_before', 'falling_curve', 'steady') provides = 'rectification', array_attributes = 'rectification', mean_attributes = 'rectification', window_len = 11 @property @utilities.once def rectification(self): ccut = self._obj.falling_curve steady = self._obj.steady if ccut.size < self.window_len + 1: return vartype.vartype.nan pos = ccut.y.argmin() end = max(pos + self.window_len//2, ccut.size-1) bottom = vartype.array_mean(ccut[end-self.window_len : end+self.window_len+1].y) return steady - bottom
[docs] def plot(self, figure=None): ax = super().plot(figure) ccut = self._obj.falling_curve after = self._obj.steady_after before = self._obj.steady_before steady = self._obj.steady ax.set_xlim(self._obj.injection_start - 0.005, before) _plot_line(ax, [(after, before)], steady, 'steady', 'r') right = (after + before) / 2 bottom = steady.x - self.rectification.x if np.isnan(bottom): ax.text(0.5, 0.5, 'rectification not detected', horizontalalignment='center', transform=ax.transAxes, color='red') else: _plot_line(ax, [(after, right)], bottom, 'rectification bottom', 'g') ax.annotate('rectification', xytext=(right, bottom), xy=(right, self._obj.steady.x), arrowprops=dict(facecolor='black', shrink=0), horizontalalignment='center', verticalalignment='top') ax.legend(loc='upper right') ax.figure.tight_layout()
[docs]class ChargingCurve(Feature): requires = ('wave', 'injection_start', 'baseline', 'baseline_before', 'spikes', 'spike_count', 'spike_threshold') provides = 'charging_curve_halfheight', array_attributes = 'charging_curve_halfheight', @property @utilities.once def charging_curve_halfheight(self): "The height in the middle between depolarization start and first spike" ccut = self.charging_curve if ccut is None: return vartype.vartype.nan threshold = self._obj.spike_threshold[0] baseline = self._obj.baseline return (threshold - baseline) / 2 @property @utilities.once def charging_curve(self): if self._obj.spike_count < 1: return None wave = self._obj.wave injection_start = self._obj.injection_start spike0 = self._obj.spikes[0] baseline = self._obj.baseline threshold = self._obj.spike_threshold[0] what = wave[(wave.x > injection_start) & (wave.x < spike0.x)] what = what[what.y < threshold] return what
[docs] def plot(self, figure=None): ax = super().plot(figure) baseline = self._obj.baseline ccut = self.charging_curve if ccut is None: ax.text(0.05, 0.5, 'cannot determine charging curve', horizontalalignment='left', transform=ax.transAxes, color='red') else: ax.plot(ccut.x, ccut.y, 'r', label='charging curve') ax.set_xlim(ccut.x[0] - 0.005, self._obj.spikes[0].x) _plot_line(ax, [(ccut.x[0], ccut.x[-1])], baseline + self.charging_curve_halfheight, 'charging curve halfheight', 'g') _plot_line(ax, [(0, self._obj.baseline_before)], baseline, 'baseline', 'k') ax.legend(loc='upper left') ax.figure.tight_layout()
standard_features = ( SteadyState, Spikes, AHP, FallingCurve, Rectification, ChargingCurve, )