From b072d87a11476021bf5af1f4865cc8c664f943bf Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Tue, 10 Mar 2026 00:10:34 +0200 Subject: [PATCH 1/8] Adding raw waveform plot --- general functions/plotRawWaveforms.asv | 266 +++ general functions/plotRawWaveforms.m | 226 +++ .../@VStimAnalysis/PlotZScoreComparison.asv | 1506 ----------------- .../Run_Bombcell_Automatic_Sorting.asv | 157 ++ .../Run_Bombcell_Automatic_Sorting.m | 32 + 5 files changed, 681 insertions(+), 1506 deletions(-) create mode 100644 general functions/plotRawWaveforms.asv create mode 100644 general functions/plotRawWaveforms.m delete mode 100644 visualStimulationAnalysis/@VStimAnalysis/PlotZScoreComparison.asv create mode 100644 visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv diff --git a/general functions/plotRawWaveforms.asv b/general functions/plotRawWaveforms.asv new file mode 100644 index 0000000..9b352f2 --- /dev/null +++ b/general functions/plotRawWaveforms.asv @@ -0,0 +1,266 @@ +function plotRawWaveforms(obj, unitID, params) +% plotRawWaveforms - Plot raw spike waveforms from KS4 output, Phy-style +% Optionally plots an auto-correlogram. +% +% INPUTS: +% obj - Visual stimulation object with spikeSortingFolder and dataObj +% unitID - cluster ID to plot (single unit) +% params - (optional) struct with any of the following fields: +% +% WAVEFORM params: +% nWaveforms - number of random waveforms to plot (default: 100) +% nChanAround - channels above/below max amp channel (default: 4) +% nPre - samples before spike peak (default: 20) +% nPost - samples after spike peak (default: 61) +% +% CORRELOGRAM params: +% showCorr - plot auto-correlogram (default: false) +% corrWin - correlogram half-window in ms (default: 100) +% corrBin - correlogram bin size in ms (default: 1) +% +% EXAMPLE: +% % Just waveforms with defaults +% plotRawWaveforms(obj, 42) +% +% % Custom params +% params.nWaveforms = 200; +% params.nChanAround = 6; +% params.showCorr = true; +% params.corrWin = 50; +% params.corrBin = 0.5; +% plotRawWaveforms(obj, 42, params) + +arguments (Input) + obj + unitID (1,1) double + params.nWaveforms = 200; + params.nChanAround = 6; + params.showCorr = true; + params.corrWin = 50; + params.corrBin = 0.5; +end + +%% Parse params with defaults +params = parseParams(params); + +%% Paths +ksDir = obj.spikeSortingFolder; +recordingDir = obj.dataObj.recordingDir; + +%% Settings from obj +n_channels = str2double(obj.dataObj.nSavedChansImec); +sample_rate = obj.dataObj.samplingFrequency; +uV_per_bit = unique(obj.dataObj.MicrovoltsPerAD); + +fprintf('Settings — nCh: %d | Fs: %d Hz | uV/bit: %.4f\n', ... + n_channels, sample_rate, uV_per_bit); + +%% Find binary file +binFiles = dir(fullfile(recordingDir, '*.bin')); +if isempty(binFiles), binFiles = dir(fullfile(recordingDir, '*.dat')); end +if isempty(binFiles), error('No .bin or .dat file found in: %s', recordingDir); end +binPath = fullfile(recordingDir, binFiles(1).name); +fprintf('Using binary file: %s\n', binPath); + +%% Load KS4 output +spike_times = readNPY(fullfile(ksDir, 'spike_times.npy')); +spike_clusters = readNPY(fullfile(ksDir, 'spike_clusters.npy')); +templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] +chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed +chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy'));% [nCh x 2] + +%% Find template index for this unit +unit_ids = (0 : size(templates, 1) - 1)'; +tmpl_idx = find(unit_ids == unitID); +if isempty(tmpl_idx) + error('Unit %d not found in templates.npy', unitID); +end + +%% Find best channel (max peak-to-peak across template channels) +unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] +p2p = max(unit_template) - min(unit_template); +[~, best_tmpl_chan] = max(p2p); + +% Channels to extract: nChanAround above/below best channel +chan_indices = (best_tmpl_chan - params.nChanAround) : (best_tmpl_chan + params.nChanAround); +chan_indices = chan_indices(chan_indices >= 1 & chan_indices <= size(templates, 3)); +n_chans_plot = numel(chan_indices); + +% Index of best channel within the plotted subset +best_local_idx = find(chan_indices == best_tmpl_chan); + +% Map to binary file channels (1-indexed for MATLAB) +bin_chans = chan_map(chan_indices) + 1; +best_bin_chan = bin_chans(best_local_idx); + +%% Get spike times for this unit +st = double(spike_times(spike_clusters == unitID)); +if numel(st) < 2, error('Unit %d has fewer than 2 spikes.', unitID); end +fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... + unitID, numel(st), min(params.nWaveforms, numel(st))); + +idx = randperm(numel(st), min(params.nWaveforms, numel(st))); +st_sub = st(idx); + +%% Extract waveforms from binary +waveform_len = params.nPre + params.nPost + 1; +finfo = dir(binPath); +n_samp_total = finfo.bytes / (n_channels * 2); +fid = fopen(binPath, 'rb'); + +waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); + +for si = 1:numel(st_sub) + s0 = st_sub(si) - params.nPre; + s1 = st_sub(si) + params.nPost; + if s0 < 1 || s1 > n_samp_total, continue; end + + fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); + raw = fread(fid, [n_channels, waveform_len], '*int16'); + if size(raw, 2) < waveform_len, continue; end + + waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; +end +fclose(fid); + +% Baseline subtract (mean of pre-spike window) +baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); +waveforms = waveforms - baseline; + +%% Compute correlogram if requested +if params.showCorr + [ccg_counts, ccg_bins] = computeACG(st, sample_rate, params.corrWin, params.corrBin); +end + +%% ---- Build layout ---- +t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; +mean_wf = mean(waveforms, 3, 'omitnan'); +std_wf = std(waveforms, 0, 3, 'omitnan'); + +chan_depths = chan_pos(chan_indices, 2); +[~, depth_order] = sort(chan_depths, 'descend'); % shallowest at top + +colors = lines(n_chans_plot); +fig = figure('Color', 'w', 'Name', sprintf('Unit %d', unitID)); + +if params.showCorr + % Two-column layout: waveforms | correlogram + outer = tiledlayout(fig, 1, 2, 'TileSpacing', 'compact', 'Padding', 'compact'); + title(outer, sprintf('Unit %d | %d waveforms | best ch: %d', ... + unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); + + % Left: nested layout for per-channel waveforms + ax_wf_container = nexttile(outer, 1); + wf_layout = tiledlayout(ax_wf_container.Parent, n_chans_plot, 1, ... + 'TileSpacing', 'none', 'Padding', 'compact'); + wf_layout.Layout.Tile = 1; + xlabel(wf_layout, 'Time (ms)'); + + % Right: correlogram axes + ax_corr = nexttile(outer, 2); +else + % Single-column layout: waveforms only + wf_layout = tiledlayout(fig, n_chans_plot, 1, ... + 'TileSpacing', 'none', 'Padding', 'compact'); + title(wf_layout, sprintf('Unit %d | %d waveforms | best ch: %d', ... + unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); + xlabel(wf_layout, 'Time (ms)'); +end + +%% Plot one tile per channel +wf_axes = gobjects(n_chans_plot, 1); +for ci = 1:n_chans_plot + plot_ci = depth_order(ci); + ax = nexttile(wf_layout); + wf_axes(ci) = ax; + + % Individual waveforms (translucent) + wf_ci = squeeze(waveforms(plot_ci, :, :)); + plot(ax, t_ms, wf_ci, 'Color', [colors(plot_ci,:), 0.15], 'LineWidth', 0.5); + hold(ax, 'on'); + + % Std shading + upper = mean_wf(plot_ci,:) + std_wf(plot_ci,:); + lower = mean_wf(plot_ci,:) - std_wf(plot_ci,:); + fill(ax, [t_ms, fliplr(t_ms)], [upper, fliplr(lower)], ... + colors(plot_ci,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + + % Mean waveform + plot(ax, t_ms, mean_wf(plot_ci,:), 'Color', colors(plot_ci,:), 'LineWidth', 2); + + xline(ax, 0, '--k', 'Alpha', 0.3); + + % Highlight best channel with yellow background + if plot_ci == best_local_idx + set(ax, 'Color', [1 1 0.85]); + end + + % Channel label + depth + ylabel(ax, sprintf('ch%d\n%.0fµm', bin_chans(plot_ci), chan_depths(plot_ci)), ... + 'FontSize', 7, 'Rotation', 0, 'HorizontalAlignment', 'right'); + + if ci < n_chans_plot + set(ax, 'XTickLabel', []); + end + box(ax, 'off'); +end + +% Shared amplitude scale across all channels +linkaxes(wf_axes, 'y'); + +%% Plot correlogram +if params.showCorr + bar(ax_corr, ccg_bins, ccg_counts, 1, ... + 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); + hold(ax_corr, 'on'); + xline(ax_corr, 0, '--k', 'Alpha', 0.4); + + % Shade refractory period (±2 ms) + ylims = ylim(ax_corr); + patch(ax_corr, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... + 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); + + xlabel(ax_corr, 'Lag (ms)'); + ylabel(ax_corr, 'Spike count'); + title(ax_corr, sprintf('ACG | bin %.1f ms | win ±%d ms', ... + params.corrBin, params.corrWin), 'FontSize', 10); + xlim(ax_corr, [-params.corrWin params.corrWin]); + box(ax_corr, 'off'); +end + +end % main function + + +%% ========================================================================= +function params = parseParams(params) +% Fill in defaults for any missing fields +if ~isfield(params, 'nWaveforms'), params.nWaveforms = 100; end +if ~isfield(params, 'nChanAround'), params.nChanAround = 4; end +if ~isfield(params, 'nPre'), params.nPre = 20; end +if ~isfield(params, 'nPost'), params.nPost = 61; end +if ~isfield(params, 'showCorr'), params.showCorr = false; end +if ~isfield(params, 'corrWin'), params.corrWin = 100; end % ms +if ~isfield(params, 'corrBin'), params.corrBin = 1; end % ms +end + + +%% ========================================================================= +function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin_ms) +% Compute auto-correlogram for a single unit +% spike_times_samples - spike times in samples +% fs - sampling rate (Hz) +% win_ms - half-window in ms +% bin_ms - bin size in ms + +st_ms = spike_times_samples / fs * 1000; % convert to ms +edges = -win_ms : bin_ms : win_ms; +bin_centers = edges(1:end-1) + bin_ms / 2; +counts = zeros(1, numel(bin_centers)); + +for i = 1:numel(st_ms) + diffs = st_ms - st_ms(i); % lag to all other spikes + diffs(i) = NaN; % exclude self + diffs = diffs(diffs > -win_ms & diffs < win_ms); % within window + counts = counts + histcounts(diffs, edges); +end +end \ No newline at end of file diff --git a/general functions/plotRawWaveforms.m b/general functions/plotRawWaveforms.m new file mode 100644 index 0000000..e84954c --- /dev/null +++ b/general functions/plotRawWaveforms.m @@ -0,0 +1,226 @@ +function plotRawWaveforms(obj, unitID, params) +% plotRawWaveforms - Plot raw spike waveforms from KS4 output, Phy-style +% Optionally plots an auto-correlogram. +% +% INPUTS: +% obj - Visual stimulation object with spikeSortingFolder and dataObj +% unitID - cluster ID to plot (single unit) +% +% OPTIONAL NAME-VALUE PARAMS: +% nWaveforms - number of random waveforms to plot (default: 100) +% nChanAround - channels above/below max amp channel (default: 4) +% nPre - samples before spike peak (default: 20) +% nPost - samples after spike peak (default: 61) +% showCorr - plot auto-correlogram (default: false) +% corrWin - correlogram half-window in ms (default: 100) +% corrBin - correlogram bin size in ms (default: 1) +% +% EXAMPLES: +% plotRawWaveforms(obj, 42) +% plotRawWaveforms(obj, 42, nWaveforms=200, nChanAround=6) +% plotRawWaveforms(obj, 42, showCorr=true, corrWin=50, corrBin=0.5) + +arguments (Input) + obj + unitID (1,1) double + params.nWaveforms (1,1) double = 100 + params.nChanAround (1,1) double = 4 + params.nPre (1,1) double = 20 + params.nPost (1,1) double = 61 + params.showCorr (1,1) logical = false + params.corrWin (1,1) double = 100 + params.corrBin (1,1) double = 1 +end + +%% Paths +ksDir = obj.spikeSortingFolder; +recordingDir = obj.dataObj.recordingDir; + +%% Settings from obj +n_channels = str2double(obj.dataObj.nSavedChansImec); +sample_rate = obj.dataObj.samplingFrequency; +uV_per_bit = unique(obj.dataObj.MicrovoltsPerAD); + +fprintf('Settings — nCh: %d | Fs: %d Hz | uV/bit: %.4f\n', ... + n_channels, sample_rate, uV_per_bit); + +%% Find binary file +binFiles = dir(fullfile(recordingDir, '*.bin')); +if isempty(binFiles), binFiles = dir(fullfile(recordingDir, '*.dat')); end +if isempty(binFiles), error('No .bin or .dat file found in: %s', recordingDir); end +binPath = fullfile(recordingDir, binFiles(1).name); +fprintf('Using binary file: %s\n', binPath); + +%% Load KS4 output +spike_times = readNPY(fullfile(ksDir, 'spike_times.npy')); +spike_clusters = readNPY(fullfile(ksDir, 'spike_clusters.npy')); +templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] +chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed +chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy'));% [nCh x 2] + +%% Find template index for this unit +unit_ids = (0 : size(templates, 1) - 1)'; +tmpl_idx = find(unit_ids == unitID); +if isempty(tmpl_idx) + error('Unit %d not found in templates.npy', unitID); +end + +%% Find best channel (max peak-to-peak across template channels) +unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] +p2p = max(unit_template) - min(unit_template); +[~, best_tmpl_chan] = max(p2p); + +% Channels to extract: nChanAround above/below best channel +chan_indices = (best_tmpl_chan - params.nChanAround) : (best_tmpl_chan + params.nChanAround); +chan_indices = chan_indices(chan_indices >= 1 & chan_indices <= size(templates, 3)); +n_chans_plot = numel(chan_indices); + +% Index of best channel within the plotted subset +best_local_idx = find(chan_indices == best_tmpl_chan); + +% Map to binary file channels (1-indexed for MATLAB) +bin_chans = chan_map(chan_indices) + 1; +best_bin_chan = bin_chans(best_local_idx); + +%% Get spike times for this unit +st = double(spike_times(spike_clusters == unitID)); +if numel(st) < 2, error('Unit %d has fewer than 2 spikes.', unitID); end +fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... + unitID, numel(st), min(params.nWaveforms, numel(st))); + +idx = randperm(numel(st), min(params.nWaveforms, numel(st))); +st_sub = st(idx); + +%% Extract waveforms from binary +waveform_len = params.nPre + params.nPost + 1; +finfo = dir(binPath); +n_samp_total = finfo.bytes / (n_channels * 2); +fid = fopen(binPath, 'rb'); + +waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); + +for si = 1:numel(st_sub) + s0 = st_sub(si) - params.nPre; + s1 = st_sub(si) + params.nPost; + if s0 < 1 || s1 > n_samp_total, continue; end + + fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); + raw = fread(fid, [n_channels, waveform_len], '*int16'); + if size(raw, 2) < waveform_len, continue; end + + waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; +end +fclose(fid); + +% Baseline subtract (mean of pre-spike window) +baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); +waveforms = waveforms - baseline; + +%% Compute correlogram if requested +if params.showCorr + [ccg_counts, ccg_bins] = computeACG(st, sample_rate, params.corrWin, params.corrBin); +end + +%% ---- Waveform figure ---- +t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; +mean_wf = mean(waveforms, 3, 'omitnan'); +std_wf = std(waveforms, 0, 3, 'omitnan'); + +chan_depths = chan_pos(chan_indices, 2); +[~, depth_order] = sort(chan_depths, 'descend'); % shallowest at top + +colors = lines(n_chans_plot); + +figure('Color', 'w', 'Name', sprintf('Unit %d — Waveforms', unitID)); +wf_layout = tiledlayout(n_chans_plot, 1, 'TileSpacing', 'none', 'Padding', 'compact'); +title(wf_layout, sprintf('Unit %d | %d waveforms | best ch: %d', ... + unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); +xlabel(wf_layout, 'Time (ms)'); + +wf_axes = gobjects(n_chans_plot, 1); +for ci = 1:n_chans_plot + plot_ci = depth_order(ci); + ax = nexttile(wf_layout); + wf_axes(ci) = ax; + + % Individual waveforms (translucent) + wf_ci = squeeze(waveforms(plot_ci, :, :)); + plot(ax, t_ms, wf_ci, 'Color', [colors(plot_ci,:), 0.15], 'LineWidth', 0.5); + hold(ax, 'on'); + + % Std shading + upper = mean_wf(plot_ci,:) + std_wf(plot_ci,:); + lower = mean_wf(plot_ci,:) - std_wf(plot_ci,:); + fill(ax, [t_ms, fliplr(t_ms)], [upper, fliplr(lower)], ... + colors(plot_ci,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + + % Mean waveform + plot(ax, t_ms, mean_wf(plot_ci,:), 'Color', colors(plot_ci,:), 'LineWidth', 2); + + xline(ax, 0, '--k', 'Alpha', 0.3); + + % Highlight best channel with yellow background + if plot_ci == best_local_idx + set(ax, 'Color', [1 1 0.85]); + end + + % Channel label + depth + ylabel(ax, sprintf('ch%d\n%.0fµm', bin_chans(plot_ci), chan_depths(plot_ci)), ... + 'FontSize', 7, 'Rotation', 0, 'HorizontalAlignment', 'right'); + + % Only show x tick labels on bottom subplot + if ci < n_chans_plot + set(ax, 'XTickLabel', []); + end + box(ax, 'off'); +end + +% Shared amplitude scale across all channels +linkaxes(wf_axes, 'y'); + +%% ---- ACG figure (separate) ---- +if params.showCorr + figure('Color', 'w', 'Name', sprintf('Unit %d — ACG', unitID)); + ax_corr = axes; + + bar(ax_corr, ccg_bins, ccg_counts, 1, ... + 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); + hold(ax_corr, 'on'); + xline(ax_corr, 0, '--k', 'Alpha', 0.4); + + % Shade refractory period (±2 ms) + ylims = ylim(ax_corr); + patch(ax_corr, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... + 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); + + xlabel(ax_corr, 'Lag (ms)'); + ylabel(ax_corr, 'Spike count'); + title(ax_corr, sprintf('Unit %d | ACG | bin %.1f ms | win ±%d ms', ... + unitID, params.corrBin, params.corrWin), 'FontSize', 12); + xlim(ax_corr, [-params.corrWin params.corrWin]); + box(ax_corr, 'off'); +end + +end % main function + + +%% ========================================================================= +function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin_ms) +% Compute auto-correlogram for a single unit +% spike_times_samples - spike times in samples +% fs - sampling rate (Hz) +% win_ms - half-window in ms +% bin_ms - bin size in ms + +st_ms = spike_times_samples / fs * 1000; +edges = -win_ms : bin_ms : win_ms; +bin_centers = edges(1:end-1) + bin_ms / 2; +counts = zeros(1, numel(bin_centers)); + +for i = 1:numel(st_ms) + diffs = st_ms - st_ms(i); + diffs(i) = NaN; + diffs = diffs(diffs > -win_ms & diffs < win_ms); + counts = counts + histcounts(diffs, edges); +end +end \ No newline at end of file diff --git a/visualStimulationAnalysis/@VStimAnalysis/PlotZScoreComparison.asv b/visualStimulationAnalysis/@VStimAnalysis/PlotZScoreComparison.asv deleted file mode 100644 index f5e3708..0000000 --- a/visualStimulationAnalysis/@VStimAnalysis/PlotZScoreComparison.asv +++ /dev/null @@ -1,1506 +0,0 @@ -function fig = PlotZScoreComparison(expList, Stims2Comp,params) - -arguments - expList (1,:) double %%Number of experiment from excel list - Stims2Comp cell %% Comparison order {'MB','RG','MBR'} would select neurons responsive to moving ball and - % compare this neurons responses to other stimuli. - params.threshold = 0.05 - params.diffResp = false - params.overwrite = false - params.StimsPresent = {'MB','RG'} %assumes that at least moving ball is present - params.StimsNotPresent = {} - params.StimsToCompare = {} %Select 2 stims to compare scatter plots (default: 1st and 2nd stim are compared from the Stims2Comp cell array) - params.overwriteResponse = false - params.overwriteStats = false - params.overwriteGroupStats = false - params.RespDurationWin = 100; %same as default - params.shuffles = 2000; %same as default - params.StatMethod = 'ObsWindow' - params.ignoreNonSignif = false %when comparing first stim, ignore neurons non responsive to other stim - params.EachStimSignif = false %resposnive neurons for each stim are selected (default: responsive neurons of first stime are selected) - params.ComparePairs = {}; %Compare only pairs, recommended - params.PaperFig logical = false -end - -% Compare z-scores and p-values between moving ball and rect grid analyses - -animal = 0; -insertion =0; -animalVector = cell(1,numel(expList)); -insertionVector = cell(1,numel(expList)); -zScoresMB = cell(1,numel(expList)); -zScoresRG = cell(1,numel(expList)); -spKrMB = cell(1,numel(expList)); -spKrRG = cell(1,numel(expList)); -diffSpkMB = cell(1,numel(expList)); -diffSpkRG = cell(1,numel(expList)); - -zScoresSDGm = cell(1,numel(expList)); -zScoresMBR = cell(1,numel(expList)); -zScoresFFF = cell(1,numel(expList)); -spKrMBR = cell(1,numel(expList)); -spKrFFF = cell(1,numel(expList)); -spKrSDGm = cell(1,numel(expList)); -diffSpkMBR = cell(1,numel(expList)); -diffSpkFFF = cell(1,numel(expList)); -diffSpkSDGm = cell(1,numel(expList)); - -zScoresNI = cell(1,numel(expList)); -% zScoresNV = cell(1,numel(expList)); -spKrNI = cell(1,numel(expList)); -spKrNV = cell(1,numel(expList)); -diffSpkNI = cell(1,numel(expList)); -diffSpkNV = cell(1,numel(expList)); - -j = 1; -AnimalI = ""; -InsertionI = 0; - -NP = loadNPclassFromTable(expList(1)); %73 81 -vs = linearlyMovingBallAnalysis(NP); - -%%% Asumes all experiments were analyzed using the same window -vs.ResponseWindow; -MBvs = vs.ResponseWindow; -%%% - -nameOfFile = sprintf('\\Ex_%d-%d_Combined_Neural_responses_%s_filtered.mat',expList(1),expList(end),Stims2Comp{1}); -p = extractBefore(vs.getAnalysisFileName,'lizards'); -p = [p 'lizards']; - -if ~exist([p '\Combined_lizard_analysis'],'dir') - cd(p) - mkdir Combined_lizard_analysis -end -saveDir = [p '\Combined_lizard_analysis']; - -if exist([saveDir nameOfFile],'file') == 2 && ~params.overwrite - - S = load([saveDir nameOfFile]); - - expList2 = S.expList; - - if isequal(expList2,expList) - - forloop = false; - else - forloop = true; - end -else - forloop = true; -end - -longTablePairComp = table( ... - categorical.empty(0,1), ... - categorical.empty(0,1), ... - categorical.empty(0,1), ... - categorical.empty(0,1),... - double.empty(0,1), ... - double.empty(0,1), ... - 'VariableNames', {'animal','insertion','stimulus','NeurID','Z-score','SpkR'} ); - -longTable= table( ... - categorical.empty(0,1), ... - categorical.empty(0,1), ... - categorical.empty(0,1), ... - double.empty(0,1), ... - double.empty(0,1), ... - 'VariableNames', {'animal','insertion','stimulus','respNeur','totalSomaticN'} ); - -if forloop - for ex = expList - - fprintf('Processing recording: %s .\n',NP.recordingName) - NP = loadNPclassFromTable(ex); %73 81 - vs = linearlyMovingBallAnalysis(NP); - vsR = rectGridAnalysis(NP); - - %Assumes that RG and MB are present in all insertions - Animal = string(regexp( vs.getAnalysisFileName, 'PV\d+', 'match', 'once')); - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("MB"), 0,0}; - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("RG"), 0,0}; - - try - vsBr = linearlyMovingBarAnalysis(NP); - params.StimsPresent{3} = 'MBR'; - - if isempty(vsBr.VST) - error('Moving Bar stimulus not found.\n') - else - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("MBR"), 0,0}; - end - catch - params.StimsPresent{3} = ''; - fprintf('Moving Bar stimulus not found.\n') - vsBr = linearlyMovingBallAnalysis(NP); %use rectGrid here to avoid puting lots of ifs. - end - try - vsG = StaticDriftingGratingAnalysis(NP); - params.StimsPresent{4} = 'SDG'; - - if isempty(vsG.VST) - error('Gratings stimulus not found.\n') - else - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("SDGm"), 0,0}; - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("SDGs"), 0,0}; - end - catch - params.StimsPresent{4} = ''; - fprintf('Gratings stimulus not found.\n') - vsG = rectGridAnalysis(NP); %use rectGrid here to avoid puting lots of ifs. - end - try - vsNI = imageAnalysis(NP); - params.StimsPresent{5} = 'NI'; - - if isempty(vsNI.VST) - error('Gratings stimulus not found.\n') - else - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("NI"), 0,0}; - end - catch - params.StimsPresent{5} = ''; - fprintf('Natural images stimulus not found.\n') - vsNI = rectGridAnalysis(NP); %use rectGrid here to avoid puting lots of ifs. - end - try - vsNV = movieAnalysis(NP); - params.StimsPresent{6} = 'NV'; - - if isempty(vsNV.VST) - error('Gratings stimulus not found.\n') - else - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("NV"), 0,0}; - end - catch - params.StimsPresent{6} = ''; - fprintf('Natural video stimulus not found.\n') - vsNV = rectGridAnalysis(NP); %use rectGrid here to avoid puting lots of ifs. - end - - try - vsFFF = fullFieldFlashAnalysis(NP); - params.StimsPresent{7} = 'FFF'; - - if isempty(vsFFF.VST) - error('FFF stimulus not found.\n') - else - longTable(end+1,:) = {categorical(Animal),categorical(j), categorical("FFF"), 0,0}; - end - catch - params.StimsPresent{7} = ''; - fprintf('FFF stimulus not found.\n') - vsFFF = rectGridAnalysis(NP); %use moving ball here to avoid puting lots of ifs. - end - - - %%Load pvals and zscore from rect grid and moving ball - if isequal(params.StimsPresent{1},'') || ~ismember(params.StimsPresent{1}, Stims2Comp) - vs.ResponseWindow; - else - vs.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vs.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vs.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - - if isequal(params.StimsPresent{2},'') || ~ismember(params.StimsPresent{2}, Stims2Comp) - vsR.ResponseWindow; - else - vsR.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vsR.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vsR.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - if isequal(params.StimsPresent{3},'') || ~ismember(params.StimsPresent{3}, Stims2Comp) - vsBr.ResponseWindow; - else - vsBr.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vsBr.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vsBr.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - if isequal(params.StimsPresent{4},'') || ~ismember(params.StimsPresent{4}, Stims2Comp) - vsG.ResponseWindow; - else - vsG.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vsG.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vsG.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - if isequal(params.StimsPresent{5},'') || ~ismember(params.StimsPresent{5}, Stims2Comp) - vsNI.ResponseWindow; - else - vsNI.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vsNI.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vsNI.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - if isequal(params.StimsPresent{6},'') || ~ismember(params.StimsPresent{6}, Stims2Comp) - vsNV.ResponseWindow; - else - vsNV.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vsNV.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vsNV.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - if isequal(params.StimsPresent{7},'') || ~ismember(params.StimsPresent{7}, Stims2Comp) - vsFFF.ResponseWindow; - else - vsFFF.ResponseWindow('overwrite',params.overwriteResponse,'durationWindow',params.RespDurationWin); - if isequal(params.StatMethod,'ObsWindow') - vsFFF.ShufflingAnalysis('overwrite',params.overwriteStats,"N_bootstrap", params.shuffles); - else - vsFFF.BootstrapPerNeuron('overwrite',params.overwriteStats); - end - end - - if isequal(params.StatMethod,'ObsWindow') - statsMB = vs.ShufflingAnalysis; - statsRG = vsR.ShufflingAnalysis; - statsMBR = vsBr.ShufflingAnalysis; - statsSDG = vsG.ShufflingAnalysis; - statsFFF = vsFFF.ShufflingAnalysis; - statsNI = vsNI.ShufflingAnalysis; - statsNV = vsNV.ShufflingAnalysis; - else - statsMB = vs.BootstrapPerNeuron; - statsRG = vsR.BootstrapPerNeuron; - statsMBR = vsBr.BootstrapPerNeuron; - statsSDG = vsG.BootstrapPerNeuron; - statsFFF = vsFFF.BootstrapPerNeuron; - statsNI = vsNI.BootstrapPerNeuron; - statsNV = vsNV.BootstrapPerNeuron; - end - - rwRG = vsR.ResponseWindow; - rwMB = vs.ResponseWindow; - rwMBR = vsBr.ResponseWindow; - rwFFF = vsFFF.ResponseWindow; - rwSDG = vsG.ResponseWindow; - rwNI = vsNI.ResponseWindow; - rwNV = vsNV.ResponseWindow; - - %Load stats of Moving Ball, select fastest speed if there are several - zScores_MB = statsMB.Speed1.ZScoreU; - pValuesMB = statsMB.Speed1.pvalsResponse; - spkR_MB = max(rwMB.Speed1.NeuronVals(:,:,4),[],2); - spkDiff_MB = max(rwMB.Speed1.NeuronVals(:,:,5),[],2); - - if isfield(statsMB, 'Speed2') %If - zScores_MB = statsMB.Speed2.ZScoreU; - pValuesMB = statsMB.Speed2.pvalsResponse; - spkR_MB = max(rwMB.Speed2.NeuronVals(:,:,4),[],2); - spkDiff_MB = max(rwMB.Speed2.NeuronVals(:,:,5),[],2); - end - - totalU{j} = numel(zScores_MB); - %Load stats of Rect Grid. - zScores_RG = statsRG.ZScoreU; - pValuesRG = statsRG.pvalsResponse; - spkR_RG = max(rwRG.NeuronVals(:,:,4),[],2); - spkDiff_RG = max(rwRG.NeuronVals(:,:,5),[],2); - - %Load stats of Moving bar. - zScores_MBR = statsMBR.Speed1.ZScoreU; - pValuesMBR = statsMBR.Speed1.pvalsResponse; - spkR_MBR = max(rwMBR.Speed1.NeuronVals(:,:,4),[],2); - spkDiff_MBR = max(rwMBR.Speed1.NeuronVals(:,:,5),[],2); - - %Load stats of FFF - zScores_FFF = statsFFF.ZScoreU; - pValuesFFF = statsFFF.pvalsResponse; - spkR_FFF = max(rwFFF.NeuronVals(:,:,4),[],2); - spkDiff_FFF = max(rwFFF.NeuronVals(:,:,5),[],2); - - %Load stats of SDG moving - - if isequal(params.StimsPresent{4},'') - - zScores_SDGm = statsSDG.ZScoreU; - pValuesSDGm = statsSDG.pvalsResponse; - spkR_SDGm = max(rwSDG.NeuronVals(:,:,4),[],2); - spkDiff_SDGm = max(rwSDG.NeuronVals(:,:,5),[],2); - - %Load stats of SDG static - zScores_SDGs = statsSDG.ZScoreU; - pValuesSDGs = statsSDG.pvalsResponse; - spkR_SDGs = max(rwSDG.NeuronVals(:,:,4),[],2); - spkDiff_SDGs = max(rwSDG.NeuronVals(:,:,5),[],2); - - else - zScores_SDGm = statsSDG.Moving.ZScoreU; - pValuesSDGm = statsSDG.Moving.pvalsResponse; - spkR_SDGm = max(rwSDG.Moving.NeuronVals(:,:,4),[],2); - spkDiff_SDGm = max(rwSDG.Moving.NeuronVals(:,:,5),[],2); - - %Load stats of SDG static - zScores_SDGs = statsSDG.Static.ZScoreU; - pValuesSDGs = statsSDG.Static.pvalsResponse; - spkR_SDGs = max(rwSDG.Static.NeuronVals(:,:,4),[],2); - spkDiff_SDGs = max(rwSDG.Static.NeuronVals(:,:,5),[],2); - end - - %Load stats of Natural images - zScores_NI = statsNI.ZScoreU; - pValuesNI = statsNI.pvalsResponse; - spkR_NI = max(rwNI.NeuronVals(:,:,4),[],2); - spkDiff_NI = max(rwNI.NeuronVals(:,:,5),[],2); - - %Load stats of video - zScores_NV = statsNV.ZScoreU; - pValuesNV = statsNV.pvalsResponse; - spkR_NV = max(rwNV.NeuronVals(:,:,4),[],2); - spkDiff_NV = max(rwNV.NeuronVals(:,:,5),[],2); - - if ~isequal(params.StatMethod,'ObsWindow') - - spkR_NV = mean(statsNV.ObsReponse,1); - spkR_NI = mean(statsNI.ObsReponse,1); - - try - spkR_SDGs = mean(statsSDG.Static.ObsReponse,1); - spkR_SDGm = mean(statsSDG.Moving.ObsReponse,1); - - catch - spkR_SDGs = mean(statsSDG.ObsReponse,1); - spkR_SDGm = mean(statsSDG.ObsReponse,1); - end - - spkR_FFF = mean(statsFFF.ObsReponse,1); - - try - spkR_MBR = mean(statsMBR.Speed1.ObsReponse,1); - catch - spkR_MBR = mean(statsMBR.ObsReponse,1); - end - - spkR_RG = mean(statsRG.ObsReponse,1); - - if isfield(statsMB, 'Speed2') - spkR_MB = mean(statsMB.Speed2.ObsReponse); - else - spkR_MB = mean(statsMB.Speed1.ObsReponse); - end - - end - - if params.ignoreNonSignif - - zScores_NV(pValuesNV>params.threshold) = -1000; - zScores_NI(pValuesNI>params.threshold) = -1000; - zScores_SDGs(pValuesSDGs>params.threshold) = -1000; - zScores_SDGm(pValuesSDGm>params.threshold) = -1000; - zScores_FFF(pValuesFFF>params.threshold) = -1000; - zScores_MBR(pValuesMBR>params.threshold) = -1000; - zScores_RG(pValuesRG>params.threshold) = -1000; - zScores_MB(pValuesMB>params.threshold) = -1000; - - end - - pvals = {'pValuesMB','pValuesRG','pValuesMBR','pValuesFFF','pValuesSDGm','pValuesSDGs','pValuesNI','pValuesNV'... - ;pValuesMB,pValuesRG,pValuesMBR,pValuesFFF,pValuesSDGm,pValuesSDGs,pValuesNI,pValuesNV}; - - [row, col] = find(cellfun(@(x) ischar(x) && endsWith(x, Stims2Comp{1}), pvals)); - - for i=1:numel(params.ComparePairs) - - [row, col] = find(cellfun(@(x) ischar(x) && endsWith(x, params.ComparePairs{i}), pvals)); - - pvalsC{i}= pvals{2,col}; - - end - - vars = who; - - zscoresC1 = vars(contains(vars,sprintf('zScores_%s',params.ComparePairs{1}))); - zscoresC1 = eval(zscoresC1{1}); - unitIDs = 1:numel(zscoresC1); - zscoresC1 = zscoresC1(pvalsC{1}=BootFirst); - j = j+1; - end - - %%Calculate probabilities - - S.groupStats.Bayes_ZscoreCompare = probs; - S.groupStatsP_ZscoreCompare = ps; - - save([saveDir nameOfFile],'-struct', 'S'); - - end - - - %%%Scatter plot comparison for 2 stimuli Z-score (first and second input) - nexttile - %stims to compare - % boxplot(y2,'Labels',Stims2Comp) - - if isempty(params.StimsToCompare) - ind1 = 1; - ind2 = 2; - else - - ind1 = find(strcmp(Stims2Comp2, params.StimsToCompare{1})); - ind2 = find(strcmp(Stims2Comp2, params.StimsToCompare{2})); - - end - - ValsToCompare = {StimZS{ind1},StimZS{ind2}}; - - if numel(ValsToCompare{1}) == numel(ValsToCompare{2}) - - - scatter(ValsToCompare{1},ValsToCompare{2},10,AnIndex,"filled","MarkerFaceAlpha",0.5) - colormap(colormapUsed) - hold on - axis equal - - lims =[min(y(y>-inf)) max(y)]; - plot(lims, lims, 'k--', 'LineWidth', 1.5) - lims = [-5 40]; - ylim(lims) - xlim(lims) - xlabel(Stims2Comp(ind1)) - ylabel(Stims2Comp(ind2)) - - end - - %%%%%% SPIKE RATE ANALYSIS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - - y = cell2mat(stimRSP); - %y = cell2mat(StimZS); - - - % ---- Swarmchart (Larger Left Subplot) ---- - nexttile % Takes most of the space - if ~params.EachStimSignif - swarmchart(x, y, 5, [colormapUsed(allColorIndices,:)], 'filled','MarkerFaceAlpha',0.7); % Marker size 50 - else - swarmchart(x, y, 5, 'filled','MarkerFaceAlpha',0.7); % Marker size 50 - end - - xticks(1:8); - xticklabels(Stims2Comp2); - ylabel('Spike Rate'); - set(fig,'Color','w') - - %%HIERARCHICAL BOOTSTRAPPING SpikeRate hierBoot - if params.overwriteGroupStats || ~isfield(S, 'groupStats') - FirstStim = y(x==1); - - BootFirst = hierBoot(FirstStim(~isnan(FirstStim)),10000,InsIndex(~isnan(FirstStim)),AnIndex(~isnan(FirstStim))); - j=1; - for i = 2:numel(Stims2Comp2) - secondaryStim = y(x==i); - secondaryStim(isnan(secondaryStim)) =0; - secondaryStim = secondaryStim(secondaryStim~=-inf); - BootSec= hierBoot(secondaryStim,10000,InsIndex(secondaryStim~=-inf),AnIndex(secondaryStim~=-inf)); - probs{j} = get_direct_prob(BootFirst,BootSec); % - ps{j} = mean(BootSec>=BootFirst); - j = j+1; - end - - S.groupStats.Bayes_SpikeRateCompare = probs; - S.groupStats.P_SpikeRateCompare = ps; - end - - %%%Scatter plot comparison for 2 stimuli Z-score (first and second input) - nexttile - ValsToCompare = {stimRSP{ind1},stimRSP{ind2}}; - - if numel(ValsToCompare{1}) == numel(ValsToCompare{2}) - - - scatter(ValsToCompare{1},ValsToCompare{2},10,AnIndex,"filled","MarkerFaceAlpha",0.5) - colormap(colormapUsed) - hold on - axis equal - lims = [0 max(xlim)]; - plot(lims, lims, 'k--', 'LineWidth', 1.5) - ylim(lims) - xlim(lims) - xlabel(Stims2Comp(ind1)) - ylabel(Stims2Comp(ind2)) - end - - -end %% end of analysis comparing multiple pairs - -%% %% ANALYSIS OF QUANTITIES OF RESPONSIVE NEURONS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -%Run until here, check insertion list to create bootstrapping of neuronal -%quantities that are responsive to each stim -% AllNeur =0; -% fn = fieldnames(S.stimValsSignif); -% for i = 1:numel(Stims2Comp2) -% -% ending = [Stims2Comp2{i} 'g']; -% pattern = ['^zS.*' ending '$']; -% matches = fn(~cellfun('isempty', regexp(fn, pattern))); -% -% if isequal(Stims2Comp2{i},'SDGm') -% matches2 = fn(~cellfun('isempty', regexp(fn, ['^sumNeur.*' 'SDGm' '$']))); -% elseif isequal(Stims2Comp2{i},'SDGs') -% matches2 = fn(~cellfun('isempty', regexp(fn, ['^sumNeur.*' 'SDGm' '$']))); -% else -% matches2 = fn(~cellfun('isempty', regexp(fn, ['^sumNeur.*' Stims2Comp2{i} '$']))); -% end -% -% matTemp = cell2mat(S.stimValsSignif.(matches{1})); -% matTemp = matTemp(matTemp>-inf); -% RespNeurCountFraction{i} = numel(matTemp)/(sum(cell2mat(S.stimValsSignif.(matches2{1})))); -% RespNeurCount{i} = numel(matTemp); -% AllNeur = AllNeur+sum(cell2mat(S.stimValsSignif.(matches2{1}))); -% -% end - - -%Stimuli pairs to compare - -if isempty(params.ComparePairs) - pairs = {Stims2Comp{1},Stims2Comp{2}}; -else - pairs = params.ComparePairs; -end - - - -[G, insID] = findgroups(S.TableRespNeurs.insertion); -hasAll = splitapply(@(s) all(ismember(unique(categorical(pairs)), s)), S.TableRespNeurs.stimulus, G); - -tempTable = S.TableRespNeurs(hasAll(G) & ismember(S.TableRespNeurs.stimulus, unique(categorical(pairs))),:); - - -%pairs = {"SDGm","SDGs";"MB","MBR";"MB","RG";"NV","NI"}; -nBoot = 10000; -j=1; - - - -%%% BOOTSRAPPING - -ps = zeros(1,size(pairs,1)); - -for i = 1:size(pairs,1) - - diffs = []; - for ins = unique(S.TableRespNeurs.insertion)' - - idx1 = S.TableRespNeurs.insertion == categorical(ins) & S.TableRespNeurs.stimulus == pairs{j,1}; - idx2 = S.TableRespNeurs.insertion == categorical(ins) & S.TableRespNeurs.stimulus == pairs{j,2}; - - if any(idx1) && any(idx2) - diffs(end+1,1) = S.TableRespNeurs.respNeur(idx1)/ S.TableRespNeurs.totalSomaticN(idx1) - S.TableRespNeurs.respNeur(idx2)/S.TableRespNeurs.totalSomaticN(idx1); - end - end - - bootDiff = bootstrp(nBoot, @mean, diffs); - ps(j) = mean(bootDiff<=0); - j = j+1; -end - -[G,expID] = findgroups(tempTable.insertion); -totals = splitapply(@sum, tempTable.respNeur, G); - -tempTable.TotalRespNeur = totals(G); - -%%% PLOTTING - - -fig = plotSwarmBootstrapWithComparisons(tempTable,pairs,ps,{'respNeur','totalSomaticN'},fraction = true, yLegend='Responsive/total units',diff=false, filled = false, Xjitter = 'none',Alpha=0.9); - - ax = gca; - ax.YAxis.FontSize = 8; - ax.YAxis.FontName = 'helvetica'; - - ax = gca; - ax.XAxis.FontSize = 8; - ax.XAxis.FontName = 'helvetica'; - - set(fig, 'Units', 'centimeters'); - set(fig, 'Position', [20 20 4 6]); - -end \ No newline at end of file diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv new file mode 100644 index 0000000..64d7cc1 --- /dev/null +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv @@ -0,0 +1,157 @@ + +%% Run/load bombcell and confusion matrices + +% +exp = [49:54,64:97];% +%tiledlayout(numel(exp),1) +for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) + %%%%%%%%%%%% Load data and data paremeters + %1. Load NP class + NP = loadNPclassFromTable(ex); + vs = linearlyMovingBallAnalysis(NP,Session=1); + KSversion =4; + + [qMetric,unitType]=NP.getBombCell(NP.recordingDir+"\kilosort4",0,KSversion); + + %convertPhySorting2tIc(obj,pathToPhyResults,tStart,BombCelled) + + % + % goodUnits = unitType == 1; + % muaUnits = unitType == 2; + % noiseUnits = unitType == 0; + % nonSomaticUnits = unitType == 3; + + % Concordance analysis + % bc load_manual_classifications(vs.spikeSortingFolder) + % pMC = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,0,1); + % pBC = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,1,1); + + bombcell_table = readtable([vs.spikeSortingFolder filesep 'cluster_bc_unitType.tsv'], 'FileType', 'text', 'Delimiter', '\t'); + manual_table = readtable([vs.spikeSortingFolder filesep 'cluster_info.tsv'],'FileType','delimitedtext'); + + manual_table = manual_table(:,{'cluster_id','KSLabel','group'}); + sum(strcmp(pKS.label, 'good')) + + + % Load and prepare data + % Assume: + % bombcell_table: Nx2 table, columns: [id, bc_label] ("GOOD","MUA","NON-SOMA","NOISE") + % manual_table: Mx3 table, columns: [id, KS_label, group] ("good","mua","noise") + + % Rename columns for clarity (adjust if yours differ) + bombcell_table.Properties.VariableNames = {'id', 'bc_label'}; + manual_table.Properties.VariableNames = {'id', 'KS_label', 'group'}; + + % Remove NON-SOMA from bombcell + bc = bombcell_table(~strcmp(bombcell_table.bc_label, 'NON-SOMA'), :); + + % Match IDs — keep only IDs present in both tables + [~, ia, ib] = intersect(bc.id, manual_table.id); + bc_matched = bc(ia, :); + man_matched = manual_table(ib, :); + + % Harmonize labels to lowercase for comparison + bc_labels = lower(bc_matched.bc_label); % "good","mua","noise" + ks_labels = lower(man_matched.KS_label); % "good","mua","noise" + man_labels = lower(man_matched.group); % "good","mua","noise" + + %%Define category order + cats = {'good', 'mua', 'noise'}; + + bc_cat = categorical(bc_labels, cats); + ks_cat = categorical(ks_labels, cats); + man_cat = categorical(man_labels, cats); + + % --- Confusion Matrix 1: Manual curation vs BombCell --- + % figure('Position', [100, 100, 700, 600]); + % + % tiledlayout(3,2) + % nexttile + % cm1 = confusionchart(man_cat, bc_cat, ... + % 'Title', sprintf('%s-Manual curation vs BombCell',NP.recordingName),... + % 'XLabel', 'BombCell', ... + % 'YLabel', 'Manual Curation', ... + % 'RowSummary', 'row-normalized', ... + % 'ColumnSummary', 'column-normalized'); + % + % cm1.FontSize = 9; + % + % % Give the chart more room inside the figure + % %cm1.Position = [10, 10, 680, 580]; + + % --- Confusion Matrix 2: KS label vs BombCell --- + fig = figure('Position', [100, 100, 700, 600]); + %tl = nexttile; + cm2 = confusionchart(ks_cat, bc_cat, ... + 'XLabel', 'BombCell', ... + 'YLabel', 'KS Label', ... + 'RowSummary', 'row-normalized', ... + 'ColumnSummary', 'column-normalized'); + cm2.FontSize = 9; + title(sprintf('%KS Label vs BombCell',NP.recordingName)); + + + + % %% --- Confusion Matrix 3: KS label vs Manual curation --- + % figure; + % cm3 = confusionchart(ks_cat, man_cat, ... + % 'Title', printf('KS Label vs Manual Curation',NP.recordingName), ... + % 'XLabel', 'Manual Curation', ... + % 'YLabel', 'KS Label', ... + % 'RowSummary', 'row-normalized', ... + % 'ColumnSummary', 'column-normalized'); + + % --- Print mismatch summary --- + % fprintf('\n=== Manual vs BombCell ===\n') + % mismatch_man_bc = man_cat ~= bc_cat; + % fprintf('Total mismatches: %d / %d (%.1f%%)\n', ... + % sum(mismatch_man_bc), numel(mismatch_man_bc), ... + % 100*mean(mismatch_man_bc)); + + fprintf('\n=== KS Label vs BombCell ===\n') + mismatch_ks_bc = ks_cat ~= bc_cat; + fprintf('Total mismatches: %d / %d (%.1f%%)\n', ... + sum(mismatch_ks_bc), numel(mismatch_ks_bc), ... + 100*mean(mismatch_ks_bc)); + + vs.printFig(fig,sprintf('%KS Label vs BombCell',NP.recordingName),PaperFig =1) + + close + + % fprintf('\n=== KS Label vs Manual Curation ===\n') + % mismatch_ks_man = ks_cat ~= man_cat; + % fprintf('Total mismatches: %d / %d (%.1f%%)\n', ... + % sum(mismatch_ks_man), numel(mismatch_ks_man), ... + % 100*mean(mismatch_ks_man)); + + % ks = Neuropixel.KilosortDataset(vs.spikeSortingFolder); + % ks.load(); + +end +%I want to compare bombcell unit classification with manual classification in phy. + + + +%% Plot raw waveforms of specific units: + +% 1. Add to path: https://github.com/cortex-lab/spikes +% https://github.com/kwikteam/npy-matlab (dependency) + +ksDir = vs.spikeSortingFolder; +sp = loadKSdir(ksDir); % loads all KS output into a struct + +% Get waveforms +gwfparams.dataDir = ksDir; +gwfparams.fileName = 'recording.bin'; +gwfparams.dataType = 'int16'; +gwfparams.nCh = 385; +gwfparams.wfWin = [-40 41]; % samples around spike +gwfparams.nWf = 100; % waveforms per unit +gwfparams.spikeTimes = sp.st; % spike times +gwfparams.spikeClusters = sp.clu; % cluster IDs + +wf = getWaveforms(gwfparams); % wf.waveForms: [units x waveforms x channels x samples] + +% Plot mean waveform for unit 1, best channel +figure; +plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); \ No newline at end of file diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m index 7abbcf3..d55b448 100644 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m @@ -7,6 +7,7 @@ for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) %%%%%%%%%%%% Load data and data paremeters %1. Load NP class + ex=69 NP = loadNPclassFromTable(ex); vs = linearlyMovingBallAnalysis(NP,Session=1); KSversion =4; @@ -124,8 +125,39 @@ % sum(mismatch_ks_man), numel(mismatch_ks_man), ... % 100*mean(mismatch_ks_man)); + imec = Neuropixel.ImecDataset(NP.recordingDir); + ks = Neuropixel.KilosortDataset(vs.spikeSortingFolder,'imecDataset', imec); + ks.load(); + end %I want to compare bombcell unit classification with manual classification in phy. +%% Plot raw waveforms of specific units: + +% 1. Add to path: https://github.com/cortex-lab/spikes +% https://github.com/kwikteam/npy-matlab (dependency) + + +ksDir = vs.spikeSortingFolder; +sp = loadKSdir(ksDir); % loads all KS output into a struct + +% Get waveforms +gwfparams.dataDir = ksDir; +gwfparams.fileName = NP.recordingDir; +gwfparams.dataType = 'int16'; +gwfparams.nCh = 385; +gwfparams.wfWin = [-40 41]; % samples around spike +gwfparams.nWf = 100; % waveforms per unit +gwfparams.spikeTimes = sp.st; % spike times +gwfparams.spikeClusters = sp.clu; % cluster IDs + +wf = getWaveForms(gwfparams); % wf.waveForms: [units x waveforms x channels x samples] + +% Plot mean waveform for unit 1, best channel +figure; +plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); + +%% +plotRawWaveforms(vs, 47, showCorr=true, corrWin=50, corrBin=0.5) \ No newline at end of file From 0e25919f9f5dc7d1c2775967e98eb031ebc17494 Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Wed, 11 Mar 2026 18:42:13 +0200 Subject: [PATCH 2/8] spattial plotting of waveforms --- general functions/plotRawWaveforms.asv | 115 +++++-------- general functions/plotRawWaveforms.m | 157 ++++++++++++------ .../Run_Bombcell_Automatic_Sorting.asv | 18 +- .../Run_Bombcell_Automatic_Sorting.m | 4 +- 4 files changed, 160 insertions(+), 134 deletions(-) diff --git a/general functions/plotRawWaveforms.asv b/general functions/plotRawWaveforms.asv index 9b352f2..86c2464 100644 --- a/general functions/plotRawWaveforms.asv +++ b/general functions/plotRawWaveforms.asv @@ -5,44 +5,33 @@ function plotRawWaveforms(obj, unitID, params) % INPUTS: % obj - Visual stimulation object with spikeSortingFolder and dataObj % unitID - cluster ID to plot (single unit) -% params - (optional) struct with any of the following fields: % -% WAVEFORM params: -% nWaveforms - number of random waveforms to plot (default: 100) -% nChanAround - channels above/below max amp channel (default: 4) -% nPre - samples before spike peak (default: 20) -% nPost - samples after spike peak (default: 61) +% OPTIONAL NAME-VALUE PARAMS: +% nWaveforms - number of random waveforms to plot (default: 100) +% nChanAround - channels above/below max amp channel (default: 4) +% nPre - samples before spike peak (default: 20) +% nPost - samples after spike peak (default: 61) +% showCorr - plot auto-correlogram (default: false) +% corrWin - correlogram half-window in ms (default: 100) +% corrBin - correlogram bin size in ms (default: 1) % -% CORRELOGRAM params: -% showCorr - plot auto-correlogram (default: false) -% corrWin - correlogram half-window in ms (default: 100) -% corrBin - correlogram bin size in ms (default: 1) -% -% EXAMPLE: -% % Just waveforms with defaults +% EXAMPLES: % plotRawWaveforms(obj, 42) -% -% % Custom params -% params.nWaveforms = 200; -% params.nChanAround = 6; -% params.showCorr = true; -% params.corrWin = 50; -% params.corrBin = 0.5; -% plotRawWaveforms(obj, 42, params) +% plotRawWaveforms(obj, 42, nWaveforms=200, nChanAround=6) +% plotRawWaveforms(obj, 42, showCorr=true, corrWin=50, corrBin=0.5) arguments (Input) obj - unitID (1,1) double - params.nWaveforms = 200; - params.nChanAround = 6; - params.showCorr = true; - params.corrWin = 50; - params.corrBin = 0.5; + unitID (1,1) double + params.nWaveforms (1,1) double = 100 + params.nChanAround (1,1) double = 4 + params.nPre (1,1) double = 20 + params.nPost (1,1) double = 61 + params.showCorr (1,1) logical = false + params.corrWin (1,1) double = 100 + params.corrBin (1,1) double = 1 end -%% Parse params with defaults -params = parseParams(params); - %% Paths ksDir = obj.spikeSortingFolder; recordingDir = obj.dataObj.recordingDir; @@ -51,6 +40,7 @@ recordingDir = obj.dataObj.recordingDir; n_channels = str2double(obj.dataObj.nSavedChansImec); sample_rate = obj.dataObj.samplingFrequency; uV_per_bit = unique(obj.dataObj.MicrovoltsPerAD); +chPos = obj.dataObj.chLayoutPositions; fprintf('Settings — nCh: %d | Fs: %d Hz | uV/bit: %.4f\n', ... n_channels, sample_rate, uV_per_bit); @@ -132,7 +122,7 @@ if params.showCorr [ccg_counts, ccg_bins] = computeACG(st, sample_rate, params.corrWin, params.corrBin); end -%% ---- Build layout ---- +%% ---- Waveform figure ---- t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; mean_wf = mean(waveforms, 3, 'omitnan'); std_wf = std(waveforms, 0, 3, 'omitnan'); @@ -141,33 +131,13 @@ chan_depths = chan_pos(chan_indices, 2); [~, depth_order] = sort(chan_depths, 'descend'); % shallowest at top colors = lines(n_chans_plot); -fig = figure('Color', 'w', 'Name', sprintf('Unit %d', unitID)); -if params.showCorr - % Two-column layout: waveforms | correlogram - outer = tiledlayout(fig, 1, 2, 'TileSpacing', 'compact', 'Padding', 'compact'); - title(outer, sprintf('Unit %d | %d waveforms | best ch: %d', ... - unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); - - % Left: nested layout for per-channel waveforms - ax_wf_container = nexttile(outer, 1); - wf_layout = tiledlayout(ax_wf_container.Parent, n_chans_plot, 1, ... - 'TileSpacing', 'none', 'Padding', 'compact'); - wf_layout.Layout.Tile = 1; - xlabel(wf_layout, 'Time (ms)'); - - % Right: correlogram axes - ax_corr = nexttile(outer, 2); -else - % Single-column layout: waveforms only - wf_layout = tiledlayout(fig, n_chans_plot, 1, ... - 'TileSpacing', 'none', 'Padding', 'compact'); - title(wf_layout, sprintf('Unit %d | %d waveforms | best ch: %d', ... - unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); - xlabel(wf_layout, 'Time (ms)'); -end +figure('Color', 'w', 'Name', sprintf('Unit %d — Waveforms', unitID)); +wf_layout = tiledlayout(n_chans_plot, 1, 'TileSpacing', 'none', 'Padding', 'compact'); +title(wf_layout, sprintf('Unit %d | %d waveforms | best ch: %d', ... + unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); +xlabel(wf_layout, 'Time (ms)'); -%% Plot one tile per channel wf_axes = gobjects(n_chans_plot, 1); for ci = 1:n_chans_plot plot_ci = depth_order(ci); @@ -199,6 +169,7 @@ for ci = 1:n_chans_plot ylabel(ax, sprintf('ch%d\n%.0fµm', bin_chans(plot_ci), chan_depths(plot_ci)), ... 'FontSize', 7, 'Rotation', 0, 'HorizontalAlignment', 'right'); + % Only show x tick labels on bottom subplot if ci < n_chans_plot set(ax, 'XTickLabel', []); end @@ -208,8 +179,11 @@ end % Shared amplitude scale across all channels linkaxes(wf_axes, 'y'); -%% Plot correlogram +%% ---- ACG figure (separate) ---- if params.showCorr + figure('Color', 'w', 'Name', sprintf('Unit %d — ACG', unitID)); + ax_corr = axes; + bar(ax_corr, ccg_bins, ccg_counts, 1, ... 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); hold(ax_corr, 'on'); @@ -222,8 +196,8 @@ if params.showCorr xlabel(ax_corr, 'Lag (ms)'); ylabel(ax_corr, 'Spike count'); - title(ax_corr, sprintf('ACG | bin %.1f ms | win ±%d ms', ... - params.corrBin, params.corrWin), 'FontSize', 10); + title(ax_corr, sprintf('Unit %d | ACG | bin %.1f ms | win ±%d ms', ... + unitID, params.corrBin, params.corrWin), 'FontSize', 12); xlim(ax_corr, [-params.corrWin params.corrWin]); box(ax_corr, 'off'); end @@ -231,19 +205,6 @@ end end % main function -%% ========================================================================= -function params = parseParams(params) -% Fill in defaults for any missing fields -if ~isfield(params, 'nWaveforms'), params.nWaveforms = 100; end -if ~isfield(params, 'nChanAround'), params.nChanAround = 4; end -if ~isfield(params, 'nPre'), params.nPre = 20; end -if ~isfield(params, 'nPost'), params.nPost = 61; end -if ~isfield(params, 'showCorr'), params.showCorr = false; end -if ~isfield(params, 'corrWin'), params.corrWin = 100; end % ms -if ~isfield(params, 'corrBin'), params.corrBin = 1; end % ms -end - - %% ========================================================================= function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin_ms) % Compute auto-correlogram for a single unit @@ -252,15 +213,15 @@ function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin % win_ms - half-window in ms % bin_ms - bin size in ms -st_ms = spike_times_samples / fs * 1000; % convert to ms +st_ms = spike_times_samples / fs * 1000; edges = -win_ms : bin_ms : win_ms; bin_centers = edges(1:end-1) + bin_ms / 2; counts = zeros(1, numel(bin_centers)); for i = 1:numel(st_ms) - diffs = st_ms - st_ms(i); % lag to all other spikes - diffs(i) = NaN; % exclude self - diffs = diffs(diffs > -win_ms & diffs < win_ms); % within window - counts = counts + histcounts(diffs, edges); + diffs = st_ms - st_ms(i); + diffs(i) = NaN; + diffs = diffs(diffs > -win_ms & diffs < win_ms); + counts = counts + histcounts(diffs, edges); end end \ No newline at end of file diff --git a/general functions/plotRawWaveforms.m b/general functions/plotRawWaveforms.m index e84954c..84ad274 100644 --- a/general functions/plotRawWaveforms.m +++ b/general functions/plotRawWaveforms.m @@ -1,9 +1,10 @@ function plotRawWaveforms(obj, unitID, params) % plotRawWaveforms - Plot raw spike waveforms from KS4 output, Phy-style -% Optionally plots an auto-correlogram. +% Channels are drawn at their true probe x/y positions. +% Optionally plots an auto-correlogram in a separate figure. % % INPUTS: -% obj - Visual stimulation object with spikeSortingFolder and dataObj +% obj - Visual stimulation object % unitID - cluster ID to plot (single unit) % % OPTIONAL NAME-VALUE PARAMS: @@ -24,7 +25,7 @@ function plotRawWaveforms(obj, unitID, params) obj unitID (1,1) double params.nWaveforms (1,1) double = 100 - params.nChanAround (1,1) double = 4 + params.nChanAround (1,1) double = 10 params.nPre (1,1) double = 20 params.nPost (1,1) double = 61 params.showCorr (1,1) logical = false @@ -40,6 +41,7 @@ function plotRawWaveforms(obj, unitID, params) n_channels = str2double(obj.dataObj.nSavedChansImec); sample_rate = obj.dataObj.samplingFrequency; uV_per_bit = unique(obj.dataObj.MicrovoltsPerAD); +chPos = obj.dataObj.chLayoutPositions; % [2 x nAllCh]: row1=x, row2=y fprintf('Settings — nCh: %d | Fs: %d Hz | uV/bit: %.4f\n', ... n_channels, sample_rate, uV_per_bit); @@ -54,9 +56,9 @@ function plotRawWaveforms(obj, unitID, params) %% Load KS4 output spike_times = readNPY(fullfile(ksDir, 'spike_times.npy')); spike_clusters = readNPY(fullfile(ksDir, 'spike_clusters.npy')); -templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] -chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed -chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy'));% [nCh x 2] +templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] +chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed +chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy')); % [nCh x 2] %% Find template index for this unit unit_ids = (0 : size(templates, 1) - 1)'; @@ -70,16 +72,20 @@ function plotRawWaveforms(obj, unitID, params) p2p = max(unit_template) - min(unit_template); [~, best_tmpl_chan] = max(p2p); -% Channels to extract: nChanAround above/below best channel -chan_indices = (best_tmpl_chan - params.nChanAround) : (best_tmpl_chan + params.nChanAround); -chan_indices = chan_indices(chan_indices >= 1 & chan_indices <= size(templates, 3)); +% Get probe positions for all template channels via chan_map +% chan_pos is [nTemplateCh x 2]: col1=x, col2=y (from KS4, in µm) +% Find nChanAround closest channels to best channel by Euclidean distance +best_xy = chan_pos(best_tmpl_chan, :); % [1 x 2] +dists = sqrt(sum((chan_pos - best_xy).^2, 2)); % [nTemplateCh x 1] +[~, sorted_idx] = sort(dists, 'ascend'); +chan_indices = sorted_idx(1 : min(params.nChanAround + 1, numel(dists)))'; n_chans_plot = numel(chan_indices); % Index of best channel within the plotted subset best_local_idx = find(chan_indices == best_tmpl_chan); % Map to binary file channels (1-indexed for MATLAB) -bin_chans = chan_map(chan_indices) + 1; +bin_chans = chan_map(chan_indices) + 1; % [n_chans_plot x 1], 1-indexed best_bin_chan = bin_chans(best_local_idx); %% Get spike times for this unit @@ -121,62 +127,111 @@ function plotRawWaveforms(obj, unitID, params) [ccg_counts, ccg_bins] = computeACG(st, sample_rate, params.corrWin, params.corrBin); end -%% ---- Waveform figure ---- +%% ---- Spatial positions for plotted channels ---- +% chPos is [2 x nAllCh]: row 1 = x (shank col), row 2 = y (depth) +ch_x = chPos(1, bin_chans); % [1 x n_chans_plot] +ch_y = chPos(2, bin_chans); % [1 x n_chans_plot] + +% Detect inter-channel pitch from all channels on the probe +x_unique = unique(chPos(1,:)); +y_unique = unique(chPos(2,:)); +x_spacing = min(diff(sort(x_unique))); +y_spacing = min(diff(sort(y_unique))); + +if isempty(x_spacing) || numel(x_unique) == 1 + x_spacing = 32; % fallback NP1 column pitch +end +if isempty(y_spacing) || numel(y_unique) == 1 + y_spacing = 20; % fallback NP1 row pitch +end + +% Time axis scaled to fit in x_spacing (80% fill) t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; +t_scale = 0.8 * x_spacing / (t_ms(end) - t_ms(1)); % µm per ms + +% Amplitude scale: normalise so max p2p fits in y_spacing (80% fill) mean_wf = mean(waveforms, 3, 'omitnan'); std_wf = std(waveforms, 0, 3, 'omitnan'); +max_p2p = max(max(mean_wf, [], 2) - min(mean_wf, [], 2)); +if max_p2p == 0, max_p2p = 1; end +amp_scale = 0.8 * y_spacing / max_p2p; % µm per µV -chan_depths = chan_pos(chan_indices, 2); -[~, depth_order] = sort(chan_depths, 'descend'); % shallowest at top - -colors = lines(n_chans_plot); +%% ---- Colours: best channel = red, all others = blue ---- +col_default = [0.25 0.45 0.75]; % blue +col_best = [0.85 0.20 0.15]; % red +%% ---- Waveform figure ---- figure('Color', 'w', 'Name', sprintf('Unit %d — Waveforms', unitID)); -wf_layout = tiledlayout(n_chans_plot, 1, 'TileSpacing', 'none', 'Padding', 'compact'); -title(wf_layout, sprintf('Unit %d | %d waveforms | best ch: %d', ... - unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); -xlabel(wf_layout, 'Time (ms)'); +ax = axes('Color', 'w'); +hold(ax, 'on'); -wf_axes = gobjects(n_chans_plot, 1); for ci = 1:n_chans_plot - plot_ci = depth_order(ci); - ax = nexttile(wf_layout); - wf_axes(ci) = ax; + cx = ch_x(ci); + cy = ch_y(ci); + + if ci == best_local_idx + col = col_best; + else + col = col_default; + end + + x_wf = cx + t_ms * t_scale; % Individual waveforms (translucent) - wf_ci = squeeze(waveforms(plot_ci, :, :)); - plot(ax, t_ms, wf_ci, 'Color', [colors(plot_ci,:), 0.15], 'LineWidth', 0.5); - hold(ax, 'on'); + wf_ci = squeeze(waveforms(ci, :, :)); % [nSamples x nWaveforms] + y_wf = cy + wf_ci * amp_scale; + plot(ax, x_wf, y_wf, 'Color', [col, 0.12], 'LineWidth', 0.5); % Std shading - upper = mean_wf(plot_ci,:) + std_wf(plot_ci,:); - lower = mean_wf(plot_ci,:) - std_wf(plot_ci,:); - fill(ax, [t_ms, fliplr(t_ms)], [upper, fliplr(lower)], ... - colors(plot_ci,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + upper = cy + (mean_wf(ci,:) + std_wf(ci,:)) * amp_scale; + lower = cy + (mean_wf(ci,:) - std_wf(ci,:)) * amp_scale; + fill(ax, [x_wf, fliplr(x_wf)], [upper, fliplr(lower)], ... + col, 'FaceAlpha', 0.2, 'EdgeColor', 'none'); % Mean waveform - plot(ax, t_ms, mean_wf(plot_ci,:), 'Color', colors(plot_ci,:), 'LineWidth', 2); - - xline(ax, 0, '--k', 'Alpha', 0.3); - - % Highlight best channel with yellow background - if plot_ci == best_local_idx - set(ax, 'Color', [1 1 0.85]); - end - - % Channel label + depth - ylabel(ax, sprintf('ch%d\n%.0fµm', bin_chans(plot_ci), chan_depths(plot_ci)), ... - 'FontSize', 7, 'Rotation', 0, 'HorizontalAlignment', 'right'); - - % Only show x tick labels on bottom subplot - if ci < n_chans_plot - set(ax, 'XTickLabel', []); - end - box(ax, 'off'); + y_mean = cy + mean_wf(ci,:) * amp_scale; + plot(ax, x_wf, y_mean, 'k', 'LineWidth', 2); + + % Channel label: two rows, just left of waveform start + text(ax, x_wf(1) - 2, cy + amp_scale * 0, ... + sprintf('ch%d\n(%g, %g)', bin_chans(ci), cx, cy), ... + 'FontSize', 7, 'HorizontalAlignment', 'right', ... + 'VerticalAlignment', 'middle', 'Color', col); end -% Shared amplitude scale across all channels -linkaxes(wf_axes, 'y'); +%% ---- L-shaped scale bar ---- +% Fixed scale: 2 ms horizontal, 200 µV vertical +sb_ms = 1; % ms +sb_uv = 200; % µV +sb_xlen = sb_ms * t_scale; % µm +sb_ylen = sb_uv * amp_scale; % µm + +% Position: to the right of the bottom-right waveform, at the same y level +[~, bottom_right_ci] = min(ch_y - ch_x * 1e-6); % lowest y, rightmost x as tiebreak +br_cx = ch_x(bottom_right_ci); +br_cy = ch_y(bottom_right_ci); + +sb_gap = 0.2 * x_spacing; % horizontal gap from last waveform +sb_ox = br_cx + t_ms(end) * t_scale + sb_gap; % L corner x: just right of waveform end +sb_oy = br_cy; % L corner y: same level as that channel + +% Draw L: vertical arm then horizontal arm, meeting at bottom-left corner +plot(ax, [sb_ox, sb_ox], [sb_oy, sb_oy - sb_ylen], 'k', 'LineWidth', 2); +plot(ax, [sb_ox, sb_ox + sb_xlen],[sb_oy, sb_oy], 'k', 'LineWidth', 2); + +% Labels +text(ax, sb_ox - 2, sb_oy - sb_ylen/2, sprintf('%d µV', sb_uv), ... + 'FontSize', 8, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle', 'Rotation',90); +text(ax, sb_ox + sb_xlen/2, sb_oy + 2, sprintf('%d ms', sb_ms), ... + 'FontSize', 8, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'top'); + +%% Axis cosmetics — no tick marks, no box +title(ax, sprintf('Unit %d | %d waveforms | best ch: %d', ... + unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); +set(ax, 'XTick', [], 'YTick', [], 'YDir', 'normal'); +axis(ax, 'tight'); +box(ax, 'off'); +axis(ax, 'off'); % hide axes entirely — scale bar carries all metric info %% ---- ACG figure (separate) ---- if params.showCorr @@ -195,7 +250,7 @@ function plotRawWaveforms(obj, unitID, params) xlabel(ax_corr, 'Lag (ms)'); ylabel(ax_corr, 'Spike count'); - title(ax_corr, sprintf('Unit %d | ACG | bin %.1f ms | win ±%d ms', ... + title(ax_corr, sprintf('Unit %d ACG | RP 2 ms | bin %.1f ms | win ±%d ms', ... unitID, params.corrBin, params.corrWin), 'FontSize', 12); xlim(ax_corr, [-params.corrWin params.corrWin]); box(ax_corr, 'off'); diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv index 64d7cc1..f01e17a 100644 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv @@ -7,6 +7,7 @@ exp = [49:54,64:97];% for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) %%%%%%%%%%%% Load data and data paremeters %1. Load NP class + ex=69 NP = loadNPclassFromTable(ex); vs = linearlyMovingBallAnalysis(NP,Session=1); KSversion =4; @@ -124,8 +125,9 @@ for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodR % sum(mismatch_ks_man), numel(mismatch_ks_man), ... % 100*mean(mismatch_ks_man)); - % ks = Neuropixel.KilosortDataset(vs.spikeSortingFolder); - % ks.load(); + imec = Neuropixel.ImecDataset(NP.recordingDir); + ks = Neuropixel.KilosortDataset(vs.spikeSortingFolder,'imecDataset', imec); + ks.load(); end %I want to compare bombcell unit classification with manual classification in phy. @@ -137,12 +139,13 @@ end % 1. Add to path: https://github.com/cortex-lab/spikes % https://github.com/kwikteam/npy-matlab (dependency) + ksDir = vs.spikeSortingFolder; sp = loadKSdir(ksDir); % loads all KS output into a struct % Get waveforms gwfparams.dataDir = ksDir; -gwfparams.fileName = 'recording.bin'; +gwfparams.fileName = NP.recordingDir; gwfparams.dataType = 'int16'; gwfparams.nCh = 385; gwfparams.wfWin = [-40 41]; % samples around spike @@ -150,8 +153,13 @@ gwfparams.nWf = 100; % waveforms per unit gwfparams.spikeTimes = sp.st; % spike times gwfparams.spikeClusters = sp.clu; % cluster IDs -wf = getWaveforms(gwfparams); % wf.waveForms: [units x waveforms x channels x samples] +wf = getWaveForms(gwfparams); % wf.waveForms: [units x waveforms x channels x samples] % Plot mean waveform for unit 1, best channel figure; -plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); \ No newline at end of file +plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); + +%% +plotRawWaveforms(vs, 47, showCorr=true, corrWin=50, corrBin=0.5) + +plotRawWaveforms_spatially(vs, 47, showCorr=true, corrWin=50, corrBin=0.5,nChanAround105) \ No newline at end of file diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m index d55b448..131be4c 100644 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m @@ -160,4 +160,6 @@ plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); %% -plotRawWaveforms(vs, 47, showCorr=true, corrWin=50, corrBin=0.5) \ No newline at end of file +plotRawWaveforms(vs, 47, showCorr=true, corrWin=50, corrBin=0.5) + +plotRawWaveforms_spatially(vs, 47, showCorr=true, corrWin=50, corrBin=0.5,nChanAround=10) \ No newline at end of file From 1ef508a2f61c120ad2e5d37c1d98dbadf2be8c1c Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Thu, 12 Mar 2026 21:44:56 +0200 Subject: [PATCH 3/8] Chaanges to waveform plots --- general functions/plotRawWaveforms.asv | 360 +++++++++------ general functions/plotRawWaveforms.m | 416 ++++++++++-------- visualStimulationAnalysis/RunAnalysisClass.m | 4 +- .../Run_Bombcell_Automatic_Sorting.asv | 21 +- .../Run_Bombcell_Automatic_Sorting.m | 25 +- 5 files changed, 484 insertions(+), 342 deletions(-) diff --git a/general functions/plotRawWaveforms.asv b/general functions/plotRawWaveforms.asv index 86c2464..52f21d5 100644 --- a/general functions/plotRawWaveforms.asv +++ b/general functions/plotRawWaveforms.asv @@ -1,30 +1,31 @@ -function plotRawWaveforms(obj, unitID, params) +function [fig1, fig2] = plotRawWaveforms(obj, unitIDs, params) % plotRawWaveforms - Plot raw spike waveforms from KS4 output, Phy-style -% Optionally plots an auto-correlogram. +% Each unit is shown in its own tile at true probe positions. +% Optionally plots ACGs for all units in a single tiled figure. % % INPUTS: -% obj - Visual stimulation object with spikeSortingFolder and dataObj -% unitID - cluster ID to plot (single unit) +% obj - Visual stimulation object +% unitIDs - scalar or vector of cluster IDs to plot e.g. 42 or [3 7 42] % % OPTIONAL NAME-VALUE PARAMS: % nWaveforms - number of random waveforms to plot (default: 100) -% nChanAround - channels above/below max amp channel (default: 4) +% nChanAround - nearest channels around max amp channel (default: 10) % nPre - samples before spike peak (default: 20) % nPost - samples after spike peak (default: 61) -% showCorr - plot auto-correlogram (default: false) +% showCorr - plot auto-correlogram figure (default: false) % corrWin - correlogram half-window in ms (default: 100) % corrBin - correlogram bin size in ms (default: 1) % % EXAMPLES: % plotRawWaveforms(obj, 42) -% plotRawWaveforms(obj, 42, nWaveforms=200, nChanAround=6) -% plotRawWaveforms(obj, 42, showCorr=true, corrWin=50, corrBin=0.5) +% plotRawWaveforms(obj, [3 7 42], nWaveforms=200, nChanAround=6) +% plotRawWaveforms(obj, [3 7 42], showCorr=true, corrWin=50, corrBin=0.5) arguments (Input) obj - unitID (1,1) double + unitIDs (1,:) double params.nWaveforms (1,1) double = 100 - params.nChanAround (1,1) double = 4 + params.nChanAround (1,1) double = 10 params.nPre (1,1) double = 20 params.nPost (1,1) double = 61 params.showCorr (1,1) logical = false @@ -32,6 +33,8 @@ arguments (Input) params.corrBin (1,1) double = 1 end +nUnits = numel(unitIDs); + %% Paths ksDir = obj.spikeSortingFolder; recordingDir = obj.dataObj.recordingDir; @@ -40,7 +43,7 @@ recordingDir = obj.dataObj.recordingDir; n_channels = str2double(obj.dataObj.nSavedChansImec); sample_rate = obj.dataObj.samplingFrequency; uV_per_bit = unique(obj.dataObj.MicrovoltsPerAD); -chPos = obj.dataObj.chLayoutPositions; +chPos = obj.dataObj.chLayoutPositions; % [2 x nAllCh]: row1=x, row2=y fprintf('Settings — nCh: %d | Fs: %d Hz | uV/bit: %.4f\n', ... n_channels, sample_rate, uV_per_bit); @@ -52,154 +55,236 @@ if isempty(binFiles), error('No .bin or .dat file found in: %s', recordingDir); binPath = fullfile(recordingDir, binFiles(1).name); fprintf('Using binary file: %s\n', binPath); -%% Load KS4 output +%% Load KS4 output (once, shared across all units) spike_times = readNPY(fullfile(ksDir, 'spike_times.npy')); spike_clusters = readNPY(fullfile(ksDir, 'spike_clusters.npy')); -templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] -chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed -chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy'));% [nCh x 2] - -%% Find template index for this unit -unit_ids = (0 : size(templates, 1) - 1)'; -tmpl_idx = find(unit_ids == unitID); -if isempty(tmpl_idx) - error('Unit %d not found in templates.npy', unitID); -end - -%% Find best channel (max peak-to-peak across template channels) -unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] -p2p = max(unit_template) - min(unit_template); -[~, best_tmpl_chan] = max(p2p); - -% Channels to extract: nChanAround above/below best channel -chan_indices = (best_tmpl_chan - params.nChanAround) : (best_tmpl_chan + params.nChanAround); -chan_indices = chan_indices(chan_indices >= 1 & chan_indices <= size(templates, 3)); -n_chans_plot = numel(chan_indices); +templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] +chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed +chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy')); % [nCh x 2] -% Index of best channel within the plotted subset -best_local_idx = find(chan_indices == best_tmpl_chan); +unit_ids_ks = (0 : size(templates, 1) - 1)'; -% Map to binary file channels (1-indexed for MATLAB) -bin_chans = chan_map(chan_indices) + 1; -best_bin_chan = bin_chans(best_local_idx); +%% Probe pitch (shared across all units) +x_unique = unique(chPos(1,:)); +y_unique = unique(chPos(2,:)); +x_spacing = min(diff(sort(x_unique))); +y_spacing = min(diff(sort(y_unique))); +if isempty(x_spacing) || numel(x_unique) == 1, x_spacing = 32; end +if isempty(y_spacing) || numel(y_unique) == 1, y_spacing = 20; end -%% Get spike times for this unit -st = double(spike_times(spike_clusters == unitID)); -if numel(st) < 2, error('Unit %d has fewer than 2 spikes.', unitID); end -fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... - unitID, numel(st), min(params.nWaveforms, numel(st))); +t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; -idx = randperm(numel(st), min(params.nWaveforms, numel(st))); -st_sub = st(idx); +%% Colours +col_default = [0.25 0.45 0.75]; % blue +col_best = [0.85 0.20 0.15]; % red -%% Extract waveforms from binary -waveform_len = params.nPre + params.nPost + 1; +%% ---- Extract data for each unit ---- finfo = dir(binPath); n_samp_total = finfo.bytes / (n_channels * 2); fid = fopen(binPath, 'rb'); -waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); +unitData = struct(); % will hold per-unit results -for si = 1:numel(st_sub) - s0 = st_sub(si) - params.nPre; - s1 = st_sub(si) + params.nPost; - if s0 < 1 || s1 > n_samp_total, continue; end +for ui = 1:nUnits + unitID = unitIDs(ui); - fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); - raw = fread(fid, [n_channels, waveform_len], '*int16'); - if size(raw, 2) < waveform_len, continue; end + % Template index + tmpl_idx = find(unit_ids_ks == unitID); + if isempty(tmpl_idx) + warning('Unit %d not found in templates.npy, skipping.', unitID); + unitData(ui).valid = false; + continue + end - waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; -end -fclose(fid); + % Best channel by p2p on template + unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] + p2p = max(unit_template) - min(unit_template); + [~, best_tmpl_chan] = max(p2p); + + % nChanAround nearest channels by Euclidean distance on probe + best_xy = chan_pos(best_tmpl_chan, :); + dists = sqrt(sum((chan_pos - best_xy).^2, 2)); + [~, sorted_idx] = sort(dists, 'ascend'); + chan_indices = sorted_idx(1 : min(params.nChanAround + 1, numel(dists)))'; + n_chans_plot = numel(chan_indices); + best_local_idx = find(chan_indices == best_tmpl_chan); + + bin_chans = chan_map(chan_indices) + 1; % 1-indexed + best_bin_chan = bin_chans(best_local_idx); + + % Spike times for this unit + st = double(spike_times(spike_clusters == unitID)); + if numel(st) < 2 + warning('Unit %d has fewer than 2 spikes, skipping.', unitID); + unitData(ui).valid = false; + continue + end -% Baseline subtract (mean of pre-spike window) -baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); -waveforms = waveforms - baseline; + % Random subsample + idx = randperm(numel(st), min(params.nWaveforms, numel(st))); + st_sub = st(idx); + fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... + unitID, numel(st), numel(st_sub)); + + % Extract waveforms + waveform_len = params.nPre + params.nPost + 1; + waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); + + for si = 1:numel(st_sub) + s0 = st_sub(si) - params.nPre; + s1 = st_sub(si) + params.nPost; + if s0 < 1 || s1 > n_samp_total, continue; end + fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); + raw = fread(fid, [n_channels, waveform_len], '*int16'); + if size(raw, 2) < waveform_len, continue; end + waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; + end -%% Compute correlogram if requested -if params.showCorr - [ccg_counts, ccg_bins] = computeACG(st, sample_rate, params.corrWin, params.corrBin); + % Baseline subtract + baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); + waveforms = waveforms - baseline; + + % Store + unitData(ui).valid = true; + unitData(ui).unitID = unitID; + unitData(ui).waveforms = waveforms; + unitData(ui).mean_wf = mean(waveforms, 3, 'omitnan'); + unitData(ui).std_wf = std(waveforms, 0, 3, 'omitnan'); + unitData(ui).bin_chans = bin_chans; + unitData(ui).best_bin_chan = best_bin_chan; + unitData(ui).best_local_idx= best_local_idx; + unitData(ui).n_chans_plot = n_chans_plot; + unitData(ui).ch_x = chPos(1, bin_chans); + unitData(ui).ch_y = chPos(2, bin_chans); + unitData(ui).st = st; + unitData(ui).n_wf = numel(st_sub); + + % ACG + if params.showCorr + [unitData(ui).ccg_counts, unitData(ui).ccg_bins] = ... + computeACG(st, sample_rate, params.corrWin, params.corrBin); + end end +fclose(fid); -%% ---- Waveform figure ---- -t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; -mean_wf = mean(waveforms, 3, 'omitnan'); -std_wf = std(waveforms, 0, 3, 'omitnan'); - -chan_depths = chan_pos(chan_indices, 2); -[~, depth_order] = sort(chan_depths, 'descend'); % shallowest at top - -colors = lines(n_chans_plot); - -figure('Color', 'w', 'Name', sprintf('Unit %d — Waveforms', unitID)); -wf_layout = tiledlayout(n_chans_plot, 1, 'TileSpacing', 'none', 'Padding', 'compact'); -title(wf_layout, sprintf('Unit %d | %d waveforms | best ch: %d', ... - unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); -xlabel(wf_layout, 'Time (ms)'); - -wf_axes = gobjects(n_chans_plot, 1); -for ci = 1:n_chans_plot - plot_ci = depth_order(ci); - ax = nexttile(wf_layout); - wf_axes(ci) = ax; - - % Individual waveforms (translucent) - wf_ci = squeeze(waveforms(plot_ci, :, :)); - plot(ax, t_ms, wf_ci, 'Color', [colors(plot_ci,:), 0.15], 'LineWidth', 0.5); +%% ---- Waveform figure: one tile per unit ---- +% Determine tiled layout dimensions +nCols = min(nUnits, 4); +nRows = ceil(nUnits / nCols); + +fig1 = figure('Color', 'w', 'Name', 'Waveforms'); +wf_tl = tiledlayout(fig1, nRows, nCols, 'TileSpacing', 'compact', 'Padding', 'compact'); +title(wf_tl, 'Raw Waveforms', 'FontSize', 13, 'FontWeight', 'bold'); + +for ui = 1:nUnits + if ~unitData(ui).valid, continue; end + + d = unitData(ui); + mean_wf = d.mean_wf; + std_wf = d.std_wf; + ch_x = d.ch_x; + ch_y = d.ch_y; + bin_chans = d.bin_chans; + best_local_idx = d.best_local_idx; + n_chans_plot = d.n_chans_plot; + + % Per-unit amplitude scale: use mean±std envelope to prevent overlap + % on noisy units (large std compresses the scale automatically) + upper_env = max(mean_wf + std_wf, [], 2); % [nCh x 1] + lower_env = min(mean_wf - std_wf, [], 2); + max_extent = max(upper_env - lower_env); + if max_extent == 0, max_extent = 1; end + amp_scale = 0.8 * y_spacing / max_extent; + t_scale = 0.8 * x_spacing / (t_ms(end) - t_ms(1)); + + % Scale bar µV: round max amplitude to nearest 50 µV + sb_uv = max(50, round(max_extent / 50) * 50); + + ax = nexttile(wf_tl); hold(ax, 'on'); - % Std shading - upper = mean_wf(plot_ci,:) + std_wf(plot_ci,:); - lower = mean_wf(plot_ci,:) - std_wf(plot_ci,:); - fill(ax, [t_ms, fliplr(t_ms)], [upper, fliplr(lower)], ... - colors(plot_ci,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); - - % Mean waveform - plot(ax, t_ms, mean_wf(plot_ci,:), 'Color', colors(plot_ci,:), 'LineWidth', 2); - - xline(ax, 0, '--k', 'Alpha', 0.3); - - % Highlight best channel with yellow background - if plot_ci == best_local_idx - set(ax, 'Color', [1 1 0.85]); + for ci = 1:n_chans_plot + cx = ch_x(ci); + cy = ch_y(ci); + col = col_default; + if ci == best_local_idx, col = col_best; end + + x_wf = cx + t_ms * t_scale; + + % Individual waveforms + wf_ci = squeeze(d.waveforms(ci, :, :)); + plot(ax, x_wf, cy + wf_ci * amp_scale, ... + 'Color', [col, 0.12], 'LineWidth', 0.5); + + % Std shading + upper = cy + (mean_wf(ci,:) + std_wf(ci,:)) * amp_scale; + lower = cy + (mean_wf(ci,:) - std_wf(ci,:)) * amp_scale; + fill(ax, [x_wf, fliplr(x_wf)], [upper, fliplr(lower)], ... + col, 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + + % Mean waveform (black), with coloured std shading + plot(ax, x_wf, cy + mean_wf(ci,:) * amp_scale, ... + 'Color', 'k', 'LineWidth', 2); + + % Channel label (two rows, left of waveform start) + text(ax, x_wf(1) - 2, cy, ... + sprintf('ch%d\n(%g,%g)', bin_chans(ci), cx, cy), ... + 'FontSize', 6, 'HorizontalAlignment', 'right', ... + 'VerticalAlignment', 'middle', 'Color', col); end - % Channel label + depth - ylabel(ax, sprintf('ch%d\n%.0fµm', bin_chans(plot_ci), chan_depths(plot_ci)), ... - 'FontSize', 7, 'Rotation', 0, 'HorizontalAlignment', 'right'); - - % Only show x tick labels on bottom subplot - if ci < n_chans_plot - set(ax, 'XTickLabel', []); - end - box(ax, 'off'); + % L-scale bar: bottom-right channel of this unit + sb_ms = 1; % sb_uv already set above + sb_xlen = sb_ms * t_scale; + sb_ylen = sb_uv * amp_scale; + + [~, br_ci] = min(ch_y - ch_x * 1e-6); + sb_ox = ch_x(br_ci) + t_ms(end) * t_scale + 0.2 * x_spacing; + sb_oy = ch_y(br_ci); + + plot(ax, [sb_ox, sb_ox], [sb_oy, sb_oy - sb_ylen], 'k', 'LineWidth', 2); + plot(ax, [sb_ox, sb_ox + sb_xlen], [sb_oy, sb_oy], 'k', 'LineWidth', 2); + text(ax, sb_ox - 2, sb_oy - sb_ylen/2, sprintf('%d µV', sb_uv), ... + 'FontSize', 7, 'HorizontalAlignment', 'center', ... + 'VerticalAlignment', 'middle', 'Rotation', 90); + text(ax, sb_ox + sb_xlen/2, sb_oy + 2, sprintf('%d ms', sb_ms), ... + 'FontSize', 7, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'top'); + + title(ax, sprintf('Unit %d | ch%d | n=%d', ... + d.unitID, d.best_bin_chan, d.n_wf), 'FontSize', 9); + axis(ax, 'tight'); + axis(ax, 'off'); end -% Shared amplitude scale across all channels -linkaxes(wf_axes, 'y'); - -%% ---- ACG figure (separate) ---- +%% ---- ACG figure: one tile per unit ---- if params.showCorr - figure('Color', 'w', 'Name', sprintf('Unit %d — ACG', unitID)); - ax_corr = axes; - - bar(ax_corr, ccg_bins, ccg_counts, 1, ... - 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); - hold(ax_corr, 'on'); - xline(ax_corr, 0, '--k', 'Alpha', 0.4); - - % Shade refractory period (±2 ms) - ylims = ylim(ax_corr); - patch(ax_corr, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... - 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); - - xlabel(ax_corr, 'Lag (ms)'); - ylabel(ax_corr, 'Spike count'); - title(ax_corr, sprintf('Unit %d | ACG | bin %.1f ms | win ±%d ms', ... - unitID, params.corrBin, params.corrWin), 'FontSize', 12); - xlim(ax_corr, [-params.corrWin params.corrWin]); - box(ax_corr, 'off'); + fig2 = figure('Color', 'w', 'Name', 'ACGs'); + acg_tl = tiledlayout(fig2, nRows, nCols, 'TileSpacing', 'compact', 'Padding', 'compact'); + title(acg_tl, sprintf('ACG | RP 2 ms | bin %.1f ms | win ±%d ms', ... + params.corrBin, params.corrWin), 'FontSize', 12, 'FontWeight', 'bold'); + xlabel(acg_tl, 'Lag (ms)'); + ylabel(acg_tl, 'Spike count'); + + for ui = 1:nUnits + if ~unitData(ui).valid, continue; end + d = unitData(ui); + + ax_c = nexttile(acg_tl); + bar(ax_c, d.ccg_bins, d.ccg_counts, 1, ... + 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); + hold(ax_c, 'on'); + xline(ax_c, 0, '--k', 'Alpha', 0.4); + + ylims = ylim(ax_c); + patch(ax_c, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... + 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); + + xlim(ax_c, [-params.corrWin params.corrWin]); + title(ax_c, sprintf('Unit %d', d.unitID), 'FontSize', 9); + box(ax_c, 'off'); + end +else + fig2 = []; end end % main function @@ -207,17 +292,10 @@ end % main function %% ========================================================================= function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin_ms) -% Compute auto-correlogram for a single unit -% spike_times_samples - spike times in samples -% fs - sampling rate (Hz) -% win_ms - half-window in ms -% bin_ms - bin size in ms - st_ms = spike_times_samples / fs * 1000; edges = -win_ms : bin_ms : win_ms; bin_centers = edges(1:end-1) + bin_ms / 2; counts = zeros(1, numel(bin_centers)); - for i = 1:numel(st_ms) diffs = st_ms - st_ms(i); diffs(i) = NaN; diff --git a/general functions/plotRawWaveforms.m b/general functions/plotRawWaveforms.m index 84ad274..b39bfb5 100644 --- a/general functions/plotRawWaveforms.m +++ b/general functions/plotRawWaveforms.m @@ -1,29 +1,29 @@ -function plotRawWaveforms(obj, unitID, params) +function [fig1, fig2] = plotRawWaveforms(obj, unitIDs, params) % plotRawWaveforms - Plot raw spike waveforms from KS4 output, Phy-style -% Channels are drawn at their true probe x/y positions. -% Optionally plots an auto-correlogram in a separate figure. +% Each unit is shown in its own tile at true probe positions. +% Optionally plots ACGs for all units in a single tiled figure. % % INPUTS: -% obj - Visual stimulation object -% unitID - cluster ID to plot (single unit) +% obj - Visual stimulation object +% unitIDs - scalar or vector of cluster IDs to plot e.g. 42 or [3 7 42] % % OPTIONAL NAME-VALUE PARAMS: % nWaveforms - number of random waveforms to plot (default: 100) -% nChanAround - channels above/below max amp channel (default: 4) +% nChanAround - nearest channels around max amp channel (default: 10) % nPre - samples before spike peak (default: 20) % nPost - samples after spike peak (default: 61) -% showCorr - plot auto-correlogram (default: false) +% showCorr - plot auto-correlogram figure (default: false) % corrWin - correlogram half-window in ms (default: 100) % corrBin - correlogram bin size in ms (default: 1) % % EXAMPLES: % plotRawWaveforms(obj, 42) -% plotRawWaveforms(obj, 42, nWaveforms=200, nChanAround=6) -% plotRawWaveforms(obj, 42, showCorr=true, corrWin=50, corrBin=0.5) +% plotRawWaveforms(obj, [3 7 42], nWaveforms=200, nChanAround=6) +% plotRawWaveforms(obj, [3 7 42], showCorr=true, corrWin=50, corrBin=0.5) arguments (Input) obj - unitID (1,1) double + unitIDs (1,:) double params.nWaveforms (1,1) double = 100 params.nChanAround (1,1) double = 10 params.nPre (1,1) double = 20 @@ -33,6 +33,8 @@ function plotRawWaveforms(obj, unitID, params) params.corrBin (1,1) double = 1 end +nUnits = numel(unitIDs); + %% Paths ksDir = obj.spikeSortingFolder; recordingDir = obj.dataObj.recordingDir; @@ -53,207 +55,246 @@ function plotRawWaveforms(obj, unitID, params) binPath = fullfile(recordingDir, binFiles(1).name); fprintf('Using binary file: %s\n', binPath); -%% Load KS4 output +%% Load KS4 output (once, shared across all units) spike_times = readNPY(fullfile(ksDir, 'spike_times.npy')); spike_clusters = readNPY(fullfile(ksDir, 'spike_clusters.npy')); templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy')); % [nCh x 2] -%% Find template index for this unit -unit_ids = (0 : size(templates, 1) - 1)'; -tmpl_idx = find(unit_ids == unitID); -if isempty(tmpl_idx) - error('Unit %d not found in templates.npy', unitID); -end - -%% Find best channel (max peak-to-peak across template channels) -unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] -p2p = max(unit_template) - min(unit_template); -[~, best_tmpl_chan] = max(p2p); - -% Get probe positions for all template channels via chan_map -% chan_pos is [nTemplateCh x 2]: col1=x, col2=y (from KS4, in µm) -% Find nChanAround closest channels to best channel by Euclidean distance -best_xy = chan_pos(best_tmpl_chan, :); % [1 x 2] -dists = sqrt(sum((chan_pos - best_xy).^2, 2)); % [nTemplateCh x 1] -[~, sorted_idx] = sort(dists, 'ascend'); -chan_indices = sorted_idx(1 : min(params.nChanAround + 1, numel(dists)))'; -n_chans_plot = numel(chan_indices); - -% Index of best channel within the plotted subset -best_local_idx = find(chan_indices == best_tmpl_chan); - -% Map to binary file channels (1-indexed for MATLAB) -bin_chans = chan_map(chan_indices) + 1; % [n_chans_plot x 1], 1-indexed -best_bin_chan = bin_chans(best_local_idx); - -%% Get spike times for this unit -st = double(spike_times(spike_clusters == unitID)); -if numel(st) < 2, error('Unit %d has fewer than 2 spikes.', unitID); end -fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... - unitID, numel(st), min(params.nWaveforms, numel(st))); - -idx = randperm(numel(st), min(params.nWaveforms, numel(st))); -st_sub = st(idx); - -%% Extract waveforms from binary -waveform_len = params.nPre + params.nPost + 1; -finfo = dir(binPath); -n_samp_total = finfo.bytes / (n_channels * 2); -fid = fopen(binPath, 'rb'); - -waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); - -for si = 1:numel(st_sub) - s0 = st_sub(si) - params.nPre; - s1 = st_sub(si) + params.nPost; - if s0 < 1 || s1 > n_samp_total, continue; end - - fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); - raw = fread(fid, [n_channels, waveform_len], '*int16'); - if size(raw, 2) < waveform_len, continue; end +unit_ids_ks = (0 : size(templates, 1) - 1)'; - waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; -end -fclose(fid); - -% Baseline subtract (mean of pre-spike window) -baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); -waveforms = waveforms - baseline; - -%% Compute correlogram if requested -if params.showCorr - [ccg_counts, ccg_bins] = computeACG(st, sample_rate, params.corrWin, params.corrBin); -end - -%% ---- Spatial positions for plotted channels ---- -% chPos is [2 x nAllCh]: row 1 = x (shank col), row 2 = y (depth) -ch_x = chPos(1, bin_chans); % [1 x n_chans_plot] -ch_y = chPos(2, bin_chans); % [1 x n_chans_plot] - -% Detect inter-channel pitch from all channels on the probe +%% Probe pitch (shared across all units) x_unique = unique(chPos(1,:)); y_unique = unique(chPos(2,:)); x_spacing = min(diff(sort(x_unique))); y_spacing = min(diff(sort(y_unique))); +if isempty(x_spacing) || numel(x_unique) == 1, x_spacing = 32; end +if isempty(y_spacing) || numel(y_unique) == 1, y_spacing = 20; end -if isempty(x_spacing) || numel(x_unique) == 1 - x_spacing = 32; % fallback NP1 column pitch -end -if isempty(y_spacing) || numel(y_unique) == 1 - y_spacing = 20; % fallback NP1 row pitch -end - -% Time axis scaled to fit in x_spacing (80% fill) -t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; -t_scale = 0.8 * x_spacing / (t_ms(end) - t_ms(1)); % µm per ms - -% Amplitude scale: normalise so max p2p fits in y_spacing (80% fill) -mean_wf = mean(waveforms, 3, 'omitnan'); -std_wf = std(waveforms, 0, 3, 'omitnan'); -max_p2p = max(max(mean_wf, [], 2) - min(mean_wf, [], 2)); -if max_p2p == 0, max_p2p = 1; end -amp_scale = 0.8 * y_spacing / max_p2p; % µm per µV +t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; -%% ---- Colours: best channel = red, all others = blue ---- +%% Colours col_default = [0.25 0.45 0.75]; % blue col_best = [0.85 0.20 0.15]; % red -%% ---- Waveform figure ---- -figure('Color', 'w', 'Name', sprintf('Unit %d — Waveforms', unitID)); -ax = axes('Color', 'w'); -hold(ax, 'on'); +%% ---- Extract data for each unit ---- +finfo = dir(binPath); +n_samp_total = finfo.bytes / (n_channels * 2); +fid = fopen(binPath, 'rb'); -for ci = 1:n_chans_plot - cx = ch_x(ci); - cy = ch_y(ci); +unitData = struct(); % will hold per-unit results - if ci == best_local_idx - col = col_best; - else - col = col_default; +for ui = 1:nUnits + unitID = unitIDs(ui); + + % Template index + tmpl_idx = find(unit_ids_ks == unitID); + if isempty(tmpl_idx) + warning('Unit %d not found in templates.npy, skipping.', unitID); + unitData(ui).valid = false; + continue end - x_wf = cx + t_ms * t_scale; + % Best channel by p2p on template + unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] + p2p = max(unit_template) - min(unit_template); + [~, best_tmpl_chan] = max(p2p); + + % nChanAround nearest channels by Euclidean distance on probe + best_xy = chan_pos(best_tmpl_chan, :); + dists = sqrt(sum((chan_pos - best_xy).^2, 2)); + [~, sorted_idx] = sort(dists, 'ascend'); + chan_indices = sorted_idx(1 : min(params.nChanAround + 1, numel(dists)))'; + n_chans_plot = numel(chan_indices); + best_local_idx = find(chan_indices == best_tmpl_chan); + + bin_chans = chan_map(chan_indices) + 1; % 1-indexed + best_bin_chan = bin_chans(best_local_idx); + + % Spike times for this unit + st = double(spike_times(spike_clusters == unitID)); + if numel(st) < 2 + warning('Unit %d has fewer than 2 spikes, skipping.', unitID); + unitData(ui).valid = false; + continue + end - % Individual waveforms (translucent) - wf_ci = squeeze(waveforms(ci, :, :)); % [nSamples x nWaveforms] - y_wf = cy + wf_ci * amp_scale; - plot(ax, x_wf, y_wf, 'Color', [col, 0.12], 'LineWidth', 0.5); + % Random subsample + idx = randperm(numel(st), min(params.nWaveforms, numel(st))); + st_sub = st(idx); + fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... + unitID, numel(st), numel(st_sub)); + + % Extract waveforms + waveform_len = params.nPre + params.nPost + 1; + waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); + + for si = 1:numel(st_sub) + s0 = st_sub(si) - params.nPre; + s1 = st_sub(si) + params.nPost; + if s0 < 1 || s1 > n_samp_total, continue; end + fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); + raw = fread(fid, [n_channels, waveform_len], '*int16'); + if size(raw, 2) < waveform_len, continue; end + waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; + end - % Std shading - upper = cy + (mean_wf(ci,:) + std_wf(ci,:)) * amp_scale; - lower = cy + (mean_wf(ci,:) - std_wf(ci,:)) * amp_scale; - fill(ax, [x_wf, fliplr(x_wf)], [upper, fliplr(lower)], ... - col, 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + % Baseline subtract + baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); + waveforms = waveforms - baseline; + + % Store + unitData(ui).valid = true; + unitData(ui).unitID = unitID; + unitData(ui).waveforms = waveforms; + % Exclude outlier waveforms based on peak-to-peak MAD + % Compute p2p amplitude for each waveform (max across channels and time) + wf_p2p = squeeze(max(max(waveforms,[],1),[],2) - ... + min(min(waveforms,[],1),[],2)); % [1 x nWaveforms] + wf_median = median(wf_p2p, 'omitnan'); + wf_mad = median(abs(wf_p2p - wf_median), 'omitnan'); + inlier_mask = abs(wf_p2p - wf_median) < 5 * wf_mad; % 5-MAD threshold + fprintf('Unit %d: %d/%d waveforms kept for envelope (outlier rejection)\n', ... + unitID, sum(inlier_mask), numel(inlier_mask)); + + unitData(ui).mean_wf = mean(waveforms(:,:,inlier_mask), 3, 'omitnan'); + unitData(ui).std_wf = std(waveforms(:,:,inlier_mask), 0, 3, 'omitnan'); + unitData(ui).bin_chans = bin_chans; + unitData(ui).best_bin_chan = best_bin_chan; + unitData(ui).best_local_idx= best_local_idx; + unitData(ui).n_chans_plot = n_chans_plot; + unitData(ui).ch_x = chPos(1, bin_chans); + unitData(ui).ch_y = chPos(2, bin_chans); + unitData(ui).st = st; + unitData(ui).n_wf = numel(st_sub); + + % ACG + if params.showCorr + [unitData(ui).ccg_counts, unitData(ui).ccg_bins] = ... + computeACG(st, sample_rate, params.corrWin, params.corrBin); + end +end +fclose(fid); - % Mean waveform - y_mean = cy + mean_wf(ci,:) * amp_scale; - plot(ax, x_wf, y_mean, 'k', 'LineWidth', 2); +%% ---- Waveform figure: one tile per unit ---- +% Determine tiled layout dimensions +nCols = min(nUnits, 4); +nRows = ceil(nUnits / nCols); + +fig1 = figure('Color', 'w', 'Name', 'Waveforms'); +wf_tl = tiledlayout(fig1, nRows, nCols, 'TileSpacing', 'compact', 'Padding', 'compact'); +title(wf_tl, 'Raw Waveforms', 'FontSize', 13, 'FontWeight', 'bold'); + +for ui = 1:nUnits + if ~unitData(ui).valid, continue; end + + d = unitData(ui); + mean_wf = d.mean_wf; + std_wf = d.std_wf; + ch_x = d.ch_x; + ch_y = d.ch_y; + bin_chans = d.bin_chans; + best_local_idx = d.best_local_idx; + n_chans_plot = d.n_chans_plot; + + % Per-unit amplitude scale: use mean±std envelope to prevent overlap + % on noisy units (large std compresses the scale automatically) + upper_env = max(mean_wf + std_wf, [], 2); % [nCh x 1] + lower_env = min(mean_wf - std_wf, [], 2); + max_extent = max(upper_env - lower_env); + if max_extent == 0, max_extent = 1; end + amp_scale = 0.8 * y_spacing / max_extent; + t_scale = 0.8 * x_spacing / (t_ms(end) - t_ms(1)); + + % Scale bar µV: round max amplitude to nearest 50 µV + sb_uv = max(50, round(max_extent / 50) * 50); + + ax = nexttile(wf_tl); + hold(ax, 'on'); + + for ci = 1:n_chans_plot + cx = ch_x(ci); + cy = ch_y(ci); + col = col_default; + if ci == best_local_idx, col = col_best; end + + x_wf = cx + t_ms * t_scale; + + % Individual waveforms + wf_ci = squeeze(d.waveforms(ci, :, :)); + plot(ax, x_wf, cy + wf_ci * amp_scale, ... + 'Color', [col, 0.12], 'LineWidth', 0.5); + + % Std shading + upper = cy + (mean_wf(ci,:) + std_wf(ci,:)) * amp_scale; + lower = cy + (mean_wf(ci,:) - std_wf(ci,:)) * amp_scale; + fill(ax, [x_wf, fliplr(x_wf)], [upper, fliplr(lower)], ... + col, 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + + % Mean waveform (black), with coloured std shading + plot(ax, x_wf, cy + mean_wf(ci,:) * amp_scale, ... + 'Color', 'k', 'LineWidth', 2); + + % Channel label (two rows, left of waveform start) + text(ax, x_wf(1) - 2, cy, ... + sprintf('ch%d\n(%g,%g)', bin_chans(ci), cx, cy), ... + 'FontSize', 6, 'HorizontalAlignment', 'right', ... + 'VerticalAlignment', 'middle', 'Color', col); + end - % Channel label: two rows, just left of waveform start - text(ax, x_wf(1) - 2, cy + amp_scale * 0, ... - sprintf('ch%d\n(%g, %g)', bin_chans(ci), cx, cy), ... - 'FontSize', 7, 'HorizontalAlignment', 'right', ... - 'VerticalAlignment', 'middle', 'Color', col); + % L-scale bar: bottom-right channel of this unit + sb_ms = 1; % sb_uv already set above + sb_xlen = sb_ms * t_scale; + sb_ylen = sb_uv * amp_scale; + + [~, br_ci] = min(ch_y - ch_x * 1e-6); + sb_ox = ch_x(br_ci) + t_ms(end) * t_scale + 0.2 * x_spacing; + sb_oy = ch_y(br_ci); + + plot(ax, [sb_ox, sb_ox], [sb_oy, sb_oy - sb_ylen], 'k', 'LineWidth', 2); + plot(ax, [sb_ox, sb_ox + sb_xlen], [sb_oy, sb_oy], 'k', 'LineWidth', 2); + text(ax, sb_ox - 2, sb_oy - sb_ylen/2, sprintf('%d µV', sb_uv), ... + 'FontSize', 7, 'HorizontalAlignment', 'center', ... + 'VerticalAlignment', 'middle', 'Rotation', 90); + text(ax, sb_ox + sb_xlen/2, sb_oy + 2, sprintf('%d ms', sb_ms), ... + 'FontSize', 7, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'top'); + + title(ax, sprintf('Unit %d | ch%d | n=%d', ... + d.unitID, d.best_bin_chan, d.n_wf), 'FontSize', 9); + axis(ax, 'tight'); + axis(ax, 'off'); end -%% ---- L-shaped scale bar ---- -% Fixed scale: 2 ms horizontal, 200 µV vertical -sb_ms = 1; % ms -sb_uv = 200; % µV -sb_xlen = sb_ms * t_scale; % µm -sb_ylen = sb_uv * amp_scale; % µm - -% Position: to the right of the bottom-right waveform, at the same y level -[~, bottom_right_ci] = min(ch_y - ch_x * 1e-6); % lowest y, rightmost x as tiebreak -br_cx = ch_x(bottom_right_ci); -br_cy = ch_y(bottom_right_ci); - -sb_gap = 0.2 * x_spacing; % horizontal gap from last waveform -sb_ox = br_cx + t_ms(end) * t_scale + sb_gap; % L corner x: just right of waveform end -sb_oy = br_cy; % L corner y: same level as that channel - -% Draw L: vertical arm then horizontal arm, meeting at bottom-left corner -plot(ax, [sb_ox, sb_ox], [sb_oy, sb_oy - sb_ylen], 'k', 'LineWidth', 2); -plot(ax, [sb_ox, sb_ox + sb_xlen],[sb_oy, sb_oy], 'k', 'LineWidth', 2); - -% Labels -text(ax, sb_ox - 2, sb_oy - sb_ylen/2, sprintf('%d µV', sb_uv), ... - 'FontSize', 8, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle', 'Rotation',90); -text(ax, sb_ox + sb_xlen/2, sb_oy + 2, sprintf('%d ms', sb_ms), ... - 'FontSize', 8, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'top'); - -%% Axis cosmetics — no tick marks, no box -title(ax, sprintf('Unit %d | %d waveforms | best ch: %d', ... - unitID, numel(st_sub), best_bin_chan), 'FontSize', 12); -set(ax, 'XTick', [], 'YTick', [], 'YDir', 'normal'); -axis(ax, 'tight'); -box(ax, 'off'); -axis(ax, 'off'); % hide axes entirely — scale bar carries all metric info - -%% ---- ACG figure (separate) ---- +%% ---- ACG figure: one tile per unit ---- if params.showCorr - figure('Color', 'w', 'Name', sprintf('Unit %d — ACG', unitID)); - ax_corr = axes; - - bar(ax_corr, ccg_bins, ccg_counts, 1, ... - 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); - hold(ax_corr, 'on'); - xline(ax_corr, 0, '--k', 'Alpha', 0.4); - - % Shade refractory period (±2 ms) - ylims = ylim(ax_corr); - patch(ax_corr, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... - 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); - - xlabel(ax_corr, 'Lag (ms)'); - ylabel(ax_corr, 'Spike count'); - title(ax_corr, sprintf('Unit %d ACG | RP 2 ms | bin %.1f ms | win ±%d ms', ... - unitID, params.corrBin, params.corrWin), 'FontSize', 12); - xlim(ax_corr, [-params.corrWin params.corrWin]); - box(ax_corr, 'off'); + fig2 = figure('Color', 'w', 'Name', 'ACGs'); + acg_tl = tiledlayout(fig2, nRows, nCols, 'TileSpacing', 'compact', 'Padding', 'compact'); + title(acg_tl, sprintf('ACG | RP 2 ms | bin %.1f ms | win ±%d ms', ... + params.corrBin, params.corrWin), 'FontSize', 12, 'FontWeight', 'bold'); + xlabel(acg_tl, 'Lag (ms)'); + ylabel(acg_tl, 'Spike count'); + + for ui = 1:nUnits + if ~unitData(ui).valid, continue; end + d = unitData(ui); + + ax_c = nexttile(acg_tl); + bar(ax_c, d.ccg_bins, d.ccg_counts, 1, ... + 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); + hold(ax_c, 'on'); + xline(ax_c, 0, '--k', 'Alpha', 0.4); + + ylims = ylim(ax_c); + patch(ax_c, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... + 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); + + xlim(ax_c, [-params.corrWin params.corrWin]); + title(ax_c, sprintf('Unit %d', d.unitID), 'FontSize', 9); + box(ax_c, 'off'); + end +else + fig2 = []; end end % main function @@ -261,17 +302,10 @@ function plotRawWaveforms(obj, unitID, params) %% ========================================================================= function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin_ms) -% Compute auto-correlogram for a single unit -% spike_times_samples - spike times in samples -% fs - sampling rate (Hz) -% win_ms - half-window in ms -% bin_ms - bin size in ms - st_ms = spike_times_samples / fs * 1000; edges = -win_ms : bin_ms : win_ms; bin_centers = edges(1:end-1) + bin_ms / 2; counts = zeros(1, numel(bin_centers)); - for i = 1:numel(st_ms) diffs = st_ms - st_ms(i); diffs(i) = NaN; diff --git a/visualStimulationAnalysis/RunAnalysisClass.m b/visualStimulationAnalysis/RunAnalysisClass.m index 78d5485..fb86853 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.m +++ b/visualStimulationAnalysis/RunAnalysisClass.m @@ -61,8 +61,8 @@ %[49:54,84:90,92:96] %All SDG experiments %solve MBR %bootsrapRespBase -VStimAnalysis.PlotZScoreComparison([49:54,64:97],{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=false,... - overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR +VStimAnalysis.PlotZScoreComparison([49:54,64:97],{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=true,ComparePairs={'MB','RG'},PaperFig=true,... + overwriteResponse=true,overwriteStats=true)%[49:54,57:91] %%Check why I have different array dimensions in MBR %% Gratings diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv index f01e17a..39e33d0 100644 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv @@ -159,7 +159,22 @@ wf = getWaveForms(gwfparams); % wf.waveForms: [units x waveforms x channels figure; plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); -%% -plotRawWaveforms(vs, 47, showCorr=true, corrWin=50, corrBin=0.5) +%% Check low amp waveforms 10 neurons per experiment -plotRawWaveforms_spatially(vs, 47, showCorr=true, corrWin=50, corrBin=0.5,nChanAround105) \ No newline at end of file +PVexps = [49:54,64:97]; +idx = randi(length(PVexps), 1, 4); +selected = PVexps(idx); + + + +for i = selected + NP = loadNPclassFromTable(53); + vs = linearlyMovingBallAnalysis(NP,Session=1); + + p = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,1,1); + phy_IDg = p.phy_ID(string(p.label') == 'good'); + + + plotRawWaveforms(vs, [47:50], showCorr=true, corrWin=50, corrBin=0.5) + +end diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m index 131be4c..638974d 100644 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m @@ -7,12 +7,12 @@ for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) %%%%%%%%%%%% Load data and data paremeters %1. Load NP class - ex=69 + ex=53 NP = loadNPclassFromTable(ex); vs = linearlyMovingBallAnalysis(NP,Session=1); KSversion =4; - [qMetric,unitType]=NP.getBombCell(NP.recordingDir+"\kilosort4",0,KSversion); + [qMetric,unitType]=NP.getBombCell(NP.recordingDir+"\kilosort4",0,KSversion,1); %convertPhySorting2tIc(obj,pathToPhyResults,tStart,BombCelled) @@ -159,7 +159,22 @@ figure; plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); -%% -plotRawWaveforms(vs, 47, showCorr=true, corrWin=50, corrBin=0.5) +%% Check low amp waveforms 10 neurons per experiment -plotRawWaveforms_spatially(vs, 47, showCorr=true, corrWin=50, corrBin=0.5,nChanAround=10) \ No newline at end of file +PVexps = [49:54,64:97]; +idx = randi(length(PVexps), 1, 4); +selected = PVexps(idx); + + + +for i = selected + NP = loadNPclassFromTable(53); + vs = linearlyMovingBallAnalysis(NP,Session=1); + + p = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,1,1); + phy_IDg = p.phy_ID(string(p.label') == 'good'); + + + plotRawWaveforms(vs, [47:50], showCorr=true, corrWin=50, corrBin=0.5) + +end From c790253ec3253da2a9422d10f974176160a64406 Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Thu, 19 Mar 2026 23:24:47 +0200 Subject: [PATCH 4/8] Adding new spatial tuning function and cmodified receptive fields --- general functions/plotRawWaveforms.asv | 305 ------------ general functions/plotRawWaveforms.m | 2 +- .../@VStimAnalysis/BootstrapPerNeuron.m | 56 ++- .../CalculateReceptiveFields.m | 109 ++++- .../@linearlyMovingBallAnalysis/plotRaster.m | 98 +++- .../CalculateReceptiveFields.m | 90 +++- .../@rectGridAnalysis/plotRaster.m | 28 +- .../RunAnalysisClass.asv | 210 ++++++++ visualStimulationAnalysis/RunAnalysisClass.m | 38 +- .../Run_Bombcell_Automatic_Sorting.asv | 180 ------- .../Run_Bombcell_Automatic_Sorting.m | 31 +- .../SpatialTuningIndex.asv | 408 +++++++++++++++ .../SpatialTuningIndex.m | 408 +++++++++++++++ .../SpatialTuningIndexV1.m | 246 ++++++++++ visualStimulationAnalysis/plotPSTH_MultiExp.m | 463 ++++++++++++++++++ .../plotSpatialTuningIndex.m | 189 +++++++ 16 files changed, 2302 insertions(+), 559 deletions(-) delete mode 100644 general functions/plotRawWaveforms.asv create mode 100644 visualStimulationAnalysis/RunAnalysisClass.asv delete mode 100644 visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv create mode 100644 visualStimulationAnalysis/SpatialTuningIndex.asv create mode 100644 visualStimulationAnalysis/SpatialTuningIndex.m create mode 100644 visualStimulationAnalysis/SpatialTuningIndexV1.m create mode 100644 visualStimulationAnalysis/plotPSTH_MultiExp.m create mode 100644 visualStimulationAnalysis/plotSpatialTuningIndex.m diff --git a/general functions/plotRawWaveforms.asv b/general functions/plotRawWaveforms.asv deleted file mode 100644 index 52f21d5..0000000 --- a/general functions/plotRawWaveforms.asv +++ /dev/null @@ -1,305 +0,0 @@ -function [fig1, fig2] = plotRawWaveforms(obj, unitIDs, params) -% plotRawWaveforms - Plot raw spike waveforms from KS4 output, Phy-style -% Each unit is shown in its own tile at true probe positions. -% Optionally plots ACGs for all units in a single tiled figure. -% -% INPUTS: -% obj - Visual stimulation object -% unitIDs - scalar or vector of cluster IDs to plot e.g. 42 or [3 7 42] -% -% OPTIONAL NAME-VALUE PARAMS: -% nWaveforms - number of random waveforms to plot (default: 100) -% nChanAround - nearest channels around max amp channel (default: 10) -% nPre - samples before spike peak (default: 20) -% nPost - samples after spike peak (default: 61) -% showCorr - plot auto-correlogram figure (default: false) -% corrWin - correlogram half-window in ms (default: 100) -% corrBin - correlogram bin size in ms (default: 1) -% -% EXAMPLES: -% plotRawWaveforms(obj, 42) -% plotRawWaveforms(obj, [3 7 42], nWaveforms=200, nChanAround=6) -% plotRawWaveforms(obj, [3 7 42], showCorr=true, corrWin=50, corrBin=0.5) - -arguments (Input) - obj - unitIDs (1,:) double - params.nWaveforms (1,1) double = 100 - params.nChanAround (1,1) double = 10 - params.nPre (1,1) double = 20 - params.nPost (1,1) double = 61 - params.showCorr (1,1) logical = false - params.corrWin (1,1) double = 100 - params.corrBin (1,1) double = 1 -end - -nUnits = numel(unitIDs); - -%% Paths -ksDir = obj.spikeSortingFolder; -recordingDir = obj.dataObj.recordingDir; - -%% Settings from obj -n_channels = str2double(obj.dataObj.nSavedChansImec); -sample_rate = obj.dataObj.samplingFrequency; -uV_per_bit = unique(obj.dataObj.MicrovoltsPerAD); -chPos = obj.dataObj.chLayoutPositions; % [2 x nAllCh]: row1=x, row2=y - -fprintf('Settings — nCh: %d | Fs: %d Hz | uV/bit: %.4f\n', ... - n_channels, sample_rate, uV_per_bit); - -%% Find binary file -binFiles = dir(fullfile(recordingDir, '*.bin')); -if isempty(binFiles), binFiles = dir(fullfile(recordingDir, '*.dat')); end -if isempty(binFiles), error('No .bin or .dat file found in: %s', recordingDir); end -binPath = fullfile(recordingDir, binFiles(1).name); -fprintf('Using binary file: %s\n', binPath); - -%% Load KS4 output (once, shared across all units) -spike_times = readNPY(fullfile(ksDir, 'spike_times.npy')); -spike_clusters = readNPY(fullfile(ksDir, 'spike_clusters.npy')); -templates = readNPY(fullfile(ksDir, 'templates.npy')); % [nUnits x T x nCh] -chan_map = readNPY(fullfile(ksDir, 'channel_map.npy')); % [nCh x 1], 0-indexed -chan_pos = readNPY(fullfile(ksDir, 'channel_positions.npy')); % [nCh x 2] - -unit_ids_ks = (0 : size(templates, 1) - 1)'; - -%% Probe pitch (shared across all units) -x_unique = unique(chPos(1,:)); -y_unique = unique(chPos(2,:)); -x_spacing = min(diff(sort(x_unique))); -y_spacing = min(diff(sort(y_unique))); -if isempty(x_spacing) || numel(x_unique) == 1, x_spacing = 32; end -if isempty(y_spacing) || numel(y_unique) == 1, y_spacing = 20; end - -t_ms = (-params.nPre : params.nPost) / sample_rate * 1000; - -%% Colours -col_default = [0.25 0.45 0.75]; % blue -col_best = [0.85 0.20 0.15]; % red - -%% ---- Extract data for each unit ---- -finfo = dir(binPath); -n_samp_total = finfo.bytes / (n_channels * 2); -fid = fopen(binPath, 'rb'); - -unitData = struct(); % will hold per-unit results - -for ui = 1:nUnits - unitID = unitIDs(ui); - - % Template index - tmpl_idx = find(unit_ids_ks == unitID); - if isempty(tmpl_idx) - warning('Unit %d not found in templates.npy, skipping.', unitID); - unitData(ui).valid = false; - continue - end - - % Best channel by p2p on template - unit_template = squeeze(templates(tmpl_idx, :, :)); % [T x nCh] - p2p = max(unit_template) - min(unit_template); - [~, best_tmpl_chan] = max(p2p); - - % nChanAround nearest channels by Euclidean distance on probe - best_xy = chan_pos(best_tmpl_chan, :); - dists = sqrt(sum((chan_pos - best_xy).^2, 2)); - [~, sorted_idx] = sort(dists, 'ascend'); - chan_indices = sorted_idx(1 : min(params.nChanAround + 1, numel(dists)))'; - n_chans_plot = numel(chan_indices); - best_local_idx = find(chan_indices == best_tmpl_chan); - - bin_chans = chan_map(chan_indices) + 1; % 1-indexed - best_bin_chan = bin_chans(best_local_idx); - - % Spike times for this unit - st = double(spike_times(spike_clusters == unitID)); - if numel(st) < 2 - warning('Unit %d has fewer than 2 spikes, skipping.', unitID); - unitData(ui).valid = false; - continue - end - - % Random subsample - idx = randperm(numel(st), min(params.nWaveforms, numel(st))); - st_sub = st(idx); - fprintf('Unit %d: %d total spikes, extracting %d waveforms\n', ... - unitID, numel(st), numel(st_sub)); - - % Extract waveforms - waveform_len = params.nPre + params.nPost + 1; - waveforms = NaN(n_chans_plot, waveform_len, numel(st_sub)); - - for si = 1:numel(st_sub) - s0 = st_sub(si) - params.nPre; - s1 = st_sub(si) + params.nPost; - if s0 < 1 || s1 > n_samp_total, continue; end - fseek(fid, (s0 - 1) * n_channels * 2, 'bof'); - raw = fread(fid, [n_channels, waveform_len], '*int16'); - if size(raw, 2) < waveform_len, continue; end - waveforms(:, :, si) = double(raw(bin_chans, :)) * uV_per_bit; - end - - % Baseline subtract - baseline = mean(waveforms(:, 1:params.nPre, :), 2, 'omitnan'); - waveforms = waveforms - baseline; - - % Store - unitData(ui).valid = true; - unitData(ui).unitID = unitID; - unitData(ui).waveforms = waveforms; - unitData(ui).mean_wf = mean(waveforms, 3, 'omitnan'); - unitData(ui).std_wf = std(waveforms, 0, 3, 'omitnan'); - unitData(ui).bin_chans = bin_chans; - unitData(ui).best_bin_chan = best_bin_chan; - unitData(ui).best_local_idx= best_local_idx; - unitData(ui).n_chans_plot = n_chans_plot; - unitData(ui).ch_x = chPos(1, bin_chans); - unitData(ui).ch_y = chPos(2, bin_chans); - unitData(ui).st = st; - unitData(ui).n_wf = numel(st_sub); - - % ACG - if params.showCorr - [unitData(ui).ccg_counts, unitData(ui).ccg_bins] = ... - computeACG(st, sample_rate, params.corrWin, params.corrBin); - end -end -fclose(fid); - -%% ---- Waveform figure: one tile per unit ---- -% Determine tiled layout dimensions -nCols = min(nUnits, 4); -nRows = ceil(nUnits / nCols); - -fig1 = figure('Color', 'w', 'Name', 'Waveforms'); -wf_tl = tiledlayout(fig1, nRows, nCols, 'TileSpacing', 'compact', 'Padding', 'compact'); -title(wf_tl, 'Raw Waveforms', 'FontSize', 13, 'FontWeight', 'bold'); - -for ui = 1:nUnits - if ~unitData(ui).valid, continue; end - - d = unitData(ui); - mean_wf = d.mean_wf; - std_wf = d.std_wf; - ch_x = d.ch_x; - ch_y = d.ch_y; - bin_chans = d.bin_chans; - best_local_idx = d.best_local_idx; - n_chans_plot = d.n_chans_plot; - - % Per-unit amplitude scale: use mean±std envelope to prevent overlap - % on noisy units (large std compresses the scale automatically) - upper_env = max(mean_wf + std_wf, [], 2); % [nCh x 1] - lower_env = min(mean_wf - std_wf, [], 2); - max_extent = max(upper_env - lower_env); - if max_extent == 0, max_extent = 1; end - amp_scale = 0.8 * y_spacing / max_extent; - t_scale = 0.8 * x_spacing / (t_ms(end) - t_ms(1)); - - % Scale bar µV: round max amplitude to nearest 50 µV - sb_uv = max(50, round(max_extent / 50) * 50); - - ax = nexttile(wf_tl); - hold(ax, 'on'); - - for ci = 1:n_chans_plot - cx = ch_x(ci); - cy = ch_y(ci); - col = col_default; - if ci == best_local_idx, col = col_best; end - - x_wf = cx + t_ms * t_scale; - - % Individual waveforms - wf_ci = squeeze(d.waveforms(ci, :, :)); - plot(ax, x_wf, cy + wf_ci * amp_scale, ... - 'Color', [col, 0.12], 'LineWidth', 0.5); - - % Std shading - upper = cy + (mean_wf(ci,:) + std_wf(ci,:)) * amp_scale; - lower = cy + (mean_wf(ci,:) - std_wf(ci,:)) * amp_scale; - fill(ax, [x_wf, fliplr(x_wf)], [upper, fliplr(lower)], ... - col, 'FaceAlpha', 0.2, 'EdgeColor', 'none'); - - % Mean waveform (black), with coloured std shading - plot(ax, x_wf, cy + mean_wf(ci,:) * amp_scale, ... - 'Color', 'k', 'LineWidth', 2); - - % Channel label (two rows, left of waveform start) - text(ax, x_wf(1) - 2, cy, ... - sprintf('ch%d\n(%g,%g)', bin_chans(ci), cx, cy), ... - 'FontSize', 6, 'HorizontalAlignment', 'right', ... - 'VerticalAlignment', 'middle', 'Color', col); - end - - % L-scale bar: bottom-right channel of this unit - sb_ms = 1; % sb_uv already set above - sb_xlen = sb_ms * t_scale; - sb_ylen = sb_uv * amp_scale; - - [~, br_ci] = min(ch_y - ch_x * 1e-6); - sb_ox = ch_x(br_ci) + t_ms(end) * t_scale + 0.2 * x_spacing; - sb_oy = ch_y(br_ci); - - plot(ax, [sb_ox, sb_ox], [sb_oy, sb_oy - sb_ylen], 'k', 'LineWidth', 2); - plot(ax, [sb_ox, sb_ox + sb_xlen], [sb_oy, sb_oy], 'k', 'LineWidth', 2); - text(ax, sb_ox - 2, sb_oy - sb_ylen/2, sprintf('%d µV', sb_uv), ... - 'FontSize', 7, 'HorizontalAlignment', 'center', ... - 'VerticalAlignment', 'middle', 'Rotation', 90); - text(ax, sb_ox + sb_xlen/2, sb_oy + 2, sprintf('%d ms', sb_ms), ... - 'FontSize', 7, 'HorizontalAlignment', 'center', 'VerticalAlignment', 'top'); - - title(ax, sprintf('Unit %d | ch%d | n=%d', ... - d.unitID, d.best_bin_chan, d.n_wf), 'FontSize', 9); - axis(ax, 'tight'); - axis(ax, 'off'); -end - -%% ---- ACG figure: one tile per unit ---- -if params.showCorr - fig2 = figure('Color', 'w', 'Name', 'ACGs'); - acg_tl = tiledlayout(fig2, nRows, nCols, 'TileSpacing', 'compact', 'Padding', 'compact'); - title(acg_tl, sprintf('ACG | RP 2 ms | bin %.1f ms | win ±%d ms', ... - params.corrBin, params.corrWin), 'FontSize', 12, 'FontWeight', 'bold'); - xlabel(acg_tl, 'Lag (ms)'); - ylabel(acg_tl, 'Spike count'); - - for ui = 1:nUnits - if ~unitData(ui).valid, continue; end - d = unitData(ui); - - ax_c = nexttile(acg_tl); - bar(ax_c, d.ccg_bins, d.ccg_counts, 1, ... - 'FaceColor', [0.3 0.5 0.8], 'EdgeColor', 'none'); - hold(ax_c, 'on'); - xline(ax_c, 0, '--k', 'Alpha', 0.4); - - ylims = ylim(ax_c); - patch(ax_c, [-2 2 2 -2], [0 0 ylims(2) ylims(2)], ... - 'r', 'FaceAlpha', 0.1, 'EdgeColor', 'none'); - - xlim(ax_c, [-params.corrWin params.corrWin]); - title(ax_c, sprintf('Unit %d', d.unitID), 'FontSize', 9); - box(ax_c, 'off'); - end -else - fig2 = []; -end - -end % main function - - -%% ========================================================================= -function [counts, bin_centers] = computeACG(spike_times_samples, fs, win_ms, bin_ms) -st_ms = spike_times_samples / fs * 1000; -edges = -win_ms : bin_ms : win_ms; -bin_centers = edges(1:end-1) + bin_ms / 2; -counts = zeros(1, numel(bin_centers)); -for i = 1:numel(st_ms) - diffs = st_ms - st_ms(i); - diffs(i) = NaN; - diffs = diffs(diffs > -win_ms & diffs < win_ms); - counts = counts + histcounts(diffs, edges); -end -end \ No newline at end of file diff --git a/general functions/plotRawWaveforms.m b/general functions/plotRawWaveforms.m index b39bfb5..1c129f0 100644 --- a/general functions/plotRawWaveforms.m +++ b/general functions/plotRawWaveforms.m @@ -25,7 +25,7 @@ obj unitIDs (1,:) double params.nWaveforms (1,1) double = 100 - params.nChanAround (1,1) double = 10 + params.nChanAround (1,1) double = 8 params.nPre (1,1) double = 20 params.nPost (1,1) double = 61 params.showCorr (1,1) logical = false diff --git a/visualStimulationAnalysis/@VStimAnalysis/BootstrapPerNeuron.m b/visualStimulationAnalysis/@VStimAnalysis/BootstrapPerNeuron.m index 7c5c987..f216e72 100644 --- a/visualStimulationAnalysis/@VStimAnalysis/BootstrapPerNeuron.m +++ b/visualStimulationAnalysis/@VStimAnalysis/BootstrapPerNeuron.m @@ -3,8 +3,8 @@ arguments (Input) obj params.nBoot = 10000 - params.EmptyTrialPerc = 0.6 - params.FilterEmptyResponses = false + params.EmptyTrialPerc = 0.7 %If empty trials per category are higher than EmptyTrialPerc then filter + params.FilterEmptyResponses = true params.overwrite = false end % Computes per-neuron z-scores of stimulus responses vs baseline using bootstrap @@ -24,10 +24,61 @@ p = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); label = string(p.label'); goodU = p.ic(:,label == 'good'); %somatic neurons +responseParams = obj.ResponseWindow; + if isempty(goodU) warning('%s has No somatic Neurons, skipping experiment/n',obj.dataObj.recordingName) + results = []; + fprintf('Saving results to file.\n'); + if isequal(obj.stimName, 'linearlyMovingBall') + % S.(fieldName).BootResponse = respBoot; + % S.(fieldName).BootBaseline = baseBoot; + S.Speed1.BootDiff = []; + S.Speed1.pvalsResponse = []; + S.Speed1.ZScoreU = []; + S.Speed1.ObsDiff = []; + S.Speed1.ObsReponse = []; + S.Speed1.ObsBaseline = []; + + if isfield(responseParams, "Speed2") + S.Speed2.BootDiff = []; + S.Speed2.pvalsResponse = []; + S.Speed2.ZScoreU = []; + S.Speed2.ObsDiff = []; + S.Speed2.ObsReponse = []; + S.Speed2.ObsBaseline = []; + end + elseif isequal(obj.stimName,'StaticDriftingGrating') + % S.(fieldName).BootResponse = respBoot; + % S.(fieldName).BootBaseline = baseBoot; + S.Moving.BootDiff = []; + S.Moving.pvalsResponse = []; + S.Moving.ZScoreU = []; + S.Moving.ObsDiff = []; + S.Moving.ObsReponse = []; + S.Moving.ObsBaseline = []; + + S.Static.BootDiff = []; + S.Static.pvalsResponse = []; + S.Static.ZScoreU = []; + S.Static.ObsDiff = []; + S.Static.ObsReponse = []; + S.Static.ObsBaseline = []; + else + % S.BootResponse = respBoot; + % S.BootBaseline = baseBoot; + S.BootDiff = []; + S.pvalsResponse = []; + S.ZScoreU = []; + S.ObsDiff = []; + S.ObsReponse = []; + S.ObsBaseline = []; + end + + S.params = params; + save(obj.getAnalysisFileName,'-struct', 'S'); return end @@ -40,7 +91,6 @@ end -responseParams = obj.ResponseWindow; %%If it is a moving stimulus with speed cathegories if isfield(responseParams, "Speed1") diff --git a/visualStimulationAnalysis/@linearlyMovingBallAnalysis/CalculateReceptiveFields.m b/visualStimulationAnalysis/@linearlyMovingBallAnalysis/CalculateReceptiveFields.m index fd97c8f..b587d59 100644 --- a/visualStimulationAnalysis/@linearlyMovingBallAnalysis/CalculateReceptiveFields.m +++ b/visualStimulationAnalysis/@linearlyMovingBallAnalysis/CalculateReceptiveFields.m @@ -16,6 +16,8 @@ params.nShuffle = 2 %Number of shuffles to generate shuffled receptive fields. params.testConvolution = false params.reduceFactor = 20 %reduce factor for screen resolution + params.statType string = "BootstrapPerNeuron" + params.nGrid = 9 end if params.inputParams,disp(params),return,end @@ -37,10 +39,18 @@ end NeuronResp = obj.ResponseWindow; -Stats = obj.ShufflingAnalysis; -goodU = NeuronResp.goodU; + +% Stats struct for p-values +if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; +else + Stats = obj.ShufflingAnalysis; +end + p = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); phy_IDg = p.phy_ID(string(p.label') == 'good'); +label = string(p.label'); +goodU = p.ic(:, label == 'good'); fieldName = sprintf('Speed%d', params.speed); pvals = Stats.(fieldName).pvalsResponse; @@ -101,6 +111,7 @@ trialDivisionVid = size(C,1)/numel(unique(C(:,2)))/numel(unique(C(:,3)))/numel(unique(C(:,4)))... /numel(unique(C(:,5))); + %%%Create a matrix with trials that have unique positions ChangePosX = zeros(sizeX(1)*sizeX(2)*sizeX(3)*sizeN*trialDivisionVid,sizeX(4)); ChangePosY = zeros(sizeX(1)*sizeX(2)*sizeX(3)*sizeN*trialDivisionVid,sizeX(4)); @@ -308,7 +319,7 @@ if params.testConvolution %%%Test convolution spikeSumArt = zeros(size(spikeSum)); -spikeSumArt([1:10 91:100 181:190 271:280],1,end-20:end)=1;%Spiking at the end of first offset +spikeSumArt([1:10 91:100 161:180],1,end-20:end)=1;%Spiking at the end of first offset %spikeSumArt(nT/4+1:nT/4+15,7,end-20:end) =1;% spikeSumsDiv{1}{1} = spikeSumArt; figure;imagesc(squeeze( spikeSumsDiv{1}{1}(:,:,:)));colormap(flipud(gray(64))); @@ -358,19 +369,37 @@ % Generate video trials (this part stays the same) videoTrials = zeros(nT/trialDivisionVid,redCoorY,redCoorY,sizeX(4),'single'); h =1; +pad = ceil(max(sizesU)/(2*reduceFactor)) + 1; +[x_pad, y_pad] = meshgrid(1:redCoorY + 2*pad, 1:redCoorY + 2*pad); +x_pad = fliplr(x_pad); +cropOffsetX = (redCoorX - redCoorY)/2; + for i = 1:trialDivisionVid:nT + % for j = 1:sizeX(4) + % xyScreen = zeros(redCoorY,redCoorX,"single"); + % centerX = ChangePosX(i,j)/reduceFactor; + % centerY = ChangePosY(i,j)/reduceFactor; + % radius = sizeV(i)/2; + % distances = sqrt((x - centerX).^2 + (y - centerY).^2); + % xyScreen(distances <= radius/reduceFactor+0.5) = 1; + % videoTrials(h,:,:,j) = xyScreen(:,(redCoorX-redCoorY)/2+1:(redCoorX-redCoorY)/2+redCoorY); + % end + for j = 1:sizeX(4) - xyScreen = zeros(redCoorY,redCoorX,"single"); - centerX = ChangePosX(i,j)/reduceFactor; - centerY = ChangePosY(i,j)/reduceFactor; + xyScreen = zeros(redCoorY + 2*pad, redCoorY + 2*pad, "single"); + centerX = ChangePosX(i,j)/reduceFactor - cropOffsetX + pad; + centerY = ChangePosY(i,j)/reduceFactor + pad; radius = sizeV(i)/2; - distances = sqrt((x - centerX).^2 + (y - centerY).^2); - xyScreen(distances <= radius/reduceFactor+0.5) = 1; - videoTrials(h,:,:,j) = xyScreen(:,(redCoorX-redCoorY)/2+1:(redCoorX-redCoorY)/2+redCoorY); + distances = sqrt((x_pad - centerX).^2 + (y_pad - centerY).^2); + xyScreen(distances <= radius/reduceFactor + 0.5) = 1; + % Crop back to original square size, removing padding + videoTrials(h,:,:,j) = xyScreen(pad+1:pad+redCoorY, pad+1:pad+redCoorY); end h = h+1; end +%implay(squeeze(videoTrials(9,:,:,:))); + for t = 1:numel(IndexDiv) for q = 1:numel(IndexQ) @@ -406,7 +435,11 @@ % videoTrials(:,:,j) = xyScreen(:,(redCoorX-redCoorY)/2+1:(redCoorX-redCoorY)/2+redCoorY); % end - videoTrialsi = squeeze(videoTrials(ceil(p/2),:,:,:)); + if trialDivision*2 == trialDivisionVid %two luminosities are used, so trial division for videos are twicethe trialdivision + videoTrialsi = squeeze(videoTrials(ceil(p/2),:,:,:)); + else + videoTrialsi = squeeze(videoTrials(p,:,:,:)); + end % OPTIMIZATION 3: Vectorized spike mean calculation spikeMean = mean(spikeSum(i:i+trialDivision-1,:,:), 'omitnan'); spikeMeanShuff = mean(shuffledData(i:i+trialDivision-1,:,:,:), 'omitnan'); @@ -456,6 +489,58 @@ %implay(squeeze(RFu(:,:,:,1))); %implay(videoTrials) + %%%%%%%%%% Spike rate grid map + nGrid = params.nGrid; + cropOffsetX = (redCoorX - redCoorY)/2; + + xMin = cropOffsetX * reduceFactor; + xMax = (cropOffsetX + redCoorY) * reduceFactor; + yMin = 0; + yMax = redCoorY * reduceFactor; + + xEdges = linspace(xMin, xMax, nGrid+1); + yEdges = linspace(yMin, yMax, nGrid+1); + + gridSpikeRate = zeros(nGrid, nGrid, nNeurons, length(Usize), length(Ulum)); + gridSpikeRateShuff = zeros(nGrid, nGrid, nNeurons, nShuffle, length(Usize), length(Ulum)); + trialCount = zeros(nGrid, nGrid, length(Usize), length(Ulum)); + + for i = 1:nT + xPos = mean(ChangePosX(i,:)); + yPos = mean(ChangePosY(i,:)); + + xBin = discretize(xPos, xEdges); + yBin = discretize(yPos, yEdges); + + if isnan(xBin) || isnan(yBin) + continue + end + + sizeIdx = find(Usize == C(i,5)); + lumIdx = find(Ulum == C(i,6)); + + trialCount(yBin, xBin, sizeIdx, lumIdx) = trialCount(yBin, xBin, sizeIdx, lumIdx) + 1; + + gridSpikeRate(yBin, xBin, :, sizeIdx, lumIdx) = gridSpikeRate(yBin, xBin, :, sizeIdx, lumIdx) + ... + reshape(mean(spikeSum(i,:,:), 3), [1 1 nNeurons]); + + for s = 1:nShuffle + gridSpikeRateShuff(yBin, xBin, :, s, sizeIdx, lumIdx) = gridSpikeRateShuff(yBin, xBin, :, s, sizeIdx, lumIdx) + ... + reshape(mean(shuffledData(i,:,:,s), 3), [1 1 nNeurons]); + end + end + + % Normalize by trial count + for si = 1:length(Usize) + for li = 1:length(Ulum) + tc = max(trialCount(:,:,si,li), 1); + gridSpikeRate(:,:,:,si,li) = gridSpikeRate(:,:,:,si,li) ./ tc; + for s = 1:nShuffle + gridSpikeRateShuff(:,:,:,s,si,li) = gridSpikeRateShuff(:,:,:,s,si,li) ./ tc; + end + end + end + %%%%%%%%%% Normalization parameters L = size(spikeSum,3); time_zero_index = ceil(L / 2); @@ -501,6 +586,8 @@ names = {'X','Y'}; + %figure;imagesc(squeeze(RFuDirSizeLumFilt(1,:,:,:,:))); + if params.noEyeMoves save(sprintf('NEM-RFuSTDirSizeLumFilt-Q%d-Div-%s-%s',q,names{t},NP.recordingName),'RFuDirSizeLumFilt','-v7.3') save(sprintf('NEM-RFuSTDirSize-Q%d-Div-%s-%s',q,names{t},NP.recordingName),'RFuSTDirSize','-v7.3') @@ -528,6 +615,8 @@ S.RFuSTDirSizeLum = RFuSTDirSizeLum; S.RFuST = RFuST; S.RFuShuffST = RFuShuffST; + S.gridSpikeRate = gridSpikeRate; + S.gridSpikeRateShuff = gridSpikeRateShuff; save(sprintf('%s-Speed-%d.mat',filename,params.speed),'-struct','S'); results = S; end diff --git a/visualStimulationAnalysis/@linearlyMovingBallAnalysis/plotRaster.m b/visualStimulationAnalysis/@linearlyMovingBallAnalysis/plotRaster.m index 67ea352..2575e8f 100644 --- a/visualStimulationAnalysis/@linearlyMovingBallAnalysis/plotRaster.m +++ b/visualStimulationAnalysis/@linearlyMovingBallAnalysis/plotRaster.m @@ -6,7 +6,7 @@ function plotRaster(obj,params) params.analysisTime = datetime('now') params.inputParams = false params.preBase = 200 - params.bin = 15 + params.bin = 30 params.exNeurons = 1 params.AllSomaticNeurons = false params.AllResponsiveNeurons = false @@ -15,16 +15,25 @@ function plotRaster(obj,params) params.MergeNtrials =1 params.oneTrial = false params.GaussianLength = 10 + params.Gaussian logical = false params.MaxVal_1 =true params.useNormTrialWindow = false params.OneDirection string = "all" params.OneLuminosity string = "all" params.PaperFig logical = false + params.statType string = "BootstrapPerNeuron" end NeuronResp = obj.ResponseWindow; -Stats = obj.ShufflingAnalysis; + +if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; +else + Stats = obj.ShufflingAnalysis; +end + + if params.speed ~= "max" fieldName = sprintf('Speed%d', str2double(params.speed)); @@ -46,12 +55,13 @@ function plotRaster(obj,params) end -goodU = NeuronResp.goodU; p = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); phy_IDg = p.phy_ID(string(p.label') == 'good'); pvals = Stats.(fieldName).pvalsResponse; stimDur = NeuronResp.(fieldName).stimDur; stimInter = NeuronResp.stimInter; +label = string(p.label'); +goodU = p.ic(:,label == 'good'); %somatic neurons C = NeuronResp.(fieldName).C; @@ -73,9 +83,9 @@ function plotRaster(obj,params) if params.OneLuminosity ~= "all" switch params.OneLuminosity case "black" - C = NeuronResp.(fieldName).C(round(C(:,6), 2)==1,:); + C = C(round(C(:,6), 2)==1,:); case "white" - C = NeuronResp.(fieldName).C(round(C(:,6), 2)==255,:); + C = C(round(C(:,6), 2)==255,:); otherwise error("Unknown inputPa value: %s", params.OneLuminosity) end @@ -121,7 +131,9 @@ function plotRaster(obj,params) [Mr] = BuildBurstMatrix(goodU(:,eNeuron),round(p.t/params.bin),round((directimesSorted-preBase)/params.bin),round((stimDur+preBase*2)/params.bin)); -[Mr]=ConvBurstMatrix(Mr,fspecial('gaussian',[1 params.GaussianLength],3),'same'); +if params.Gaussian + [Mr]=ConvBurstMatrix(Mr,fspecial('gaussian',[1 params.GaussianLength],3),'same'); +end channels = goodU(1,eNeuron); @@ -257,15 +269,50 @@ function plotRaster(obj,params) maxRespIn = maxRespIn-1; X = squeeze(Mr2(maxRespIn*trialDivision+1:maxRespIn*trialDivision + trialDivision,:,:)); window = 500; %in ms - % Moving mean across 2nd dimension - mm = movmean(X, round(window/params.bin), 2, 'Endpoints', 'discard'); - % Average across rows to get kernel score - score = mean(mm, 1); - % Find max kernel location - [maxVal, idx] = max(score); + + + % % Moving mean across 2nd dimension + % mm = movmean(X, round(window/params.bin), 2, 'Endpoints', 'discard'); + % % Average across rows to get kernel score + % score = mean(mm, 1); + % % Find max kernel location + % [maxVal, idx] = max(score); + + X(X>1) = 1; + [n_rows, n_cols] = size(X); + n_windows = n_cols - round(window/params.bin) + 1; + + % Compute mean for every sliding window in every row + % Result: 20 x n_windows matrix + window_means = zeros(n_rows, n_windows); + for col = 1:n_windows + window_means(:, col) = mean(X(:, col:col+round(window/params.bin)-1), 2); + end + + % Find the overall maximum mean across all rows and windows + [~, linear_idx] = max(window_means(:)); + + % Convert linear index to (row, col) — col = start of window + [best_row, best_col] = ind2sub(size(window_means), linear_idx); % Kernel column range - start = idx; + start = best_col*params.bin; + + + % % --- Plot --- + % figure; + % imagesc(X); + % colorbar; + % axis tight; + % hold on; + % + % % Highlight the full best row (horizontal span) + % rectangle('Position', [0.5, best_row - 0.5, n_cols, 1], ... + % 'EdgeColor', 'r', 'LineWidth', 1.5, 'LineStyle', '--'); + % + % % Highlight the selected window (column span within best row) + % rectangle('Position', [best_col - 0.5, best_row - 0.5, round(window/params.bin), 1], ... + % 'EdgeColor', 'y', 'LineWidth', 2.5); else if params.useNormTrialWindow @@ -292,18 +339,23 @@ function plotRaster(obj,params) 'k','FaceAlpha',0.1,'EdgeColor','none') - TrialM = squeeze(Mr2(trials,:,round((preBase+start)/params.bin):round((preBase+start+window)/params.bin)))'; + % TrialM = squeeze(Mr2(trials,round((preBase+start)/params.bin):round((preBase+start+window)/params.bin)))'; + % + % [mxTrial TrialNumber] = max(sum(TrialM)); - [mxTrial TrialNumber] = max(sum(TrialM)); + RasterTrials = trials(best_row); - RasterTrials = trials(TrialNumber); + % patch([(preBase+start)/params.bin (preBase+start+window)/params.bin (preBase+start+window)/params.bin (preBase+start)/params.bin],... + % [RasterTrials-0.5 RasterTrials-0.5 RasterTrials+0.5 RasterTrials+0.5],... + % 'r','FaceAlpha',0.3,'EdgeColor','none') - patch([(preBase+start)/params.bin (preBase+start+window)/params.bin (preBase+start+window)/params.bin (preBase+start)/params.bin],... + patch([(start)/params.bin (start+window)/params.bin (start+window)/params.bin (start)/params.bin],... [RasterTrials-0.5 RasterTrials-0.5 RasterTrials+0.5 RasterTrials+0.5],... 'r','FaceAlpha',0.3,'EdgeColor','none') + %%%%%% Plot PSTH %%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% @@ -362,7 +414,7 @@ function plotRaster(obj,params) xlabel('Time [s]','FontSize',10,'FontName','helvetica'); ylims = ylim; - yticks([round(ylims(2)/10)*5 round(ylims(2)/10)*10]) + yticks([round(ylims(2)/10)*5 ceil(ylims(2)/10)*10]) %%%%PLot raw data several trials one @@ -370,19 +422,21 @@ function plotRaster(obj,params) %Mark selected trial - bin3 = 2; + bin3 = 1; trialM = BuildBurstMatrix(goodU(:,u),round(p.t/bin3),round((directimesSorted+start)/bin3),round((window)/bin3)); TrialM = squeeze(trialM(trials,:,:))'; - [mxTrial TrialNumber] = max(sum(TrialM)); + [mxTrial TrialNumber] = max(mean(TrialM)); + + %RasterTrials = trials(TrialNumber); - RasterTrials = trials(TrialNumber); + RasterTrials = trials(best_row); chan = goodU(1,u); subplot(18,1,[1 3]) - startTimes = directimesSorted(RasterTrials)+start; + startTimes = directimesSorted(RasterTrials)+start-preBase; freq = "AP"; %or "LFP" diff --git a/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m b/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m index 9aae3eb..7465c3e 100644 --- a/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m +++ b/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m @@ -19,6 +19,8 @@ params.durationOff = 3000; params.offsetR = 50; %Response after onset of stim params.TakeAllStimDur = true %calculate the receptive fields taking into account the whole window + params.statType string = "BootstrapPerNeuron" + params.nGrid = 9 end @@ -39,10 +41,18 @@ end NeuronResp = obj.ResponseWindow; -Stats = obj.ShufflingAnalysis; -goodU = NeuronResp.goodU; + +% Stats struct for p-values +if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; +else + Stats = obj.ShufflingAnalysis; +end + p = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); phy_IDg = p.phy_ID(string(p.label') == 'good'); +label = string(p.label'); +goodU = p.ic(:, label == 'good'); pvals = Stats.pvalsResponse; C = NeuronResp.C; @@ -76,8 +86,10 @@ params.durationOff = NeuronResp.stimInter; end -[Mr] = BuildBurstMatrix(goodU,round(p.t/params.bin),round((directimesSorted+params.offsetR)/params.bin),round(params.duration/params.bin)); -[Mro] = BuildBurstMatrix(goodU,round(p.t/params.bin),round((directimesSorted+stimDur)/params.bin),round(params.durationOff/params.bin)); +durationMin = min([params.duration params.durationOff]); + +[Mr] = BuildBurstMatrix(goodU,round(p.t/params.bin),round((directimesSorted+params.offsetR)/params.bin),round(durationMin/params.bin)); +[Mro] = BuildBurstMatrix(goodU,round(p.t/params.bin),round((directimesSorted+stimDur)/params.bin),round(durationMin/params.bin)); % Mr = Mr.*(1000/params.bin); %convert to seconds % Mro = Mro.*(1000/params.bin); %convert to seconds @@ -182,7 +194,7 @@ %%Create summary of identical trials - for u = 1:length(goodU) + for u = 1:size(goodU,2) for o = 1:2 @@ -215,6 +227,10 @@ rectData = obj.VST.rectData; +% Before the loop: +XcStore = zeros(1, size(C,1)/trialDiv); +YcStore = zeros(1, size(C,1)/trialDiv); + j=1; for i = 1:trialDiv:length(C) @@ -229,6 +245,9 @@ Yc = round((rectData.Y4{1,C(i,3)}(C(i,2))-rectData.Y1{1,C(i,3)}(C(i,2)))/2)+rectData.Y1{1,C(i,3)}(C(i,2));%... Yc = Yc/params.reduceFactor; + XcStore(j) = Xc; % still in pixel coords at this point + YcStore(j) = Yc; + r = round((rectData.X2{1,C(i,3)}(C(i,2))-rectData.X1{1,C(i,3)}(C(i,2)))/2); r= r/params.reduceFactor; @@ -255,8 +274,61 @@ end -% M = MrMean(:,u)'./Nbase(u); +%%%%%%%%%% Spike rate grid map +nGrid = params.nGrid; +xEdges = linspace(0, screenSide(3)/params.reduceFactor, nGrid+1); % reduced pixel coords +yEdges = linspace(0, screenSide(4)/params.reduceFactor, nGrid+1); + +gridSpikeRate = zeros(nGrid, nGrid, nN, 2, nSize, nLums); +gridSpikeRateShuff = zeros(nGrid, nGrid, nN, nShuffle, 2, nSize, nLums); +trialCount = zeros(nGrid, nGrid, nSize, nLums); + +jj = 1; +for i = 1:trialDiv:nT + + xBin = discretize(XcStore(jj), xEdges); + yBin = discretize(YcStore(jj), yEdges); + + if isnan(xBin) || isnan(yBin) + jj = jj + 1; + continue + end + + sizeIdx = find(uSize == C(i,3)); + lumIdx = find(uLums == C(i,4)); + + trialCount(yBin, xBin, sizeIdx, lumIdx) = trialCount(yBin, xBin, sizeIdx, lumIdx) + 1; + % On and off response + onRate = reshape(mean(mean(Mr( i:i+trialDiv-1,:,:), 1), 3) .* (1000/params.bin), [1 1 nN]); + offRate = reshape(mean(mean(Mro(i:i+trialDiv-1,:,:), 1), 3) .* (1000/params.bin), [1 1 nN]); + + gridSpikeRate(yBin, xBin, :, 1, sizeIdx, lumIdx) = gridSpikeRate(yBin, xBin, :, 1, sizeIdx, lumIdx) + onRate; + gridSpikeRate(yBin, xBin, :, 2, sizeIdx, lumIdx) = gridSpikeRate(yBin, xBin, :, 2, sizeIdx, lumIdx) + offRate; + + for s = 1:nShuffle + shuffRate = reshape(mean(mean(shuffledData(i:i+trialDiv-1,:,:,s), 1), 3), [1 1 nN]); + gridSpikeRateShuff(yBin, xBin, :, s, sizeIdx, lumIdx) = ... + gridSpikeRateShuff(yBin, xBin, :, s, sizeIdx, lumIdx) + shuffRate; + end + + jj = jj + 1; +end + +% Normalize by trial count +for si = 1:nSize + for li = 1:nLums + tc = max(trialCount(:,:,si,li), 1); % [nGrid x nGrid] + for s = 1:nShuffle + gridSpikeRateShuff(:,:,:,s,si,li) = gridSpikeRateShuff(:,:,:,s,si,li) ./ tc; + end + for oi = 1:2 + gridSpikeRate(:,:,:,oi,si,li) = gridSpikeRate(:,:,:,oi,si,li) ./ tc; + end + end +end + +%%%%%% Convolution VD = reshape(VideoScreen,[1 1 1 size(VideoScreen,1) size(VideoScreen,1) size(VideoScreen,3)]); VD = repmat(VD,[1,1,1,1,1,1,nN]); @@ -265,7 +337,7 @@ MrMean(NanPos) = 0; -Res = reshape(MrMean,[size(MrMean,1),size(MrMean,2),size(MrMean,3),1,1,size(MrMean,4),nN]).*1000; +Res = reshape(MrMean,[size(MrMean,1),size(MrMean,2),size(MrMean,3),1,1,size(MrMean,4),nN]); %Take mean RFu = reshape(mean(VD.*Res,6),[size(MrMean,1),size(MrMean,2),size(MrMean,3),size(VD,4),size(VD,4),nN]); @@ -298,6 +370,10 @@ S.shuffledData = shuffledData; +S.gridSpikeRateShuff = gridSpikeRateShuff; + +S.gridSpikeRate = gridSpikeRate; + S.params = params; save(filename,'-struct','S'); diff --git a/visualStimulationAnalysis/@rectGridAnalysis/plotRaster.m b/visualStimulationAnalysis/@rectGridAnalysis/plotRaster.m index a7c1f4e..a21684b 100644 --- a/visualStimulationAnalysis/@rectGridAnalysis/plotRaster.m +++ b/visualStimulationAnalysis/@rectGridAnalysis/plotRaster.m @@ -6,7 +6,7 @@ function plotRaster(obj,params) params.analysisTime = datetime('now') params.inputParams = false params.preBase = 200 - params.bin = 40 + params.bin = 15 params.exNeurons = [] params.AllSomaticNeurons = false params.AllResponsiveNeurons = true @@ -18,12 +18,19 @@ function plotRaster(obj,params) params.plotPatch logical = true params.PaperFig logical = false params.stim2show = 300 + params.statType string = "BootstrapPerNeuron" end NeuronResp = obj.ResponseWindow; -Stats = obj.ShufflingAnalysis; + +if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; +else + Stats = obj.ShufflingAnalysis; +end + directimesSorted = NeuronResp.C(:,1)'; nSize = numel(unique(NeuronResp.C(:,3))); @@ -38,10 +45,11 @@ function plotRaster(obj,params) proportionTrials = 1/(numel(NeuronResp.C(:,1))/numel(directimesSorted)); -goodU = NeuronResp.goodU; p = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); phy_IDg = p.phy_ID(string(p.label') == 'good'); pvals = Stats.pvalsResponse; +label = string(p.label'); +goodU = p.ic(:,label == 'good'); %somatic neurons stimDur = NeuronResp.stimDur; @@ -111,7 +119,7 @@ function plotRaster(obj,params) trialsPerCath = trialsPerCath/mergeTrials; nT = nT/mergeTrials; else - Mr2=Mr(:,u,:); + Mr2=Mr(:,ur,:); mergeTrials =1; end @@ -288,16 +296,19 @@ function plotRaster(obj,params) % pos1 = cb.Position(1); % cb.Position(1) = pos1 + 0.03; + figName = sprintf('%s-rect-GRid-raster-eNeuron-%d-Lum-%d',obj.dataObj.recordingName,u,params.selectedLum); + if params.PaperFig - obj.printFig(fig,sprintf('%s-rect-GRid-raster-eNeuron-%d',obj.dataObj.recordingName,u),PaperFig = params.PaperFig) + obj.printFig(fig,figName,PaperFig = params.PaperFig) elseif params.overwrite - obj.printFig(fig,sprintf('%s-rect-GRid-raster-eNeuron-%d',obj.dataObj.recordingName,u)) + obj.printFig(fig,figName) end %%Plot raw data + maxRespIn = maxRespIn-1; trialsPerCath = length(directimesSorted)/(length(unique(seqMatrix))); @@ -345,10 +356,11 @@ function plotRaster(obj,params) tr = numel(ind); end + figName = sprintf('%s-rect-GRid-rawData-%d-Trials-raster-eNeuron-%d-Lum%d',obj.dataObj.recordingName,tr,u,params.selectedLum); if params.PaperFig - obj.printFig(fig2,sprintf('%s-rect-GRid-rawData-%d-Trials-raster-eNeuron-%d',obj.dataObj.recordingName,tr,u),PaperFig = params.PaperFig) + obj.printFig(fig2,figName,PaperFig = params.PaperFig) elseif params.overwrite - obj.printFig(fig2,sprintf('%s-rect-GRid-rawData-%d-Trials-raster-eNeuron-%d',obj.dataObj.recordingName,u)) + obj.printFig(fig2,figName) end %prettify_plot diff --git a/visualStimulationAnalysis/RunAnalysisClass.asv b/visualStimulationAnalysis/RunAnalysisClass.asv new file mode 100644 index 0000000..9f7d3fd --- /dev/null +++ b/visualStimulationAnalysis/RunAnalysisClass.asv @@ -0,0 +1,210 @@ +cd('\\sil3\data\Large_scale_mapping_NP') +excelFile = 'Experiment_Excel.xlsx'; + +data = readtable(excelFile); + +%% +%% Rect Grid +for ex = [49:54,64:97] %84:91 + NP = loadNPclassFromTable(ex); %73 81 + vsRe = rectGridAnalysis(NP); + % vsRe.getSessionTime("overwrite",true); + % %vsRe.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + % vsRe.getDiodeTriggers('overwrite',true); + % vsRe.getSyncedDiodeTriggers("overwrite",true); + % % vsRe.plotSpatialTuningSpikes; + % % vsRe.plotSpatialTuningLFP; + % vsRe.ResponseWindow('overwrite',true) + % results = vsRe.ShufflingAnalysis('overwrite',true); + % vsRe.plotRaster(MergeNtrials=1,overwrite=true,AllResponsiveNeurons = true, selectedLum=[],oneTrial = true,PaperFig = true) %43 + % close all;vsRe.plotRaster(MergeNtrials=1,overwrite=true,exNeurons=18, selectedLum=255,oneTrial = true,PaperFig = true) %43 + vsRe.CalculateReceptiveFields('overwrite',true) + %[colorbarLims] = vsRe.PlotReceptiveFields(exNeurons=18,allStimParamsCombined=true,PaperFig=true,overwrite=true); + %result = vsRe.BootstrapPerNeuron('overwrite',true); + +end +% vsRe.CalculateReceptiveFields +% vsRe.PlotReceptiveFields("meanAllNeurons",true) + +%% Moving ball + +for ex = [84:97]%97 74:84 (Neurons, 96_74, ) + NP = loadNPclassFromTable(ex); %73 81 + vs = linearlyMovingBallAnalysis(NP,Session=1); + % vs.getSessionTime("overwrite",true); + % vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + % % %vs.plotDiodeTriggers + % vs.getSyncedDiodeTriggers("overwrite",true); + % % %vs.plotSpatialTuningSpikes; + % r = vs.ResponseWindow('overwrite',true); + % results = vs.ShufflingAnalysis('overwrite',true); + % % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'MergeNtrials',3) + % %vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'MergeNtrials',2,'bin',5,'GaussianLength',30,'MaxVal_1', false) + % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'speed',2,'MergeNtrials',3) + %vs.plotRaster('exNeurons',82,'overwrite',true,'MergeNtrials',1,'OneDirection','up','OneLuminosity','white','PaperFig',true) + % % %vs.plotCorrSpikePattern + % vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'OneDirection','up','OneLuminosity','white','MergeNtrials',1,'PaperFig',true) + + %vs.plotRaster('exNeurons',9,'AllResponsiveNeurons',false,'overwrite',true,'MergeNtrials',3,MaxVal_1=false) + vs.CalculateReceptiveFields('overwrite',true,testConvolution=false); + % colorbarLims=vs.PlotReceptiveFields('exNeurons',82,'overwrite',true,'OneDirection','up','OneLuminosity','white','PaperFig',true); + %result = vs.BootstrapPerNeuron('overwrite',true);%('overwrite',true); + % pvals0_6Filter =result.Speed2.pvalsResponse'; + % compare = [pvals,pvalsNoFilt,pvals0_6Filter]; +end + +%% PlotZScoreComparison +%[49:54 57:81] MBR all experiments 'NV','NI' +%[44:56,64:88] All experiments +%[28:32,44,45,47,48,56,98] All SA experiments +%Check triggers 45, SA82 44,45,47:54,56,64:88 +% All stim: 'FFF','SDG','MBR','MB','RG','NI','NV' +%[49:54,64:97] %All PV good experiments +% %%[89,90,92,93,95,96,97] %Al NV and NI experiments +%[49:54,84:90,92:96] %All SDG experiments +%solve MBR +%bootsrapRespBase +VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... + overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR +%% PSTH for all experiments +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); + +%% Calculate spatial tuning +SpatialTuningIndex([52:54,64:97]) + +%% Gratings + +for ex = [54 84:90] + NP = loadNPclassFromTable(ex); %73 81 + vs = StaticDriftingGratingAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + dT = vs.getDiodeTriggers; + % vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + result = vs.BootstrapPerNeuron('overwrite',true); +end + +%% movie + +for ex = [89,90,92,93,95:97] + NP = loadNPclassFromTable(ex); %73 81 + vs = movieAnalysis(NP); + % vs.getSessionTime("overwrite",true); + % vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + % dT = vs.getDiodeTriggers; + % vs.plotDiodeTriggers + %vs.getSyncedDiodeTriggers("overwrite",true); + %r = vs.ResponseWindow('overwrite',true); + %results = vs.ShufflingAnalysis('overwrite',true); + vs.plotRaster('AllResponsiveNeurons',true,MergeNtrials=1,overwrite=true) +end + + +%% image + +for ex = [89,90,92,93,95:97] + NP = loadNPclassFromTable(ex); %73 81 + vs = imageAnalysis(NP); + %vs.getSessionTime("overwrite",true); + %vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %dT = vs.getDiodeTriggers; + % vs.plotDiodeTriggers + %vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + %results = vs.ShufflingAnalysis('overwrite',true); + vs.plotRaster('AllResponsiveNeurons',true,MergeNtrials=1,overwrite=true) + +end + + +%% Moving bar +for ex = 81 + NP = loadNPclassFromTable(ex); %73 81 + vs = linearlyMovingBarAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + if ~any(results.Speed1.pvalsResponse<0.05) + fprintf('%d-No responsive neurons.\n',ex) + continue + end + vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); + vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) +end + +%% FFF +for ex = 56 + NP = loadNPclassFromTable(ex); %73 81 + vs = fullFieldFlashAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + if ~any(results.Speed1.pvalsResponse<0.05) + fprintf('%d-No responsive neurons.\n',ex) + continue + end + vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); + vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) +end + + +%% Run for all +for ex = 85:88 + NP = loadNPclassFromTable(ex); %73 81 + vs = linearlyMovingBallAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + if ~any(results.Speed1.pvalsResponse<0.05) + fprintf('%d-No responsive neurons.\n',ex) + continue + end + vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); + vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) +end + +%% Check experiments in timseseries viewer +timeSeriesViewer(NP) +t=NP.getTrigger; +data.VS_ordered(ex) + +stimOn = t{3}; +stimOff = t{4}; + +MBRtOn = stimOn(stimOn > t{1}(1) & stimOn < t{2}(1)); +MBRtOff = stimOff(stimOff > t{1}(1) & stimOff < t{2}(1)); + +MBtOn = stimOn(stimOn > t{1}(2) & stimOn < t{2}(2)); +MBtOff = stimOff(stimOff > t{1}(2) & stimOff < t{2}(2)); + +RGtOn = stimOn(stimOn > t{1}(3) & stimOn < t{2}(3)); +RGtOff = stimOff(stimOff > t{1}(3) & stimOff < t{2}(3)); + +NGtOn = stimOn(stimOn > t{1}(4) & stimOn < t{2}(4)); +NGtOff = stimOff(stimOff > t{1}(4) & stimOff < t{2}(4)); + +DtOn = stimOn(stimOn > t{1}(5) & stimOn < t{2}(5)); +DtOff = stimOff(stimOff > t{1}(5) & stimOff < t{2}(5)); + +MovingBallTriggersDiode = d3.stimOnFlipTimes; + + + +%% %% check neural data sync and analog data sync + +allTimes = [stimOn(:); stimOff(:); onSync(:); offSync(:)]; % concatenate as column + +% Sort from earliest to latest +sortedTimesDiodeOldMethod = sort(allTimes); diff --git a/visualStimulationAnalysis/RunAnalysisClass.m b/visualStimulationAnalysis/RunAnalysisClass.m index fb86853..11c74e7 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.m +++ b/visualStimulationAnalysis/RunAnalysisClass.m @@ -5,7 +5,7 @@ %% %% Rect Grid -for ex = [97] %84:91 +for ex = 52 %84:91 NP = loadNPclassFromTable(ex); %73 81 vsRe = rectGridAnalysis(NP); % vsRe.getSessionTime("overwrite",true); @@ -16,10 +16,11 @@ % % vsRe.plotSpatialTuningLFP; % vsRe.ResponseWindow('overwrite',true) % results = vsRe.ShufflingAnalysis('overwrite',true); - close all;vsRe.plotRaster(MergeNtrials=1,overwrite=true,exNeurons = 43, selectedLum=255,oneTrial = true,PaperFig = true) + % vsRe.plotRaster(MergeNtrials=1,overwrite=true,AllResponsiveNeurons = true, selectedLum=[],oneTrial = true,PaperFig = true) %43 + % close all;vsRe.plotRaster(MergeNtrials=1,overwrite=true,exNeurons=18, selectedLum=255,oneTrial = true,PaperFig = true) %43 vsRe.CalculateReceptiveFields('overwrite',true) - [colorbarLims] = vsRe.PlotReceptiveFields(exNeurons=43,allStimParamsCombined=true,PaperFig=true,overwrite=true); - result = vsRe.BootstrapPerNeuron('overwrite',true); + %[colorbarLims] = vsRe.PlotReceptiveFields(exNeurons=18,allStimParamsCombined=true,PaperFig=true,overwrite=true); + %result = vsRe.BootstrapPerNeuron('overwrite',true); end % vsRe.CalculateReceptiveFields @@ -27,7 +28,7 @@ %% Moving ball -for ex = [69,81,95,97] %97 +for ex = [84:97]%97 74:84 (Neurons, 96_74, ) NP = loadNPclassFromTable(ex); %73 81 vs = linearlyMovingBallAnalysis(NP,Session=1); % vs.getSessionTime("overwrite",true); @@ -35,17 +36,19 @@ % % %vs.plotDiodeTriggers % vs.getSyncedDiodeTriggers("overwrite",true); % % %vs.plotSpatialTuningSpikes; - r = vs.ResponseWindow('overwrite',true); - results = vs.ShufflingAnalysis('overwrite',true); + % r = vs.ResponseWindow('overwrite',true); + % results = vs.ShufflingAnalysis('overwrite',true); % % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'MergeNtrials',3) % %vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'MergeNtrials',2,'bin',5,'GaussianLength',30,'MaxVal_1', false) % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'speed',2,'MergeNtrials',3) - %vs.plotRaster('exNeurons',73,'overwrite',true,'MergeNtrials',1,'OneDirection','up','OneLuminosity','white','PaperFig',true) - % %vs.plotCorrSpikePattern - % %vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'speed',2) - %vs.CalculateReceptiveFields('overwrite',true); - %vs.PlotReceptiveFields('exNeurons',73,'overwrite',true,'OneDirection','up','OneLuminosity','white','PaperFig',true) - result = vs.BootstrapPerNeuron('overwrite',true);%('overwrite',true); + %vs.plotRaster('exNeurons',82,'overwrite',true,'MergeNtrials',1,'OneDirection','up','OneLuminosity','white','PaperFig',true) + % % %vs.plotCorrSpikePattern + % vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'OneDirection','up','OneLuminosity','white','MergeNtrials',1,'PaperFig',true) + + %vs.plotRaster('exNeurons',9,'AllResponsiveNeurons',false,'overwrite',true,'MergeNtrials',3,MaxVal_1=false) + vs.CalculateReceptiveFields('overwrite',true,testConvolution=false); + % colorbarLims=vs.PlotReceptiveFields('exNeurons',82,'overwrite',true,'OneDirection','up','OneLuminosity','white','PaperFig',true); + %result = vs.BootstrapPerNeuron('overwrite',true);%('overwrite',true); % pvals0_6Filter =result.Speed2.pvalsResponse'; % compare = [pvals,pvalsNoFilt,pvals0_6Filter]; end @@ -61,8 +64,13 @@ %[49:54,84:90,92:96] %All SDG experiments %solve MBR %bootsrapRespBase -VStimAnalysis.PlotZScoreComparison([49:54,64:97],{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=true,ComparePairs={'MB','RG'},PaperFig=true,... - overwriteResponse=true,overwriteStats=true)%[49:54,57:91] %%Check why I have different array dimensions in MBR +VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... + overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR +%% PSTH for all experiments +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); + +%% Calculate spatial tuning +SpatialTuningIndex([49:54,64:97], overwrite=true) %% Gratings diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv deleted file mode 100644 index 39e33d0..0000000 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.asv +++ /dev/null @@ -1,180 +0,0 @@ - -%% Run/load bombcell and confusion matrices - -% -exp = [49:54,64:97];% -%tiledlayout(numel(exp),1) -for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) - %%%%%%%%%%%% Load data and data paremeters - %1. Load NP class - ex=69 - NP = loadNPclassFromTable(ex); - vs = linearlyMovingBallAnalysis(NP,Session=1); - KSversion =4; - - [qMetric,unitType]=NP.getBombCell(NP.recordingDir+"\kilosort4",0,KSversion); - - %convertPhySorting2tIc(obj,pathToPhyResults,tStart,BombCelled) - - % - % goodUnits = unitType == 1; - % muaUnits = unitType == 2; - % noiseUnits = unitType == 0; - % nonSomaticUnits = unitType == 3; - - % Concordance analysis - % bc load_manual_classifications(vs.spikeSortingFolder) - % pMC = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,0,1); - % pBC = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,1,1); - - bombcell_table = readtable([vs.spikeSortingFolder filesep 'cluster_bc_unitType.tsv'], 'FileType', 'text', 'Delimiter', '\t'); - manual_table = readtable([vs.spikeSortingFolder filesep 'cluster_info.tsv'],'FileType','delimitedtext'); - - manual_table = manual_table(:,{'cluster_id','KSLabel','group'}); - sum(strcmp(pKS.label, 'good')) - - - % Load and prepare data - % Assume: - % bombcell_table: Nx2 table, columns: [id, bc_label] ("GOOD","MUA","NON-SOMA","NOISE") - % manual_table: Mx3 table, columns: [id, KS_label, group] ("good","mua","noise") - - % Rename columns for clarity (adjust if yours differ) - bombcell_table.Properties.VariableNames = {'id', 'bc_label'}; - manual_table.Properties.VariableNames = {'id', 'KS_label', 'group'}; - - % Remove NON-SOMA from bombcell - bc = bombcell_table(~strcmp(bombcell_table.bc_label, 'NON-SOMA'), :); - - % Match IDs — keep only IDs present in both tables - [~, ia, ib] = intersect(bc.id, manual_table.id); - bc_matched = bc(ia, :); - man_matched = manual_table(ib, :); - - % Harmonize labels to lowercase for comparison - bc_labels = lower(bc_matched.bc_label); % "good","mua","noise" - ks_labels = lower(man_matched.KS_label); % "good","mua","noise" - man_labels = lower(man_matched.group); % "good","mua","noise" - - %%Define category order - cats = {'good', 'mua', 'noise'}; - - bc_cat = categorical(bc_labels, cats); - ks_cat = categorical(ks_labels, cats); - man_cat = categorical(man_labels, cats); - - % --- Confusion Matrix 1: Manual curation vs BombCell --- - % figure('Position', [100, 100, 700, 600]); - % - % tiledlayout(3,2) - % nexttile - % cm1 = confusionchart(man_cat, bc_cat, ... - % 'Title', sprintf('%s-Manual curation vs BombCell',NP.recordingName),... - % 'XLabel', 'BombCell', ... - % 'YLabel', 'Manual Curation', ... - % 'RowSummary', 'row-normalized', ... - % 'ColumnSummary', 'column-normalized'); - % - % cm1.FontSize = 9; - % - % % Give the chart more room inside the figure - % %cm1.Position = [10, 10, 680, 580]; - - % --- Confusion Matrix 2: KS label vs BombCell --- - fig = figure('Position', [100, 100, 700, 600]); - %tl = nexttile; - cm2 = confusionchart(ks_cat, bc_cat, ... - 'XLabel', 'BombCell', ... - 'YLabel', 'KS Label', ... - 'RowSummary', 'row-normalized', ... - 'ColumnSummary', 'column-normalized'); - cm2.FontSize = 9; - title(sprintf('%KS Label vs BombCell',NP.recordingName)); - - - - % %% --- Confusion Matrix 3: KS label vs Manual curation --- - % figure; - % cm3 = confusionchart(ks_cat, man_cat, ... - % 'Title', printf('KS Label vs Manual Curation',NP.recordingName), ... - % 'XLabel', 'Manual Curation', ... - % 'YLabel', 'KS Label', ... - % 'RowSummary', 'row-normalized', ... - % 'ColumnSummary', 'column-normalized'); - - % --- Print mismatch summary --- - % fprintf('\n=== Manual vs BombCell ===\n') - % mismatch_man_bc = man_cat ~= bc_cat; - % fprintf('Total mismatches: %d / %d (%.1f%%)\n', ... - % sum(mismatch_man_bc), numel(mismatch_man_bc), ... - % 100*mean(mismatch_man_bc)); - - fprintf('\n=== KS Label vs BombCell ===\n') - mismatch_ks_bc = ks_cat ~= bc_cat; - fprintf('Total mismatches: %d / %d (%.1f%%)\n', ... - sum(mismatch_ks_bc), numel(mismatch_ks_bc), ... - 100*mean(mismatch_ks_bc)); - - vs.printFig(fig,sprintf('%KS Label vs BombCell',NP.recordingName),PaperFig =1) - - close - - % fprintf('\n=== KS Label vs Manual Curation ===\n') - % mismatch_ks_man = ks_cat ~= man_cat; - % fprintf('Total mismatches: %d / %d (%.1f%%)\n', ... - % sum(mismatch_ks_man), numel(mismatch_ks_man), ... - % 100*mean(mismatch_ks_man)); - - imec = Neuropixel.ImecDataset(NP.recordingDir); - ks = Neuropixel.KilosortDataset(vs.spikeSortingFolder,'imecDataset', imec); - ks.load(); - -end -%I want to compare bombcell unit classification with manual classification in phy. - - - -%% Plot raw waveforms of specific units: - -% 1. Add to path: https://github.com/cortex-lab/spikes -% https://github.com/kwikteam/npy-matlab (dependency) - - -ksDir = vs.spikeSortingFolder; -sp = loadKSdir(ksDir); % loads all KS output into a struct - -% Get waveforms -gwfparams.dataDir = ksDir; -gwfparams.fileName = NP.recordingDir; -gwfparams.dataType = 'int16'; -gwfparams.nCh = 385; -gwfparams.wfWin = [-40 41]; % samples around spike -gwfparams.nWf = 100; % waveforms per unit -gwfparams.spikeTimes = sp.st; % spike times -gwfparams.spikeClusters = sp.clu; % cluster IDs - -wf = getWaveForms(gwfparams); % wf.waveForms: [units x waveforms x channels x samples] - -% Plot mean waveform for unit 1, best channel -figure; -plot(squeeze(mean(wf.waveFormsMean(1,:,:), 2))); - -%% Check low amp waveforms 10 neurons per experiment - -PVexps = [49:54,64:97]; -idx = randi(length(PVexps), 1, 4); -selected = PVexps(idx); - - - -for i = selected - NP = loadNPclassFromTable(53); - vs = linearlyMovingBallAnalysis(NP,Session=1); - - p = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,1,1); - phy_IDg = p.phy_ID(string(p.label') == 'good'); - - - plotRawWaveforms(vs, [47:50], showCorr=true, corrWin=50, corrBin=0.5) - -end diff --git a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m index 638974d..ebbf9cb 100644 --- a/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m +++ b/visualStimulationAnalysis/Run_Bombcell_Automatic_Sorting.m @@ -4,10 +4,9 @@ % exp = [49:54,64:97];% %tiledlayout(numel(exp),1) -for ex = exp%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) +for ex = exp(2:end)%GoodRecordingsPV%allGoodRec %GoodRecordings%GoodRecordingsPV%GoodRecordingsPV%selecN{1}(1,:) %1:size(data,1) %%%%%%%%%%%% Load data and data paremeters %1. Load NP class - ex=53 NP = loadNPclassFromTable(ex); vs = linearlyMovingBallAnalysis(NP,Session=1); KSversion =4; @@ -15,7 +14,9 @@ [qMetric,unitType]=NP.getBombCell(NP.recordingDir+"\kilosort4",0,KSversion,1); %convertPhySorting2tIc(obj,pathToPhyResults,tStart,BombCelled) - +end +%% +for ex = exp % % goodUnits = unitType == 1; % muaUnits = unitType == 2; @@ -166,15 +167,29 @@ selected = PVexps(idx); - -for i = selected - NP = loadNPclassFromTable(53); +%% +selected =69; +for i = selected(1:end) + NP = loadNPclassFromTable(i); vs = linearlyMovingBallAnalysis(NP,Session=1); - p = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder,0,1,1); + p = vs.dataObj.convertPhySorting2tIc(vs.spikeSortingFolder); phy_IDg = p.phy_ID(string(p.label') == 'good'); + [param, qMetric, fractionRPVs_allTauR] = bc.load.loadSavedMetrics([NP.recordingDir filesep 'qMetrics']); - plotRawWaveforms(vs, [47:50], showCorr=true, corrWin=50, corrBin=0.5) + [~ ,idx] = sort(qMetric.rawAmplitude(ismember(qMetric.phy_clusterID,phy_IDg))); + + %Select units with lowest amplitude + selecUnits = qMetric.phy_clusterID(ismember(qMetric.phy_clusterID,phy_IDg)); + selecUnits = selecUnits(idx(1:min([10 numel(selecUnits)]))); + selecUnits = 104; + + plotRawWaveforms(vs, selecUnits, showCorr=true, corrWin=50, corrBin=0.5,nChanAround=6) + + qMetric.signalToNoiseRatio(qMetric.phy_clusterID == 630,:) + % q = qMetric(ismember(qMetric.phy_clusterID,selecUnits),:); end + +[qMetric,unitType]=NP.getBombCell(NP.recordingDir+"\kilosort4",1,KSversion,0); \ No newline at end of file diff --git a/visualStimulationAnalysis/SpatialTuningIndex.asv b/visualStimulationAnalysis/SpatialTuningIndex.asv new file mode 100644 index 0000000..0f6d98a --- /dev/null +++ b/visualStimulationAnalysis/SpatialTuningIndex.asv @@ -0,0 +1,408 @@ +function results = SpatialTuningIndex(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.topPercent double = 10 + params.overwrite logical = false + params.statType string = "BootstrapPerNeuron" + params.speed double = 1 + params.plot logical = true + params.indexType string = "L_combined" % L_amplitude, L_geometric, L_combined + params.onOff double = 1 % 1=on, 2=off (rectGrid only) + params.sizeIdx double = 1 + params.lumIdx double = 1 + params.nBoot double = 10000 + params.yLegend char = 'Spatial Tuning Index' + params.yMaxVis double = 1 + params.Alpha double = 0.4 + params.PaperFig logical = false +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); + +switch params.stimTypes(1) + case "rectGrid" + vs_first = rectGridAnalysis(NP_first); + case "linearlyMovingBall" + vs_first = linearlyMovingBallAnalysis(NP_first); +end + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +nameOfFile = sprintf('\\Ex_%d-%d_SpatialTuningIndex_%s.mat', ... + exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to compute or load +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved SpatialTuningIndex from:\n %s\n', [saveDir nameOfFile]); + % Jump straight to table building + tbl = S.tbl; + goto_plot = true; + else + fprintf('Experiment list mismatch — recomputing.\n'); + goto_plot = false; + end +else + goto_plot = false; +end + +% ========================================================================= +% COMPUTE +% ========================================================================= +if ~goto_plot + + nExp = numel(exList); + nStim = numel(params.stimTypes); + + tbl = table(); + + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + continue + end + + nameParts = split(NP.recordingName, '_'); + animalName = nameParts{1}; + + % ---------------------------------------------------------- + % Find union of responsive neurons across ALL stim types + % ---------------------------------------------------------- + % Get phy IDs and responsive units for each stim type + respPhyIDs_all = cell(1, nStim); + phyIDs_all = cell(1, nStim); + + p_s = obj_s.dataObj.convertPhySorting2tIc(obj_s.spikeSortingFolder); + phy_IDg = p_s.phy_ID(string(p_s.label') == 'good'); + + + for s = 1:nStim + stimType = params.stimTypes(s); + try + switch stimType + case "rectGrid" + obj_s = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj_s = linearlyMovingBallAnalysis(NP); + end + + if params.statType == "BootstrapPerNeuron" + Stats = obj_s.BootstrapPerNeuron; + else + Stats = obj_s.ShufflingAnalysis; + end + + + try + switch stimType + case "linearlyMovingBall" + fieldName = sprintf('Speed%d', params.speed); + pvals = Stats.(fieldName).pvalsResponse; + otherwise + pvals = Stats.pvalsResponse; + end + catch + pvals = Stats.pvalsResponse; + end + + respU = find(pvals < 0.05); + phyIDs_all{s} = phy_IDg; % all good unit phy IDs for this stim + respPhyIDs_all{s} = phy_IDg(respU); % only responsive ones + fprintf(' [%s] %d responsive neuron(s).\n', stimType, numel(respU)); + + catch ME + warning('Could not get pvals for %s exp %d: %s', stimType, ex, ME.message); + phyIDs_all{s} = []; + respPhyIDs_all{s} = []; + end + end + + % Union of responsive phy IDs across stim types + sharedPhyIDs = respPhyIDs_all{1}; + for s = 2:nStim + sharedPhyIDs = union(sharedPhyIDs, respPhyIDs_all{s}); + end + + if isempty(sharedPhyIDs) + fprintf(' No responsive neurons in exp %d — skipping.\n', ex); + continue + end + + fprintf(' %d neuron(s) responsive to at least one stim type in exp %d.\n', numel(sharedPhyIDs), ex); + + + for s = 1:nStim + + stimType = params.stimTypes(s); + + % Build analysis object + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + continue + end + + + % ---------------------------------------------------------- + % Load grid results + % ---------------------------------------------------------- + S_rf = obj.CalculateReceptiveFields; + + gridSpikeRate = S_rf.gridSpikeRate; + gridSpikeRateShuff = S_rf.gridSpikeRateShuff; + + switch stimType + case "rectGrid" + % Select onOff from both + gridSpikeRateSelected = gridSpikeRate(:,:,:,params.onOff,:,:); % [nGrid nGrid nN nSize nLum] -- but with singleton onOff removed + gridShuffSelected = gridSpikeRateShuff(:,:,:,:,params.onOff,:,:); % [nGrid nGrid nN nShuffle nSize nLum] + case "linearlyMovingBall" + gridSpikeRateSelected = gridSpikeRate; % [nGrid nGrid nN nSize nLum] + gridShuffSelected = gridSpikeRateShuff; % [nGrid nGrid nN nShuffle nSize nLum] + end + + % Find indices in this stim's good units that match sharedPhyIDs + [~, neuronIdx] = ismember(sharedPhyIDs, phyIDs_all{s}); + neuronIdx = neuronIdx(neuronIdx > 0); % remove any not found in this stim + + gridSpikeRateSelected = gridSpikeRateSelected(:,:,neuronIdx,:,:); + gridShuffSelected = gridShuffSelected(:,:,neuronIdx,:,:,:); + + % Average over shuffles and reshape explicitly — no squeeze + gridShuffMean = mean(gridShuffSelected, 4); % [nGrid nGrid nN 1 nSize nLum] + + % Get dimensions explicitly + nN = size(gridSpikeRateSelected, 3); + nSize = size(gridSpikeRateSelected, 5); + nLum = size(gridSpikeRateSelected, 6); + + % Reshape both to clean [nGrid nGrid nN nSize nLum] + gridSpikeRateSelected = reshape(gridSpikeRateSelected, [nGrid nGrid nN nSize nLum]); + gridShuffMean = reshape(gridShuffMean, [nGrid nGrid nN nSize nLum]); + + nCells = nGrid * nGrid; + maxDist = sqrt(2) * (nGrid - 1); + + % Average over shuffles + + + % ---------------------------------------------------------- + % Compute indices + % ---------------------------------------------------------- + + fprintf('gridSpikeRate size: %s\n', num2str(size(gridSpikeRate))); + fprintf('gridSpikeRateShuff size: %s\n', num2str(size(gridSpikeRateShuff))); + fprintf('gridShuffMean size: %s\n', num2str(size(gridShuffMean))); + + for si = 1:nSize + for li = 1:nLum + + rateFlat = reshape(gridSpikeRateSelected(:,:,:,si,li), [nCells, nN]); + rateFlatShuff = reshape(gridShuffMean(:,:,:,si,li), [nCells, nN]); + + L_amplitude = zeros(nN, 1); + L_geometric = zeros(nN, 1); + L_combined = zeros(nN, 1); + + for u = 1:nN + + rateVec = rateFlat(:, u); + rateVecShuff = rateFlatShuff(:, u); + + % Top cells + threshold = prctile(rateVec, 100 - params.topPercent); + thresholdShuff = prctile(rateVecShuff, 100 - params.topPercent); + + topIdx = find(rateVec >= threshold); + topIdxShuff = find(rateVecShuff >= thresholdShuff); + restIdx = setdiff(1:nCells, topIdx); + restIdxShuff = setdiff(1:nCells, topIdxShuff); + + % Amplitude + meanTop = mean(rateVec(topIdx)); + meanRest = mean(rateVec(restIdx)); + meanAll = mean(rateVec); + meanTopShuff = mean(rateVecShuff(topIdxShuff)); + meanRestShuff = mean(rateVecShuff(restIdxShuff)); + meanAllShuff = mean(rateVecShuff); + + if meanAll == 0, meanAll = eps; end + if meanAllShuff == 0, meanAllShuff = eps; end + + L_amplitude(u) = ... + (meanTop - meanRest) / meanAll - ... + (meanTopShuff - meanRestShuff) / meanAllShuff; + + % Geometric + [rowIdx, colIdx] = ind2sub([nGrid nGrid], topIdx); + [rowIdxShuff, colIdxShuff] = ind2sub([nGrid nGrid], topIdxShuff); + + if size(rowIdx, 1) > 1 + D = mean(pdist([rowIdx, colIdx], 'euclidean')) / maxDist; + else + D = 0; + end + if size(rowIdxShuff, 1) > 1 + DShuff = mean(pdist([rowIdxShuff, colIdxShuff], 'euclidean')) / maxDist; + else + DShuff = 0; + end + + L_geometric(u) = (1 - D) - (1 - DShuff); + L_combined(u) = L_amplitude(u) * L_geometric(u); + + end + + % Build rows for this condition + rows = table(); + rows.L_amplitude = L_amplitude; + rows.L_geometric = L_geometric; + rows.L_combined = L_combined; + rows.stimulus = categorical(repmat({char(stimType)}, nN, 1)); + rows.insertion = categorical(repmat(ex, nN, 1)); + rows.animal = categorical(repmat({animalName}, nN, 1)); + rows.NeurID = (1:nN)'; + rows.onOff = repmat(params.onOff, nN, 1); % params.onOff for rectGrid, meaningless but consistent for movingBall + rows.sizeIdx = repmat(si, nN, 1); + rows.lumIdx = repmat(li, nN, 1); + + tbl = [tbl; rows]; + + end + end + + fprintf(' [%s] Indices computed. %d neurons.\n', stimType, nN); + + end % stim loop + end % exp loop + + % Clean categories + tbl.stimulus = removecats(tbl.stimulus); + tbl.animal = removecats(tbl.animal); + tbl.insertion = removecats(tbl.insertion); + + % Save + S.expList = exList; + S.tbl = tbl; + S.params = params; + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved to:\n %s\n', [saveDir nameOfFile]); + +end % compute block + +results.tbl = tbl; + +% ========================================================================= +% PLOT +% ========================================================================= +if params.plot + + % Filter table to requested condition + idx = tbl.onOff == params.onOff & ... + tbl.sizeIdx == params.sizeIdx & ... + tbl.lumIdx == params.lumIdx; + + tblPlot = tbl(idx, :); + tblPlot.value = tblPlot.(params.indexType); % select which index to plot + + % ---------------------------------------------------------- + % Compute p-values using hierBoot + % ---------------------------------------------------------- + ps = []; + + pairs = {char(params.stimTypes(1)), char(params.stimTypes(2))}; + + + ps = zeros(size(pairs, 1), 1); + j = 1; + + for i = 1:size(pairs, 1) + diffs = []; + insers = []; + animals = []; + + for ins = unique(tblPlot.insertion)' + idx1 = tblPlot.insertion == categorical(ins) & tblPlot.stimulus == pairs{i,1}; + idx2 = tblPlot.insertion == categorical(ins) & tblPlot.stimulus == pairs{i,2}; + + V1 = tblPlot.value(idx1); + V2 = tblPlot.value(idx2); + + if isempty(V1) || isempty(V2) + continue + end + + animal = unique(tblPlot.animal(idx1)); + diffs = [diffs; V1 - V2]; + insers = [insers; double(repmat(ins, size(V1,1), 1))]; + animals = [animals; double(repmat(animal, size(V1,1), 1))]; + end + + if isempty(diffs) + ps(j) = NaN; + else + bootDiff = hierBoot(diffs, params.nBoot, insers, animals); + ps(j) = mean(bootDiff <= 0); + end + j = j + 1; + end + + + % ---------------------------------------------------------- + % Plot + % ---------------------------------------------------------- + V1max = max(tblPlot.value, [], 'omitnan'); + + [fig, ~] = plotSwarmBootstrapWithComparisons(tblPlot, pairs, ps, {'value'}, ... + yLegend = params.yLegend, ... + yMaxVis = max(params.yMaxVis, V1max), ... + diff = false, ... + Alpha = params.Alpha, ... + plotMeanSem = true); + + title(sprintf('%s — %s (onOff=%d, size=%d, lum=%d)', ... + params.indexType, strjoin(params.stimTypes, '/'), ... + params.onOff, params.sizeIdx, params.lumIdx), ... + 'FontSize', 9); + + if params.PaperFig + vs_first.printFig(fig, sprintf('SpatialTuningIndex-%s-%s', ... + params.indexType, strjoin(params.stimTypes, '-')), ... + PaperFig = params.PaperFig); + end + + results.fig = fig; + results.ps = ps; + +end + +end \ No newline at end of file diff --git a/visualStimulationAnalysis/SpatialTuningIndex.m b/visualStimulationAnalysis/SpatialTuningIndex.m new file mode 100644 index 0000000..0f6d98a --- /dev/null +++ b/visualStimulationAnalysis/SpatialTuningIndex.m @@ -0,0 +1,408 @@ +function results = SpatialTuningIndex(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.topPercent double = 10 + params.overwrite logical = false + params.statType string = "BootstrapPerNeuron" + params.speed double = 1 + params.plot logical = true + params.indexType string = "L_combined" % L_amplitude, L_geometric, L_combined + params.onOff double = 1 % 1=on, 2=off (rectGrid only) + params.sizeIdx double = 1 + params.lumIdx double = 1 + params.nBoot double = 10000 + params.yLegend char = 'Spatial Tuning Index' + params.yMaxVis double = 1 + params.Alpha double = 0.4 + params.PaperFig logical = false +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); + +switch params.stimTypes(1) + case "rectGrid" + vs_first = rectGridAnalysis(NP_first); + case "linearlyMovingBall" + vs_first = linearlyMovingBallAnalysis(NP_first); +end + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +nameOfFile = sprintf('\\Ex_%d-%d_SpatialTuningIndex_%s.mat', ... + exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to compute or load +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved SpatialTuningIndex from:\n %s\n', [saveDir nameOfFile]); + % Jump straight to table building + tbl = S.tbl; + goto_plot = true; + else + fprintf('Experiment list mismatch — recomputing.\n'); + goto_plot = false; + end +else + goto_plot = false; +end + +% ========================================================================= +% COMPUTE +% ========================================================================= +if ~goto_plot + + nExp = numel(exList); + nStim = numel(params.stimTypes); + + tbl = table(); + + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + continue + end + + nameParts = split(NP.recordingName, '_'); + animalName = nameParts{1}; + + % ---------------------------------------------------------- + % Find union of responsive neurons across ALL stim types + % ---------------------------------------------------------- + % Get phy IDs and responsive units for each stim type + respPhyIDs_all = cell(1, nStim); + phyIDs_all = cell(1, nStim); + + p_s = obj_s.dataObj.convertPhySorting2tIc(obj_s.spikeSortingFolder); + phy_IDg = p_s.phy_ID(string(p_s.label') == 'good'); + + + for s = 1:nStim + stimType = params.stimTypes(s); + try + switch stimType + case "rectGrid" + obj_s = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj_s = linearlyMovingBallAnalysis(NP); + end + + if params.statType == "BootstrapPerNeuron" + Stats = obj_s.BootstrapPerNeuron; + else + Stats = obj_s.ShufflingAnalysis; + end + + + try + switch stimType + case "linearlyMovingBall" + fieldName = sprintf('Speed%d', params.speed); + pvals = Stats.(fieldName).pvalsResponse; + otherwise + pvals = Stats.pvalsResponse; + end + catch + pvals = Stats.pvalsResponse; + end + + respU = find(pvals < 0.05); + phyIDs_all{s} = phy_IDg; % all good unit phy IDs for this stim + respPhyIDs_all{s} = phy_IDg(respU); % only responsive ones + fprintf(' [%s] %d responsive neuron(s).\n', stimType, numel(respU)); + + catch ME + warning('Could not get pvals for %s exp %d: %s', stimType, ex, ME.message); + phyIDs_all{s} = []; + respPhyIDs_all{s} = []; + end + end + + % Union of responsive phy IDs across stim types + sharedPhyIDs = respPhyIDs_all{1}; + for s = 2:nStim + sharedPhyIDs = union(sharedPhyIDs, respPhyIDs_all{s}); + end + + if isempty(sharedPhyIDs) + fprintf(' No responsive neurons in exp %d — skipping.\n', ex); + continue + end + + fprintf(' %d neuron(s) responsive to at least one stim type in exp %d.\n', numel(sharedPhyIDs), ex); + + + for s = 1:nStim + + stimType = params.stimTypes(s); + + % Build analysis object + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + continue + end + + + % ---------------------------------------------------------- + % Load grid results + % ---------------------------------------------------------- + S_rf = obj.CalculateReceptiveFields; + + gridSpikeRate = S_rf.gridSpikeRate; + gridSpikeRateShuff = S_rf.gridSpikeRateShuff; + + switch stimType + case "rectGrid" + % Select onOff from both + gridSpikeRateSelected = gridSpikeRate(:,:,:,params.onOff,:,:); % [nGrid nGrid nN nSize nLum] -- but with singleton onOff removed + gridShuffSelected = gridSpikeRateShuff(:,:,:,:,params.onOff,:,:); % [nGrid nGrid nN nShuffle nSize nLum] + case "linearlyMovingBall" + gridSpikeRateSelected = gridSpikeRate; % [nGrid nGrid nN nSize nLum] + gridShuffSelected = gridSpikeRateShuff; % [nGrid nGrid nN nShuffle nSize nLum] + end + + % Find indices in this stim's good units that match sharedPhyIDs + [~, neuronIdx] = ismember(sharedPhyIDs, phyIDs_all{s}); + neuronIdx = neuronIdx(neuronIdx > 0); % remove any not found in this stim + + gridSpikeRateSelected = gridSpikeRateSelected(:,:,neuronIdx,:,:); + gridShuffSelected = gridShuffSelected(:,:,neuronIdx,:,:,:); + + % Average over shuffles and reshape explicitly — no squeeze + gridShuffMean = mean(gridShuffSelected, 4); % [nGrid nGrid nN 1 nSize nLum] + + % Get dimensions explicitly + nN = size(gridSpikeRateSelected, 3); + nSize = size(gridSpikeRateSelected, 5); + nLum = size(gridSpikeRateSelected, 6); + + % Reshape both to clean [nGrid nGrid nN nSize nLum] + gridSpikeRateSelected = reshape(gridSpikeRateSelected, [nGrid nGrid nN nSize nLum]); + gridShuffMean = reshape(gridShuffMean, [nGrid nGrid nN nSize nLum]); + + nCells = nGrid * nGrid; + maxDist = sqrt(2) * (nGrid - 1); + + % Average over shuffles + + + % ---------------------------------------------------------- + % Compute indices + % ---------------------------------------------------------- + + fprintf('gridSpikeRate size: %s\n', num2str(size(gridSpikeRate))); + fprintf('gridSpikeRateShuff size: %s\n', num2str(size(gridSpikeRateShuff))); + fprintf('gridShuffMean size: %s\n', num2str(size(gridShuffMean))); + + for si = 1:nSize + for li = 1:nLum + + rateFlat = reshape(gridSpikeRateSelected(:,:,:,si,li), [nCells, nN]); + rateFlatShuff = reshape(gridShuffMean(:,:,:,si,li), [nCells, nN]); + + L_amplitude = zeros(nN, 1); + L_geometric = zeros(nN, 1); + L_combined = zeros(nN, 1); + + for u = 1:nN + + rateVec = rateFlat(:, u); + rateVecShuff = rateFlatShuff(:, u); + + % Top cells + threshold = prctile(rateVec, 100 - params.topPercent); + thresholdShuff = prctile(rateVecShuff, 100 - params.topPercent); + + topIdx = find(rateVec >= threshold); + topIdxShuff = find(rateVecShuff >= thresholdShuff); + restIdx = setdiff(1:nCells, topIdx); + restIdxShuff = setdiff(1:nCells, topIdxShuff); + + % Amplitude + meanTop = mean(rateVec(topIdx)); + meanRest = mean(rateVec(restIdx)); + meanAll = mean(rateVec); + meanTopShuff = mean(rateVecShuff(topIdxShuff)); + meanRestShuff = mean(rateVecShuff(restIdxShuff)); + meanAllShuff = mean(rateVecShuff); + + if meanAll == 0, meanAll = eps; end + if meanAllShuff == 0, meanAllShuff = eps; end + + L_amplitude(u) = ... + (meanTop - meanRest) / meanAll - ... + (meanTopShuff - meanRestShuff) / meanAllShuff; + + % Geometric + [rowIdx, colIdx] = ind2sub([nGrid nGrid], topIdx); + [rowIdxShuff, colIdxShuff] = ind2sub([nGrid nGrid], topIdxShuff); + + if size(rowIdx, 1) > 1 + D = mean(pdist([rowIdx, colIdx], 'euclidean')) / maxDist; + else + D = 0; + end + if size(rowIdxShuff, 1) > 1 + DShuff = mean(pdist([rowIdxShuff, colIdxShuff], 'euclidean')) / maxDist; + else + DShuff = 0; + end + + L_geometric(u) = (1 - D) - (1 - DShuff); + L_combined(u) = L_amplitude(u) * L_geometric(u); + + end + + % Build rows for this condition + rows = table(); + rows.L_amplitude = L_amplitude; + rows.L_geometric = L_geometric; + rows.L_combined = L_combined; + rows.stimulus = categorical(repmat({char(stimType)}, nN, 1)); + rows.insertion = categorical(repmat(ex, nN, 1)); + rows.animal = categorical(repmat({animalName}, nN, 1)); + rows.NeurID = (1:nN)'; + rows.onOff = repmat(params.onOff, nN, 1); % params.onOff for rectGrid, meaningless but consistent for movingBall + rows.sizeIdx = repmat(si, nN, 1); + rows.lumIdx = repmat(li, nN, 1); + + tbl = [tbl; rows]; + + end + end + + fprintf(' [%s] Indices computed. %d neurons.\n', stimType, nN); + + end % stim loop + end % exp loop + + % Clean categories + tbl.stimulus = removecats(tbl.stimulus); + tbl.animal = removecats(tbl.animal); + tbl.insertion = removecats(tbl.insertion); + + % Save + S.expList = exList; + S.tbl = tbl; + S.params = params; + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved to:\n %s\n', [saveDir nameOfFile]); + +end % compute block + +results.tbl = tbl; + +% ========================================================================= +% PLOT +% ========================================================================= +if params.plot + + % Filter table to requested condition + idx = tbl.onOff == params.onOff & ... + tbl.sizeIdx == params.sizeIdx & ... + tbl.lumIdx == params.lumIdx; + + tblPlot = tbl(idx, :); + tblPlot.value = tblPlot.(params.indexType); % select which index to plot + + % ---------------------------------------------------------- + % Compute p-values using hierBoot + % ---------------------------------------------------------- + ps = []; + + pairs = {char(params.stimTypes(1)), char(params.stimTypes(2))}; + + + ps = zeros(size(pairs, 1), 1); + j = 1; + + for i = 1:size(pairs, 1) + diffs = []; + insers = []; + animals = []; + + for ins = unique(tblPlot.insertion)' + idx1 = tblPlot.insertion == categorical(ins) & tblPlot.stimulus == pairs{i,1}; + idx2 = tblPlot.insertion == categorical(ins) & tblPlot.stimulus == pairs{i,2}; + + V1 = tblPlot.value(idx1); + V2 = tblPlot.value(idx2); + + if isempty(V1) || isempty(V2) + continue + end + + animal = unique(tblPlot.animal(idx1)); + diffs = [diffs; V1 - V2]; + insers = [insers; double(repmat(ins, size(V1,1), 1))]; + animals = [animals; double(repmat(animal, size(V1,1), 1))]; + end + + if isempty(diffs) + ps(j) = NaN; + else + bootDiff = hierBoot(diffs, params.nBoot, insers, animals); + ps(j) = mean(bootDiff <= 0); + end + j = j + 1; + end + + + % ---------------------------------------------------------- + % Plot + % ---------------------------------------------------------- + V1max = max(tblPlot.value, [], 'omitnan'); + + [fig, ~] = plotSwarmBootstrapWithComparisons(tblPlot, pairs, ps, {'value'}, ... + yLegend = params.yLegend, ... + yMaxVis = max(params.yMaxVis, V1max), ... + diff = false, ... + Alpha = params.Alpha, ... + plotMeanSem = true); + + title(sprintf('%s — %s (onOff=%d, size=%d, lum=%d)', ... + params.indexType, strjoin(params.stimTypes, '/'), ... + params.onOff, params.sizeIdx, params.lumIdx), ... + 'FontSize', 9); + + if params.PaperFig + vs_first.printFig(fig, sprintf('SpatialTuningIndex-%s-%s', ... + params.indexType, strjoin(params.stimTypes, '-')), ... + PaperFig = params.PaperFig); + end + + results.fig = fig; + results.ps = ps; + +end + +end \ No newline at end of file diff --git a/visualStimulationAnalysis/SpatialTuningIndexV1.m b/visualStimulationAnalysis/SpatialTuningIndexV1.m new file mode 100644 index 0000000..68a48d4 --- /dev/null +++ b/visualStimulationAnalysis/SpatialTuningIndexV1.m @@ -0,0 +1,246 @@ +function results = SpatialTuningIndex(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.topPercent double = 10 + params.overwrite logical = false + params.statType string = "BootstrapPerNeuron" + params.speed double = 1 +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); + +switch params.stimTypes(1) + case "rectGrid" + vs_first = rectGridAnalysis(NP_first); + case "linearlyMovingBall" + vs_first = linearlyMovingBallAnalysis(NP_first); +end + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +nameOfFile = sprintf('\\Ex_%d-%d_SpatialTuningIndex_%s.mat', ... + exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to compute or load +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved SpatialTuningIndex from:\n %s\n', [saveDir nameOfFile]); + results = S; + return + else + fprintf('Experiment list mismatch — recomputing.\n'); + end +end + +% ------------------------------------------------------------------------- +% EXPERIMENT LOOP +% ------------------------------------------------------------------------- +nExp = numel(exList); +nStim = numel(params.stimTypes); + +% Will grow as we discover dimensions from first valid experiment +L_amplitude_all = cell(nStim, nExp); +L_geometric_all = cell(nStim, nExp); +L_combined_all = cell(nStim, nExp); + +for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + continue + end + + for s = 1:nStim + + stimType = params.stimTypes(s); + + % Build analysis object + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + continue + end + + % ---------------------------------------------------------- + % Check for responsive neurons + % ---------------------------------------------------------- + try + if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; + else + Stats = obj.ShufflingAnalysis; + end + + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); + + % Resolve field name depending on stim type + try + switch stimType + case "linearlyMovingBall" + fieldName = sprintf('Speed%d', params.speed); + pvals = Stats.(fieldName).pvalsResponse; + otherwise + pvals = Stats.pvalsResponse; + end + catch + pvals = Stats.pvalsResponse; + end + + respU = find(pvals < 0.05); + + catch ME + warning('Could not load stats for %s exp %d: %s', stimType, ex, ME.message); + L_amplitude_all{s, ei} = []; + L_geometric_all{s, ei} = []; + L_combined_all{s, ei} = []; + continue + end + + if isempty(respU) + fprintf(' [%s] No responsive neurons in exp %d — skipping.\n', stimType, ex); + L_amplitude_all{s, ei} = []; + L_geometric_all{s, ei} = []; + L_combined_all{s, ei} = []; + continue + end + + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, ex, numel(respU)); + + % Load grid results + + S_rf = obj.CalculateReceptiveFields; + + gridSpikeRate = S_rf.gridSpikeRate; % [nGrid x nGrid x nN x onOff x nSize x nLum] + gridSpikeRateShuff = S_rf.gridSpikeRateShuff; % [nGrid x nGrid x nN x nShuffle x nSize x nLum] + + [nGrid, ~, nN, nOnOff, nSize, nLum] = size(gridSpikeRate); + nShuffle = size(gridSpikeRateShuff, 4); + nCells = nGrid * nGrid; + + % Average over shuffles + gridShuffMean = mean(gridSpikeRateShuff, 4); % [nGrid x nGrid x nN x nSize x nLum] + + L_amplitude = zeros(nN, nOnOff, nSize, nLum); + L_geometric = zeros(nN, nOnOff, nSize, nLum); + L_combined = zeros(nN, nOnOff, nSize, nLum); + + maxDist = sqrt(2) * (nGrid - 1); + + for oi = 1:nOnOff + for si = 1:nSize + for li = 1:nLum + + rateFlat = reshape(gridSpikeRate(:,:,:,oi,si,li), [nCells, nN]); + rateFlatShuff = reshape(gridShuffMean(:,:,:,si,li), [nCells, nN]); + + for u = 1:nN + + rateVec = rateFlat(:, u); + rateVecShuff = rateFlatShuff(:, u); + + %% ---- Shared: top cells ---- + threshold = prctile(rateVec, 100 - params.topPercent); + thresholdShuff = prctile(rateVecShuff, 100 - params.topPercent); + + topIdx = find(rateVec >= threshold); + topIdxShuff = find(rateVecShuff >= thresholdShuff); + + restIdx = setdiff(1:nCells, topIdx); + restIdxShuff = setdiff(1:nCells, topIdxShuff); + + %% ---- 1. Amplitude index ---- + meanTop = mean(rateVec(topIdx)); + meanRest = mean(rateVec(restIdx)); + meanAll = mean(rateVec); + + meanTopShuff = mean(rateVecShuff(topIdxShuff)); + meanRestShuff = mean(rateVecShuff(restIdxShuff)); + meanAllShuff = mean(rateVecShuff); + + if meanAll == 0, meanAll = eps; end + if meanAllShuff == 0, meanAllShuff = eps; end + + L_amplitude(u, oi, si, li) = ... + (meanTop - meanRest) / meanAll - ... + (meanTopShuff - meanRestShuff) / meanAllShuff; + + %% ---- 2. Geometric index ---- + [rowIdx, colIdx] = ind2sub([nGrid nGrid], topIdx); + [rowIdxShuff, colIdxShuff] = ind2sub([nGrid nGrid], topIdxShuff); + + if size(rowIdx, 1) > 1 + D = mean(pdist([rowIdx, colIdx], 'euclidean')) / maxDist; + else + D = 0; + end + + if size(rowIdxShuff, 1) > 1 + DShuff = mean(pdist([rowIdxShuff, colIdxShuff], 'euclidean')) / maxDist; + else + DShuff = 0; + end + + L_geometric(u, oi, si, li) = (1 - D) - (1 - DShuff); + + %% ---- 3. Combined index ---- + L_combined(u, oi, si, li) = L_amplitude(u, oi, si, li) * L_geometric(u, oi, si, li); + + end + end + end + end + + L_amplitude_all{s, ei} = L_amplitude; % [nN x nOnOff x nSize x nLum] + L_geometric_all{s, ei} = L_geometric; + L_combined_all{s, ei} = L_combined; + + fprintf(' [%s] Done. %d neurons.\n', stimType, nN); + + end % stim loop +end % experiment loop + +% ------------------------------------------------------------------------- +% Save +% ------------------------------------------------------------------------- +S.expList = exList; +S.L_amplitude_all = L_amplitude_all; % {nStim x nExp} cell, each [nN x nOnOff x nSize x nLum] +S.L_geometric_all = L_geometric_all; +S.L_combined_all = L_combined_all; +S.params = params; + +save([saveDir nameOfFile], '-struct', 'S'); +fprintf('\nSaved SpatialTuningIndex to:\n %s\n', [saveDir nameOfFile]); + +results = S; + +end \ No newline at end of file diff --git a/visualStimulationAnalysis/plotPSTH_MultiExp.m b/visualStimulationAnalysis/plotPSTH_MultiExp.m new file mode 100644 index 0000000..24e42dc --- /dev/null +++ b/visualStimulationAnalysis/plotPSTH_MultiExp.m @@ -0,0 +1,463 @@ +function plotPSTH_MultiExp(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.bin double = 30 + params.binWidth double = 10 + params.statType string = "BootstrapPerNeuron" + params.speed string = "max" + params.alpha double = 0.05 + params.shadeSTD logical = true + params.postStim double = 500 % ms after stim onset to include + params.preBase double = 200 % ms of baseline before stim onset + params.overwrite logical = false % force recompute even if file exists + params.TakeTopPercentTrials double = 0.3 %Percentage of highest spiking rate trials to take to calculate PSTHs + params.zScore logical = false % normalize firing rate to z-score using baseline + params.PaperFig logical = false %Is this going to be used in the paper? +end + +% ------------------------------------------------------------------------- +% Build save path using first experiment to get the analysis folder +% This mirrors the convention used in PlotZScoreComparison +% ------------------------------------------------------------------------- + +% Load first experiment just to get the folder path +NP_first = loadNPclassFromTable(exList(1)); +vs_first = linearlyMovingBallAnalysis(NP_first); % used only for path + +% Build the save directory path +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +% Build filename — includes stim types so different comparisons don't clash +stimLabel = strjoin(params.stimTypes, '-'); % e.g. "rectGrid-linearlyMovingBall" +nameOfFile = sprintf('\\Ex_%d-%d_Combined_PSTHs_%s.mat', ... + exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to run the experiment loop or load from disk +% forloop = true → compute PSTHs from scratch +% forloop = false → load saved struct and skip to plotting +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + % File exists and overwrite is off — check if expList matches + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved PSTHs from:\n %s\n', [saveDir nameOfFile]); + forloop = false; % skip computation, go straight to plot + else + fprintf('Experiment list mismatch — recomputing.\n'); + forloop = true; % expList changed, recompute + end +else + forloop = true; % file doesn't exist or overwrite requested +end + +% ========================================================================= +% EXPERIMENT LOOP — only runs if forloop is true +% ========================================================================= +if forloop + + nStim = numel(params.stimTypes); + nExp = numel(exList); + + % One cell per stim type, grows one row per experiment + psthAll = cell(1, nStim); + for s = 1:nStim + psthAll{s} = []; + end + + % Locked time window — set from first valid experiment + lockedPreBase = []; + lockedNBins = []; + lockedEdges = []; + + % ------------------------------------------------------------------ + % LOOP OVER EXPERIMENTS + % ------------------------------------------------------------------ + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + % Load NP data for this experiment + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + % Add NaN placeholder row if window is already locked + for s = 1:nStim + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + end + continue + end + + % -------------------------------------------------------------- + % LOOP OVER STIMULUS TYPES + % -------------------------------------------------------------- + for s = 1:nStim + + stimType = params.stimTypes(s); + + % Build analysis object for this stim type + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + case 'StaticGrating' + obj = StaticDriftingGratingAnalysis(NP); + case 'MovingGrating' + obj = StaticDriftingGratingAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + continue + end + + % ---------------------------------------------------------- + % Extract data structures + % ---------------------------------------------------------- + + % ResponseWindow holds trial timing and spike data + NeuronResp = obj.ResponseWindow; + + % Stats struct for p-values + if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; + else + Stats = obj.ShufflingAnalysis; + end + + % Resolve speed field name + if params.speed ~= "max" && isequal(obj.stimName,'linearlyMovingBall') + fieldName = 'Speed2'; + startStim = 0; + elseif isequal(obj.stimName,'linearlyMovingBall') + fieldName = 'Speed1'; + startStim = 0; + elseif isequal(params.stimTypes,'StaticGrating') + fieldName = 'Static'; + startStim = 0; + + elseif isequal(params.stimTypes,'MovingGrating') + startStim = obj.VST.static_time*1000; + fieldName = 'Moving'; + else + startStim = 0; + end + + % Spike trains of somatic (good) units + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); + + % P-values for each unit + try + pvals = Stats.(fieldName).pvalsResponse; + catch + pvals = Stats.pvalsResponse; + end + + % Trial onset times in ms + try + C = NeuronResp.(fieldName).C; + catch + C = NeuronResp.C; + end + directimesSorted = C(:, 1)' + startStim; + + % Use params.preBase directly — no formula needed + preBase = params.preBase; + + % Total trial window = baseline + post-stim period + windowTotal = preBase + params.postStim; + + % Lock in time window from first valid experiment + if isempty(lockedPreBase) + lockedPreBase = preBase; + lockedEdges = 0 : params.binWidth : windowTotal; + lockedNBins = numel(lockedEdges) - 1; + tAxis = lockedEdges(1:end-1); + fprintf(' Locked window: preBase=%d ms, postStim=%d ms, nBins=%d\n', ... + lockedPreBase, params.postStim, lockedNBins); + end + + % ---------------------------------------------------------- + % Find responsive neurons + % ---------------------------------------------------------- + eNeurons = find(pvals < params.alpha); + + if isempty(eNeurons) + fprintf(' [%s] No responsive neurons in exp %d.\n', stimType, ex); + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + continue + end + + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', ... + stimType, ex, numel(eNeurons)); + + % ---------------------------------------------------------- + % Build PSTH for each responsive neuron + % BuildBurstMatrix returns nTrials x 1 x nTimeBins + % Window: from (trialOnset - preBase) for windowTotal ms + % ---------------------------------------------------------- + psthRateNeurons = zeros(numel(eNeurons), lockedNBins); + + for ni = 1:numel(eNeurons) + u = eNeurons(ni); + + % Spike matrix: rows = trials, cols = time bins (1ms each) + MRhist = BuildBurstMatrix( ... + goodU(:, u), ... + round(p_sort.t), ... + round(directimesSorted - lockedPreBase), ... + round(windowTotal)); + + + + % Remove singleton dimensions → nTrials x nTimeBins + MRhist = squeeze(MRhist); + + if ~isempty(params.TakeTopPercentTrials) + MeanTrial = mean(MRhist,2); + [~, ind] = sort(MeanTrial,'descend'); + + takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); + + MRhist = MRhist(takeTrials,:); + + end + nTrials = size(MRhist, 1); + + % Convert to spike times in ms + spikeTimes = repmat((1:size(MRhist, 2)), nTrials, 1); + spikeTimes = spikeTimes(logical(MRhist)); + + % Bin into locked edges and convert to spk/s + counts = histcounts(spikeTimes, lockedEdges); + psthRateNeurons(ni, :) = (counts / (params.binWidth * nTrials)) * 1000; + end + + % Average across responsive neurons → 1 x lockedNBins + psthExp = mean(psthRateNeurons, 1, 'omitnan'); + + if params.zScore + baselineBins = tAxis < lockedPreBase; + baselineMean = mean(psthExp(baselineBins)); + baselineStd = std(psthExp(baselineBins)); + if baselineStd > 0 + psthExp = (psthExp - baselineMean) / baselineStd; + else + warning(' [%s] Baseline std is zero for exp %d — skipping experiment.', stimType, ex); + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + continue % skip to next experiment, do not append raw rates + end + end + + % Append as new row — guaranteed lockedNBins wide + psthAll{s} = [psthAll{s}; psthExp(:)']; + + end % end stim loop + end % end experiment loop + + % ------------------------------------------------------------------ + % Save results to struct + % ------------------------------------------------------------------ + S.expList = exList; % experiment list for future matching + S.lockedEdges = lockedEdges; % bin edges used (ms from trial start) + S.lockedPreBase = lockedPreBase; % baseline duration in ms + S.params = params; % all parameters used + + % Save one field per stim type, named by stim e.g. S.rectGrid + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); % safe field name + S.(stimField) = psthAll{s}; % nExp x nBins PSTH matrix + end + + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved PSTHs to:\n %s\n', [saveDir nameOfFile]); + +else + % ------------------------------------------------------------------ + % Load psthAll from saved struct + % ------------------------------------------------------------------ + lockedEdges = S.lockedEdges; + lockedPreBase = S.lockedPreBase; + + psthAll = cell(1, numel(params.stimTypes)); + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + if isfield(S, stimField) + psthAll{s} = S.(stimField); % load the nExp x nBins matrix + else + % Stim type not found in saved file — warn and leave empty + warning('Stim type "%s" not found in saved file.', params.stimTypes(s)); + psthAll{s} = []; + end + end + +end % end forloop + +% ========================================================================= +% PLOT +% ========================================================================= + +tAxis = lockedEdges(1:end-1); +tAxisPlot = tAxis - lockedPreBase; + +colors = lines(numel(params.stimTypes)); + +fig = figure; +set(fig, 'Units', 'centimeters', 'Position', [5 5 9 10]); % single axis now + +% ------------------------------------------------------------------ +% Map stimulus type names to short legend labels +% ------------------------------------------------------------------ +stimLegendMap = containers.Map(... + {'linearlyMovingBall', 'rectGrid', 'MovingGrating', 'StaticGrating'}, ... + {'MB', 'SB', 'MG', 'SG'}); + +% ------------------------------------------------------------------ +% First pass: compute mean/sem for all stim types and find global ylim +% ------------------------------------------------------------------ +meanAll = cell(1, numel(params.stimTypes)); +semAll = cell(1, numel(params.stimTypes)); +yMax = 0; +yMin = inf; + +for s = 1:numel(params.stimTypes) + data = psthAll{s}; + if isempty(data) + continue + end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data) + continue + end + meanAll{s} = mean(data, 1, 'omitnan'); + semAll{s} = std(data, 0, 1, 'omitnan') / sqrt(sum(~isnan(data(:,1)))); + yMax = max(yMax, max(meanAll{s} + semAll{s})); + yMin = min(yMin, min(meanAll{s} - semAll{s})); +end + +% Y limits with 10% padding +yPad = (yMax - yMin) * 0.1; +if params.zScore + yLims = [yMin - yPad, yMax + yPad]; +else + yLims = [max(0, yMin - yPad), yMax + yPad]; +end + +% ------------------------------------------------------------------ +% Single axis plot — all stim types overlaid +% ------------------------------------------------------------------ +ax = axes(fig); +hold(ax, 'on'); + +legendHandles = gobjects(numel(params.stimTypes), 1); % store line handles for legend + +for s = 1:numel(params.stimTypes) + + data = psthAll{s}; + if isempty(data) + continue + end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data) + continue + end + + meanPSTH = meanAll{s}; + semPSTH = semAll{s}; + + % Get short legend label for this stim type + stimKey = char(params.stimTypes(s)); + if isKey(stimLegendMap, stimKey) + legendLabel = stimLegendMap(stimKey); + else + legendLabel = stimKey; % fallback to full name if not in map + end + + % Shade ±SEM band + if params.shadeSTD && size(data, 1) > 1 + upper = meanPSTH + semPSTH; + lower = meanPSTH - semPSTH; + xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; + yFill = [upper(:)', fliplr(lower(:)') ]; + fill(ax, xFill, yFill, colors(s,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + end + + % Mean PSTH line — store handle for legend + legendHandles(s) = plot(ax, tAxisPlot(:)', meanPSTH(:)', ... + 'Color', colors(s,:), 'LineWidth', 1.5, 'DisplayName', legendLabel); + + % Number of contributing experiments as text + nValid = sum(validRows); + fprintf(' [%s] n=%d experiments in plot.\n', legendLabel, nValid); + +end + +% Stim onset and end of post-stim window +xline(ax, 0, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); +xline(ax, params.postStim, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); + +% Y label +if params.zScore + yLabel = 'Z-score'; +else + yLabel = '[spk/s]'; +end + +xlabel(ax, 'Time re. stim onset [ms]', 'FontName', 'helvetica', 'FontSize', 8); +ylabel(ax, yLabel, 'FontName', 'helvetica', 'FontSize', 8); +xlim(ax, [tAxisPlot(1) tAxisPlot(end)]); +ylim(ax, yLims); + +% Legend — only show valid handles (skip stim types with no data) +validHandles = legendHandles(isgraphics(legendHandles)); +legend(validHandles, 'Location', 'northeast', 'FontName', 'helvetica', 'FontSize', 8); + +ax.FontName = 'helvetica'; +ax.FontSize = 8; +hold(ax, 'off'); + +sgtitle(sprintf('N = %d', numel(exList)), 'FontName', 'helvetica', 'FontSize', 11); + +ax = gca; +ax.YAxis.FontSize = 8; +ax.YAxis.FontName = 'helvetica'; + +ax = gca; +ax.XAxis.FontSize = 8; +ax.XAxis.FontName = 'helvetica'; + +set(fig, 'Units', 'centimeters'); +set(fig, 'Position', [20 20 5 6]); + +if params.PaperFig + vs_first.printFig(fig, sprintf('PSTH-comparison-%s-%s', ... + params.stimTypes(1), params.stimTypes(2)), PaperFig = params.PaperFig) +end + +end \ No newline at end of file diff --git a/visualStimulationAnalysis/plotSpatialTuningIndex.m b/visualStimulationAnalysis/plotSpatialTuningIndex.m new file mode 100644 index 0000000..b163552 --- /dev/null +++ b/visualStimulationAnalysis/plotSpatialTuningIndex.m @@ -0,0 +1,189 @@ +function [fig, tbl] = plotSpatialTuningIndex(exList, pairs, params) + +arguments + exList double + pairs cell = {} + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.indexType string = "L_combined" % L_amplitude, L_geometric, L_combined + params.onOff double = 1 % 1=on, 2=off (rectGrid only) + params.sizeIdx double = 1 + params.lumIdx double = 1 + params.nBoot double = 10000 + params.overwrite logical = false + params.yLegend char = 'Spatial Tuning Index' + params.yMaxVis double = 1 + params.Alpha double = 0.4 + params.PaperFig logical = false +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); + +switch params.stimTypes(1) + case "rectGrid" + vs_first = rectGridAnalysis(NP_first); + case "linearlyMovingBall" + vs_first = linearlyMovingBallAnalysis(NP_first); +end + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +nameOfFile = sprintf('\\Ex_%d-%d_SpatialTuningIndex_%s.mat', ... + exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Load SpatialTuningIndex results +% ------------------------------------------------------------------------- +if ~exist([saveDir nameOfFile], 'file') + error('SpatialTuningIndex results not found. Run SpatialTuningIndex first.'); +end + +S = load([saveDir nameOfFile]); + +% ------------------------------------------------------------------------- +% Build long table +% ------------------------------------------------------------------------- +tbl = table(); + +nExp = numel(exList); +nStim = numel(params.stimTypes); + +for ei = 1:nExp + ex = exList(ei); + + % Get animal/insertion info + try + NP = loadNPclassFromTable(ex); + catch + warning('Could not load experiment %d — skipping.', ex); + continue + end + + for s = 1:nStim + + stimType = params.stimTypes(s); + + % Get the right index matrix for this stim/exp + switch params.indexType + case "L_amplitude" + idxMat = S.L_amplitude_all{s, ei}; + case "L_geometric" + idxMat = S.L_geometric_all{s, ei}; + case "L_combined" + idxMat = S.L_combined_all{s, ei}; + end + + if isempty(idxMat) + continue + end + + % idxMat is [nN x nOnOff x nSize x nLum] + % for linearlyMovingBall there is no onOff dimension — handle both + if ndims(idxMat) == 3 + % [nN x nSize x nLum] — no onOff + vals = idxMat(:, params.sizeIdx, params.lumIdx); + oi = 1; + else + vals = idxMat(:, params.onOff, params.sizeIdx, params.lumIdx); + oi = params.onOff; + end + + nN = numel(vals); + + % Build rows for this experiment/stim + rows = table(); + rows.value = vals; + rows.stimulus = categorical(repmat({char(stimType)}, nN, 1)); + rows.insertion = categorical(repmat(ex, nN, 1)); + rows.animal = categorical(repmat({NP.animalName}, nN, 1)); + rows.NeurID = (1:nN)'; + rows.onOff = repmat(oi, nN, 1); + rows.sizeIdx = repmat(params.sizeIdx, nN, 1); + rows.lumIdx = repmat(params.lumIdx, nN, 1); + rows.indexType = categorical(repmat({params.indexType}, nN, 1)); + + tbl = [tbl; rows]; + end +end + +if isempty(tbl) + warning('No data found — table is empty.'); + fig = []; + return +end + +% Clean up categories +tbl.stimulus = removecats(tbl.stimulus); +tbl.animal = removecats(tbl.animal); +tbl.insertion = removecats(tbl.insertion); + +% ------------------------------------------------------------------------- +% Compute p-values using hierBoot +% ------------------------------------------------------------------------- +ps = []; + +if ~isempty(pairs) + ps = zeros(size(pairs, 1), 1); + j = 1; + + for i = 1:size(pairs, 1) + diffs = []; + insers = []; + animals = []; + + for ins = unique(tbl.insertion)' + idx1 = tbl.insertion == categorical(ins) & tbl.stimulus == pairs{i,1}; + idx2 = tbl.insertion == categorical(ins) & tbl.stimulus == pairs{i,2}; + + V1 = tbl.value(idx1); + V2 = tbl.value(idx2); + + if isempty(V1) || isempty(V2) + continue + end + + animal = unique(tbl.animal(idx1)); + diffs = [diffs; V1 - V2]; + insers = [insers; double(repmat(ins, size(V1,1), 1))]; + animals = [animals; double(repmat(animal, size(V1,1), 1))]; + end + + if isempty(diffs) + ps(j) = NaN; + else + bootDiff = hierBoot(diffs, params.nBoot, insers, animals); + ps(j) = mean(bootDiff <= 0); + end + j = j + 1; + end +end + +% ------------------------------------------------------------------------- +% Plot +% ------------------------------------------------------------------------- +V1max = max(tbl.value, [], 'omitnan'); + +[fig, ~] = plotSwarmBootstrapWithComparisons(tbl, pairs, ps, {'value'}, ... + yLegend = params.yLegend, ... + yMaxVis = max(params.yMaxVis, V1max), ... + diff = false, ... + Alpha = params.Alpha, ... + plotMeanSem = true); + +title(sprintf('%s — %s (size=%d, lum=%d)', ... + params.indexType, strjoin(params.stimTypes,'/'), ... + params.sizeIdx, params.lumIdx), ... + 'FontSize', 9); + +if params.PaperFig + vs_first.printFig(fig, sprintf('SpatialTuningIndex-%s-%s', ... + params.indexType, strjoin(params.stimTypes, '-')), ... + PaperFig = params.PaperFig); +end + +end \ No newline at end of file From da2415431b3a4a1c95428d1a932f73004ee347ed Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Fri, 20 Mar 2026 00:53:19 +0200 Subject: [PATCH 5/8] details spatial ytuning --- .../RunAnalysisClass.asv | 210 --------- visualStimulationAnalysis/RunAnalysisClass.m | 2 +- .../SpatialTuningIndex.asv | 408 ------------------ .../SpatialTuningIndex.m | 79 ++-- 4 files changed, 51 insertions(+), 648 deletions(-) delete mode 100644 visualStimulationAnalysis/RunAnalysisClass.asv delete mode 100644 visualStimulationAnalysis/SpatialTuningIndex.asv diff --git a/visualStimulationAnalysis/RunAnalysisClass.asv b/visualStimulationAnalysis/RunAnalysisClass.asv deleted file mode 100644 index 9f7d3fd..0000000 --- a/visualStimulationAnalysis/RunAnalysisClass.asv +++ /dev/null @@ -1,210 +0,0 @@ -cd('\\sil3\data\Large_scale_mapping_NP') -excelFile = 'Experiment_Excel.xlsx'; - -data = readtable(excelFile); - -%% -%% Rect Grid -for ex = [49:54,64:97] %84:91 - NP = loadNPclassFromTable(ex); %73 81 - vsRe = rectGridAnalysis(NP); - % vsRe.getSessionTime("overwrite",true); - % %vsRe.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - % vsRe.getDiodeTriggers('overwrite',true); - % vsRe.getSyncedDiodeTriggers("overwrite",true); - % % vsRe.plotSpatialTuningSpikes; - % % vsRe.plotSpatialTuningLFP; - % vsRe.ResponseWindow('overwrite',true) - % results = vsRe.ShufflingAnalysis('overwrite',true); - % vsRe.plotRaster(MergeNtrials=1,overwrite=true,AllResponsiveNeurons = true, selectedLum=[],oneTrial = true,PaperFig = true) %43 - % close all;vsRe.plotRaster(MergeNtrials=1,overwrite=true,exNeurons=18, selectedLum=255,oneTrial = true,PaperFig = true) %43 - vsRe.CalculateReceptiveFields('overwrite',true) - %[colorbarLims] = vsRe.PlotReceptiveFields(exNeurons=18,allStimParamsCombined=true,PaperFig=true,overwrite=true); - %result = vsRe.BootstrapPerNeuron('overwrite',true); - -end -% vsRe.CalculateReceptiveFields -% vsRe.PlotReceptiveFields("meanAllNeurons",true) - -%% Moving ball - -for ex = [84:97]%97 74:84 (Neurons, 96_74, ) - NP = loadNPclassFromTable(ex); %73 81 - vs = linearlyMovingBallAnalysis(NP,Session=1); - % vs.getSessionTime("overwrite",true); - % vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - % % %vs.plotDiodeTriggers - % vs.getSyncedDiodeTriggers("overwrite",true); - % % %vs.plotSpatialTuningSpikes; - % r = vs.ResponseWindow('overwrite',true); - % results = vs.ShufflingAnalysis('overwrite',true); - % % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'MergeNtrials',3) - % %vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'MergeNtrials',2,'bin',5,'GaussianLength',30,'MaxVal_1', false) - % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'speed',2,'MergeNtrials',3) - %vs.plotRaster('exNeurons',82,'overwrite',true,'MergeNtrials',1,'OneDirection','up','OneLuminosity','white','PaperFig',true) - % % %vs.plotCorrSpikePattern - % vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'OneDirection','up','OneLuminosity','white','MergeNtrials',1,'PaperFig',true) - - %vs.plotRaster('exNeurons',9,'AllResponsiveNeurons',false,'overwrite',true,'MergeNtrials',3,MaxVal_1=false) - vs.CalculateReceptiveFields('overwrite',true,testConvolution=false); - % colorbarLims=vs.PlotReceptiveFields('exNeurons',82,'overwrite',true,'OneDirection','up','OneLuminosity','white','PaperFig',true); - %result = vs.BootstrapPerNeuron('overwrite',true);%('overwrite',true); - % pvals0_6Filter =result.Speed2.pvalsResponse'; - % compare = [pvals,pvalsNoFilt,pvals0_6Filter]; -end - -%% PlotZScoreComparison -%[49:54 57:81] MBR all experiments 'NV','NI' -%[44:56,64:88] All experiments -%[28:32,44,45,47,48,56,98] All SA experiments -%Check triggers 45, SA82 44,45,47:54,56,64:88 -% All stim: 'FFF','SDG','MBR','MB','RG','NI','NV' -%[49:54,64:97] %All PV good experiments -% %%[89,90,92,93,95,96,97] %Al NV and NI experiments -%[49:54,84:90,92:96] %All SDG experiments -%solve MBR -%bootsrapRespBase -VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... - overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR -%% PSTH for all experiments -plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); - -%% Calculate spatial tuning -SpatialTuningIndex([52:54,64:97]) - -%% Gratings - -for ex = [54 84:90] - NP = loadNPclassFromTable(ex); %73 81 - vs = StaticDriftingGratingAnalysis(NP); - vs.getSessionTime("overwrite",true); - vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - dT = vs.getDiodeTriggers; - % vs.plotDiodeTriggers - vs.getSyncedDiodeTriggers("overwrite",true); - r = vs.ResponseWindow('overwrite',true); - results = vs.ShufflingAnalysis('overwrite',true); - result = vs.BootstrapPerNeuron('overwrite',true); -end - -%% movie - -for ex = [89,90,92,93,95:97] - NP = loadNPclassFromTable(ex); %73 81 - vs = movieAnalysis(NP); - % vs.getSessionTime("overwrite",true); - % vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - % dT = vs.getDiodeTriggers; - % vs.plotDiodeTriggers - %vs.getSyncedDiodeTriggers("overwrite",true); - %r = vs.ResponseWindow('overwrite',true); - %results = vs.ShufflingAnalysis('overwrite',true); - vs.plotRaster('AllResponsiveNeurons',true,MergeNtrials=1,overwrite=true) -end - - -%% image - -for ex = [89,90,92,93,95:97] - NP = loadNPclassFromTable(ex); %73 81 - vs = imageAnalysis(NP); - %vs.getSessionTime("overwrite",true); - %vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - %dT = vs.getDiodeTriggers; - % vs.plotDiodeTriggers - %vs.getSyncedDiodeTriggers("overwrite",true); - r = vs.ResponseWindow('overwrite',true); - %results = vs.ShufflingAnalysis('overwrite',true); - vs.plotRaster('AllResponsiveNeurons',true,MergeNtrials=1,overwrite=true) - -end - - -%% Moving bar -for ex = 81 - NP = loadNPclassFromTable(ex); %73 81 - vs = linearlyMovingBarAnalysis(NP); - vs.getSessionTime("overwrite",true); - vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - %vs.plotDiodeTriggers - vs.getSyncedDiodeTriggers("overwrite",true); - r = vs.ResponseWindow('overwrite',true); - results = vs.ShufflingAnalysis('overwrite',true); - if ~any(results.Speed1.pvalsResponse<0.05) - fprintf('%d-No responsive neurons.\n',ex) - continue - end - vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); - vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) -end - -%% FFF -for ex = 56 - NP = loadNPclassFromTable(ex); %73 81 - vs = fullFieldFlashAnalysis(NP); - vs.getSessionTime("overwrite",true); - vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - %vs.plotDiodeTriggers - vs.getSyncedDiodeTriggers("overwrite",true); - r = vs.ResponseWindow('overwrite',true); - results = vs.ShufflingAnalysis('overwrite',true); - if ~any(results.Speed1.pvalsResponse<0.05) - fprintf('%d-No responsive neurons.\n',ex) - continue - end - vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); - vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) -end - - -%% Run for all -for ex = 85:88 - NP = loadNPclassFromTable(ex); %73 81 - vs = linearlyMovingBallAnalysis(NP); - vs.getSessionTime("overwrite",true); - vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); - %vs.plotDiodeTriggers - vs.getSyncedDiodeTriggers("overwrite",true); - r = vs.ResponseWindow('overwrite',true); - results = vs.ShufflingAnalysis('overwrite',true); - if ~any(results.Speed1.pvalsResponse<0.05) - fprintf('%d-No responsive neurons.\n',ex) - continue - end - vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); - vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) -end - -%% Check experiments in timseseries viewer -timeSeriesViewer(NP) -t=NP.getTrigger; -data.VS_ordered(ex) - -stimOn = t{3}; -stimOff = t{4}; - -MBRtOn = stimOn(stimOn > t{1}(1) & stimOn < t{2}(1)); -MBRtOff = stimOff(stimOff > t{1}(1) & stimOff < t{2}(1)); - -MBtOn = stimOn(stimOn > t{1}(2) & stimOn < t{2}(2)); -MBtOff = stimOff(stimOff > t{1}(2) & stimOff < t{2}(2)); - -RGtOn = stimOn(stimOn > t{1}(3) & stimOn < t{2}(3)); -RGtOff = stimOff(stimOff > t{1}(3) & stimOff < t{2}(3)); - -NGtOn = stimOn(stimOn > t{1}(4) & stimOn < t{2}(4)); -NGtOff = stimOff(stimOff > t{1}(4) & stimOff < t{2}(4)); - -DtOn = stimOn(stimOn > t{1}(5) & stimOn < t{2}(5)); -DtOff = stimOff(stimOff > t{1}(5) & stimOff < t{2}(5)); - -MovingBallTriggersDiode = d3.stimOnFlipTimes; - - - -%% %% check neural data sync and analog data sync - -allTimes = [stimOn(:); stimOff(:); onSync(:); offSync(:)]; % concatenate as column - -% Sort from earliest to latest -sortedTimesDiodeOldMethod = sort(allTimes); diff --git a/visualStimulationAnalysis/RunAnalysisClass.m b/visualStimulationAnalysis/RunAnalysisClass.m index 11c74e7..9332153 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.m +++ b/visualStimulationAnalysis/RunAnalysisClass.m @@ -70,7 +70,7 @@ plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); %% Calculate spatial tuning -SpatialTuningIndex([49:54,64:97], overwrite=true) +SpatialTuningIndex([49:54,64:97], indexType = "L_geometric",overwrite=true) %% Gratings diff --git a/visualStimulationAnalysis/SpatialTuningIndex.asv b/visualStimulationAnalysis/SpatialTuningIndex.asv deleted file mode 100644 index 0f6d98a..0000000 --- a/visualStimulationAnalysis/SpatialTuningIndex.asv +++ /dev/null @@ -1,408 +0,0 @@ -function results = SpatialTuningIndex(exList, params) - -arguments - exList double - params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] - params.topPercent double = 10 - params.overwrite logical = false - params.statType string = "BootstrapPerNeuron" - params.speed double = 1 - params.plot logical = true - params.indexType string = "L_combined" % L_amplitude, L_geometric, L_combined - params.onOff double = 1 % 1=on, 2=off (rectGrid only) - params.sizeIdx double = 1 - params.lumIdx double = 1 - params.nBoot double = 10000 - params.yLegend char = 'Spatial Tuning Index' - params.yMaxVis double = 1 - params.Alpha double = 0.4 - params.PaperFig logical = false -end - -% ------------------------------------------------------------------------- -% Build save path -% ------------------------------------------------------------------------- -NP_first = loadNPclassFromTable(exList(1)); - -switch params.stimTypes(1) - case "rectGrid" - vs_first = rectGridAnalysis(NP_first); - case "linearlyMovingBall" - vs_first = linearlyMovingBallAnalysis(NP_first); -end - -p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); -p = [p 'lizards']; -if ~exist([p '\Combined_lizard_analysis'], 'dir') - cd(p) - mkdir Combined_lizard_analysis -end -saveDir = [p '\Combined_lizard_analysis']; - -stimLabel = strjoin(params.stimTypes, '-'); -nameOfFile = sprintf('\\Ex_%d-%d_SpatialTuningIndex_%s.mat', ... - exList(1), exList(end), stimLabel); - -% ------------------------------------------------------------------------- -% Decide whether to compute or load -% ------------------------------------------------------------------------- -if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite - S = load([saveDir nameOfFile]); - if isequal(S.expList, exList) - fprintf('Loading saved SpatialTuningIndex from:\n %s\n', [saveDir nameOfFile]); - % Jump straight to table building - tbl = S.tbl; - goto_plot = true; - else - fprintf('Experiment list mismatch — recomputing.\n'); - goto_plot = false; - end -else - goto_plot = false; -end - -% ========================================================================= -% COMPUTE -% ========================================================================= -if ~goto_plot - - nExp = numel(exList); - nStim = numel(params.stimTypes); - - tbl = table(); - - for ei = 1:nExp - - ex = exList(ei); - fprintf('\n=== Experiment %d ===\n', ex); - - try - NP = loadNPclassFromTable(ex); - catch ME - warning('Could not load experiment %d: %s', ex, ME.message); - continue - end - - nameParts = split(NP.recordingName, '_'); - animalName = nameParts{1}; - - % ---------------------------------------------------------- - % Find union of responsive neurons across ALL stim types - % ---------------------------------------------------------- - % Get phy IDs and responsive units for each stim type - respPhyIDs_all = cell(1, nStim); - phyIDs_all = cell(1, nStim); - - p_s = obj_s.dataObj.convertPhySorting2tIc(obj_s.spikeSortingFolder); - phy_IDg = p_s.phy_ID(string(p_s.label') == 'good'); - - - for s = 1:nStim - stimType = params.stimTypes(s); - try - switch stimType - case "rectGrid" - obj_s = rectGridAnalysis(NP); - case "linearlyMovingBall" - obj_s = linearlyMovingBallAnalysis(NP); - end - - if params.statType == "BootstrapPerNeuron" - Stats = obj_s.BootstrapPerNeuron; - else - Stats = obj_s.ShufflingAnalysis; - end - - - try - switch stimType - case "linearlyMovingBall" - fieldName = sprintf('Speed%d', params.speed); - pvals = Stats.(fieldName).pvalsResponse; - otherwise - pvals = Stats.pvalsResponse; - end - catch - pvals = Stats.pvalsResponse; - end - - respU = find(pvals < 0.05); - phyIDs_all{s} = phy_IDg; % all good unit phy IDs for this stim - respPhyIDs_all{s} = phy_IDg(respU); % only responsive ones - fprintf(' [%s] %d responsive neuron(s).\n', stimType, numel(respU)); - - catch ME - warning('Could not get pvals for %s exp %d: %s', stimType, ex, ME.message); - phyIDs_all{s} = []; - respPhyIDs_all{s} = []; - end - end - - % Union of responsive phy IDs across stim types - sharedPhyIDs = respPhyIDs_all{1}; - for s = 2:nStim - sharedPhyIDs = union(sharedPhyIDs, respPhyIDs_all{s}); - end - - if isempty(sharedPhyIDs) - fprintf(' No responsive neurons in exp %d — skipping.\n', ex); - continue - end - - fprintf(' %d neuron(s) responsive to at least one stim type in exp %d.\n', numel(sharedPhyIDs), ex); - - - for s = 1:nStim - - stimType = params.stimTypes(s); - - % Build analysis object - try - switch stimType - case "rectGrid" - obj = rectGridAnalysis(NP); - case "linearlyMovingBall" - obj = linearlyMovingBallAnalysis(NP); - otherwise - error('Unknown stimType: %s', stimType); - end - catch ME - warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); - continue - end - - - % ---------------------------------------------------------- - % Load grid results - % ---------------------------------------------------------- - S_rf = obj.CalculateReceptiveFields; - - gridSpikeRate = S_rf.gridSpikeRate; - gridSpikeRateShuff = S_rf.gridSpikeRateShuff; - - switch stimType - case "rectGrid" - % Select onOff from both - gridSpikeRateSelected = gridSpikeRate(:,:,:,params.onOff,:,:); % [nGrid nGrid nN nSize nLum] -- but with singleton onOff removed - gridShuffSelected = gridSpikeRateShuff(:,:,:,:,params.onOff,:,:); % [nGrid nGrid nN nShuffle nSize nLum] - case "linearlyMovingBall" - gridSpikeRateSelected = gridSpikeRate; % [nGrid nGrid nN nSize nLum] - gridShuffSelected = gridSpikeRateShuff; % [nGrid nGrid nN nShuffle nSize nLum] - end - - % Find indices in this stim's good units that match sharedPhyIDs - [~, neuronIdx] = ismember(sharedPhyIDs, phyIDs_all{s}); - neuronIdx = neuronIdx(neuronIdx > 0); % remove any not found in this stim - - gridSpikeRateSelected = gridSpikeRateSelected(:,:,neuronIdx,:,:); - gridShuffSelected = gridShuffSelected(:,:,neuronIdx,:,:,:); - - % Average over shuffles and reshape explicitly — no squeeze - gridShuffMean = mean(gridShuffSelected, 4); % [nGrid nGrid nN 1 nSize nLum] - - % Get dimensions explicitly - nN = size(gridSpikeRateSelected, 3); - nSize = size(gridSpikeRateSelected, 5); - nLum = size(gridSpikeRateSelected, 6); - - % Reshape both to clean [nGrid nGrid nN nSize nLum] - gridSpikeRateSelected = reshape(gridSpikeRateSelected, [nGrid nGrid nN nSize nLum]); - gridShuffMean = reshape(gridShuffMean, [nGrid nGrid nN nSize nLum]); - - nCells = nGrid * nGrid; - maxDist = sqrt(2) * (nGrid - 1); - - % Average over shuffles - - - % ---------------------------------------------------------- - % Compute indices - % ---------------------------------------------------------- - - fprintf('gridSpikeRate size: %s\n', num2str(size(gridSpikeRate))); - fprintf('gridSpikeRateShuff size: %s\n', num2str(size(gridSpikeRateShuff))); - fprintf('gridShuffMean size: %s\n', num2str(size(gridShuffMean))); - - for si = 1:nSize - for li = 1:nLum - - rateFlat = reshape(gridSpikeRateSelected(:,:,:,si,li), [nCells, nN]); - rateFlatShuff = reshape(gridShuffMean(:,:,:,si,li), [nCells, nN]); - - L_amplitude = zeros(nN, 1); - L_geometric = zeros(nN, 1); - L_combined = zeros(nN, 1); - - for u = 1:nN - - rateVec = rateFlat(:, u); - rateVecShuff = rateFlatShuff(:, u); - - % Top cells - threshold = prctile(rateVec, 100 - params.topPercent); - thresholdShuff = prctile(rateVecShuff, 100 - params.topPercent); - - topIdx = find(rateVec >= threshold); - topIdxShuff = find(rateVecShuff >= thresholdShuff); - restIdx = setdiff(1:nCells, topIdx); - restIdxShuff = setdiff(1:nCells, topIdxShuff); - - % Amplitude - meanTop = mean(rateVec(topIdx)); - meanRest = mean(rateVec(restIdx)); - meanAll = mean(rateVec); - meanTopShuff = mean(rateVecShuff(topIdxShuff)); - meanRestShuff = mean(rateVecShuff(restIdxShuff)); - meanAllShuff = mean(rateVecShuff); - - if meanAll == 0, meanAll = eps; end - if meanAllShuff == 0, meanAllShuff = eps; end - - L_amplitude(u) = ... - (meanTop - meanRest) / meanAll - ... - (meanTopShuff - meanRestShuff) / meanAllShuff; - - % Geometric - [rowIdx, colIdx] = ind2sub([nGrid nGrid], topIdx); - [rowIdxShuff, colIdxShuff] = ind2sub([nGrid nGrid], topIdxShuff); - - if size(rowIdx, 1) > 1 - D = mean(pdist([rowIdx, colIdx], 'euclidean')) / maxDist; - else - D = 0; - end - if size(rowIdxShuff, 1) > 1 - DShuff = mean(pdist([rowIdxShuff, colIdxShuff], 'euclidean')) / maxDist; - else - DShuff = 0; - end - - L_geometric(u) = (1 - D) - (1 - DShuff); - L_combined(u) = L_amplitude(u) * L_geometric(u); - - end - - % Build rows for this condition - rows = table(); - rows.L_amplitude = L_amplitude; - rows.L_geometric = L_geometric; - rows.L_combined = L_combined; - rows.stimulus = categorical(repmat({char(stimType)}, nN, 1)); - rows.insertion = categorical(repmat(ex, nN, 1)); - rows.animal = categorical(repmat({animalName}, nN, 1)); - rows.NeurID = (1:nN)'; - rows.onOff = repmat(params.onOff, nN, 1); % params.onOff for rectGrid, meaningless but consistent for movingBall - rows.sizeIdx = repmat(si, nN, 1); - rows.lumIdx = repmat(li, nN, 1); - - tbl = [tbl; rows]; - - end - end - - fprintf(' [%s] Indices computed. %d neurons.\n', stimType, nN); - - end % stim loop - end % exp loop - - % Clean categories - tbl.stimulus = removecats(tbl.stimulus); - tbl.animal = removecats(tbl.animal); - tbl.insertion = removecats(tbl.insertion); - - % Save - S.expList = exList; - S.tbl = tbl; - S.params = params; - save([saveDir nameOfFile], '-struct', 'S'); - fprintf('\nSaved to:\n %s\n', [saveDir nameOfFile]); - -end % compute block - -results.tbl = tbl; - -% ========================================================================= -% PLOT -% ========================================================================= -if params.plot - - % Filter table to requested condition - idx = tbl.onOff == params.onOff & ... - tbl.sizeIdx == params.sizeIdx & ... - tbl.lumIdx == params.lumIdx; - - tblPlot = tbl(idx, :); - tblPlot.value = tblPlot.(params.indexType); % select which index to plot - - % ---------------------------------------------------------- - % Compute p-values using hierBoot - % ---------------------------------------------------------- - ps = []; - - pairs = {char(params.stimTypes(1)), char(params.stimTypes(2))}; - - - ps = zeros(size(pairs, 1), 1); - j = 1; - - for i = 1:size(pairs, 1) - diffs = []; - insers = []; - animals = []; - - for ins = unique(tblPlot.insertion)' - idx1 = tblPlot.insertion == categorical(ins) & tblPlot.stimulus == pairs{i,1}; - idx2 = tblPlot.insertion == categorical(ins) & tblPlot.stimulus == pairs{i,2}; - - V1 = tblPlot.value(idx1); - V2 = tblPlot.value(idx2); - - if isempty(V1) || isempty(V2) - continue - end - - animal = unique(tblPlot.animal(idx1)); - diffs = [diffs; V1 - V2]; - insers = [insers; double(repmat(ins, size(V1,1), 1))]; - animals = [animals; double(repmat(animal, size(V1,1), 1))]; - end - - if isempty(diffs) - ps(j) = NaN; - else - bootDiff = hierBoot(diffs, params.nBoot, insers, animals); - ps(j) = mean(bootDiff <= 0); - end - j = j + 1; - end - - - % ---------------------------------------------------------- - % Plot - % ---------------------------------------------------------- - V1max = max(tblPlot.value, [], 'omitnan'); - - [fig, ~] = plotSwarmBootstrapWithComparisons(tblPlot, pairs, ps, {'value'}, ... - yLegend = params.yLegend, ... - yMaxVis = max(params.yMaxVis, V1max), ... - diff = false, ... - Alpha = params.Alpha, ... - plotMeanSem = true); - - title(sprintf('%s — %s (onOff=%d, size=%d, lum=%d)', ... - params.indexType, strjoin(params.stimTypes, '/'), ... - params.onOff, params.sizeIdx, params.lumIdx), ... - 'FontSize', 9); - - if params.PaperFig - vs_first.printFig(fig, sprintf('SpatialTuningIndex-%s-%s', ... - params.indexType, strjoin(params.stimTypes, '-')), ... - PaperFig = params.PaperFig); - end - - results.fig = fig; - results.ps = ps; - -end - -end \ No newline at end of file diff --git a/visualStimulationAnalysis/SpatialTuningIndex.m b/visualStimulationAnalysis/SpatialTuningIndex.m index 0f6d98a..98bf534 100644 --- a/visualStimulationAnalysis/SpatialTuningIndex.m +++ b/visualStimulationAnalysis/SpatialTuningIndex.m @@ -8,7 +8,7 @@ params.statType string = "BootstrapPerNeuron" params.speed double = 1 params.plot logical = true - params.indexType string = "L_combined" % L_amplitude, L_geometric, L_combined + params.indexType string = "L_amplitude" % L_amplitude, L_geometric, L_combined params.onOff double = 1 % 1=on, 2=off (rectGrid only) params.sizeIdx double = 1 params.lumIdx double = 1 @@ -83,19 +83,21 @@ continue end + obj_s = linearlyMovingBallAnalysis(NP); + nameParts = split(NP.recordingName, '_'); animalName = nameParts{1}; % ---------------------------------------------------------- % Find union of responsive neurons across ALL stim types % ---------------------------------------------------------- - % Get phy IDs and responsive units for each stim type - respPhyIDs_all = cell(1, nStim); - phyIDs_all = cell(1, nStim); - p_s = obj_s.dataObj.convertPhySorting2tIc(obj_s.spikeSortingFolder); - phy_IDg = p_s.phy_ID(string(p_s.label') == 'good'); + % Get phy IDs once — same for all stim types + p_s = NP.convertPhySorting2tIc(obj_s.spikeSortingFolder); + phy_IDg = p_s.phy_ID(string(p_s.label') == 'good'); + respPhyIDs_all = cell(1, nStim); + respU_all = cell(1, nStim); % ADD — stores respU indices per stim for s = 1:nStim stimType = params.stimTypes(s); @@ -113,7 +115,6 @@ Stats = obj_s.ShufflingAnalysis; end - try switch stimType case "linearlyMovingBall" @@ -126,30 +127,30 @@ pvals = Stats.pvalsResponse; end - respU = find(pvals < 0.05); - phyIDs_all{s} = phy_IDg; % all good unit phy IDs for this stim - respPhyIDs_all{s} = phy_IDg(respU); % only responsive ones + respU = find(pvals < 0.05); + respU_all{s} = respU; % ADD — index into gridSpikeRate dim 3 + respPhyIDs_all{s} = phy_IDg(respU); % phy IDs of responsive neurons fprintf(' [%s] %d responsive neuron(s).\n', stimType, numel(respU)); catch ME warning('Could not get pvals for %s exp %d: %s', stimType, ex, ME.message); - phyIDs_all{s} = []; + respU_all{s} = []; respPhyIDs_all{s} = []; end end - % Union of responsive phy IDs across stim types + % Intersection of responsive phy IDs across stim types sharedPhyIDs = respPhyIDs_all{1}; for s = 2:nStim - sharedPhyIDs = union(sharedPhyIDs, respPhyIDs_all{s}); + sharedPhyIDs = intersect(sharedPhyIDs, respPhyIDs_all{s}); end if isempty(sharedPhyIDs) - fprintf(' No responsive neurons in exp %d — skipping.\n', ex); + fprintf(' No neurons responsive to all stim types in exp %d — skipping.\n', ex); continue end - fprintf(' %d neuron(s) responsive to at least one stim type in exp %d.\n', numel(sharedPhyIDs), ex); + fprintf(' %d neuron(s) responsive to all stim types in exp %d.\n', numel(sharedPhyIDs), ex); for s = 1:nStim @@ -182,17 +183,27 @@ switch stimType case "rectGrid" - % Select onOff from both - gridSpikeRateSelected = gridSpikeRate(:,:,:,params.onOff,:,:); % [nGrid nGrid nN nSize nLum] -- but with singleton onOff removed - gridShuffSelected = gridSpikeRateShuff(:,:,:,:,params.onOff,:,:); % [nGrid nGrid nN nShuffle nSize nLum] + gridSpikeRateSelected = gridSpikeRate(:,:,:,params.onOff,:,:); + gridShuffSelected = gridSpikeRateShuff(:,:,:,:,params.onOff,:,:); + + % Remove onOff singleton at dim 4 for rate: [9 9 nN 1 nSize nLum] -> [9 9 nN nSize nLum] + gridSpikeRateSelected = reshape(gridSpikeRateSelected, ... + [size(gridSpikeRateSelected,1), size(gridSpikeRateSelected,2), ... + size(gridSpikeRateSelected,3), size(gridSpikeRateSelected,5), ... + size(gridSpikeRateSelected,6)]); + + % Remove onOff singleton at dim 5 for shuff: [9 9 nN nShuffle 1 nSize nLum] -> [9 9 nN nShuffle nSize nLum] + gridShuffSelected = reshape(gridShuffSelected, ... + [size(gridShuffSelected,1), size(gridShuffSelected,2), ... + size(gridShuffSelected,3), size(gridShuffSelected,4), ... + size(gridShuffSelected,6), size(gridShuffSelected,7)]); case "linearlyMovingBall" gridSpikeRateSelected = gridSpikeRate; % [nGrid nGrid nN nSize nLum] gridShuffSelected = gridSpikeRateShuff; % [nGrid nGrid nN nShuffle nSize nLum] end - % Find indices in this stim's good units that match sharedPhyIDs - [~, neuronIdx] = ismember(sharedPhyIDs, phyIDs_all{s}); - neuronIdx = neuronIdx(neuronIdx > 0); % remove any not found in this stim + % Find which indices of THIS stim's gridSpikeRate correspond to sharedPhyIDs + [~, neuronIdx] = ismember(sharedPhyIDs, phy_IDg(respU_all{s})); gridSpikeRateSelected = gridSpikeRateSelected(:,:,neuronIdx,:,:); gridShuffSelected = gridShuffSelected(:,:,neuronIdx,:,:,:); @@ -202,8 +213,12 @@ % Get dimensions explicitly nN = size(gridSpikeRateSelected, 3); - nSize = size(gridSpikeRateSelected, 5); - nLum = size(gridSpikeRateSelected, 6); + nSize = size(gridSpikeRateSelected, 4); + nLum = size(gridSpikeRateSelected, 5); + nGrid = size(gridSpikeRateSelected, 1); + + fprintf('gridSpikeRateSelected size before reshape: %s\n', num2str(size(gridSpikeRateSelected))); + fprintf('Expected: [%d %d %d %d %d]\n', nGrid, nGrid, nN, nSize, nLum); % Reshape both to clean [nGrid nGrid nN nSize nLum] gridSpikeRateSelected = reshape(gridSpikeRateSelected, [nGrid nGrid nN nSize nLum]); @@ -214,7 +229,6 @@ % Average over shuffles - % ---------------------------------------------------------- % Compute indices % ---------------------------------------------------------- @@ -229,7 +243,8 @@ rateFlat = reshape(gridSpikeRateSelected(:,:,:,si,li), [nCells, nN]); rateFlatShuff = reshape(gridShuffMean(:,:,:,si,li), [nCells, nN]); - L_amplitude = zeros(nN, 1); + L_amplitude_diff = zeros(nN, 1); + L_amplitude_ratio = zeros(nN, 1); L_geometric = zeros(nN, 1); L_combined = zeros(nN, 1); @@ -258,10 +273,15 @@ if meanAll == 0, meanAll = eps; end if meanAllShuff == 0, meanAllShuff = eps; end - L_amplitude(u) = ... + L_amplitude_diff(u) = ... (meanTop - meanRest) / meanAll - ... (meanTopShuff - meanRestShuff) / meanAllShuff; + shuffleNorm = (meanTopShuff - meanRestShuff) / meanAllShuff; + if shuffleNorm == 0, shuffleNorm = eps; end + + L_amplitude_ratio(u) = ((meanTop - meanRest) / meanAll) / shuffleNorm; + % Geometric [rowIdx, colIdx] = ind2sub([nGrid nGrid], topIdx); [rowIdxShuff, colIdxShuff] = ind2sub([nGrid nGrid], topIdxShuff); @@ -278,13 +298,14 @@ end L_geometric(u) = (1 - D) - (1 - DShuff); - L_combined(u) = L_amplitude(u) * L_geometric(u); + L_combined(u) = L_amplitude_diff(u) * L_geometric(u); end % Build rows for this condition rows = table(); - rows.L_amplitude = L_amplitude; + rows.L_amplitude_diff = L_amplitude_diff; + rows.L_amplitude_ratio = L_amplitude_ratio; rows.L_geometric = L_geometric; rows.L_combined = L_combined; rows.stimulus = categorical(repmat({char(stimType)}, nN, 1)); @@ -385,7 +406,7 @@ [fig, ~] = plotSwarmBootstrapWithComparisons(tblPlot, pairs, ps, {'value'}, ... yLegend = params.yLegend, ... yMaxVis = max(params.yMaxVis, V1max), ... - diff = false, ... + diff = true, ... Alpha = params.Alpha, ... plotMeanSem = true); From 14e9f4524ff03b153f2c22b610c915cb9ed25db3 Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Wed, 25 Mar 2026 00:42:59 +0200 Subject: [PATCH 6/8] Added option to plot by depth, and function to get depth of all exps --- .../CalculateReceptiveFields.m | 1 + .../RunAnalysisClass.asv | 213 ++++++++ visualStimulationAnalysis/RunAnalysisClass.m | 7 +- .../SpatialTuningIndex.m | 2 +- visualStimulationAnalysis/getNeuronDepths.m | 107 ++++ visualStimulationAnalysis/plotPSTH_MultiExp.m | 439 ++++++++--------- .../plotPSTH_MultiExpV1.m | 463 ++++++++++++++++++ 7 files changed, 1007 insertions(+), 225 deletions(-) create mode 100644 visualStimulationAnalysis/RunAnalysisClass.asv create mode 100644 visualStimulationAnalysis/getNeuronDepths.m create mode 100644 visualStimulationAnalysis/plotPSTH_MultiExpV1.m diff --git a/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m b/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m index 7465c3e..7080102 100644 --- a/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m +++ b/visualStimulationAnalysis/@rectGridAnalysis/CalculateReceptiveFields.m @@ -284,6 +284,7 @@ trialCount = zeros(nGrid, nGrid, nSize, nLums); jj = 1; + for i = 1:trialDiv:nT xBin = discretize(XcStore(jj), xEdges); diff --git a/visualStimulationAnalysis/RunAnalysisClass.asv b/visualStimulationAnalysis/RunAnalysisClass.asv new file mode 100644 index 0000000..e674375 --- /dev/null +++ b/visualStimulationAnalysis/RunAnalysisClass.asv @@ -0,0 +1,213 @@ +cd('\\sil3\data\Large_scale_mapping_NP') +excelFile = 'Experiment_Excel.xlsx'; + +data = readtable(excelFile); + +%% +%% Rect Grid +for ex = 52 %84:91 + NP = loadNPclassFromTable(ex); %73 81 + vsRe = rectGridAnalysis(NP); + % vsRe.getSessionTime("overwrite",true); + % %vsRe.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + % vsRe.getDiodeTriggers('overwrite',true); + % vsRe.getSyncedDiodeTriggers("overwrite",true); + % % vsRe.plotSpatialTuningSpikes; + % % vsRe.plotSpatialTuningLFP; + % vsRe.ResponseWindow('overwrite',true) + % results = vsRe.ShufflingAnalysis('overwrite',true); + % vsRe.plotRaster(MergeNtrials=1,overwrite=true,AllResponsiveNeurons = true, selectedLum=[],oneTrial = true,PaperFig = true) %43 + % close all;vsRe.plotRaster(MergeNtrials=1,overwrite=true,exNeurons=18, selectedLum=255,oneTrial = true,PaperFig = true) %43 + vsRe.CalculateReceptiveFields('overwrite',true) + %[colorbarLims] = vsRe.PlotReceptiveFields(exNeurons=18,allStimParamsCombined=true,PaperFig=true,overwrite=true); + %result = vsRe.BootstrapPerNeuron('overwrite',true); + +end +% vsRe.CalculateReceptiveFields +% vsRe.PlotReceptiveFields("meanAllNeurons",true) + +%% Moving ball + +for ex = [84:97]%97 74:84 (Neurons, 96_74, ) + ex = 84 + NP = loadNPclassFromTable(ex); %73 81 + vs = linearlyMovingBallAnalysis(NP,Session=1); + % vs.getSessionTime("overwrite",true); + % vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + % % %vs.plotDiodeTriggers + % vs.getSyncedDiodeTriggers("overwrite",true); + % % %vs.plotSpatialTuningSpikes; + % r = vs.ResponseWindow('overwrite',true); + % results = vs.ShufflingAnalysis('overwrite',true); + % % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'MergeNtrials',3) + % %vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'MergeNtrials',2,'bin',5,'GaussianLength',30,'MaxVal_1', false) + % vs.plotRaster('AllSomaticNeurons',true,'overwrite',true,'speed',2,'MergeNtrials',3) + %vs.plotRaster('exNeurons',82,'overwrite',true,'MergeNtrials',1,'OneDirection','up','OneLuminosity','white','PaperFig',true) + % % %vs.plotCorrSpikePattern + % vs.plotRaster('AllResponsiveNeurons',true,'overwrite',true,'OneDirection','up','OneLuminosity','white','MergeNtrials',1,'PaperFig',true) + + %vs.plotRaster('exNeurons',9,'AllResponsiveNeurons',false,'overwrite',true,'MergeNtrials',3,MaxVal_1=false) + vs.CalculateReceptiveFields('overwrite',true,testConvolution=false); + % colorbarLims=vs.PlotReceptiveFields('exNeurons',82,'overwrite',true,'OneDirection','up','OneLuminosity','white','PaperFig',true); + %result = vs.BootstrapPerNeuron('overwrite',true);%('overwrite',true); + % pvals0_6Filter =result.Speed2.pvalsResponse'; + % compare = [pvals,pvalsNoFilt,pvals0_6Filter]; +end + +%% PlotZScoreComparison +%[49:54 57:81] MBR all experiments 'NV','NI' +%[44:56,64:88] All experiments +%[28:32,44,45,47,48,56,98] All SA experiments +%Check triggers 45, SA82 44,45,47:54,56,64:88 +% All stim: 'FFF','SDG','MBR','MB','RG','NI','NV' +%[49:54,64:97] %All PV good experiments +% %%[89,90,92,93,95,96,97] %Al NV and NI experiments +%[49:54,84:90,92:96] %All SDG experiments +%solve MBR +%bootsrapRespBase +VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... + overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR +%% PSTH for all experiments +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); + +%% Calculate spatial tuning +SpatialTuningIndex([49:54,64:97], indexType = "L_amplitude_ratio" ,overwrite=true, topPercent = 20) + +%% Get neuron depths +getNeuronDepths([49:54,64:72,84:97]) %[49:54,64:72,84:97] %% PV140 missing depth coordinates +%% Gratings + +for ex = [54 84:90] + NP = loadNPclassFromTable(ex); %73 81 + vs = StaticDriftingGratingAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + dT = vs.getDiodeTriggers; + % vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + result = vs.BootstrapPerNeuron('overwrite',true); +end + +%% movie + +for ex = [89,90,92,93,95:97] + NP = loadNPclassFromTable(ex); %73 81 + vs = movieAnalysis(NP); + % vs.getSessionTime("overwrite",true); + % vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + % dT = vs.getDiodeTriggers; + % vs.plotDiodeTriggers + %vs.getSyncedDiodeTriggers("overwrite",true); + %r = vs.ResponseWindow('overwrite',true); + %results = vs.ShufflingAnalysis('overwrite',true); + vs.plotRaster('AllResponsiveNeurons',true,MergeNtrials=1,overwrite=true) +end + + +%% image + +for ex = [89,90,92,93,95:97] + NP = loadNPclassFromTable(ex); %73 81 + vs = imageAnalysis(NP); + %vs.getSessionTime("overwrite",true); + %vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %dT = vs.getDiodeTriggers; + % vs.plotDiodeTriggers + %vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + %results = vs.ShufflingAnalysis('overwrite',true); + vs.plotRaster('AllResponsiveNeurons',true,MergeNtrials=1,overwrite=true) + +end + + +%% Moving bar +for ex = 81 + NP = loadNPclassFromTable(ex); %73 81 + vs = linearlyMovingBarAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + if ~any(results.Speed1.pvalsResponse<0.05) + fprintf('%d-No responsive neurons.\n',ex) + continue + end + vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); + vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) +end + +%% FFF +for ex = 56 + NP = loadNPclassFromTable(ex); %73 81 + vs = fullFieldFlashAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + if ~any(results.Speed1.pvalsResponse<0.05) + fprintf('%d-No responsive neurons.\n',ex) + continue + end + vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); + vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) +end + + +%% Run for all +for ex = 85:88 + NP = loadNPclassFromTable(ex); %73 81 + vs = linearlyMovingBallAnalysis(NP); + vs.getSessionTime("overwrite",true); + vs.getDiodeTriggers('extractionMethod','digitalTriggerDiode','overwrite',true); + %vs.plotDiodeTriggers + vs.getSyncedDiodeTriggers("overwrite",true); + r = vs.ResponseWindow('overwrite',true); + results = vs.ShufflingAnalysis('overwrite',true); + if ~any(results.Speed1.pvalsResponse<0.05) + fprintf('%d-No responsive neurons.\n',ex) + continue + end + vs.CalculateReceptiveFields('overwrite',true,'nShuffle',20); + vs.PlotReceptiveFields('overwrite',true,'RFsDivision',{'Directions','','Luminosities'},meanAllNeurons=true) +end + +%% Check experiments in timseseries viewer +timeSeriesViewer(NP) +t=NP.getTrigger; +data.VS_ordered(ex) + +stimOn = t{3}; +stimOff = t{4}; + +MBRtOn = stimOn(stimOn > t{1}(1) & stimOn < t{2}(1)); +MBRtOff = stimOff(stimOff > t{1}(1) & stimOff < t{2}(1)); + +MBtOn = stimOn(stimOn > t{1}(2) & stimOn < t{2}(2)); +MBtOff = stimOff(stimOff > t{1}(2) & stimOff < t{2}(2)); + +RGtOn = stimOn(stimOn > t{1}(3) & stimOn < t{2}(3)); +RGtOff = stimOff(stimOff > t{1}(3) & stimOff < t{2}(3)); + +NGtOn = stimOn(stimOn > t{1}(4) & stimOn < t{2}(4)); +NGtOff = stimOff(stimOff > t{1}(4) & stimOff < t{2}(4)); + +DtOn = stimOn(stimOn > t{1}(5) & stimOn < t{2}(5)); +DtOff = stimOff(stimOff > t{1}(5) & stimOff < t{2}(5)); + +MovingBallTriggersDiode = d3.stimOnFlipTimes; + + + +%% %% check neural data sync and analog data sync + +allTimes = [stimOn(:); stimOff(:); onSync(:); offSync(:)]; % concatenate as column + +% Sort from earliest to latest +sortedTimesDiodeOldMethod = sort(allTimes); diff --git a/visualStimulationAnalysis/RunAnalysisClass.m b/visualStimulationAnalysis/RunAnalysisClass.m index 9332153..020bac3 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.m +++ b/visualStimulationAnalysis/RunAnalysisClass.m @@ -29,6 +29,7 @@ %% Moving ball for ex = [84:97]%97 74:84 (Neurons, 96_74, ) + ex = 84 NP = loadNPclassFromTable(ex); %73 81 vs = linearlyMovingBallAnalysis(NP,Session=1); % vs.getSessionTime("overwrite",true); @@ -67,11 +68,13 @@ VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR %% PSTH for all experiments -plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); +plotPSTH_MultiExp([49:54,64:72,84:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=true); %% Calculate spatial tuning -SpatialTuningIndex([49:54,64:97], indexType = "L_geometric",overwrite=true) +SpatialTuningIndex([49:54,64:97], indexType = "L_amplitude_ratio" ,overwrite=true, topPercent = 20) +%% Get neuron depths +getNeuronDepths([49:54,64:72,84:97]) %[49:54,64:72,84:97] %% PV140 missing depth coordinates %% Gratings for ex = [54 84:90] diff --git a/visualStimulationAnalysis/SpatialTuningIndex.m b/visualStimulationAnalysis/SpatialTuningIndex.m index 98bf534..5a80b33 100644 --- a/visualStimulationAnalysis/SpatialTuningIndex.m +++ b/visualStimulationAnalysis/SpatialTuningIndex.m @@ -8,7 +8,7 @@ params.statType string = "BootstrapPerNeuron" params.speed double = 1 params.plot logical = true - params.indexType string = "L_amplitude" % L_amplitude, L_geometric, L_combined + params.indexType string = "L_amplitude" % L_amplitude_diff,L_amplitude_ratio, L_geometric, L_combined params.onOff double = 1 % 1=on, 2=off (rectGrid only) params.sizeIdx double = 1 params.lumIdx double = 1 diff --git a/visualStimulationAnalysis/getNeuronDepths.m b/visualStimulationAnalysis/getNeuronDepths.m new file mode 100644 index 0000000..35f98bd --- /dev/null +++ b/visualStimulationAnalysis/getNeuronDepths.m @@ -0,0 +1,107 @@ +function [result] = getNeuronDepths(exList) +% getNeuronDepths Returns cortical depths of good units across all experiments, +% and computes 3 globally-defined equal depth bins. +% +% Inputs: +% exList - vector of experiment numbers (same as used in plotPSTH_MultiExp) +% +% Outputs: +% result - struct with fields: +% .depthTable - table with columns: Experiment, Unit, Depth_um +% .depthBinEdges - 1x4 vector [min, t1, t2, max] in um +% .perExp - struct array with per-experiment data: +% .exNum, .goodU, .p_sort + +% ------------------------------------------------------------------ +% Load Excel once +% ------------------------------------------------------------------ +excelPath = '\\sil3\data\Large_scale_mapping_NP\Experiment_Excel.xlsx'; +T = readtable(excelPath); + +% ------------------------------------------------------------------ +% Preallocate collections +% ------------------------------------------------------------------ +expCol = []; % experiment number per unit +unitCol = []; % unit index (1-based) per unit +depthCol = []; % depth in um per unit + +result.perExp(numel(exList)) = struct('exNum', [], 'goodU', [], 'p_sort', []); + +% ------------------------------------------------------------------ +% Loop over experiments +% ------------------------------------------------------------------ +for ei = 1:numel(exList) + + ex = exList(ei); + fprintf('Loading experiment %d ...\n', ex); + + try + NP = loadNPclassFromTable(ex); + obj = linearlyMovingBallAnalysis(NP); + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + result.perExp(ei).exNum = ex; + result.perExp(ei).goodU = []; + result.perExp(ei).p_sort = []; + continue + end + + % coor_Z for this experiment + coor_Z = T.coor_Z(ex); + + % Good units + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); % nTimePoints x nGoodUnits + nGood = size(goodU, 2); + + % Channel IDs (0-based) → Y positions → real depths + channelIDs = goodU(1, :); % 1 x nGoodUnits, 0-based + yPos = NP.chLayoutPositions(2, channelIDs + 1); % 1 x nGoodUnits + neuronDepths = coor_Z - yPos; % 1 x nGoodUnits, in um + + % Accumulate table columns + expCol = [expCol, repmat(ex, 1, nGood)]; + unitCol = [unitCol, 1:nGood ]; + depthCol = [depthCol, neuronDepths ]; + + % Store per-experiment data + result.perExp(ei).exNum = ex; + result.perExp(ei).goodU = goodU; + result.perExp(ei).p_sort = p_sort; + + fprintf(' coor_Z = %.0f um | Good units: %d | Depth range: %.0f - %.0f um\n', ... + coor_Z, nGood, min(neuronDepths), max(neuronDepths)); + +end + +% ------------------------------------------------------------------ +% Build table +% ------------------------------------------------------------------ +result.depthTable = table(expCol(:), unitCol(:), depthCol(:), ... + 'VariableNames', {'Experiment', 'Unit', 'Depth_um'}); + +% ------------------------------------------------------------------ +% Global depth bins +% ------------------------------------------------------------------ +dMin = min(depthCol); +dMax = max(depthCol); +step = (dMax - dMin) / 3; + +result.depthBinEdges = [dMin, dMin+step, dMin+2*step, dMax]; + +fprintf('\nGlobal depth range: %.0f - %.0f um\n', dMin, dMax); +fprintf('Depth bins:\n'); +fprintf(' Bin 1 (shallow) : %.0f - %.0f um\n', result.depthBinEdges(1), result.depthBinEdges(2)); +fprintf(' Bin 2 (middle) : %.0f - %.0f um\n', result.depthBinEdges(2), result.depthBinEdges(3)); +fprintf(' Bin 3 (deep) : %.0f - %.0f um\n', result.depthBinEdges(3), result.depthBinEdges(4)); + +% ------------------------------------------------------------------ +% Save to disk +% ------------------------------------------------------------------ + +n = extractBefore(obj.getAnalysisFileName,'lizards'); +saveName = [n 'lizards' filesep 'Combined_lizard_analysis' filesep 'NeuronDepths.mat']; +save(saveName, '-struct', 'result'); +fprintf('\nSaved to: %s\n', saveName); +end \ No newline at end of file diff --git a/visualStimulationAnalysis/plotPSTH_MultiExp.m b/visualStimulationAnalysis/plotPSTH_MultiExp.m index 24e42dc..ec92d0d 100644 --- a/visualStimulationAnalysis/plotPSTH_MultiExp.m +++ b/visualStimulationAnalysis/plotPSTH_MultiExp.m @@ -9,24 +9,41 @@ function plotPSTH_MultiExp(exList, params) params.speed string = "max" params.alpha double = 0.05 params.shadeSTD logical = true - params.postStim double = 500 % ms after stim onset to include - params.preBase double = 200 % ms of baseline before stim onset - params.overwrite logical = false % force recompute even if file exists - params.TakeTopPercentTrials double = 0.3 %Percentage of highest spiking rate trials to take to calculate PSTHs - params.zScore logical = false % normalize firing rate to z-score using baseline - params.PaperFig logical = false %Is this going to be used in the paper? + params.postStim double = 500 + params.preBase double = 200 + params.overwrite logical = false + params.TakeTopPercentTrials double = 0.3 + params.zScore logical = false + params.PaperFig logical = false + params.byDepth logical = false % plot 3 depth bins per stim type end % ------------------------------------------------------------------------- -% Build save path using first experiment to get the analysis folder -% This mirrors the convention used in PlotZScoreComparison +% Load depth info from saved file (only if byDepth is requested) % ------------------------------------------------------------------------- +if params.byDepth + depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; + if ~exist(depthFile, 'file') + error('NeuronDepths.mat not found. Run getNeuronDepths() first.'); + end + D = load(depthFile); + depthTable = D.depthTable; + depthBinEdges = D.depthBinEdges; + nDepthBins = 3; + fprintf('Depth bins loaded:\n'); + fprintf(' Bin 1 (shallow): %.0f - %.0f um\n', depthBinEdges(1), depthBinEdges(2)); + fprintf(' Bin 2 (middle) : %.0f - %.0f um\n', depthBinEdges(2), depthBinEdges(3)); + fprintf(' Bin 3 (deep) : %.0f - %.0f um\n', depthBinEdges(3), depthBinEdges(4)); +else + nDepthBins = 1; +end -% Load first experiment just to get the folder path +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- NP_first = loadNPclassFromTable(exList(1)); -vs_first = linearlyMovingBallAnalysis(NP_first); % used only for path +vs_first = linearlyMovingBallAnalysis(NP_first); -% Build the save directory path p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); p = [p 'lizards']; if ~exist([p '\Combined_lizard_analysis'], 'dir') @@ -35,79 +52,66 @@ function plotPSTH_MultiExp(exList, params) end saveDir = [p '\Combined_lizard_analysis']; -% Build filename — includes stim types so different comparisons don't clash -stimLabel = strjoin(params.stimTypes, '-'); % e.g. "rectGrid-linearlyMovingBall" -nameOfFile = sprintf('\\Ex_%d-%d_Combined_PSTHs_%s.mat', ... - exList(1), exList(end), stimLabel); +stimLabel = strjoin(params.stimTypes, '-'); +depthSuffix = ''; +if params.byDepth; depthSuffix = '_byDepth'; end +nameOfFile = sprintf('\\Ex_%d-%d_Combined_PSTHs_%s%s.mat', ... + exList(1), exList(end), stimLabel, depthSuffix); % ------------------------------------------------------------------------- -% Decide whether to run the experiment loop or load from disk -% forloop = true → compute PSTHs from scratch -% forloop = false → load saved struct and skip to plotting +% Decide whether to recompute or load % ------------------------------------------------------------------------- if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite - % File exists and overwrite is off — check if expList matches S = load([saveDir nameOfFile]); if isequal(S.expList, exList) fprintf('Loading saved PSTHs from:\n %s\n', [saveDir nameOfFile]); - forloop = false; % skip computation, go straight to plot + forloop = false; else fprintf('Experiment list mismatch — recomputing.\n'); - forloop = true; % expList changed, recompute + forloop = true; end else - forloop = true; % file doesn't exist or overwrite requested + forloop = true; end % ========================================================================= -% EXPERIMENT LOOP — only runs if forloop is true +% EXPERIMENT LOOP % ========================================================================= if forloop nStim = numel(params.stimTypes); nExp = numel(exList); - % One cell per stim type, grows one row per experiment - psthAll = cell(1, nStim); - for s = 1:nStim - psthAll{s} = []; - end + % psthAll{s,b} — s = stim type, b = depth bin (1 if byDepth is off) + psthAll = cell(nStim, nDepthBins); - % Locked time window — set from first valid experiment - lockedPreBase = []; - lockedNBins = []; - lockedEdges = []; + lockedPreBase = []; + lockedNBins = []; + lockedEdges = []; - % ------------------------------------------------------------------ - % LOOP OVER EXPERIMENTS - % ------------------------------------------------------------------ for ei = 1:nExp ex = exList(ei); fprintf('\n=== Experiment %d ===\n', ex); - % Load NP data for this experiment try NP = loadNPclassFromTable(ex); catch ME warning('Could not load experiment %d: %s', ex, ME.message); - % Add NaN placeholder row if window is already locked for s = 1:nStim - if ~isempty(psthAll{s}) - psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end end end continue end - % -------------------------------------------------------------- - % LOOP OVER STIMULUS TYPES - % -------------------------------------------------------------- for s = 1:nStim stimType = params.stimTypes(s); - % Build analysis object for this stim type try switch stimType case "rectGrid" @@ -123,71 +127,54 @@ function plotPSTH_MultiExp(exList, params) end catch ME warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); - if ~isempty(psthAll{s}) - psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end end continue end - % ---------------------------------------------------------- - % Extract data structures - % ---------------------------------------------------------- - - % ResponseWindow holds trial timing and spike data NeuronResp = obj.ResponseWindow; - % Stats struct for p-values if params.statType == "BootstrapPerNeuron" Stats = obj.BootstrapPerNeuron; else Stats = obj.ShufflingAnalysis; end - % Resolve speed field name if params.speed ~= "max" && isequal(obj.stimName,'linearlyMovingBall') - fieldName = 'Speed2'; - startStim = 0; + fieldName = 'Speed2'; startStim = 0; elseif isequal(obj.stimName,'linearlyMovingBall') - fieldName = 'Speed1'; - startStim = 0; + fieldName = 'Speed1'; startStim = 0; elseif isequal(params.stimTypes,'StaticGrating') - fieldName = 'Static'; - startStim = 0; - + fieldName = 'Static'; startStim = 0; elseif isequal(params.stimTypes,'MovingGrating') - startStim = obj.VST.static_time*1000; - fieldName = 'Moving'; + startStim = obj.VST.static_time*1000; fieldName = 'Moving'; else startStim = 0; end - % Spike trains of somatic (good) units p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); label = string(p_sort.label'); goodU = p_sort.ic(:, label == 'good'); - % P-values for each unit try - pvals = Stats.(fieldName).pvalsResponse; + pvals = Stats.(fieldName).pvalsResponse; catch - pvals = Stats.pvalsResponse; + pvals = Stats.pvalsResponse; end - % Trial onset times in ms try - C = NeuronResp.(fieldName).C; + C = NeuronResp.(fieldName).C; catch C = NeuronResp.C; end directimesSorted = C(:, 1)' + startStim; - % Use params.preBase directly — no formula needed - preBase = params.preBase; - - % Total trial window = baseline + post-stim period + preBase = params.preBase; windowTotal = preBase + params.postStim; - % Lock in time window from first valid experiment if isempty(lockedPreBase) lockedPreBase = preBase; lockedEdges = 0 : params.binWidth : windowTotal; @@ -197,125 +184,143 @@ function plotPSTH_MultiExp(exList, params) lockedPreBase, params.postStim, lockedNBins); end - % ---------------------------------------------------------- - % Find responsive neurons - % ---------------------------------------------------------- eNeurons = find(pvals < params.alpha); if isempty(eNeurons) fprintf(' [%s] No responsive neurons in exp %d.\n', stimType, ex); - if ~isempty(psthAll{s}) - psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end end continue end - fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', ... - stimType, ex, numel(eNeurons)); + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, ex, numel(eNeurons)); % ---------------------------------------------------------- - % Build PSTH for each responsive neuron - % BuildBurstMatrix returns nTrials x 1 x nTimeBins - % Window: from (trialOnset - preBase) for windowTotal ms + % Build PSTH per neuron % ---------------------------------------------------------- psthRateNeurons = zeros(numel(eNeurons), lockedNBins); + neuronBinIdx = zeros(numel(eNeurons), 1); for ni = 1:numel(eNeurons) u = eNeurons(ni); - % Spike matrix: rows = trials, cols = time bins (1ms each) + % Assign depth bin + if params.byDepth + depthRow = depthTable.Experiment == ex & depthTable.Unit == u; + if ~any(depthRow) + neuronBinIdx(ni) = 0; % unknown depth — skip + continue + end + unitDepth = depthTable.Depth_um(depthRow); + if unitDepth <= depthBinEdges(2) + neuronBinIdx(ni) = 1; + elseif unitDepth <= depthBinEdges(3) + neuronBinIdx(ni) = 2; + else + neuronBinIdx(ni) = 3; + end + else + neuronBinIdx(ni) = 1; % all neurons in single bin + end + MRhist = BuildBurstMatrix( ... goodU(:, u), ... round(p_sort.t), ... round(directimesSorted - lockedPreBase), ... round(windowTotal)); + MRhist = squeeze(MRhist); - - - % Remove singleton dimensions → nTrials x nTimeBins - MRhist = squeeze(MRhist); - - if ~isempty(params.TakeTopPercentTrials) - MeanTrial = mean(MRhist,2); - [~, ind] = sort(MeanTrial,'descend'); - - takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); - - MRhist = MRhist(takeTrials,:); - + if ~isempty(params.TakeTopPercentTrials) + MeanTrial = mean(MRhist, 2); + [~, ind] = sort(MeanTrial, 'descend'); + takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); + MRhist = MRhist(takeTrials, :); end - nTrials = size(MRhist, 1); - % Convert to spike times in ms - spikeTimes = repmat((1:size(MRhist, 2)), nTrials, 1); + nTrials = size(MRhist, 1); + spikeTimes = repmat((1:size(MRhist,2)), nTrials, 1); spikeTimes = spikeTimes(logical(MRhist)); - - % Bin into locked edges and convert to spk/s - counts = histcounts(spikeTimes, lockedEdges); + counts = histcounts(spikeTimes, lockedEdges); psthRateNeurons(ni, :) = (counts / (params.binWidth * nTrials)) * 1000; end - % Average across responsive neurons → 1 x lockedNBins - psthExp = mean(psthRateNeurons, 1, 'omitnan'); + % ---------------------------------------------------------- + % Average per depth bin and append + % ---------------------------------------------------------- + for b = 1:nDepthBins + binNeurons = neuronBinIdx == b; + if ~any(binNeurons) + fprintf(' [%s] No neurons in depth bin %d for exp %d.\n', stimType, b, ex); + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + continue + end - if params.zScore - baselineBins = tAxis < lockedPreBase; - baselineMean = mean(psthExp(baselineBins)); - baselineStd = std(psthExp(baselineBins)); - if baselineStd > 0 - psthExp = (psthExp - baselineMean) / baselineStd; - else - warning(' [%s] Baseline std is zero for exp %d — skipping experiment.', stimType, ex); - if ~isempty(psthAll{s}) - psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + psthExp = mean(psthRateNeurons(binNeurons, :), 1, 'omitnan'); + + if params.zScore + baselineBins = tAxis < lockedPreBase; + baselineMean = mean(psthExp(baselineBins)); + baselineStd = std(psthExp(baselineBins)); + if baselineStd > 0 + psthExp = (psthExp - baselineMean) / baselineStd; + else + warning(' [%s] Bin %d: baseline std is zero for exp %d.', stimType, b, ex); + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + continue end - continue % skip to next experiment, do not append raw rates end - end - % Append as new row — guaranteed lockedNBins wide - psthAll{s} = [psthAll{s}; psthExp(:)']; + psthAll{s,b} = [psthAll{s,b}; psthExp(:)']; + fprintf(' [%s] Bin %d: %d neuron(s) in exp %d.\n', stimType, b, sum(binNeurons), ex); + end - end % end stim loop - end % end experiment loop + end % stim loop + end % experiment loop % ------------------------------------------------------------------ - % Save results to struct + % Save % ------------------------------------------------------------------ - S.expList = exList; % experiment list for future matching - S.lockedEdges = lockedEdges; % bin edges used (ms from trial start) - S.lockedPreBase = lockedPreBase; % baseline duration in ms - S.params = params; % all parameters used + S.expList = exList; + S.lockedEdges = lockedEdges; + S.lockedPreBase = lockedPreBase; + S.params = params; - % Save one field per stim type, named by stim e.g. S.rectGrid for s = 1:numel(params.stimTypes) - stimField = matlab.lang.makeValidName(params.stimTypes(s)); % safe field name - S.(stimField) = psthAll{s}; % nExp x nBins PSTH matrix + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + for b = 1:nDepthBins + S.(sprintf('%s_bin%d', stimField, b)) = psthAll{s,b}; + end end save([saveDir nameOfFile], '-struct', 'S'); fprintf('\nSaved PSTHs to:\n %s\n', [saveDir nameOfFile]); else - % ------------------------------------------------------------------ - % Load psthAll from saved struct - % ------------------------------------------------------------------ + % Load psthAll from disk lockedEdges = S.lockedEdges; lockedPreBase = S.lockedPreBase; - psthAll = cell(1, numel(params.stimTypes)); + psthAll = cell(numel(params.stimTypes), nDepthBins); for s = 1:numel(params.stimTypes) - stimField = matlab.lang.makeValidName(params.stimTypes(s)); - if isfield(S, stimField) - psthAll{s} = S.(stimField); % load the nExp x nBins matrix - else - % Stim type not found in saved file — warn and leave empty - warning('Stim type "%s" not found in saved file.', params.stimTypes(s)); - psthAll{s} = []; + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + for b = 1:nDepthBins + fieldKey = sprintf('%s_bin%d', stimField, b); + if isfield(S, fieldKey) + psthAll{s,b} = S.(fieldKey); + else + warning('Field "%s" not found in saved file.', fieldKey); + psthAll{s,b} = []; + end end end - -end % end forloop +end % ========================================================================= % PLOT @@ -324,43 +329,37 @@ function plotPSTH_MultiExp(exList, params) tAxis = lockedEdges(1:end-1); tAxisPlot = tAxis - lockedPreBase; -colors = lines(numel(params.stimTypes)); - -fig = figure; -set(fig, 'Units', 'centimeters', 'Position', [5 5 9 10]); % single axis now +baseColors = lines(numel(params.stimTypes)); +depthShades = [0.6, 0.35, 0.1]; % light → dark for shallow → deep +binLabels = {'shallow', 'middle', 'deep'}; -% ------------------------------------------------------------------ -% Map stimulus type names to short legend labels -% ------------------------------------------------------------------ stimLegendMap = containers.Map(... {'linearlyMovingBall', 'rectGrid', 'MovingGrating', 'StaticGrating'}, ... {'MB', 'SB', 'MG', 'SG'}); % ------------------------------------------------------------------ -% First pass: compute mean/sem for all stim types and find global ylim +% First pass: global ylim % ------------------------------------------------------------------ -meanAll = cell(1, numel(params.stimTypes)); -semAll = cell(1, numel(params.stimTypes)); -yMax = 0; -yMin = inf; +yMax = 0; +yMin = inf; + +meanAll = cell(numel(params.stimTypes), nDepthBins); +semAll = cell(numel(params.stimTypes), nDepthBins); for s = 1:numel(params.stimTypes) - data = psthAll{s}; - if isempty(data) - continue + for b = 1:nDepthBins + data = psthAll{s,b}; + if isempty(data); continue; end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data); continue; end + meanAll{s,b} = mean(data, 1, 'omitnan'); + semAll{s,b} = std(data, 0, 1, 'omitnan') / sqrt(sum(~isnan(data(:,1)))); + yMax = max(yMax, max(meanAll{s,b} + semAll{s,b})); + yMin = min(yMin, min(meanAll{s,b} - semAll{s,b})); end - validRows = ~all(isnan(data), 2); - data = data(validRows, :); - if isempty(data) - continue - end - meanAll{s} = mean(data, 1, 'omitnan'); - semAll{s} = std(data, 0, 1, 'omitnan') / sqrt(sum(~isnan(data(:,1)))); - yMax = max(yMax, max(meanAll{s} + semAll{s})); - yMin = min(yMin, min(meanAll{s} - semAll{s})); end -% Y limits with 10% padding yPad = (yMax - yMin) * 0.1; if params.zScore yLims = [yMin - yPad, yMax + yPad]; @@ -369,94 +368,90 @@ function plotPSTH_MultiExp(exList, params) end % ------------------------------------------------------------------ -% Single axis plot — all stim types overlaid +% Plot % ------------------------------------------------------------------ +fig = figure; +set(fig, 'Units', 'centimeters', 'Position', [5 5 9 10]); ax = axes(fig); hold(ax, 'on'); -legendHandles = gobjects(numel(params.stimTypes), 1); % store line handles for legend +legendHandles = []; +legendLabels = {}; for s = 1:numel(params.stimTypes) - data = psthAll{s}; - if isempty(data) - continue - end - validRows = ~all(isnan(data), 2); - data = data(validRows, :); - if isempty(data) - continue - end - - meanPSTH = meanAll{s}; - semPSTH = semAll{s}; - - % Get short legend label for this stim type stimKey = char(params.stimTypes(s)); if isKey(stimLegendMap, stimKey) - legendLabel = stimLegendMap(stimKey); + shortName = stimLegendMap(stimKey); else - legendLabel = stimKey; % fallback to full name if not in map + shortName = stimKey; end - % Shade ±SEM band - if params.shadeSTD && size(data, 1) > 1 - upper = meanPSTH + semPSTH; - lower = meanPSTH - semPSTH; - xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; - yFill = [upper(:)', fliplr(lower(:)') ]; - fill(ax, xFill, yFill, colors(s,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); - end + for b = 1:nDepthBins + + data = psthAll{s,b}; + if isempty(data); continue; end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data); continue; end + + meanPSTH = meanAll{s,b}; + semPSTH = semAll{s,b}; + + % Color and label depend on mode + if params.byDepth + lineColor = baseColors(s,:) * (1 - depthShades(b)); + legendLabel = sprintf('%s %s (%.0f-%.0f um)', ... + shortName, binLabels{b}, depthBinEdges(b), depthBinEdges(b+1)); + else + lineColor = baseColors(s,:); + legendLabel = shortName; + end + + % SEM shading + if params.shadeSTD && size(data,1) > 1 + upper = meanPSTH + semPSTH; + lower = meanPSTH - semPSTH; + xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; + yFill = [upper(:)', fliplr(lower(:)') ]; + fill(ax, xFill, yFill, lineColor, 'FaceAlpha', 0.15, 'EdgeColor', 'none'); + end - % Mean PSTH line — store handle for legend - legendHandles(s) = plot(ax, tAxisPlot(:)', meanPSTH(:)', ... - 'Color', colors(s,:), 'LineWidth', 1.5, 'DisplayName', legendLabel); + % Mean line + h = plot(ax, tAxisPlot(:)', meanPSTH(:)', ... + 'Color', lineColor, 'LineWidth', 1.5); - % Number of contributing experiments as text - nValid = sum(validRows); - fprintf(' [%s] n=%d experiments in plot.\n', legendLabel, nValid); + legendHandles(end+1) = h; %#ok + legendLabels{end+1} = legendLabel; %#ok + fprintf(' [%s] n=%d experiments in plot.\n', legendLabel, sum(validRows)); + end end -% Stim onset and end of post-stim window xline(ax, 0, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); xline(ax, params.postStim, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); -% Y label -if params.zScore - yLabel = 'Z-score'; -else - yLabel = '[spk/s]'; -end +if params.zScore; yLabel = 'Z-score'; else; yLabel = '[spk/s]'; end xlabel(ax, 'Time re. stim onset [ms]', 'FontName', 'helvetica', 'FontSize', 8); ylabel(ax, yLabel, 'FontName', 'helvetica', 'FontSize', 8); xlim(ax, [tAxisPlot(1) tAxisPlot(end)]); ylim(ax, yLims); -% Legend — only show valid handles (skip stim types with no data) -validHandles = legendHandles(isgraphics(legendHandles)); -legend(validHandles, 'Location', 'northeast', 'FontName', 'helvetica', 'FontSize', 8); +legend(legendHandles, legendLabels, 'Location', 'northeast', ... + 'FontName', 'helvetica', 'FontSize', 7); -ax.FontName = 'helvetica'; -ax.FontSize = 8; -hold(ax, 'off'); - -sgtitle(sprintf('N = %d', numel(exList)), 'FontName', 'helvetica', 'FontSize', 11); - -ax = gca; +ax.FontName = 'helvetica'; +ax.FontSize = 8; ax.YAxis.FontSize = 8; -ax.YAxis.FontName = 'helvetica'; - -ax = gca; ax.XAxis.FontSize = 8; -ax.XAxis.FontName = 'helvetica'; +hold(ax, 'off'); -set(fig, 'Units', 'centimeters'); -set(fig, 'Position', [20 20 5 6]); +sgtitle(sprintf('N = %d', numel(exList)), 'FontName', 'helvetica', 'FontSize', 11); +set(fig, 'Units', 'centimeters', 'Position', [20 20 8 6]); if params.PaperFig - vs_first.printFig(fig, sprintf('PSTH-comparison-%s-%s', ... + vs_first.printFig(fig, sprintf('PSTH-depth-%s-%s', ... params.stimTypes(1), params.stimTypes(2)), PaperFig = params.PaperFig) end diff --git a/visualStimulationAnalysis/plotPSTH_MultiExpV1.m b/visualStimulationAnalysis/plotPSTH_MultiExpV1.m new file mode 100644 index 0000000..f9a404d --- /dev/null +++ b/visualStimulationAnalysis/plotPSTH_MultiExpV1.m @@ -0,0 +1,463 @@ +function plotPSTH_MultiExpV1(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.bin double = 30 + params.binWidth double = 10 + params.statType string = "BootstrapPerNeuron" + params.speed string = "max" + params.alpha double = 0.05 + params.shadeSTD logical = true + params.postStim double = 500 % ms after stim onset to include + params.preBase double = 200 % ms of baseline before stim onset + params.overwrite logical = false % force recompute even if file exists + params.TakeTopPercentTrials double = 0.3 %Percentage of highest spiking rate trials to take to calculate PSTHs + params.zScore logical = false % normalize firing rate to z-score using baseline + params.PaperFig logical = false %Is this going to be used in the paper? +end + +% ------------------------------------------------------------------------- +% Build save path using first experiment to get the analysis folder +% This mirrors the convention used in PlotZScoreComparison +% ------------------------------------------------------------------------- + +% Load first experiment just to get the folder path +NP_first = loadNPclassFromTable(exList(1)); +vs_first = linearlyMovingBallAnalysis(NP_first); % used only for path + +% Build the save directory path +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +% Build filename — includes stim types so different comparisons don't clash +stimLabel = strjoin(params.stimTypes, '-'); % e.g. "rectGrid-linearlyMovingBall" +nameOfFile = sprintf('\\Ex_%d-%d_Combined_PSTHs_%s.mat', ... + exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to run the experiment loop or load from disk +% forloop = true → compute PSTHs from scratch +% forloop = false → load saved struct and skip to plotting +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + % File exists and overwrite is off — check if expList matches + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved PSTHs from:\n %s\n', [saveDir nameOfFile]); + forloop = false; % skip computation, go straight to plot + else + fprintf('Experiment list mismatch — recomputing.\n'); + forloop = true; % expList changed, recompute + end +else + forloop = true; % file doesn't exist or overwrite requested +end + +% ========================================================================= +% EXPERIMENT LOOP — only runs if forloop is true +% ========================================================================= +if forloop + + nStim = numel(params.stimTypes); + nExp = numel(exList); + + % One cell per stim type, grows one row per experiment + psthAll = cell(1, nStim); + for s = 1:nStim + psthAll{s} = []; + end + + % Locked time window — set from first valid experiment + lockedPreBase = []; + lockedNBins = []; + lockedEdges = []; + + % ------------------------------------------------------------------ + % LOOP OVER EXPERIMENTS + % ------------------------------------------------------------------ + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + % Load NP data for this experiment + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + % Add NaN placeholder row if window is already locked + for s = 1:nStim + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + end + continue + end + + % -------------------------------------------------------------- + % LOOP OVER STIMULUS TYPES + % -------------------------------------------------------------- + for s = 1:nStim + + stimType = params.stimTypes(s); + + % Build analysis object for this stim type + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + case 'StaticGrating' + obj = StaticDriftingGratingAnalysis(NP); + case 'MovingGrating' + obj = StaticDriftingGratingAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + continue + end + + % ---------------------------------------------------------- + % Extract data structures + % ---------------------------------------------------------- + + % ResponseWindow holds trial timing and spike data + NeuronResp = obj.ResponseWindow; + + % Stats struct for p-values + if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; + else + Stats = obj.ShufflingAnalysis; + end + + % Resolve speed field name + if params.speed ~= "max" && isequal(obj.stimName,'linearlyMovingBall') + fieldName = 'Speed2'; + startStim = 0; + elseif isequal(obj.stimName,'linearlyMovingBall') + fieldName = 'Speed1'; + startStim = 0; + elseif isequal(params.stimTypes,'StaticGrating') + fieldName = 'Static'; + startStim = 0; + + elseif isequal(params.stimTypes,'MovingGrating') + startStim = obj.VST.static_time*1000; + fieldName = 'Moving'; + else + startStim = 0; + end + + % Spike trains of somatic (good) units + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); + + % P-values for each unit + try + pvals = Stats.(fieldName).pvalsResponse; + catch + pvals = Stats.pvalsResponse; + end + + % Trial onset times in ms + try + C = NeuronResp.(fieldName).C; + catch + C = NeuronResp.C; + end + directimesSorted = C(:, 1)' + startStim; + + % Use params.preBase directly — no formula needed + preBase = params.preBase; + + % Total trial window = baseline + post-stim period + windowTotal = preBase + params.postStim; + + % Lock in time window from first valid experiment + if isempty(lockedPreBase) + lockedPreBase = preBase; + lockedEdges = 0 : params.binWidth : windowTotal; + lockedNBins = numel(lockedEdges) - 1; + tAxis = lockedEdges(1:end-1); + fprintf(' Locked window: preBase=%d ms, postStim=%d ms, nBins=%d\n', ... + lockedPreBase, params.postStim, lockedNBins); + end + + % ---------------------------------------------------------- + % Find responsive neurons + % ---------------------------------------------------------- + eNeurons = find(pvals < params.alpha); + + if isempty(eNeurons) + fprintf(' [%s] No responsive neurons in exp %d.\n', stimType, ex); + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + continue + end + + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', ... + stimType, ex, numel(eNeurons)); + + % ---------------------------------------------------------- + % Build PSTH for each responsive neuron + % BuildBurstMatrix returns nTrials x 1 x nTimeBins + % Window: from (trialOnset - preBase) for windowTotal ms + % ---------------------------------------------------------- + psthRateNeurons = zeros(numel(eNeurons), lockedNBins); + + for ni = 1:numel(eNeurons) + u = eNeurons(ni); + + % Spike matrix: rows = trials, cols = time bins (1ms each) + MRhist = BuildBurstMatrix( ... + goodU(:, u), ... + round(p_sort.t), ... + round(directimesSorted - lockedPreBase), ... + round(windowTotal)); + + + + % Remove singleton dimensions → nTrials x nTimeBins + MRhist = squeeze(MRhist); + + if ~isempty(params.TakeTopPercentTrials) + MeanTrial = mean(MRhist,2); + [~, ind] = sort(MeanTrial,'descend'); + + takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); + + MRhist = MRhist(takeTrials,:); + + end + nTrials = size(MRhist, 1); + + % Convert to spike times in ms + spikeTimes = repmat((1:size(MRhist, 2)), nTrials, 1); + spikeTimes = spikeTimes(logical(MRhist)); + + % Bin into locked edges and convert to spk/s + counts = histcounts(spikeTimes, lockedEdges); + psthRateNeurons(ni, :) = (counts / (params.binWidth * nTrials)) * 1000; + end + + % Average across responsive neurons → 1 x lockedNBins + psthExp = mean(psthRateNeurons, 1, 'omitnan'); + + if params.zScore + baselineBins = tAxis < lockedPreBase; + baselineMean = mean(psthExp(baselineBins)); + baselineStd = std(psthExp(baselineBins)); + if baselineStd > 0 + psthExp = (psthExp - baselineMean) / baselineStd; + else + warning(' [%s] Baseline std is zero for exp %d — skipping experiment.', stimType, ex); + if ~isempty(psthAll{s}) + psthAll{s} = [psthAll{s}; NaN(1, lockedNBins)]; + end + continue % skip to next experiment, do not append raw rates + end + end + + % Append as new row — guaranteed lockedNBins wide + psthAll{s} = [psthAll{s}; psthExp(:)']; + + end % end stim loop + end % end experiment loop + + % ------------------------------------------------------------------ + % Save results to struct + % ------------------------------------------------------------------ + S.expList = exList; % experiment list for future matching + S.lockedEdges = lockedEdges; % bin edges used (ms from trial start) + S.lockedPreBase = lockedPreBase; % baseline duration in ms + S.params = params; % all parameters used + + % Save one field per stim type, named by stim e.g. S.rectGrid + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); % safe field name + S.(stimField) = psthAll{s}; % nExp x nBins PSTH matrix + end + + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved PSTHs to:\n %s\n', [saveDir nameOfFile]); + +else + % ------------------------------------------------------------------ + % Load psthAll from saved struct + % ------------------------------------------------------------------ + lockedEdges = S.lockedEdges; + lockedPreBase = S.lockedPreBase; + + psthAll = cell(1, numel(params.stimTypes)); + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + if isfield(S, stimField) + psthAll{s} = S.(stimField); % load the nExp x nBins matrix + else + % Stim type not found in saved file — warn and leave empty + warning('Stim type "%s" not found in saved file.', params.stimTypes(s)); + psthAll{s} = []; + end + end + +end % end forloop + +% ========================================================================= +% PLOT +% ========================================================================= + +tAxis = lockedEdges(1:end-1); +tAxisPlot = tAxis - lockedPreBase; + +colors = lines(numel(params.stimTypes)); + +fig = figure; +set(fig, 'Units', 'centimeters', 'Position', [5 5 9 10]); % single axis now + +% ------------------------------------------------------------------ +% Map stimulus type names to short legend labels +% ------------------------------------------------------------------ +stimLegendMap = containers.Map(... + {'linearlyMovingBall', 'rectGrid', 'MovingGrating', 'StaticGrating'}, ... + {'MB', 'SB', 'MG', 'SG'}); + +% ------------------------------------------------------------------ +% First pass: compute mean/sem for all stim types and find global ylim +% ------------------------------------------------------------------ +meanAll = cell(1, numel(params.stimTypes)); +semAll = cell(1, numel(params.stimTypes)); +yMax = 0; +yMin = inf; + +for s = 1:numel(params.stimTypes) + data = psthAll{s}; + if isempty(data) + continue + end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data) + continue + end + meanAll{s} = mean(data, 1, 'omitnan'); + semAll{s} = std(data, 0, 1, 'omitnan') / sqrt(sum(~isnan(data(:,1)))); + yMax = max(yMax, max(meanAll{s} + semAll{s})); + yMin = min(yMin, min(meanAll{s} - semAll{s})); +end + +% Y limits with 10% padding +yPad = (yMax - yMin) * 0.1; +if params.zScore + yLims = [yMin - yPad, yMax + yPad]; +else + yLims = [max(0, yMin - yPad), yMax + yPad]; +end + +% ------------------------------------------------------------------ +% Single axis plot — all stim types overlaid +% ------------------------------------------------------------------ +ax = axes(fig); +hold(ax, 'on'); + +legendHandles = gobjects(numel(params.stimTypes), 1); % store line handles for legend + +for s = 1:numel(params.stimTypes) + + data = psthAll{s}; + if isempty(data) + continue + end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data) + continue + end + + meanPSTH = meanAll{s}; + semPSTH = semAll{s}; + + % Get short legend label for this stim type + stimKey = char(params.stimTypes(s)); + if isKey(stimLegendMap, stimKey) + legendLabel = stimLegendMap(stimKey); + else + legendLabel = stimKey; % fallback to full name if not in map + end + + % Shade ±SEM band + if params.shadeSTD && size(data, 1) > 1 + upper = meanPSTH + semPSTH; + lower = meanPSTH - semPSTH; + xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; + yFill = [upper(:)', fliplr(lower(:)') ]; + fill(ax, xFill, yFill, colors(s,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none'); + end + + % Mean PSTH line — store handle for legend + legendHandles(s) = plot(ax, tAxisPlot(:)', meanPSTH(:)', ... + 'Color', colors(s,:), 'LineWidth', 1.5, 'DisplayName', legendLabel); + + % Number of contributing experiments as text + nValid = sum(validRows); + fprintf(' [%s] n=%d experiments in plot.\n', legendLabel, nValid); + +end + +% Stim onset and end of post-stim window +xline(ax, 0, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); +xline(ax, params.postStim, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); + +% Y label +if params.zScore + yLabel = 'Z-score'; +else + yLabel = '[spk/s]'; +end + +xlabel(ax, 'Time re. stim onset [ms]', 'FontName', 'helvetica', 'FontSize', 8); +ylabel(ax, yLabel, 'FontName', 'helvetica', 'FontSize', 8); +xlim(ax, [tAxisPlot(1) tAxisPlot(end)]); +ylim(ax, yLims); + +% Legend — only show valid handles (skip stim types with no data) +validHandles = legendHandles(isgraphics(legendHandles)); +legend(validHandles, 'Location', 'northeast', 'FontName', 'helvetica', 'FontSize', 8); + +ax.FontName = 'helvetica'; +ax.FontSize = 8; +hold(ax, 'off'); + +sgtitle(sprintf('N = %d', numel(exList)), 'FontName', 'helvetica', 'FontSize', 11); + +ax = gca; +ax.YAxis.FontSize = 8; +ax.YAxis.FontName = 'helvetica'; + +ax = gca; +ax.XAxis.FontSize = 8; +ax.XAxis.FontName = 'helvetica'; + +set(fig, 'Units', 'centimeters'); +set(fig, 'Position', [20 20 5 6]); + +if params.PaperFig + vs_first.printFig(fig, sprintf('PSTH-comparison-%s-%s', ... + params.stimTypes(1), params.stimTypes(2)), PaperFig = params.PaperFig) +end + +end \ No newline at end of file From 807f862bec81bdeb352570606c0867499d052338 Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Wed, 25 Mar 2026 01:40:55 +0200 Subject: [PATCH 7/8] updates o PSTH --- .../RunAnalysisClass.asv | 5 +- visualStimulationAnalysis/RunAnalysisClass.m | 5 +- .../plotPSTH_MultiExp.asv | 463 ++++++++++++++++++ visualStimulationAnalysis/plotPSTH_MultiExp.m | 15 +- .../plotPSTH_MultiExpV1.m | 2 + 5 files changed, 484 insertions(+), 6 deletions(-) create mode 100644 visualStimulationAnalysis/plotPSTH_MultiExp.asv diff --git a/visualStimulationAnalysis/RunAnalysisClass.asv b/visualStimulationAnalysis/RunAnalysisClass.asv index e674375..f372c34 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.asv +++ b/visualStimulationAnalysis/RunAnalysisClass.asv @@ -68,7 +68,10 @@ end VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR %% PSTH for all experiments -plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false); +plotPSTH_MultiExp([49:54,64:72,84:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=true, smooth=50, stimTypes=["linearlyMovingBall"]); + +%% +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=false); %% Calculate spatial tuning SpatialTuningIndex([49:54,64:97], indexType = "L_amplitude_ratio" ,overwrite=true, topPercent = 20) diff --git a/visualStimulationAnalysis/RunAnalysisClass.m b/visualStimulationAnalysis/RunAnalysisClass.m index 020bac3..f372c34 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.m +++ b/visualStimulationAnalysis/RunAnalysisClass.m @@ -68,7 +68,10 @@ VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR %% PSTH for all experiments -plotPSTH_MultiExp([49:54,64:72,84:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=true); +plotPSTH_MultiExp([49:54,64:72,84:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=true, smooth=50, stimTypes=["linearlyMovingBall"]); + +%% +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=false); %% Calculate spatial tuning SpatialTuningIndex([49:54,64:97], indexType = "L_amplitude_ratio" ,overwrite=true, topPercent = 20) diff --git a/visualStimulationAnalysis/plotPSTH_MultiExp.asv b/visualStimulationAnalysis/plotPSTH_MultiExp.asv new file mode 100644 index 0000000..58b8e5d --- /dev/null +++ b/visualStimulationAnalysis/plotPSTH_MultiExp.asv @@ -0,0 +1,463 @@ +function plotPSTH_MultiExp(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.bin double = 30 + params.binWidth double = 10 + params.statType string = "BootstrapPerNeuron" + params.speed string = "max" + params.alpha double = 0.05 + params.shadeSTD logical = true + params.postStim double = 500 + params.preBase double = 200 + params.overwrite logical = false + params.TakeTopPercentTrials double = 0.3 + params.zScore logical = false + params.PaperFig logical = false + params.byDepth logical = false +end + +% ------------------------------------------------------------------------- +% Load depth info (only if byDepth requested) +% ------------------------------------------------------------------------- +if params.byDepth + depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; + if ~exist(depthFile, 'file') + error('NeuronDepths.mat not found. Run getNeuronDepths() first.'); + end + D = load(depthFile); + depthTable = D.depthTable; + depthBinEdges = D.depthBinEdges; + nDepthBins = 3; + fprintf('Depth bins loaded:\n'); + fprintf(' Bin 1 (shallow): %.0f - %.0f um\n', depthBinEdges(1), depthBinEdges(2)); + fprintf(' Bin 2 (middle) : %.0f - %.0f um\n', depthBinEdges(2), depthBinEdges(3)); + fprintf(' Bin 3 (deep) : %.0f - %.0f um\n', depthBinEdges(3), depthBinEdges(4)); +else + nDepthBins = 1; +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); +vs_first = linearlyMovingBallAnalysis(NP_first); + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +depthSuffix = ''; +if params.byDepth; depthSuffix = '_byDepth'; end +nameOfFile = sprintf('\\Ex_%d-%d_Combined_PSTHs_%s%s.mat', ... + exList(1), exList(end), stimLabel, depthSuffix); + +% ------------------------------------------------------------------------- +% Decide whether to recompute or load +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved PSTHs from:\n %s\n', [saveDir nameOfFile]); + forloop = false; + else + fprintf('Experiment list mismatch — recomputing.\n'); + forloop = true; + end +else + forloop = true; +end + +% ========================================================================= +% EXPERIMENT LOOP +% ========================================================================= +if forloop + + nStim = numel(params.stimTypes); + nExp = numel(exList); + + psthAll = cell(nStim, nDepthBins); + + lockedPreBase = []; + lockedNBins = []; + lockedEdges = []; + tAxis = []; % FIX 3: initialise here so it is always defined + + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + % FIX 4: only append NaN rows if window is already locked + if ~isempty(lockedNBins) + for s = 1:nStim + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + end + end + end + continue + end + + for s = 1:nStim + + stimType = params.stimTypes(s); + + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + case "StaticGrating" + obj = StaticDriftingGratingAnalysis(NP); + case "MovingGrating" + obj = StaticDriftingGratingAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + if ~isempty(lockedNBins) + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + end + end + continue + end + + NeuronResp = obj.ResponseWindow; + + if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; + else + Stats = obj.ShufflingAnalysis; + end + + % FIX 1+2: initialise fieldName and use stimType (loop var) not params.stimTypes + fieldName = ''; + startStim = 0; + if params.speed ~= "max" && isequal(obj.stimName, 'linearlyMovingBall') + fieldName = 'Speed2'; + elseif isequal(obj.stimName, 'linearlyMovingBall') + fieldName = 'Speed1'; + elseif isequal(stimType, 'StaticGrating') % FIX 2 + fieldName = 'Static'; + elseif isequal(stimType, 'MovingGrating') % FIX 2 + fieldName = 'Moving'; + startStim = obj.VST.static_time * 1000; + end + + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); + + % Use fieldName if set, otherwise fall back to top-level fields + try + pvals = Stats.(fieldName).pvalsResponse; + catch + pvals = Stats.pvalsResponse; + end + + try + C = NeuronResp.(fieldName).C; + catch + C = NeuronResp.C; + end + directimesSorted = C(:, 1)' + startStim; + + preBase = params.preBase; + windowTotal = preBase + params.postStim; + + if isempty(lockedPreBase) + lockedPreBase = preBase; + lockedEdges = 0 : params.binWidth : windowTotal; + lockedNBins = numel(lockedEdges) - 1; + tAxis = lockedEdges(1:end-1); % FIX 3: set alongside lockedEdges + fprintf(' Locked window: preBase=%d ms, postStim=%d ms, nBins=%d\n', ... + lockedPreBase, params.postStim, lockedNBins); + end + + eNeurons = find(pvals < params.alpha); + + if isempty(eNeurons) + fprintf(' [%s] No responsive neurons in exp %d.\n', stimType, ex); + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + end + continue + end + + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, numel(eNeurons), ex); + + % ---------------------------------------------------------- + % Build PSTH per neuron + % ---------------------------------------------------------- + psthRateNeurons = zeros(numel(eNeurons), lockedNBins); + neuronBinIdx = zeros(numel(eNeurons), 1); + + for ni = 1:numel(eNeurons) + u = eNeurons(ni); + + % Assign depth bin + if params.byDepth + depthRow = depthTable.Experiment == ex & depthTable.Unit == u; + if ~any(depthRow) + neuronBinIdx(ni) = 0; % unknown — will be skipped + continue + end + unitDepth = depthTable.Depth_um(depthRow); + if unitDepth <= depthBinEdges(2) + neuronBinIdx(ni) = 1; + elseif unitDepth <= depthBinEdges(3) + neuronBinIdx(ni) = 2; + else + neuronBinIdx(ni) = 3; + end + else + neuronBinIdx(ni) = 1; + end + + MRhist = BuildBurstMatrix( ... + goodU(:, u), ... + round(p_sort.t), ... + round(directimesSorted - lockedPreBase), ... + round(windowTotal)); + MRhist = squeeze(MRhist); + + if ~isempty(params.TakeTopPercentTrials) + MeanTrial = mean(MRhist, 2); + [~, ind] = sort(MeanTrial, 'descend'); + takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); + MRhist = MRhist(takeTrials, :); + end + + nTrials = size(MRhist, 1); + spikeTimes = repmat((1:size(MRhist,2)), nTrials, 1); + spikeTimes = spikeTimes(logical(MRhist)); + counts = histcounts(spikeTimes, lockedEdges); + psthRateNeurons(ni, :) = (counts / (params.binWidth * nTrials)) * 1000; + end + + % ---------------------------------------------------------- + % Average per depth bin and append + % ---------------------------------------------------------- + for b = 1:nDepthBins + binNeurons = neuronBinIdx == b; + if ~any(binNeurons) + fprintf(' [%s] No neurons in depth bin %d for exp %d.\n', stimType, b, ex); + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + continue + end + + psthExp = mean(psthRateNeurons(binNeurons, :), 1, 'omitnan'); + + if params.zScore + baselineBins = tAxis < lockedPreBase; + baselineMean = mean(psthExp(baselineBins)); + baselineStd = std(psthExp(baselineBins)); + if baselineStd > 0 + psthExp = (psthExp - baselineMean) / baselineStd; + else + warning(' [%s] Bin %d: baseline std is zero for exp %d.', stimType, b, ex); + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; + end + continue + end + end + + psthAll{s,b} = [psthAll{s,b}; psthExp(:)']; + fprintf(' [%s] Bin %d: %d neuron(s) in exp %d.\n', stimType, b, sum(binNeurons), ex); + end + + end % stim loop + end % experiment loop + + % ------------------------------------------------------------------ + % Save + % ------------------------------------------------------------------ + S.expList = exList; + S.lockedEdges = lockedEdges; + S.lockedPreBase = lockedPreBase; + S.params = params; + + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + for b = 1:nDepthBins + S.(sprintf('%s_bin%d', stimField, b)) = psthAll{s,b}; + end + end + + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved PSTHs to:\n %s\n', [saveDir nameOfFile]); + +else + % Load psthAll from disk + lockedEdges = S.lockedEdges; + lockedPreBase = S.lockedPreBase; + + psthAll = cell(numel(params.stimTypes), nDepthBins); + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + for b = 1:nDepthBins + fieldKey = sprintf('%s_bin%d', stimField, b); + if isfield(S, fieldKey) + psthAll{s,b} = S.(fieldKey); + else + warning('Field "%s" not found in saved file.', fieldKey); + psthAll{s,b} = []; + end + end + end +end + +% ========================================================================= +% PLOT +% ========================================================================= + +tAxis = lockedEdges(1:end-1); +tAxisPlot = tAxis - lockedPreBase; + +baseColors = lines(numel(params.stimTypes)); +depthShades = [0.6, 0.35, 0.1]; % light → dark for shallow → deep +binLabels = {'shallow', 'middle', 'deep'}; + +stimLegendMap = containers.Map(... + {'linearlyMovingBall', 'rectGrid', 'MovingGrating', 'StaticGrating'}, ... + {'MB', 'SB', 'MG', 'SG'}); + +% ------------------------------------------------------------------ +% First pass: global ylim +% ------------------------------------------------------------------ +yMax = 0; +yMin = inf; + +meanAll = cell(numel(params.stimTypes), nDepthBins); +semAll = cell(numel(params.stimTypes), nDepthBins); + +for s = 1:numel(params.stimTypes) + for b = 1:nDepthBins + data = psthAll{s,b}; + if isempty(data); continue; end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data); continue; end + meanAll{s,b} = mean(data, 1, 'omitnan'); + semAll{s,b} = std(data, 0, 1, 'omitnan') / sqrt(sum(~isnan(data(:,1)))); + yMax = max(yMax, max(meanAll{s,b} + semAll{s,b})); + yMin = min(yMin, min(meanAll{s,b} - semAll{s,b})); + end +end + +yPad = (yMax - yMin) * 0.1; +if params.zScore + yLims = [yMin - yPad, yMax + yPad]; +else + yLims = [max(0, yMin - yPad), yMax + yPad]; +end + +% ------------------------------------------------------------------ +% Plot +% ------------------------------------------------------------------ +fig = figure; +set(fig, 'Units', 'centimeters', 'Position', [5 5 9 10]); +ax = axes(fig); +hold(ax, 'on'); + +legendHandles = []; +legendLabels = {}; + +for s = 1:numel(params.stimTypes) + + stimKey = char(params.stimTypes(s)); + if isKey(stimLegendMap, stimKey) + shortName = stimLegendMap(stimKey); + else + shortName = stimKey; + end + + for b = 1:nDepthBins + + data = psthAll{s,b}; + if isempty(data); continue; end + validRows = ~all(isnan(data), 2); + data = data(validRows, :); + if isempty(data); continue; end + + meanPSTH = meanAll{s,b}; + semPSTH = semAll{s,b}; + + if params.byDepth + lineColor = baseColors(s,:) * (1 - depthShades(b)); + legendLabel = sprintf('%s %s (%.0f-%.0f um)', ... + shortName, binLabels{b}, depthBinEdges(b), depthBinEdges(b+1)); + else + lineColor = baseColors(s,:); + legendLabel = shortName; + end + + if params.shadeSTD && size(data,1) > 1 + upper = meanPSTH + semPSTH; + lower = meanPSTH - semPSTH; + xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; + yFill = [upper(:)', fliplr(lower(:)') ]; + fill(ax, xFill, yFill, lineColor, 'FaceAlpha', 0.15, 'EdgeColor', 'none'); + end + + h = plot(ax, tAxisPlot(:)', meanPSTH(:)', ... + 'Color', lineColor, 'LineWidth', 1.5); + + legendHandles(end+1) = h; %#ok + legendLabels{end+1} = legendLabel; %#ok + + fprintf(' [%s] n=%d experiments in plot.\n', legendLabel, sum(validRows)); + end +end + +xline(ax, 0, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); +xline(ax, params.postStim, 'k--', 'LineWidth', 1.2, 'HandleVisibility', 'off'); + +if params.zScore; yLabel = 'Z-score'; else; yLabel = '[spk/s]'; end + +xlabel(ax, 'Time re. stim onset [ms]', 'FontName', 'helvetica', 'FontSize', 8); +ylabel(ax, yLabel, 'FontName', 'helvetica', 'FontSize', 8); +xlim(ax, [tAxisPlot(1) tAxisPlot(end)]); +ylim(ax, yLims); + +legend(legendHandles, legendLabels, 'Location', 'northeast', ... + 'FontName', 'helvetica', 'FontSize', 7); + +ax.FontName = 'helvetica'; +ax.FontSize = 8; +ax.YAxis.FontSize = 8; +ax.XAxis.FontSize = 8; +hold(ax, 'off'); + +sgtitle(sprintf('N = %d', numel(exList)), 'FontName', 'helvetica', 'FontSize', 11); +set(fig, 'Units', 'centimeters', 'Position', [20 20 8 6]); + +if params.PaperFig + vs_first.printFig(fig, sprintf('PSTH-depth-%s-%s', ... + params.stimTypes(1), params.stimTypes(2)), PaperFig = params.PaperFig) +end + +end \ No newline at end of file diff --git a/visualStimulationAnalysis/plotPSTH_MultiExp.m b/visualStimulationAnalysis/plotPSTH_MultiExp.m index ec92d0d..dc0c682 100644 --- a/visualStimulationAnalysis/plotPSTH_MultiExp.m +++ b/visualStimulationAnalysis/plotPSTH_MultiExp.m @@ -3,13 +3,13 @@ function plotPSTH_MultiExp(exList, params) arguments exList double params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] - params.bin double = 30 params.binWidth double = 10 + params.smooth double = 0 % smoothing window in ms (0 = no smoothing) params.statType string = "BootstrapPerNeuron" params.speed string = "max" params.alpha double = 0.05 params.shadeSTD logical = true - params.postStim double = 500 + params.postStim double = 2000 params.preBase double = 200 params.overwrite logical = false params.TakeTopPercentTrials double = 0.3 @@ -330,7 +330,7 @@ function plotPSTH_MultiExp(exList, params) tAxisPlot = tAxis - lockedPreBase; baseColors = lines(numel(params.stimTypes)); -depthShades = [0.6, 0.35, 0.1]; % light → dark for shallow → deep +depthShades = [0.05, 0.45, 0.78]; % light → dark for shallow → deep binLabels = {'shallow', 'middle', 'deep'}; stimLegendMap = containers.Map(... @@ -398,6 +398,13 @@ function plotPSTH_MultiExp(exList, params) meanPSTH = meanAll{s,b}; semPSTH = semAll{s,b}; + % Smooth if requested + if params.smooth > 0 + smoothBins = round(params.smooth / params.binWidth); % convert ms to bins + meanPSTH = smoothdata(meanPSTH, 'gaussian', smoothBins); + semPSTH = smoothdata(semPSTH, 'gaussian', smoothBins); + end + % Color and label depend on mode if params.byDepth lineColor = baseColors(s,:) * (1 - depthShades(b)); @@ -414,7 +421,7 @@ function plotPSTH_MultiExp(exList, params) lower = meanPSTH - semPSTH; xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; yFill = [upper(:)', fliplr(lower(:)') ]; - fill(ax, xFill, yFill, lineColor, 'FaceAlpha', 0.15, 'EdgeColor', 'none'); + fill(ax, xFill, yFill, lineColor, 'FaceAlpha', 0.08, 'EdgeColor', 'none'); end % Mean line diff --git a/visualStimulationAnalysis/plotPSTH_MultiExpV1.m b/visualStimulationAnalysis/plotPSTH_MultiExpV1.m index f9a404d..e9270cd 100644 --- a/visualStimulationAnalysis/plotPSTH_MultiExpV1.m +++ b/visualStimulationAnalysis/plotPSTH_MultiExpV1.m @@ -5,6 +5,7 @@ function plotPSTH_MultiExpV1(exList, params) params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] params.bin double = 30 params.binWidth double = 10 + params.smooth double = 0 % smoothing window in ms (0 = no smoothing) params.statType string = "BootstrapPerNeuron" params.speed string = "max" params.alpha double = 0.05 @@ -391,6 +392,7 @@ function plotPSTH_MultiExpV1(exList, params) meanPSTH = meanAll{s}; semPSTH = semAll{s}; + % Get short legend label for this stim type stimKey = char(params.stimTypes(s)); if isKey(stimLegendMap, stimKey) From 68bbe596c2159486b4160c5b0e8f14edcad08d52 Mon Sep 17 00:00:00 2001 From: simon37robledo Date: Thu, 26 Mar 2026 01:25:38 +0200 Subject: [PATCH 8/8] Adding general raster --- .../RunAnalysisClass.asv | 6 +- visualStimulationAnalysis/RunAnalysisClass.m | 6 +- visualStimulationAnalysis/getNeuronDepths.m | 2 +- .../plotPSTH_MultiExp.asv | 80 +-- visualStimulationAnalysis/plotPSTH_MultiExp.m | 2 +- .../plotRaster_MultiExp.asv | 448 +++++++++++++++++ .../plotRaster_MultiExp.m | 471 ++++++++++++++++++ 7 files changed, 968 insertions(+), 47 deletions(-) create mode 100644 visualStimulationAnalysis/plotRaster_MultiExp.asv create mode 100644 visualStimulationAnalysis/plotRaster_MultiExp.m diff --git a/visualStimulationAnalysis/RunAnalysisClass.asv b/visualStimulationAnalysis/RunAnalysisClass.asv index f372c34..10bb840 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.asv +++ b/visualStimulationAnalysis/RunAnalysisClass.asv @@ -68,16 +68,16 @@ end VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR %% PSTH for all experiments -plotPSTH_MultiExp([49:54,64:72,84:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=true, smooth=50, stimTypes=["linearlyMovingBall"]); +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=true, byDepth=true, smooth=50); %stimTypes=["linearlyMovingBall"] %% -plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=false); +plotRaster_MultiExp([49:54,64:97], sortBy = "depth",overwrite=false,TakeTopPercentTrials=[]) %% Calculate spatial tuning SpatialTuningIndex([49:54,64:97], indexType = "L_amplitude_ratio" ,overwrite=true, topPercent = 20) %% Get neuron depths -getNeuronDepths([49:54,64:72,84:97]) %[49:54,64:72,84:97] %% PV140 missing depth coordinates +getNeuronDepths([49:54,64:97]) %[49:54,64:72,84:97] %% PV140 missing depth coordinates %% Gratings for ex = [54 84:90] diff --git a/visualStimulationAnalysis/RunAnalysisClass.m b/visualStimulationAnalysis/RunAnalysisClass.m index f372c34..8be8cca 100644 --- a/visualStimulationAnalysis/RunAnalysisClass.m +++ b/visualStimulationAnalysis/RunAnalysisClass.m @@ -68,16 +68,16 @@ VStimAnalysis.PlotZScoreComparison([49:54,64:97] ,{'MB','RG'},StatMethod='bootsrapRespBase', overwrite=false,ComparePairs={'MB','RG'},PaperFig=true,... overwriteResponse=false,overwriteStats=false)%[49:54,57:91] %%Check why I have different array dimensions in MBR %% PSTH for all experiments -plotPSTH_MultiExp([49:54,64:72,84:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=true, smooth=50, stimTypes=["linearlyMovingBall"]); +plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=true, byDepth=true, smooth=50); %stimTypes=["linearlyMovingBall"] %% -plotPSTH_MultiExp([49:54,64:97], overwrite=true, zScore=true,TakeTopPercentTrials=[], PaperFig=false, byDepth=false); +plotRaster_MultiExp([49:54,64:97], sortBy = "peak",overwrite=false,TakeTopPercentTrials=[]) %% Calculate spatial tuning SpatialTuningIndex([49:54,64:97], indexType = "L_amplitude_ratio" ,overwrite=true, topPercent = 20) %% Get neuron depths -getNeuronDepths([49:54,64:72,84:97]) %[49:54,64:72,84:97] %% PV140 missing depth coordinates +getNeuronDepths([49:54,64:97]) %[49:54,64:72,84:97] %% PV140 missing depth coordinates %% Gratings for ex = [54 84:90] diff --git a/visualStimulationAnalysis/getNeuronDepths.m b/visualStimulationAnalysis/getNeuronDepths.m index 35f98bd..b653851 100644 --- a/visualStimulationAnalysis/getNeuronDepths.m +++ b/visualStimulationAnalysis/getNeuronDepths.m @@ -57,7 +57,7 @@ % Channel IDs (0-based) → Y positions → real depths channelIDs = goodU(1, :); % 1 x nGoodUnits, 0-based - yPos = NP.chLayoutPositions(2, channelIDs + 1); % 1 x nGoodUnits + yPos = NP.chLayoutPositions(2, channelIDs); % 1 x nGoodUnits neuronDepths = coor_Z - yPos; % 1 x nGoodUnits, in um % Accumulate table columns diff --git a/visualStimulationAnalysis/plotPSTH_MultiExp.asv b/visualStimulationAnalysis/plotPSTH_MultiExp.asv index 58b8e5d..6b6bc02 100644 --- a/visualStimulationAnalysis/plotPSTH_MultiExp.asv +++ b/visualStimulationAnalysis/plotPSTH_MultiExp.asv @@ -3,8 +3,8 @@ function plotPSTH_MultiExp(exList, params) arguments exList double params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] - params.bin double = 30 params.binWidth double = 10 + params.smooth double = 0 % smoothing window in ms (0 = no smoothing) params.statType string = "BootstrapPerNeuron" params.speed string = "max" params.alpha double = 0.05 @@ -15,11 +15,11 @@ arguments params.TakeTopPercentTrials double = 0.3 params.zScore logical = false params.PaperFig logical = false - params.byDepth logical = false + params.byDepth logical = false % plot 3 depth bins per stim type end % ------------------------------------------------------------------------- -% Load depth info (only if byDepth requested) +% Load depth info from saved file (only if byDepth is requested) % ------------------------------------------------------------------------- if params.byDepth depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; @@ -82,12 +82,12 @@ if forloop nStim = numel(params.stimTypes); nExp = numel(exList); + % psthAll{s,b} — s = stim type, b = depth bin (1 if byDepth is off) psthAll = cell(nStim, nDepthBins); lockedPreBase = []; lockedNBins = []; lockedEdges = []; - tAxis = []; % FIX 3: initialise here so it is always defined for ei = 1:nExp @@ -98,13 +98,10 @@ if forloop NP = loadNPclassFromTable(ex); catch ME warning('Could not load experiment %d: %s', ex, ME.message); - % FIX 4: only append NaN rows if window is already locked - if ~isempty(lockedNBins) - for s = 1:nStim - for b = 1:nDepthBins - if ~isempty(psthAll{s,b}) - psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; - end + for s = 1:nStim + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; end end end @@ -121,20 +118,18 @@ if forloop obj = rectGridAnalysis(NP); case "linearlyMovingBall" obj = linearlyMovingBallAnalysis(NP); - case "StaticGrating" + case 'StaticGrating' obj = StaticDriftingGratingAnalysis(NP); - case "MovingGrating" + case 'MovingGrating' obj = StaticDriftingGratingAnalysis(NP); otherwise error('Unknown stimType: %s', stimType); end catch ME warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); - if ~isempty(lockedNBins) - for b = 1:nDepthBins - if ~isempty(psthAll{s,b}) - psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; - end + for b = 1:nDepthBins + if ~isempty(psthAll{s,b}) + psthAll{s,b} = [psthAll{s,b}; NaN(1, lockedNBins)]; end end continue @@ -148,25 +143,22 @@ if forloop Stats = obj.ShufflingAnalysis; end - % FIX 1+2: initialise fieldName and use stimType (loop var) not params.stimTypes - fieldName = ''; - startStim = 0; - if params.speed ~= "max" && isequal(obj.stimName, 'linearlyMovingBall') - fieldName = 'Speed2'; - elseif isequal(obj.stimName, 'linearlyMovingBall') - fieldName = 'Speed1'; - elseif isequal(stimType, 'StaticGrating') % FIX 2 - fieldName = 'Static'; - elseif isequal(stimType, 'MovingGrating') % FIX 2 - fieldName = 'Moving'; - startStim = obj.VST.static_time * 1000; + if params.speed ~= "max" && isequal(obj.stimName,'linearlyMovingBall') + fieldName = 'Speed2'; startStim = 0; + elseif isequal(obj.stimName,'linearlyMovingBall') + fieldName = 'Speed1'; startStim = 0; + elseif isequal(params.stimTypes,'StaticGrating') + fieldName = 'Static'; startStim = 0; + elseif isequal(params.stimTypes,'MovingGrating') + startStim = obj.VST.static_time*1000; fieldName = 'Moving'; + else + startStim = 0; end - p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder,1,1); label = string(p_sort.label'); goodU = p_sort.ic(:, label == 'good'); - % Use fieldName if set, otherwise fall back to top-level fields try pvals = Stats.(fieldName).pvalsResponse; catch @@ -187,7 +179,7 @@ if forloop lockedPreBase = preBase; lockedEdges = 0 : params.binWidth : windowTotal; lockedNBins = numel(lockedEdges) - 1; - tAxis = lockedEdges(1:end-1); % FIX 3: set alongside lockedEdges + tAxis = lockedEdges(1:end-1); fprintf(' Locked window: preBase=%d ms, postStim=%d ms, nBins=%d\n', ... lockedPreBase, params.postStim, lockedNBins); end @@ -204,7 +196,7 @@ if forloop continue end - fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, numel(eNeurons), ex); + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, ex, numel(eNeurons)); % ---------------------------------------------------------- % Build PSTH per neuron @@ -219,7 +211,7 @@ if forloop if params.byDepth depthRow = depthTable.Experiment == ex & depthTable.Unit == u; if ~any(depthRow) - neuronBinIdx(ni) = 0; % unknown — will be skipped + neuronBinIdx(ni) = 0; % unknown depth — skip continue end unitDepth = depthTable.Depth_um(depthRow); @@ -231,7 +223,7 @@ if forloop neuronBinIdx(ni) = 3; end else - neuronBinIdx(ni) = 1; + neuronBinIdx(ni) = 1; % all neurons in single bin end MRhist = BuildBurstMatrix( ... @@ -338,7 +330,7 @@ tAxis = lockedEdges(1:end-1); tAxisPlot = tAxis - lockedPreBase; baseColors = lines(numel(params.stimTypes)); -depthShades = [0.6, 0.35, 0.1]; % light → dark for shallow → deep +depthShades = [0.05, 0.45, 0.78]; % light → dark for shallow → deep binLabels = {'shallow', 'middle', 'deep'}; stimLegendMap = containers.Map(... @@ -406,6 +398,14 @@ for s = 1:numel(params.stimTypes) meanPSTH = meanAll{s,b}; semPSTH = semAll{s,b}; + % Smooth if requested + if params.smooth > 0 + smoothBins = round(params.smooth / params.binWidth); % convert ms to bins + meanPSTH = smoothdata(meanPSTH, 'gaussian', smoothBins); + semPSTH = smoothdata(semPSTH, 'gaussian', smoothBins); + end + + % Color and label depend on mode if params.byDepth lineColor = baseColors(s,:) * (1 - depthShades(b)); legendLabel = sprintf('%s %s (%.0f-%.0f um)', ... @@ -415,18 +415,20 @@ for s = 1:numel(params.stimTypes) legendLabel = shortName; end + % SEM shading if params.shadeSTD && size(data,1) > 1 upper = meanPSTH + semPSTH; lower = meanPSTH - semPSTH; xFill = [tAxisPlot(:)', fliplr(tAxisPlot(:)')]; yFill = [upper(:)', fliplr(lower(:)') ]; - fill(ax, xFill, yFill, lineColor, 'FaceAlpha', 0.15, 'EdgeColor', 'none'); + fill(ax, xFill, yFill, lineColor, 'FaceAlpha', 0.08, 'EdgeColor', 'none'); end + % Mean line h = plot(ax, tAxisPlot(:)', meanPSTH(:)', ... 'Color', lineColor, 'LineWidth', 1.5); - legendHandles(end+1) = h; %#ok + legendHandles(end+1) = h; %#ok legendLabels{end+1} = legendLabel; %#ok fprintf(' [%s] n=%d experiments in plot.\n', legendLabel, sum(validRows)); diff --git a/visualStimulationAnalysis/plotPSTH_MultiExp.m b/visualStimulationAnalysis/plotPSTH_MultiExp.m index dc0c682..e8ecb3a 100644 --- a/visualStimulationAnalysis/plotPSTH_MultiExp.m +++ b/visualStimulationAnalysis/plotPSTH_MultiExp.m @@ -9,7 +9,7 @@ function plotPSTH_MultiExp(exList, params) params.speed string = "max" params.alpha double = 0.05 params.shadeSTD logical = true - params.postStim double = 2000 + params.postStim double = 500 params.preBase double = 200 params.overwrite logical = false params.TakeTopPercentTrials double = 0.3 diff --git a/visualStimulationAnalysis/plotRaster_MultiExp.asv b/visualStimulationAnalysis/plotRaster_MultiExp.asv new file mode 100644 index 0000000..3623ccb --- /dev/null +++ b/visualStimulationAnalysis/plotRaster_MultiExp.asv @@ -0,0 +1,448 @@ +function plotRaster_MultiExp(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.binWidth double = 10 + params.smooth double = 0 + params.statType string = "BootstrapPerNeuron" + params.speed string = "max" + params.alpha double = 0.05 + params.postStim double = 500 + params.preBase double = 200 + params.overwrite logical = false + params.TakeTopPercentTrials double = 0.3 + params.zScore logical = true % default true — more meaningful for raster + params.sortBy string = "peak" % "peak" = sort by peak response time, "depth" = sort by depth + params.PaperFig logical = false +end + +% ------------------------------------------------------------------------- +% Load depth info if sorting by depth +% ------------------------------------------------------------------------- +if params.sortBy == "depth" + depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; + if ~exist(depthFile, 'file') + error('NeuronDepths.mat not found. Run getNeuronDepths() first.'); + end + D = load(depthFile); + depthTable = D.depthTable; +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); +vs_first = linearlyMovingBallAnalysis(NP_first); + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +nameOfFile = sprintf('\\Ex_%d-%d_Raster_%s.mat', exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to recompute or load +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved raster data from:\n %s\n', [saveDir nameOfFile]); + forloop = false; + else + fprintf('Experiment list mismatch — recomputing.\n'); + forloop = true; + end +else + forloop = true; +end + +% ========================================================================= +% EXPERIMENT LOOP +% ========================================================================= +if forloop + + nStim = numel(params.stimTypes); + nExp = numel(exList); + + % rasterAll{s} grows one row per responsive neuron across all experiments + % each row = mean PSTH of one neuron in spk/s (or z-score) + rasterAll = cell(1, nStim); % nNeurons x nBins + depthAll = cell(1, nStim); % nNeurons x 1 — depth of each neuron + expAll = cell(1, nStim); % nNeurons x 1 — which experiment each neuron came from + + for s = 1:nStim + rasterAll{s} = []; + depthAll{s} = []; + expAll{s} = []; + end + + lockedPreBase = []; + lockedNBins = []; + lockedEdges = []; + tAxis = []; + + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + continue + end + + for s = 1:nStim + + stimType = params.stimTypes(s); + + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + case "StaticGrating" + obj = StaticDriftingGratingAnalysis(NP); + case "MovingGrating" + obj = StaticDriftingGratingAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + continue + end + + NeuronResp = obj.ResponseWindow; + + if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; + else + Stats = obj.ShufflingAnalysis; + end + + % Resolve field name and stim start + fieldName = ''; + startStim = 0; + if params.speed ~= "max" && isequal(obj.stimName, 'linearlyMovingBall') + fieldName = 'Speed2'; + elseif isequal(obj.stimName, 'linearlyMovingBall') + fieldName = 'Speed1'; + elseif isequal(stimType, 'StaticGrating') + fieldName = 'Static'; + elseif isequal(stimType, 'MovingGrating') + fieldName = 'Moving'; + startStim = obj.VST.static_time * 1000; + end + + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); + + try + pvals = Stats.(fieldName).pvalsResponse; + catch + pvals = Stats.pvalsResponse; + end + + try + C = NeuronResp.(fieldName).C; + catch + C = NeuronResp.C; + end + directimesSorted = C(:, 1)' + startStim; + + preBase = params.preBase; + windowTotal = preBase + params.postStim; + + if isempty(lockedPreBase) + lockedPreBase = preBase; + lockedEdges = 0 : params.binWidth : windowTotal; + lockedNBins = numel(lockedEdges) - 1; + tAxis = lockedEdges(1:end-1); + fprintf(' Locked window: preBase=%d ms, postStim=%d ms, nBins=%d\n', ... + lockedPreBase, params.postStim, lockedNBins); + end + + eNeurons = find(pvals < params.alpha); + + if isempty(eNeurons) + fprintf(' [%s] No responsive neurons in exp %d.\n', stimType, ex); + continue + end + + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, numel(eNeurons), ex); + + % ---------------------------------------------------------- + % Build per-neuron PSTH + % ---------------------------------------------------------- + for ni = 1:numel(eNeurons) + u = eNeurons(ni); + + MRhist = BuildBurstMatrix( ... + goodU(:, u), ... + round(p_sort.t), ... + round(directimesSorted - lockedPreBase), ... + round(windowTotal)); + MRhist = squeeze(MRhist); + + if ~isempty(params.TakeTopPercentTrials) + MeanTrial = mean(MRhist, 2); + [~, ind] = sort(MeanTrial, 'descend'); + takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); + MRhist = MRhist(takeTrials, :); + end + + nTrials = size(MRhist, 1); + spikeTimes = repmat((1:size(MRhist,2)), nTrials, 1); + spikeTimes = spikeTimes(logical(MRhist)); + counts = histcounts(spikeTimes, lockedEdges); + neuronPSTH = (counts / (params.binWidth * nTrials)) * 1000; % spk/s + + % Z-score using baseline + if params.zScore + baselineBins = tAxis < lockedPreBase; + bMean = mean(neuronPSTH(baselineBins)); + bStd = std(neuronPSTH(baselineBins)); + if bStd > 0 + neuronPSTH = (neuronPSTH - bMean) / bStd; + else + continue % skip neuron if baseline std is zero + end + end + + % Smooth if requested + if params.smooth > 0 + smoothBins = round(params.smooth / params.binWidth); + neuronPSTH = smoothdata(neuronPSTH, 'gaussian', smoothBins); + end + + % Append neuron row + rasterAll{s} = [rasterAll{s}; neuronPSTH]; + + % Get depth for this neuron + if params.sortBy == "depth" + depthRow = depthTable.Experiment == ex & depthTable.Unit == u; + if any(depthRow) + depthAll{s}(end+1) = depthTable.Depth_um(depthRow); + else + depthAll{s}(end+1) = NaN; + end + else + depthAll{s}(end+1) = NaN; + end + + expAll{s}(end+1) = ex; + end + + end % stim loop + end % experiment loop + + % ------------------------------------------------------------------ + % Save + % ------------------------------------------------------------------ + S.expList = exList; + S.lockedEdges = lockedEdges; + S.lockedPreBase = lockedPreBase; + S.params = params; + + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + S.(sprintf('%s_raster', stimField)) = rasterAll{s}; + S.(sprintf('%s_depth', stimField)) = depthAll{s}; + S.(sprintf('%s_exp', stimField)) = expAll{s}; + end + + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved raster data to:\n %s\n', [saveDir nameOfFile]); + +else + % Load from disk + lockedEdges = S.lockedEdges; + lockedPreBase = S.lockedPreBase; + + rasterAll = cell(1, numel(params.stimTypes)); + depthAll = cell(1, numel(params.stimTypes)); + expAll = cell(1, numel(params.stimTypes)); + + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + rasterAll{s} = S.(sprintf('%s_raster', stimField)); + depthAll{s} = S.(sprintf('%s_depth', stimField)); + expAll{s} = S.(sprintf('%s_exp', stimField)); + end +end + +% ========================================================================= +% SORT NEURONS +% ========================================================================= +for s = 1:numel(params.stimTypes) + data = rasterAll{s}; + if isempty(data); continue; end + + if params.sortBy == "peak" + % Sort by time of peak response in the post-stimulus window + postStimBins = tAxis >= lockedPreBase; + [~, peakBin] = max(data(:, postStimBins), [], 2); + [~, sortIdx] = sort(peakBin); + elseif params.sortBy == "depth" + % Sort by depth (shallow to deep) + [~, sortIdx] = sort(depthAll{s}, 'ascend'); + else + sortIdx = 1:size(data, 1); % no sorting + end + + rasterAll{s} = data(sortIdx, :); + depthAll{s} = depthAll{s}(sortIdx); + expAll{s} = expAll{s}(sortIdx); +end + +% ========================================================================= +% PLOT +% ========================================================================= + +tAxis = lockedEdges(1:end-1); +tAxisPlot = tAxis - lockedPreBase; + +stimLegendMap = containers.Map(... + {'linearlyMovingBall', 'rectGrid', 'MovingGrating', 'StaticGrating'}, ... + {'MB', 'SB', 'MG', 'SG'}); + +nStim = numel(params.stimTypes); + +% ------------------------------------------------------------------ +% Global color limits across all stim types +% ------------------------------------------------------------------ +allValues = []; +for s = 1:nStim + if ~isempty(rasterAll{s}) + allValues = [allValues, rasterAll{s}(:)']; %#ok + end +end +cLimMax = prctile(abs(allValues), 98); % robust limit — ignore extreme outliers +if params.zScore + cLims = [-cLimMax, cLimMax]; % symmetric around zero for z-score +else + cLims = [0, cLimMax]; +end + +% ------------------------------------------------------------------ +% Figure and tiled layout +% ------------------------------------------------------------------ +fig = figure; +set(fig, 'Units', 'centimeters', 'Position', [5 5 5*nStim + 2, 10]); + +tl = tiledlayout(fig, 1, nStim, 'TileSpacing', 'compact', 'Padding', 'compact'); + +axAll = gobjects(1, nStim); + +for s = 1:nStim + + data = rasterAll{s}; + stimKey = char(params.stimTypes(s)); + if isKey(stimLegendMap, stimKey) + shortName = stimLegendMap(stimKey); + else + shortName = stimKey; + end + + axAll(s) = nexttile(tl); + ax = axAll(s); + + if isempty(data) + title(ax, shortName, 'FontName', 'helvetica', 'FontSize', 8); + axis(ax, 'off'); + continue + end + + % imagesc: x = time, y = neuron index + imagesc(ax, tAxisPlot, 1:size(data,1), data); + clim(ax, cLims); + colormap(ax, flipud(gray)); % white = low, black = high + + % ------------------------------------------------------------------ + % Depth bin boundary lines (only when sorted by depth) + % ------------------------------------------------------------------ + if params.sortBy == "depth" && ~isempty(depthAll{s}) + + % Load bin edges + depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; + D = load(depthFile); + depthBinEdges = D.depthBinEdges; + + binLabelsDepth = {sprintf('%.0f-%.0f um', depthBinEdges(1), depthBinEdges(2)), ... + sprintf('%.0f-%.0f um', depthBinEdges(2), depthBinEdges(3)), ... + sprintf('%.0f-%.0f um', depthBinEdges(3), depthBinEdges(4))}; + + % Find the last neuron index belonging to each bin boundary + for edge = 2:3 % edges 2 and 3 are the internal boundaries + %lastInBin = find(depthAll{s} <= depthBinEdges(edge), 1, 'last'); + %lastInBin = find(~isnan(depthAll{s}) & depthAll{s} <= depthBinEdges(edge), 1, 'last'); + depthCombined = depthAll{s}; + depthCombined = depthCombined(); + if ~isempty(lastInBin) && lastInBin < size(data,1) + yline(ax, lastInBin + 0.5, 'r-', 'LineWidth', 1.2); + % Label on the right side showing the bin range + text(ax, tAxisPlot(end), lastInBin - size(data,1)*0.02, ... + binLabelsDepth{edge-1}, ... + 'Color', 'r', 'FontSize', 6, 'FontName', 'helvetica', ... + 'HorizontalAlignment', 'right', 'VerticalAlignment', 'top'); + end + end + % Label for the deepest bin + text(ax, tAxisPlot(end), size(data,1), ... + binLabelsDepth{3}, ... + 'Color', 'r', 'FontSize', 6, 'FontName', 'helvetica', ... + 'HorizontalAlignment', 'right', 'VerticalAlignment', 'top'); + end + + % Stim onset and offset lines + xline(ax, 0, 'w--', 'LineWidth', 1.0); + xline(ax, params.postStim, 'w--', 'LineWidth', 1.0); + + xlim(ax, [tAxisPlot(1), tAxisPlot(end)]); + ylim(ax, [0.5, size(data,1)+0.5]); + + xlabel(ax, 'Time re. stim onset [ms]', 'FontName', 'helvetica', 'FontSize', 8); + if s == 1 + ylabel(ax, 'Neuron #', 'FontName', 'helvetica', 'FontSize', 8); + end + title(ax, sprintf('%s (n=%d)', shortName, size(data,1)), ... + 'FontName', 'helvetica', 'FontSize', 8); + + ax.FontName = 'helvetica'; + ax.FontSize = 8; + ax.YDir = 'normal'; % neuron 1 at bottom + +end + +% ------------------------------------------------------------------ +% Single colorbar for the whole layout +% ------------------------------------------------------------------ +cb = colorbar(axAll(end)); +if params.zScore + cb.Label.String = 'Z-score'; +else + cb.Label.String = 'Firing rate [spk/s]'; +end +cb.Label.FontName = 'helvetica'; +cb.Label.FontSize = 8; +cb.FontName = 'helvetica'; +cb.FontSize = 8; + +sgtitle(sprintf('N = %d experiments', numel(exList)), ... + 'FontName', 'helvetica', 'FontSize', 10); + +if params.PaperFig + vs_first.printFig(fig, sprintf('Raster-%s', stimLabel), PaperFig=params.PaperFig); +end + +end \ No newline at end of file diff --git a/visualStimulationAnalysis/plotRaster_MultiExp.m b/visualStimulationAnalysis/plotRaster_MultiExp.m new file mode 100644 index 0000000..b31a0f9 --- /dev/null +++ b/visualStimulationAnalysis/plotRaster_MultiExp.m @@ -0,0 +1,471 @@ +function plotRaster_MultiExp(exList, params) + +arguments + exList double + params.stimTypes (1,:) string = ["rectGrid", "linearlyMovingBall"] + params.binWidth double = 10 + params.smooth double = 0 + params.statType string = "BootstrapPerNeuron" + params.speed string = "max" + params.alpha double = 0.05 + params.postStim double = 500 + params.preBase double = 200 + params.overwrite logical = false + params.TakeTopPercentTrials double = 0.3 + params.zScore logical = true % default true — more meaningful for raster + params.sortBy string = "peak" % "peak" = sort by peak response time, "depth" = sort by depth + params.PaperFig logical = false + params.climPrctile double = 90 % percentile for color limit — lower = more contrast + params.climNeg double = 0 % fixed negative z-score limit (absolute value) +end + +% ------------------------------------------------------------------------- +% Load depth info if sorting by depth +% ------------------------------------------------------------------------- +if params.sortBy == "depth" + depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; + if ~exist(depthFile, 'file') + error('NeuronDepths.mat not found. Run getNeuronDepths() first.'); + end + D = load(depthFile); + depthTable = D.depthTable; +end + +% ------------------------------------------------------------------------- +% Build save path +% ------------------------------------------------------------------------- +NP_first = loadNPclassFromTable(exList(1)); +vs_first = linearlyMovingBallAnalysis(NP_first); + +p = extractBefore(vs_first.getAnalysisFileName, 'lizards'); +p = [p 'lizards']; +if ~exist([p '\Combined_lizard_analysis'], 'dir') + cd(p) + mkdir Combined_lizard_analysis +end +saveDir = [p '\Combined_lizard_analysis']; + +stimLabel = strjoin(params.stimTypes, '-'); +nameOfFile = sprintf('\\Ex_%d-%d_Raster_%s.mat', exList(1), exList(end), stimLabel); + +% ------------------------------------------------------------------------- +% Decide whether to recompute or load +% ------------------------------------------------------------------------- +if exist([saveDir nameOfFile], 'file') == 2 && ~params.overwrite + S = load([saveDir nameOfFile]); + if isequal(S.expList, exList) + fprintf('Loading saved raster data from:\n %s\n', [saveDir nameOfFile]); + forloop = false; + else + fprintf('Experiment list mismatch — recomputing.\n'); + forloop = true; + end +else + forloop = true; +end + +% ========================================================================= +% EXPERIMENT LOOP +% ========================================================================= +if forloop + + nStim = numel(params.stimTypes); + nExp = numel(exList); + + % rasterAll{s} grows one row per responsive neuron across all experiments + % each row = mean PSTH of one neuron in spk/s (or z-score) + rasterAll = cell(1, nStim); % nNeurons x nBins + depthAll = cell(1, nStim); % nNeurons x 1 — depth of each neuron + expAll = cell(1, nStim); % nNeurons x 1 — which experiment each neuron came from + + for s = 1:nStim + rasterAll{s} = []; + depthAll{s} = []; + expAll{s} = []; + end + + lockedPreBase = []; + lockedNBins = []; + lockedEdges = []; + tAxis = []; + + for ei = 1:nExp + + ex = exList(ei); + fprintf('\n=== Experiment %d ===\n', ex); + + try + NP = loadNPclassFromTable(ex); + catch ME + warning('Could not load experiment %d: %s', ex, ME.message); + continue + end + + for s = 1:nStim + + stimType = params.stimTypes(s); + + try + switch stimType + case "rectGrid" + obj = rectGridAnalysis(NP); + case "linearlyMovingBall" + obj = linearlyMovingBallAnalysis(NP); + case "StaticGrating" + obj = StaticDriftingGratingAnalysis(NP); + case "MovingGrating" + obj = StaticDriftingGratingAnalysis(NP); + otherwise + error('Unknown stimType: %s', stimType); + end + catch ME + warning('Could not build %s for exp %d: %s', stimType, ex, ME.message); + continue + end + + NeuronResp = obj.ResponseWindow; + + if params.statType == "BootstrapPerNeuron" + Stats = obj.BootstrapPerNeuron; + else + Stats = obj.ShufflingAnalysis; + end + + % Resolve field name and stim start + fieldName = ''; + startStim = 0; + if params.speed ~= "max" && isequal(obj.stimName, 'linearlyMovingBall') + fieldName = 'Speed2'; + elseif isequal(obj.stimName, 'linearlyMovingBall') + fieldName = 'Speed1'; + elseif isequal(stimType, 'StaticGrating') + fieldName = 'Static'; + elseif isequal(stimType, 'MovingGrating') + fieldName = 'Moving'; + startStim = obj.VST.static_time * 1000; + end + + p_sort = obj.dataObj.convertPhySorting2tIc(obj.spikeSortingFolder); + label = string(p_sort.label'); + goodU = p_sort.ic(:, label == 'good'); + + try + pvals = Stats.(fieldName).pvalsResponse; + catch + pvals = Stats.pvalsResponse; + end + + try + C = NeuronResp.(fieldName).C; + catch + C = NeuronResp.C; + end + directimesSorted = C(:, 1)' + startStim; + + preBase = params.preBase; + windowTotal = preBase + params.postStim; + + if isempty(lockedPreBase) + lockedPreBase = preBase; + lockedEdges = 0 : params.binWidth : windowTotal; + lockedNBins = numel(lockedEdges) - 1; + tAxis = lockedEdges(1:end-1); + fprintf(' Locked window: preBase=%d ms, postStim=%d ms, nBins=%d\n', ... + lockedPreBase, params.postStim, lockedNBins); + end + + eNeurons = find(pvals < params.alpha); + + if isempty(eNeurons) + fprintf(' [%s] No responsive neurons in exp %d.\n', stimType, ex); + continue + end + + fprintf(' [%s] %d responsive neuron(s) in exp %d.\n', stimType, numel(eNeurons), ex); + + % ---------------------------------------------------------- + % Build per-neuron PSTH + % ---------------------------------------------------------- + for ni = 1:numel(eNeurons) + u = eNeurons(ni); + + MRhist = BuildBurstMatrix( ... + goodU(:, u), ... + round(p_sort.t), ... + round(directimesSorted - lockedPreBase), ... + round(windowTotal)); + MRhist = squeeze(MRhist); + + if ~isempty(params.TakeTopPercentTrials) + MeanTrial = mean(MRhist, 2); + [~, ind] = sort(MeanTrial, 'descend'); + takeTrials = ind(1:round(numel(MeanTrial)*params.TakeTopPercentTrials)); + MRhist = MRhist(takeTrials, :); + end + + nTrials = size(MRhist, 1); + spikeTimes = repmat((1:size(MRhist,2)), nTrials, 1); + spikeTimes = spikeTimes(logical(MRhist)); + counts = histcounts(spikeTimes, lockedEdges); + neuronPSTH = (counts / (params.binWidth * nTrials)) * 1000; % spk/s + + % Z-score using baseline + if params.zScore + baselineBins = tAxis < lockedPreBase; + bMean = mean(neuronPSTH(baselineBins)); + bStd = std(neuronPSTH(baselineBins)); + if bStd > 0 + neuronPSTH = (neuronPSTH - bMean) / bStd; + else + continue % skip neuron if baseline std is zero + end + end + + % Smooth if requested + if params.smooth > 0 + smoothBins = round(params.smooth / params.binWidth); + neuronPSTH = smoothdata(neuronPSTH, 'gaussian', smoothBins); + end + + % Append neuron row + rasterAll{s} = [rasterAll{s}; neuronPSTH]; + + % Get depth for this neuron + if params.sortBy == "depth" + depthRow = depthTable.Experiment == ex & depthTable.Unit == u; + if any(depthRow) + depthAll{s}(end+1) = depthTable.Depth_um(depthRow); + else + depthAll{s}(end+1) = NaN; + end + else + depthAll{s}(end+1) = NaN; + end + + expAll{s}(end+1) = ex; + end + + end % stim loop + end % experiment loop + + % ------------------------------------------------------------------ + % Save + % ------------------------------------------------------------------ + S.expList = exList; + S.lockedEdges = lockedEdges; + S.lockedPreBase = lockedPreBase; + S.params = params; + + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + S.(sprintf('%s_raster', stimField)) = rasterAll{s}; + S.(sprintf('%s_depth', stimField)) = depthAll{s}; + S.(sprintf('%s_exp', stimField)) = expAll{s}; + end + + save([saveDir nameOfFile], '-struct', 'S'); + fprintf('\nSaved raster data to:\n %s\n', [saveDir nameOfFile]); + +else + % Load from disk + lockedEdges = S.lockedEdges; + lockedPreBase = S.lockedPreBase; + + rasterAll = cell(1, numel(params.stimTypes)); + depthAll = cell(1, numel(params.stimTypes)); + expAll = cell(1, numel(params.stimTypes)); + + for s = 1:numel(params.stimTypes) + stimField = matlab.lang.makeValidName(params.stimTypes(s)); + rasterAll{s} = S.(sprintf('%s_raster', stimField)); + depthAll{s} = S.(sprintf('%s_depth', stimField)); + expAll{s} = S.(sprintf('%s_exp', stimField)); + end +end + +tAxis = lockedEdges(1:end-1); +tAxisPlot = tAxis - lockedPreBase; + +% ========================================================================= +% SORT NEURONS +% ========================================================================= +for s = 1:numel(params.stimTypes) + data = rasterAll{s}; + if isempty(data); continue; end + + if params.sortBy == "peak" + % Sort by time of peak response in the post-stimulus window + postStimBins = tAxis >= lockedPreBase; + [~, peakBin] = max(data(:, postStimBins), [], 2); + [~, sortIdx] = sort(peakBin); + elseif params.sortBy == "depth" + % Sort by depth (shallow to deep) + [~, sortIdx] = sort(depthAll{s}, 'ascend'); + else + sortIdx = 1:size(data, 1); % no sorting + end + + rasterAll{s} = data(sortIdx, :); + depthAll{s} = depthAll{s}(sortIdx); + expAll{s} = expAll{s}(sortIdx); +end + +% ========================================================================= +% PLOT +% ========================================================================= + + +stimLegendMap = containers.Map(... + {'linearlyMovingBall', 'rectGrid', 'MovingGrating', 'StaticGrating'}, ... + {'MB', 'SB', 'MG', 'SG'}); + +nStim = numel(params.stimTypes); + +% ------------------------------------------------------------------ +% ------------------------------------------------------------------ +% Global color limits — use lower percentile for better contrast +allValues = []; +for s = 1:nStim + if ~isempty(rasterAll{s}) + allValues = [allValues, rasterAll{s}(:)']; %#ok + end +end + +if params.zScore + cLimPos = prctile(allValues, params.climPrctile); % data-driven positive limit + cLims = [-params.climNeg, cLimPos]; % asymmetric: fixed neg, data-driven pos +else + cLims = [prctile(allValues, 2), prctile(allValues, params.climPrctile)]; +end + +% ------------------------------------------------------------------ +% Figure and tiled layout +% ------------------------------------------------------------------ +fig = figure; +set(fig, 'Units', 'centimeters', 'Position', [5 5 5*nStim + 2, 10]); + +tl = tiledlayout(fig, 1, nStim, 'TileSpacing', 'compact', 'Padding', 'compact'); + +axAll = gobjects(1, nStim); + +for s = 1:nStim + + data = rasterAll{s}; + stimKey = char(params.stimTypes(s)); + if isKey(stimLegendMap, stimKey) + shortName = stimLegendMap(stimKey); + else + shortName = stimKey; + end + + axAll(s) = nexttile(tl); + ax = axAll(s); + + if isempty(data) + title(ax, shortName, 'FontName', 'helvetica', 'FontSize', 8); + axis(ax, 'off'); + continue + end + + % imagesc: x = time, y = neuron index + imagesc(ax, tAxisPlot, 1:size(data,1), data); + clim(ax, cLims); + %colormap(ax, flipud(gray)); % white = low, black = high + if params.zScore + cLimPos = prctile(allValues, params.climPrctile); + cLims = [-params.climNeg, cLimPos]; + + % Proportion of colors for each side — white stays at zero + nColors = 256; + nNeg = round(nColors * params.climNeg / (params.climNeg + cLimPos)); + nPos = nColors - nNeg; + + blueHalf = [linspace(0.1, 1, nNeg)', linspace(0.2, 1, nNeg)', linspace(0.8, 1, nNeg)']; + redHalf = [linspace(1, 0.9, nPos)', linspace(1, 0.2, nPos)', linspace(1, 0.05, nPos)']; + colormap(ax, [blueHalf; redHalf]); + else + cLims = [prctile(allValues, 2), prctile(allValues, params.climPrctile)]; + colormap(ax, flipud(gray)); + end + + % ------------------------------------------------------------------ + % Depth bin boundary lines (only when sorted by depth) + % ------------------------------------------------------------------ + if params.sortBy == "depth" && ~isempty(depthAll{s}) + + % Load bin edges + depthFile = 'W:\Large_scale_mapping_NP\lizards\Combined_lizard_analysis\NeuronDepths.mat'; + D = load(depthFile); + depthBinEdges = D.depthBinEdges; + + binLabelsDepth = {sprintf('%.0f-%.0f um', depthBinEdges(1), depthBinEdges(2)), ... + sprintf('%.0f-%.0f um', depthBinEdges(2), depthBinEdges(3)), ... + sprintf('%.0f-%.0f um', depthBinEdges(3), depthBinEdges(4))}; + + % Find the last neuron index belonging to each bin boundary + for edge = 2:3 % edges 2 and 3 are the internal boundaries + %lastInBin = find(depthAll{s} <= depthBinEdges(edge), 1, 'last'); + %lastInBin = find(~isnan(depthAll{s}) & depthAll{s} <= depthBinEdges(edge), 1, 'last'); + depthCombined = depthAll{s}; + depthCombined = depthCombined(~isnan(depthCombined)); + lastInBin = find(depthCombined <= depthBinEdges(edge), 1, 'last'); + if ~isempty(lastInBin) && lastInBin < size(data,1) + yline(ax, lastInBin + 0.5, 'k-', 'LineWidth', 1.2); + % Label on the right side showing the bin range + text(ax, tAxisPlot(5), lastInBin - size(data,1)*0.02, ... + binLabelsDepth{edge-1}, ... + 'Color', 'w', 'FontSize', 6, 'FontName', 'helvetica', ... + 'HorizontalAlignment', 'left', 'VerticalAlignment', 'top'); + end + end + % Label for the deepest bin + text(ax, tAxisPlot(5), size(data,1), ... + binLabelsDepth{3}, ... + 'Color', 'w', 'FontSize', 6, 'FontName', 'helvetica', ... + 'HorizontalAlignment', 'left', 'VerticalAlignment', 'top'); + end + + % Stim onset and offset lines + xline(ax, 0, 'k--', 'LineWidth', 1.0); + xline(ax, params.postStim, 'k--', 'LineWidth', 1.0); + + xlim(ax, [tAxisPlot(1), tAxisPlot(end)]); + ylim(ax, [0.5, size(data,1)+0.5]); + xticks(ax, -params.preBase : 100 : params.postStim); + + xlabel(ax, 'Time re. stim onset [ms]', 'FontName', 'helvetica', 'FontSize', 8); + if s == 1 + ylabel(ax, 'Neuron #', 'FontName', 'helvetica', 'FontSize', 8); + end + title(ax, sprintf('%s (n=%d)', shortName, size(data,1)), ... + 'FontName', 'helvetica', 'FontSize', 8); + + ax.FontName = 'helvetica'; + ax.FontSize = 8; + ax.YDir = 'normal'; % neuron 1 at bottom + + +end + +% ------------------------------------------------------------------ +% Single colorbar for the whole layout +% ------------------------------------------------------------------ +cb = colorbar(axAll(end)); +if params.zScore + cb.Label.String = 'Z-score'; +else + cb.Label.String = 'Firing rate [spk/s]'; +end +cb.Label.FontName = 'helvetica'; +cb.Label.FontSize = 8; +cb.FontName = 'helvetica'; +cb.FontSize = 8; + +sgtitle(sprintf('N = %d experiments', numel(exList)), ... + 'FontName', 'helvetica', 'FontSize', 10); + +if params.PaperFig + vs_first.printFig(fig, sprintf('Raster-%s', stimLabel), PaperFig=params.PaperFig); +end + +end \ No newline at end of file