-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptParams.R
More file actions
173 lines (154 loc) · 6.76 KB
/
optParams.R
File metadata and controls
173 lines (154 loc) · 6.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#' Tune model arguments one at a time using repeated hold-out validation.
#'
#' After a random "cold start" that picks an initial set of argument values,
#' the function loops over epochs (one per entry of `nTrain`) and, inside each
#' epoch, over the arguments described in `optArgs`. For each argument it fits
#' `func` on freshly sampled training rows for every candidate value, scores
#' the predictions on disjoint validation rows with `optFunc`, keeps the
#' candidate with the smallest mean score, and then shrinks (numeric/ordered)
#' or prunes (categorical) that argument's search space. A diagnostic plot per
#' argument and epoch is written to the working directory with `ggsave()`.
#'
#' Side effects: resets the RNG with `set.seed(seed)`, attaches ggplot2,
#' writes PNG files, prints progress, and assigns the current validation
#' indices to `samVal` in the global environment so `predFunc` can read them.
#'
#' @param func Model-fitting function. It is called with `form`/`data` (or
#'   `x`/`y`) first, followed by `constArgs` and the tuned arguments.
#' @param form,data Formula and data set; supply these OR `x` and `y`.
#' @param x,y Predictor matrix/data frame and response; supply these OR
#'   `form` and `data`.
#' @param nTrain,nValid Training / validation sample sizes, one per epoch.
#' @param replications Number of random splits per epoch (length 1 or
#'   `length(nTrain)`).
#' @param optFunc Loss `function(pred, actual)`; smaller is better. It is
#'   always called positionally.
#' @param optArgs List of length-3 lists: argument name, type (`"numeric"`,
#'   `"ordered"` or `"categorical"`), and either a length-2 range or a vector
#'   of more than one candidate value.
#' @param optVals Number of grid points per numeric/ordered argument
#'   (length 1 or `length(optArgs)`).
#' @param optRed Fraction of the current range width kept after each update
#'   (length 1 or `length(optArgs)`).
#' @param predFunc Prediction function called as `predFunc(fit, newdata=)`.
#' @param constArgs Named list of arguments always passed to `func`.
#' @param coldStart Number of random argument draws used to pick the start.
#' @param seed RNG seed.
#' @return Invisibly, the tuned argument values (`currArgs`) after the last
#'   epoch. Earlier versions returned `NULL`.
optParams <-
function(func, form=NULL, data=NULL, x=NULL, y=NULL
    ,nTrain=c(100,1000,10000), nValid=nTrain, replications=rep(30, length(nTrain))
    ,optFunc=function(pred,actual){mean((pred-actual)^2)}
    ,optArgs=list()
    ,optVals=rep(5,length(optArgs))
    ,optRed=rep(.7,length(optArgs))
    ,predFunc=predict
    ,constArgs=list()
    ,coldStart=10
    ,seed=321)
{
  set.seed(seed)

  # ---- Input validation ----
  if (!is.function(func))
    stop("func must be a function!")
  # Exactly one of the two interfaces must be fully specified.
  useForm <- !is.null(form) && !is.null(data) && is.null(x) && is.null(y)
  useXY <- is.null(form) && is.null(data) && !is.null(x) && !is.null(y)
  if (!useForm && !useXY)
    stop("Must specify form, data OR x, y")
  if (useXY && is.null(dim(x)))
    stop("x must be a matrix or data frame!")
  if (useXY && nrow(x) != length(y))
    stop("Number of rows of x must be the same as length of y!")
  if (length(nTrain) != length(nValid))
    stop("nTrain and nValid must have the same length!")
  n <- if (is.null(data)) length(y) else nrow(data)
  if (max(nTrain + nValid) > n)
    stop("nTrain + nValid exceeds n at some point!")
  if (length(optArgs) == 0)
    stop("No arguments to optimize, length(optArgs)=0!")
  if (!is.list(optArgs) || !all(lengths(optArgs) == 3))
    stop("optArgs's arguments don't have the correct form! Each should be a list of length 3.")
  # Every spec must be valid: numeric/ordered need a length-2 range,
  # categorical needs at least two candidate values.
  specOk <- vapply(optArgs, function(spec) {
    type <- spec[[2]]
    if (!is.character(type) || length(type) != 1)
      return(FALSE)
    if (type %in% c("numeric", "ordered"))
      return(length(spec[[3]]) == 2)
    identical(type, "categorical") && length(spec[[3]]) > 1
  }, logical(1))
  if (!all(specOk))
    stop("optArgs is not of the right form!")
  if (length(coldStart) != 1 || is.na(coldStart) || coldStart < 1)
    stop("coldStart must be at least 1 so that a starting point exists!")

  # Allow a single value to be recycled across epochs / arguments.
  recycle <- function(v, len, what) {
    if (!(length(v) %in% c(1L, len)))
      stop(what, " must have length 1 or ", len, "!")
    rep_len(v, len)
  }
  replications <- recycle(replications, length(nTrain), "replications")
  optVals <- recycle(optVals, length(optArgs), "optVals")
  optRed <- recycle(optRed, length(optArgs), "optRed")
  if (anyNA(replications) || any(replications < 1))
    stop("replications must be at least 1 for every epoch!")

  library(ggplot2)

  respVar <- if (useForm) all.vars(form)[1] else NULL

  # Draw disjoint training / validation row indices. The validation indices
  # are also published as `samVal` in the global environment so that a
  # user-supplied predFunc can access them, if necessary.
  drawSplit <- function(nTrn, nVal) {
    trn <- sample.int(n, size = nTrn)
    rest <- setdiff(seq_len(n), trn)
    # Index via sample.int so a single remaining row is not mistaken for the
    # upper bound of a 1:rest draw (sample()'s length-one special case).
    val <- rest[sample.int(length(rest), size = nVal)]
    assign("samVal", val, envir = globalenv())
    list(trn = trn, val = val)
  }

  # Fit func on the training rows with constArgs + tuneArgs and return the
  # optFunc score of its predictions on the validation rows.
  fitAndScore <- function(tuneArgs, idx) {
    if (useForm) {
      dataArgs <- list(form = form, data = data[idx$trn, ])
      fit <- do.call(func, c(dataArgs, constArgs, tuneArgs))
      preds <- predFunc(fit, newdata = data[idx$val, ])
      actual <- data[idx$val, respVar]
    } else {
      # drop = FALSE keeps a single-column x two-dimensional.
      dataArgs <- list(x = x[idx$trn, , drop = FALSE], y = y[idx$trn])
      fit <- do.call(func, c(dataArgs, constArgs, tuneArgs))
      preds <- predFunc(fit, newdata = x[idx$val, , drop = FALSE])
      actual <- y[idx$val]
    }
    optFunc(preds, actual)
  }

  # ---- Cold start: keep the best of `coldStart` random argument draws ----
  bestError <- Inf
  currArgs <- NULL
  print("Starting cold start...")
  for (i in seq_len(coldStart)) {
    # randArgs() is defined elsewhere; presumably it returns a named list with
    # one randomly drawn value per entry of optArgs -- TODO confirm.
    tempArgs <- randArgs(optArgs)
    idx <- drawSplit(nTrain[1], nValid[1])
    error <- fitAndScore(tempArgs, idx)
    if (!is.na(error) && error < bestError) {
      currArgs <- tempArgs
      bestError <- error
    }
  }
  if (is.null(currArgs))
    stop("Cold start did not produce any finite value of optFunc()!")
  print("Cold start completed, beginning main optimization...")

  # ---- Main training loop ----
  for (epoch in seq_along(nTrain)) {
    for (par in seq_along(optArgs)) {
      currParam <- optArgs[[par]][[1]]
      parType <- optArgs[[par]][[2]]
      # Candidate values: a grid over the current range for numeric/ordered
      # arguments (rounded and de-duplicated for ordered), or the remaining
      # candidates as-is for categorical arguments.
      paramVals <- optArgs[[par]][[3]]
      if (parType == "ordered") {
        paramVals <- unique(round(
          seq(paramVals[1], paramVals[2], length.out = optVals[par])))
      } else if (parType == "numeric") {
        paramVals <- seq(paramVals[1], paramVals[2], length.out = optVals[par])
      }

      # One row per replication of THIS epoch, one column per candidate.
      nRep <- replications[epoch]
      errors <- matrix(0, nrow = nRep, ncol = length(paramVals))
      for (repl in seq_len(nRep)) {
        # The same split is shared by all candidates within a replication.
        idx <- drawSplit(nTrain[epoch], nValid[epoch])
        for (parVal in seq_along(paramVals)) {
          # Best estimates so far, with the current argument replaced by the
          # candidate value (appended last).
          tuneArgs <- as.list(currArgs[names(currArgs) != currParam])
          tuneArgs[[currParam]] <- paramVals[parVal]
          errors[repl, parVal] <- fitAndScore(tuneArgs, idx)
        } #close parameter value loop
      } #close replication loop

      # Summarise each candidate and update currArgs with the best one.
      errors <- data.frame(mean = apply(errors, 2, mean),
                           sd = apply(errors, 2, sd))
      bestIdx <- which.min(errors$mean)
      if (length(bestIdx) == 0)
        stop("All values of optFunc() are missing for parameter ", currParam, "!")
      newPar <- paramVals[bestIdx]
      currArgs[currParam] <- newPar

      # Update optArgs: tune parameter search as appropriate.
      if (parType %in% c("ordered", "numeric")) {
        oldRange <- optArgs[[par]][[3]]
        len <- (oldRange[2] - oldRange[1]) * optRed[par]
        newRange <- newPar + c(-len / 2, len / 2)
        # Shift the new range back inside the previous one if it sticks out.
        if (newRange[1] < oldRange[1])
          newRange <- newRange + oldRange[1] - newRange[1]
        if (newRange[2] > oldRange[2])
          newRange <- newRange + oldRange[2] - newRange[2]
        optArgs[[par]][[3]] <- newRange
      } else {
        # Categorical: drop candidates that did significantly worse than the
        # best. Assume normality, then we're testing a difference in means
        # with different variances.
        best <- errors[bestIdx, ]
        zVals <- (errors$mean - best$mean) / sqrt(errors$sd^2 + best$sd^2)
        # Remove vals with z-scores of 2 or more, i.e. alpha=.05 test.
        # Undefined scores (single replication, zero variance) are kept
        # rather than turned into NA candidates.
        keep <- is.na(zVals) | zVals < 2
        optArgs[[par]][[3]] <- optArgs[[par]][[3]][keep]
      }

      # Diagnostic plot: mean +/- 2 sd of optFunc() per candidate value.
      errors$paramVals <- paramVals
      ggsave(paste0(Sys.info()[4],"_Param_",currParam,"_epoch_",epoch,".png")
        ,ggplot(errors, aes(x=paramVals)) + geom_point(aes(y=mean)) +
          geom_errorbar(aes(ymax=mean+2*sd, ymin=mean-2*sd)) +
          labs(x=currParam, y="Value of optFunc()") )
      print(paste("Parameter",currParam,"optimized to",newPar,"for epoch",epoch))
    } #close parameter loop
  } #close epoch loop

  # Hand the tuned arguments back to the caller.
  invisible(currArgs)
}