-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathTrainTestSplit.m
More file actions
47 lines (38 loc) · 1.6 KB
/
TrainTestSplit.m
File metadata and controls
47 lines (38 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
function [ strTrain, strTest, yTrain, yTest ] = TrainTestSplit( strInstances, labels, numTrainPerClass )
%TRAINTESTSPLIT Stratified train/test split where instances are filenames
% For each distinct label (taken in the sorted order returned by UNIQUE),
% the instances carrying that label are shuffled with RANDPERM. The first
% numTrainPerClass of them go to the training set and the remainder go to
% the test set.
%
% Inputs:
% `strInstances` M x 1 (or 1 x M) cell array of instances
% `labels` M x 1 (or 1 x M) numeric array of integer labels
% `numTrainPerClass` Number of training examples per class. Takes all
% instances in the class as training examples if number of instances in
% that class is less than this.
% Outputs:
% `strTrain` 1 x K cell array of training instances
% `strTest` 1 x L cell array of test instances
% `yTrain` K x 1 array of labels of training instances
% `yTest` L x 1 array of labels of test instances
% Outputs are grouped by class in ascending label order; the order within a
% class is random.

% Normalize to columns so both row and column inputs are handled the same.
strInstances = strInstances(:);
labels = labels(:);

idxTrain = zeros(0, 1);
idxTest = zeros(0, 1);
% UNIQUE returns a column for column input. Transpose it so the FOR loop
% visits each label value individually; FOR iterates over columns.
for c = unique(labels)'
    indAll = find(labels == c);
    indAll = indAll(randperm(numel(indAll))); % shuffle within the class
    % Classes smaller than numTrainPerClass go entirely to training.
    nTrain = min(numTrainPerClass, numel(indAll));
    idxTrain = [idxTrain; indAll(1:nTrain)]; %#ok<AGROW>
    idxTest = [idxTest; indAll(nTrain + 1:end)]; %#ok<AGROW>
end

% Instances are returned as row cell arrays and labels as columns, matching
% the shapes the original implementation produced.
strTrain = strInstances(idxTrain)';
strTest = strInstances(idxTest)';
yTrain = labels(idxTrain);
yTest = labels(idxTest);
end