Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions neuroanalysis/stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ class SquarePulse(Stimulus):
_attributes = Stimulus._attributes + ['duration', 'amplitude']

def __init__(self, start_time, duration, amplitude, description="square pulse", units=None, parent=None):
if duration < 0:
raise ValueError("duration must be non-negative")
self.duration = duration
self.amplitude = amplitude
Stimulus.__init__(self, description=description, start_time=start_time, units=units, parent=parent)
Expand Down Expand Up @@ -518,6 +520,12 @@ class SquarePulseTrain(Stimulus):
_attributes = Stimulus._attributes + ['n_pulses', 'pulse_duration', 'amplitude', 'interval']

def __init__(self, start_time, n_pulses, pulse_duration, amplitude, interval, description="square pulse train", units=None, parent=None):
if not isinstance(n_pulses, int):
raise TypeError("n_pulses must be an integer")
if n_pulses < 1:
raise ValueError("n_pulses must be a positive integer")
if interval < pulse_duration:
raise ValueError("interval must be >= pulse_duration (pulses would overlap)")
self.n_pulses = n_pulses
self.pulse_duration = pulse_duration
self.amplitude = amplitude
Expand Down Expand Up @@ -572,7 +580,8 @@ def __init__(self, start_time, pulse_times, pulse_durations, amplitudes, descrip
self.pulse_times = pulse_times
self.pulse_durations = pulse_durations
self.amplitudes = amplitudes
assert len(pulse_times) == len(pulse_durations) == len(amplitudes)
if not (len(pulse_times) == len(pulse_durations) == len(amplitudes)):
raise ValueError("pulse_times, pulse_durations, and amplitudes must all have the same length")
Stimulus.__init__(self, description=description, start_time=start_time, units=units, parent=parent)

for i,t in enumerate(pulse_times):
Expand Down Expand Up @@ -613,6 +622,8 @@ class Ramp(Stimulus):
_attributes = Stimulus._attributes + ['duration', 'slope', 'initial_amplitude']

def __init__(self, start_time, duration, slope, offset=0, description="linear ramp", units=None, parent=None):
if duration < 0:
raise ValueError("duration must be non-negative")
self.duration = duration
self.slope = slope
self.offset = offset
Expand Down Expand Up @@ -659,6 +670,10 @@ class Sine(Stimulus):
_attributes = Stimulus._attributes + ['duration', 'frequency', 'amplitude', 'phase']

def __init__(self, start_time, duration, frequency, amplitude, phase=0, offset=0, description="sine wave", units=None, parent=None):
if duration < 0:
raise ValueError("duration must be non-negative")
if frequency <= 0:
raise ValueError("frequency must be positive")
self.duration = duration
self.frequency = frequency
self.amplitude = amplitude
Expand All @@ -667,11 +682,14 @@ def __init__(self, start_time, duration, frequency, amplitude, phase=0, offset=0
Stimulus.__init__(self, description=description, start_time=start_time, parent=parent, units=units)

def eval(self, **kwds):
if kwds.get("sample_rate") is not None:
if self.frequency > kwds["sample_rate"] / 2:
raise ValueError("Sample rate must be at least twice the frequency of the sine wave.")
trace = Stimulus.eval(self, **kwds)
start = self.global_start_time
chunk = trace.time_slice(start, start+self.duration, index_mode=kwds.get('index_mode'))
chunk.data[:] += self.offset

t = chunk.time_values - start
phase = self.phase_at(t)
chunk.data[:] += self.amplitude * np.sin(phase)
Expand Down Expand Up @@ -759,6 +777,12 @@ class Chirp(Stimulus):
_attributes = Stimulus._attributes + ['duration', 'start_frequency', 'end_frequency', 'amplitude', 'phase', 'offset']

def __init__(self, start_time, duration, start_frequency, end_frequency, amplitude, phase=0, offset=0, description="frequency chirp", units=None, parent=None):
if duration <= 0:
raise ValueError("duration must be positive")
if start_frequency <= 0:
raise ValueError("start_frequency must be positive")
if end_frequency <= 0:
raise ValueError("end_frequency must be positive")
self.duration = duration
self.start_frequency = start_frequency
self.end_frequency = end_frequency
Expand All @@ -768,6 +792,9 @@ def __init__(self, start_time, duration, start_frequency, end_frequency, amplitu
Stimulus.__init__(self, description=description, start_time=start_time, parent=parent, units=units)

def eval(self, **kwds):
if kwds.get("sample_rate") is not None:
if self.start_frequency > kwds["sample_rate"] / 2 or self.end_frequency > kwds["sample_rate"] / 2:
raise ValueError("Sample rate must be at least twice the maximum frequency in the chirp.")
trace = Stimulus.eval(self, **kwds)
start = self.global_start_time
chunk = trace.time_slice(start, start + self.duration, index_mode=kwds.get('index_mode'))
Expand Down Expand Up @@ -837,6 +864,12 @@ class Psp(Stimulus):
_attributes = Stimulus._attributes + ['rise_time', 'decay_tau', 'amplitude', 'rise_power']

def __init__(self, start_time, rise_time, decay_tau, amplitude, rise_power=2, description="frequency chirp", units=None, parent=None):
if rise_time <= 0:
raise ValueError("rise_time must be positive")
if decay_tau <= 0:
raise ValueError("decay_tau must be positive")
if rise_power <= 0:
raise ValueError("rise_power must be positive")
self.rise_time = rise_time
self.decay_tau = decay_tau
self.amplitude = amplitude
Expand Down
79 changes: 79 additions & 0 deletions neuroanalysis/tests/test_stimuli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import OrderedDict

import numpy as np
import pytest

import neuroanalysis.stimuli as stimuli
from neuroanalysis.data.dataset import TSeries
Expand Down Expand Up @@ -296,6 +297,84 @@ def test_chirp():
freqs = f0 ** np.linspace(1, np.log(f1) / np.log(f0), len(t)+1)[:-1]
assert np.allclose(freqs, stim.frequency_at(t))

def test_validation():
# Sine: Nyquist check in eval
sine = stimuli.Sine(start_time=0, duration=0.1, frequency=100, amplitude=1)
with pytest.raises(ValueError):
sine.eval(n_pts=100, sample_rate=150) # 150 < 2*100
sine.eval(n_pts=100, sample_rate=300) # should be fine

# Sine: frequency must be positive
with pytest.raises(ValueError):
stimuli.Sine(start_time=0, duration=0.1, frequency=0, amplitude=1)
with pytest.raises(ValueError):
stimuli.Sine(start_time=0, duration=0.1, frequency=-5, amplitude=1)

# Sine: duration must be non-negative
with pytest.raises(ValueError):
stimuli.Sine(start_time=0, duration=-1, frequency=10, amplitude=1)

# Chirp: Nyquist check in eval (restore original check)
chirp = stimuli.Chirp(start_time=0, duration=0.1, start_frequency=10, end_frequency=100, amplitude=1)
with pytest.raises(ValueError):
chirp.eval(n_pts=100, sample_rate=150) # 150 < 2*100
chirp.eval(n_pts=100, sample_rate=300) # should be fine

# Chirp: frequencies must be positive
with pytest.raises(ValueError):
stimuli.Chirp(start_time=0, duration=0.1, start_frequency=0, end_frequency=100, amplitude=1)
with pytest.raises(ValueError):
stimuli.Chirp(start_time=0, duration=0.1, start_frequency=10, end_frequency=-1, amplitude=1)

# Chirp: duration must be positive
with pytest.raises(ValueError):
stimuli.Chirp(start_time=0, duration=0, start_frequency=10, end_frequency=100, amplitude=1)
with pytest.raises(ValueError):
stimuli.Chirp(start_time=0, duration=-1, start_frequency=10, end_frequency=100, amplitude=1)

# SquarePulse: duration must be non-negative
with pytest.raises(ValueError):
stimuli.SquarePulse(start_time=0, duration=-0.1, amplitude=1)

# SquarePulseTrain: n_pulses must be a positive integer
with pytest.raises(ValueError):
stimuli.SquarePulseTrain(start_time=0, n_pulses=0, pulse_duration=0.01, amplitude=1, interval=0.1)
with pytest.raises(ValueError):
stimuli.SquarePulseTrain(start_time=0, n_pulses=-3, pulse_duration=0.01, amplitude=1, interval=0.1)
with pytest.raises(TypeError):
stimuli.SquarePulseTrain(start_time=0, n_pulses=1.5, pulse_duration=0.01, amplitude=1, interval=0.1)

# SquarePulseTrain: interval must be >= pulse_duration
with pytest.raises(ValueError):
stimuli.SquarePulseTrain(start_time=0, n_pulses=3, pulse_duration=0.1, amplitude=1, interval=0.05)

# Ramp: duration must be non-negative
with pytest.raises(ValueError):
stimuli.Ramp(start_time=0, duration=-1, slope=1)

# Psp: rise_time must be positive
with pytest.raises(ValueError):
stimuli.Psp(start_time=0, rise_time=0, decay_tau=0.01, amplitude=1)
with pytest.raises(ValueError):
stimuli.Psp(start_time=0, rise_time=-0.01, decay_tau=0.01, amplitude=1)

# Psp: decay_tau must be positive
with pytest.raises(ValueError):
stimuli.Psp(start_time=0, rise_time=0.01, decay_tau=0, amplitude=1)
with pytest.raises(ValueError):
stimuli.Psp(start_time=0, rise_time=0.01, decay_tau=-0.01, amplitude=1)

# Psp: rise_power must be positive
with pytest.raises(ValueError):
stimuli.Psp(start_time=0, rise_time=0.01, decay_tau=0.01, amplitude=1, rise_power=0)
with pytest.raises(ValueError):
stimuli.Psp(start_time=0, rise_time=0.01, decay_tau=0.01, amplitude=1, rise_power=-1)

# SquarePulseSeries: mismatched array lengths should raise ValueError (not AssertionError)
with pytest.raises(ValueError):
stimuli.SquarePulseSeries(start_time=0, pulse_times=[0, 0.1], pulse_durations=[0.01], amplitudes=[1, 1])


def test_find_noisy_square_pulses():
dt = 0.0002
np.random.seed(54321)
Expand Down
Loading