-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathdataset.lua
More file actions
160 lines (132 loc) · 4.59 KB
/
dataset.lua
File metadata and controls
160 lines (132 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')
local ffi = require 'ffi'
local class = require('pl.class')
local dir = require 'pl.dir'
local tablex = require 'pl.tablex'
local argcheck = require 'argcheck'
local flowX = require 'flowExtensions'
require 'sys'
require 'xlua'
require 'image'
-- Declare the Torch class 'dataLoader'; `dataset` is the local handle that the
-- methods in this file are attached to.
local dataset = torch.class('dataLoader')

-- Named-argument checker used by dataset:__init. Entries flagged `opt` may be
-- omitted by the caller; `samplingMode` and `verbose` carry defaults.
-- NOTE(review): the sampleHookTrain help text says the hook takes the image
-- path, but dataset:sample calls it as self:sampleHookTrain(id) with an
-- integer index -- confirm which contract is intended.
local initcheck = argcheck{
   pack = true,
   help = [[
A dataset class for loading images and dense outputs such as optical flow
or segmentations in large datasets. Tested only on Linux (as it uses
command-line linux utilities to scale up)
]],

   -- sizes are indexed [1], [2], [3] when the batch tensors are allocated
   {
      name = "inputSize",
      type = "table",
      help = "the size of the input images",
   },
   {
      name = "flowSize",
      type = "table",
      help = "the size of the network output",
   },

   {
      name = "samplingMode",
      type = "string",
      help = "Sampling mode: random | balanced ",
      default = "balanced",
   },
   {
      name = "verbose",
      type = "boolean",
      help = "Verbose mode during initialization",
      default = false,
   },

   -- the sample list; its length becomes self.numSamples
   {
      name = "samples",
      type = "table",
      help = "samples of training or testing images",
      opt = true,
   },

   -- per-sample loaders, called as methods on the loader instance
   {
      name = "sampleHookTrain",
      type = "function",
      help = "applied to sample during training(ex: for lighting jitter). "
         .. "It takes the image path as input",
      opt = true,
   },
   {
      name = "sampleHookTest",
      type = "function",
      help = "applied to sample during testing",
      opt = true,
   },
}
--- Construct a dataLoader from the named arguments validated by `initcheck`.
-- Every checked argument is copied onto the instance, missing sample hooks
-- fall back to `self.defaultSampleHook`, and `self.numSamples` is set to the
-- length of `samples`.
-- Raises an error when argcheck rejects the arguments, or when `samples` is
-- missing or empty.
function dataset:__init(...)
   -- argcheck
   local args = initcheck(...)
   for k, v in pairs(args) do self[k] = v end

   -- The argument dump is diagnostic output, so it honours the verbose flag
   -- like the sample-count message below.
   if self.verbose then print(args) end

   -- NOTE(review): defaultSampleHook is not defined in the visible part of
   -- this file; if it is nil the hooks stay nil and sample()/get() will fail
   -- when they call them -- confirm where it is supposed to come from.
   if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
   if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end

   -- `samples` is optional as far as argcheck is concerned, so treat a missing
   -- table like an empty one and let the assert below report it clearly
   -- instead of crashing on the length operator.
   self.numSamples = self.samples and #self.samples or 0
   assert(self.numSamples > 0, "Could not find any sample in the given input paths")
   if self.verbose then print(self.numSamples .. ' samples found.') end
end
--- Number of samples held by this loader.
-- Both parameters are accepted but never read: every call returns the
-- `numSamples` field that the constructor stored.
function dataset:size(class, list)
   local total = self.numSamples
   return total
end
--- Pack per-sample tensors into two freshly allocated batch tensors.
-- @param self       loader instance; its `inputSize` and `flowSize` are read
-- @param imgTable   sequence of image tensors, one per sample
-- @param flowTable  sequence of flow tensors, parallel to `imgTable`
-- @return images    Tensor of size #imgTable x inputSize[1] x inputSize[2] x inputSize[3]
-- @return flows     Tensor of size #imgTable x flowSize[1] x flowSize[2] x flowSize[3]
-- Raises an error when the tables are empty or mismatched, or when the first
-- dimension of the first image/flow disagrees with the configured sizes.
local function tableToOutput(self, imgTable, flowTable)
   local quantity = #imgTable

   -- Fail early with a readable message instead of a nil-index error below.
   if quantity < 1 then
      error('tableToOutput: image table is empty, nothing to batch')
   end
   if #flowTable < quantity then
      error('tableToOutput: got ' .. #flowTable .. ' flows for '
            .. quantity .. ' images')
   end

   -- Only the first dimension of the first sample is checked up front;
   -- copy() below is left to complain about any other mismatch.
   local imgDim = imgTable[1]:size()[1]
   if imgDim ~= self.inputSize[1] then
      error('tableToOutput: image first dimension is ' .. tostring(imgDim)
            .. ', expected inputSize[1] = ' .. tostring(self.inputSize[1]))
   end
   local flowDim = flowTable[1]:size()[1]
   if flowDim ~= self.flowSize[1] then
      error('tableToOutput: flow first dimension is ' .. tostring(flowDim)
            .. ', expected flowSize[1] = ' .. tostring(self.flowSize[1]))
   end

   local images = torch.Tensor(quantity,
      self.inputSize[1], self.inputSize[2], self.inputSize[3])
   local flows = torch.Tensor(quantity,
      self.flowSize[1], self.flowSize[2], self.flowSize[3])

   for i = 1, quantity do
      images[i]:copy(imgTable[i])
      flows[i]:copy(flowTable[i])
   end
   return images, flows
end
--- Build the list of flow targets returned alongside each image batch.
-- The result has four entries: `flows` passed through flowX.scaleBatch with
-- the factors 0.125, 0.25 and 0.5 (in that order), followed by `flows` itself.
-- NOTE(review): presumably scaleBatch rescales the flow batch to give
-- coarse-to-fine targets -- confirm against flowExtensions.
local function getFlowTable(flows)
   local factors = {0.125, 0.25, 0.5}
   local pyramid = {}
   for level = 1, #factors do
      pyramid[level] = flowX.scaleBatch(flows, factors[level])
   end
   -- the unscaled batch is always the last entry
   pyramid[#factors + 1] = flows
   return pyramid
end
--- Draw a random training batch.
-- For each of the `quantity` slots an index is drawn with
-- torch.random(1, self.numSamples) and handed to `sampleHookTrain`, which
-- must return an image tensor and a flow tensor. The draws are independent,
-- so the same index may appear more than once in a batch.
-- NOTE(review): the `samplingMode` option is not consulted here.
-- @param quantity  batch size, must be at least 1
-- @return images   batch tensor built by tableToOutput
-- @return flowTab  table of flow targets built by getFlowTable
function dataset:sample(quantity)
   -- A batch of zero (or fewer) samples cannot be packed into tensors; reject
   -- it here with a clear message, mirroring the check in dataset:get.
   local n = tonumber(quantity)
   assert(n and n >= 1, 'dataset:sample expects a batch size of at least 1')

   local imgTable = {}
   local flowTable = {}
   for _ = 1, quantity do
      local id = torch.random(1, self.numSamples)
      -- the hook returns one (image, flow) pair for the chosen sample
      local img, flow = self:sampleHookTrain(id)
      imgTable[#imgTable + 1] = img
      flowTable[#flowTable + 1] = flow
   end

   local images, flows = tableToOutput(self, imgTable, flowTable)
   local flowTab = getFlowTable(flows)
   return images, flowTab
end
--- Load the contiguous range of samples i1..i2 (inclusive) for testing.
-- Each index in the range is passed to `sampleHookTest`, which must return an
-- image tensor and a flow tensor; the results are batched the same way as in
-- dataset:sample.
-- @param i1  first sample index
-- @param i2  last sample index; the range must contain at least one index
-- @return images   batch tensor built by tableToOutput
-- @return flowTab  table of flow targets built by getFlowTable
function dataset:get(i1, i2)
   local quantity = i2 - i1 + 1
   assert(quantity > 0)

   local loadedImages, loadedFlows = {}, {}
   for offset = 0, quantity - 1 do
      -- walk the range in order, starting at i1
      local img, flow = self:sampleHookTest(i1 + offset)
      loadedImages[#loadedImages + 1] = img
      loadedFlows[#loadedFlows + 1] = flow
   end

   local images, flows = tableToOutput(self, loadedImages, loadedFlows)
   return images, getFlowTable(flows)
end
return dataset