Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added data/src/__init__.py
Empty file.
41 changes: 41 additions & 0 deletions data/src/consts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np

# Recording / epoching parameters
SAMPLING_RATE = 250  # presumably samples per second (Hz) -- confirm against the recording setup
SAMPLES_PER_5_SEC = 5 * SAMPLING_RATE  # == 1250
ADAPT_TRIALS = 8

# Stimulation frequencies and the colour used for each one
EXPECTED_FREQS = [6.66, 7.50, 8.57, 10.00, 12.00]
FREQ_COLORS = dict(
    zip(EXPECTED_FREQS, ("blue", "green", "orange", "red", "purple"))
)

# Savitzky-Golay parameters
SG_WINDOW = 21
SG_POLYORDER = 3

# SNR parameters
SNR_NEIGHBOR_BINS = 5
SNR_EXCLUDE_BINS = 2
SNR_THRESHOLD_LINEAR = 3.0
SNR_THRESHOLD_DB = 10 * np.log10(SNR_THRESHOLD_LINEAR)  # same threshold in dB
CONFIDENCE_LEVEL = 0.80

# labels in the .mat files
EEG_KEY = "eeg"
DIN_KEY = "DIN_1"

# Channels of interest (occipital)
CHANNELS = {125: "Oz", 115: "O1", 149: "O2"}

# Extended channels for time-domain plots: the channels above plus Pz
CHANNELS_TIME = {**CHANNELS, 101: "Pz"}

OUTPUT_DIR = "ssvep_analysis_output"
35 changes: 35 additions & 0 deletions data/src/data_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
from enum import Enum


class TrialType(Enum):
    """Category assigned to a trial: adaptation or test."""

    ADAPT = 0
    TEST = 1


class trial_group:
    """Span of consecutive DIN markers that belong to one trial.

    Attributes:
        start_sample: Latency value of the first marker in the group.
        end_sample: Latency value of the last marker in the group.
        start_idx: Index of the first marker of the group in the latency list.
        end_idx: Index of the last marker of the group (inclusive).
    """

    def __init__(
        self, start_sample: int, end_sample: int, start_idx: int, end_idx: int
    ):
        self.start_sample = start_sample
        self.end_sample = end_sample
        self.start_idx = start_idx
        self.end_idx = end_idx

    def __repr__(self) -> str:
        # Readable representation for debugging and log output.
        return (
            f"{self.__class__.__name__}("
            f"start_sample={self.start_sample!r}, "
            f"end_sample={self.end_sample!r}, "
            f"start_idx={self.start_idx!r}, "
            f"end_idx={self.end_idx!r})"
        )


class trial_info:
    """Record describing one extracted trial.

    Attributes:
        epoch: EEG data array for the trial.
        trial: Index of the trial.
        true_freq: Frequency estimated for the trial.
        closest_freq: Expected frequency nearest to ``true_freq``.
        n_dins: Number of DIN markers that belong to the trial.
        type: ``TrialType`` of the trial.
    """

    def __init__(
        self,
        epoch: np.ndarray,
        trial: int,
        true_freq: float,
        closest_freq: float,
        n_dins: int,
        type: TrialType,
    ):
        # NOTE: the parameter name ``type`` shadows the builtin inside this
        # method; it is kept because callers pass it by keyword.
        self.epoch = epoch
        self.trial = trial
        self.true_freq = true_freq
        self.closest_freq = closest_freq
        self.n_dins = n_dins
        self.type = type

    def __repr__(self) -> str:
        # Show only the epoch's shape so the representation stays short even
        # for large arrays; falls back to None for objects without ``shape``.
        epoch_shape = getattr(self.epoch, "shape", None)
        return (
            f"{self.__class__.__name__}("
            f"trial={self.trial!r}, "
            f"type={self.type!r}, "
            f"true_freq={self.true_freq!r}, "
            f"closest_freq={self.closest_freq!r}, "
            f"n_dins={self.n_dins!r}, "
            f"epoch_shape={epoch_shape!r})"
        )
95 changes: 95 additions & 0 deletions data/src/data_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from collections import defaultdict
import scipy.io
from data_classes import *
from consts import *


def load_data(filename: str) -> tuple[np.ndarray, list[int], list[int]]:
    """Load the EEG array and DIN marker latencies from a MATLAB ``.mat`` file.

    Args:
        filename: Path of the ``.mat`` file to read.

    Returns:
        A tuple ``(eeg_data, latencies_ms, latencies)``:

        * ``eeg_data`` -- the array stored under ``EEG_KEY``.
        * ``latencies_ms`` -- the entries of row 1 of the ``DIN_KEY`` array,
          each converted to ``int``.
        * ``latencies`` -- the same values multiplied by
          ``SAMPLING_RATE / 1000`` and rounded to the nearest integer.

    Raises:
        KeyError: If ``EEG_KEY`` or ``DIN_KEY`` is not present in the file.
    """
    mat_data = scipy.io.loadmat(filename)
    eeg_data = mat_data[EEG_KEY]
    din_data = mat_data[DIN_KEY]

    # Row 1 of the DIN array carries the marker latencies (presumably in
    # milliseconds, given the /1000 rescale below -- confirm with the data).
    latencies_ms = [int(cell.item()) for cell in din_data[1, :]]
    latencies = [int(round(ms * SAMPLING_RATE / 1000)) for ms in latencies_ms]

    return eeg_data, latencies_ms, latencies


def group_din_markers_into_trials(
    latencies: list[int], gap_threshold_samples: int
) -> list[trial_group]:
    """Split a sequence of marker latencies into groups separated by large gaps.

    A new group starts whenever the distance between two consecutive
    latencies is strictly greater than ``gap_threshold_samples``.

    Args:
        latencies: Marker latencies in the order they were recorded.
        gap_threshold_samples: Largest gap between neighbouring markers that
            still keeps them in the same group.

    Returns:
        One ``trial_group`` per run of markers, holding the first and last
        latency of the run and the (inclusive) index range it covers in
        ``latencies``. An empty input yields an empty list.
    """
    # Without this guard the first-element lookup below would raise IndexError.
    if len(latencies) == 0:
        return []

    trial_groups = []
    group_start_idx = 0

    for i in range(1, len(latencies)):
        if latencies[i] - latencies[i - 1] > gap_threshold_samples:
            # The gap closes the current group at the previous marker.
            trial_groups.append(
                trial_group(
                    start_sample=latencies[group_start_idx],
                    end_sample=latencies[i - 1],
                    start_idx=group_start_idx,
                    end_idx=i - 1,
                )
            )
            group_start_idx = i

    # The last group is still open when the loop ends.
    last_idx = len(latencies) - 1
    trial_groups.append(
        trial_group(
            start_sample=latencies[group_start_idx],
            end_sample=latencies[last_idx],
            start_idx=group_start_idx,
            end_idx=last_idx,
        )
    )

    return trial_groups


def extract_trials(
    eeg_data: np.ndarray, latencies_ms: list[int], trial_groups: list[trial_group]
) -> list[trial_info]:
    """Cut one fixed-length epoch per trial group and attach its metadata.

    For every group an epoch of ``SAMPLES_PER_5_SEC`` columns is sliced from
    ``eeg_data`` starting at the group's ``start_sample``. Groups whose slice
    is shorter than that are skipped. When a group has more than one marker,
    its frequency is estimated as half of ``1000 / mean(marker interval)``
    and matched to the nearest entry of ``EXPECTED_FREQS``; otherwise both
    frequencies are ``0.0``. Groups with an index below ``ADAPT_TRIALS`` are
    labelled ``TrialType.ADAPT``, the rest ``TrialType.TEST``.

    Args:
        eeg_data: 2-D array indexed as ``[row, sample]``.
        latencies_ms: Marker latencies, indexed by the groups' index ranges.
        trial_groups: Groups that delimit the trials.

    Returns:
        A ``trial_info`` for every group whose epoch has full length.
    """
    records = []

    for group_no, group in enumerate(trial_groups):
        first = group.start_sample
        window = eeg_data[:, first : first + SAMPLES_PER_5_SEC]

        # A short window means the slice ran past the end of the recording.
        if window.shape[1] != SAMPLES_PER_5_SEC:
            continue

        markers = latencies_ms[group.start_idx : group.end_idx + 1]
        marker_count = len(markers)

        if marker_count > 1:
            gaps = [later - earlier for earlier, later in zip(markers, markers[1:])]
            measured = (1000.0 / np.mean(gaps)) / 2.0
            nearest = min(EXPECTED_FREQS, key=lambda f: abs(f - measured))
        else:
            measured = 0.0
            nearest = 0.0

        # The label depends on the group's position, counting skipped groups.
        kind = TrialType.ADAPT if group_no < ADAPT_TRIALS else TrialType.TEST

        records.append(
            trial_info(
                epoch=window,
                trial=group_no,
                true_freq=round(measured, 2),
                closest_freq=nearest,
                n_dins=marker_count,
                type=kind,
            )
        )

    return records


def main(filename: str, gap_threshold_samples: int = 250) -> list[trial_info]:
    """Load a recording and return the trials extracted from it.

    Args:
        filename: Path of the ``.mat`` file to process.
        gap_threshold_samples: Largest distance between neighbouring markers
            that still keeps them in the same trial. Defaults to 250, the
            value that was previously hard-coded here.

    Returns:
        The ``trial_info`` records produced by ``extract_trials``.
    """
    eeg_data, latencies_ms, latencies = load_data(filename)
    trial_groups = group_din_markers_into_trials(latencies, gap_threshold_samples)
    return extract_trials(eeg_data, latencies_ms, trial_groups)
150 changes: 150 additions & 0 deletions data/src/fft_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import numpy as np
from scipy.stats import t as t_dist
from collections import defaultdict

from consts import (
SAMPLING_RATE,
SAMPLES_PER_5_SEC,
SNR_NEIGHBOR_BINS,
SNR_EXCLUDE_BINS,
SNR_THRESHOLD_DB,
SNR_THRESHOLD_LINEAR,
EXPECTED_FREQS,
CHANNELS,
CONFIDENCE_LEVEL,
)
from data_classes import trial_info, TrialType
from utils import group_trials_by_frequency, get_trials_by_type

# Signal length used for FFT normalisation and for building the bin frequencies.
N = SAMPLES_PER_5_SEC
# Frequency of each rfft bin for an N-sample signal with spacing 1 / SAMPLING_RATE.
freqs_fft = np.fft.rfftfreq(N, d=1.0 / SAMPLING_RATE)


def compute_rfft(signal: np.ndarray) -> np.ndarray:
    """Return the one-sided FFT of ``signal`` after subtracting its mean."""
    return np.fft.rfft(signal - np.mean(signal))


def compute_fft_magnitude(signal: np.ndarray) -> np.ndarray:
    """Return the magnitude of the mean-removed one-sided FFT, scaled by ``2 / N``.

    The scale factor uses the module-level ``N`` rather than the length of
    ``signal``.
    """
    spectrum = compute_rfft(signal)
    return np.abs(spectrum) * (2.0 / N)


def compute_snr(magnitude: np.ndarray, target_freq: float) -> tuple[float, float, int]:
    """Compute the signal-to-noise ratio at ``target_freq``.

    Signal power is the squared magnitude of the bin of ``freqs_fft`` closest
    to ``target_freq``. Noise power is the mean squared magnitude of up to
    ``SNR_NEIGHBOR_BINS`` bins on each side of the target, leaving out the
    ``SNR_EXCLUDE_BINS`` bins directly adjacent to it.

    Args:
        magnitude: Magnitude spectrum indexed like ``freqs_fft``.
        target_freq: Frequency whose SNR is wanted.

    Returns:
        ``(snr_linear, snr_db, target_bin)``. ``snr_linear`` is ``inf`` when
        the noise power is zero. ``snr_db`` is ``-inf`` when ``snr_linear`` is
        zero and NaN when ``snr_linear`` is NaN.
    """
    target_bin = int(np.argmin(np.abs(freqs_fft - target_freq)))
    signal_power = magnitude[target_bin] ** 2

    # Clamp the left window at bin 0. A negative stop index would be read by
    # NumPy as "counted from the end" and pull in almost the whole spectrum.
    left_stop = max(0, target_bin - SNR_EXCLUDE_BINS)
    left_start = max(0, left_stop - SNR_NEIGHBOR_BINS)
    # Slicing past the end of the array is truncated by NumPy automatically.
    right_start = target_bin + SNR_EXCLUDE_BINS + 1
    right_stop = right_start + SNR_NEIGHBOR_BINS

    noise_bins = np.concatenate(
        [magnitude[left_start:left_stop], magnitude[right_start:right_stop]]
    )
    noise_power = np.mean(noise_bins**2) if len(noise_bins) > 0 else 1e-12

    snr_linear = signal_power / noise_power if noise_power > 0 else float("inf")

    if snr_linear > 0:
        snr_db = 10 * np.log10(snr_linear)
    elif snr_linear == 0:
        # No signal at all: the ratio in dB tends to -inf, not +inf.
        snr_db = float("-inf")
    else:
        # Only NaN reaches this branch; keep it visible instead of hiding it.
        snr_db = float("nan")

    return snr_linear, snr_db, target_bin


def compute_confidence_intervals(
    snr_db_values: list[float],
    peak_values: list[float],
    confidence_level: float = CONFIDENCE_LEVEL,
) -> dict:
    """Summarise repeated SNR and peak measurements with Student-t intervals.

    Args:
        snr_db_values: SNR values in dB, one per repetition.
        peak_values: Peak magnitudes, one per repetition.
        confidence_level: Two-sided confidence level of the intervals.

    Returns:
        A dict with the keys ``mean_snr_db``, ``ci_snr_low``, ``ci_snr_high``,
        ``mean_peak``, ``ci_peak_low``, ``ci_peak_high``, ``is_confident`` and
        ``n_reps``. ``is_confident`` is true when the lower SNR bound exceeds
        ``SNR_THRESHOLD_DB``. With fewer than two repetitions no interval is
        estimated: the SNR bounds are ``-inf``/``inf``, the peak bounds are 0
        and ``is_confident`` is ``False``.
    """
    n_reps = len(snr_db_values)

    if n_reps < 2:
        has_single = n_reps == 1
        return {
            "mean_snr_db": snr_db_values[0] if has_single else 0,
            "ci_snr_low": float("-inf"),
            "ci_snr_high": float("inf"),
            "mean_peak": peak_values[0] if has_single else 0,
            "ci_peak_low": 0,
            "ci_peak_high": 0,
            "is_confident": False,
            "n_reps": n_reps,
        }

    t_crit = t_dist.ppf((1 + confidence_level) / 2, df=n_reps - 1)
    root_n = np.sqrt(n_reps)

    def _mean_with_bounds(values):
        # Sample mean and the interval mean -/+ t_crit * standard error.
        center = np.mean(values)
        margin = t_crit * (np.std(values, ddof=1) / root_n)
        return center, center - margin, center + margin

    snr_mean, snr_low, snr_high = _mean_with_bounds(snr_db_values)
    peak_mean, peak_low, peak_high = _mean_with_bounds(peak_values)

    return {
        "mean_snr_db": snr_mean,
        "ci_snr_low": snr_low,
        "ci_snr_high": snr_high,
        "mean_peak": peak_mean,
        "ci_peak_low": peak_low,
        "ci_peak_high": peak_high,
        "is_confident": snr_low > SNR_THRESHOLD_DB,
        "n_reps": n_reps,
    }


def main(trials: list[trial_info]) -> dict:
    """Run the FFT/SNR analysis on the test trials.

    For every expected frequency, each matching test trial is analysed on
    every channel in ``CHANNELS``. The per-trial results are then summarised
    per (frequency, channel) pair with ``compute_confidence_intervals``.

    Args:
        trials: All extracted trials; only those selected by
            ``get_trials_by_type(trials, TrialType.TEST)`` are analysed.

    Returns:
        A dict with three entries:

        * ``"trials"`` -- the input list, unchanged.
        * ``"all_results"`` -- expected frequency -> list of per-trial,
          per-channel result dicts.
        * ``"stats"`` -- ``"<freq>_<ch_idx>"`` -> confidence-interval summary
          extended with ``freq``, ``channel`` and ``ch_idx``.
    """
    test_trials = get_trials_by_type(trials, TrialType.TEST)
    # Presumably keyed by each trial's closest expected frequency -- verify
    # against utils.group_trials_by_frequency.
    test_by_freq = group_trials_by_frequency(test_trials)

    all_results = defaultdict(list)

    for ef in EXPECTED_FREQS:
        for trial in test_by_freq.get(ef, []):
            for ch_idx, ch_name in CHANNELS.items():
                signal = trial.epoch[ch_idx, :]
                magnitude = compute_fft_magnitude(signal)
                snr_linear, snr_db, target_bin = compute_snr(magnitude, ef)

                all_results[ef].append(
                    {
                        "trial": trial.trial,
                        "channel": ch_name,
                        "ch_idx": ch_idx,
                        "magnitude": magnitude,
                        "snr_linear": snr_linear,
                        "snr_db": snr_db,
                        "target_bin": target_bin,
                        "peak_mag": magnitude[target_bin],
                        "is_confident": snr_linear >= SNR_THRESHOLD_LINEAR,
                    }
                )

    stats = {}

    for ef in EXPECTED_FREQS:
        # Indexing the defaultdict (rather than .get) leaves an empty list in
        # all_results for frequencies that had no test trials.
        freq_results = all_results[ef]
        for ch_idx, ch_name in CHANNELS.items():
            ch_results = [r for r in freq_results if r["ch_idx"] == ch_idx]
            snr_db_vals = [r["snr_db"] for r in ch_results]
            peak_vals = [r["peak_mag"] for r in ch_results]

            ci = compute_confidence_intervals(snr_db_vals, peak_vals)

            key = f"{ef}_{ch_idx}"
            stats[key] = {**ci, "freq": ef, "channel": ch_name, "ch_idx": ch_idx}

    return {
        "trials": trials,
        "all_results": dict(all_results),
        "stats": stats,
    }
Loading