|
@@ -178,6 +178,12 @@ function log_csv(model, epoch)
|
|
|
Flux.testmode!(model, false)
|
|
|
end
|
|
|
|
|
|
+function eval_model(model)
|
|
|
+ Flux.testmode!(model, true)
|
|
|
+ if (validate) return loss(model, validation_set)
|
|
|
+ else return loss(model, test_set) end
|
|
|
+end
|
|
|
+
|
|
|
function train_model()
|
|
|
model = create_model()
|
|
|
if (usegpu) model = gpu(model) end
|
|
@@ -196,19 +202,16 @@ function train_model()
|
|
|
curr_loss = loss(model, train_set)
|
|
|
if(abs(last_loss - curr_loss) < delta)
|
|
|
@printf(io, "Early stopping with %f at %d", curr_loss, i)
|
|
|
- Flux.testmode!(model, true)
|
|
|
- if (validate) return loss(model, validation_set)
|
|
|
- else return loss(model, test_set) end
|
|
|
+ return eval_model(model)
|
|
|
end
|
|
|
last_loss = curr_loss
|
|
|
end
|
|
|
- Flux.testmode!(model, true)
|
|
|
- if (validate) return loss(model, validation_set)
|
|
|
- else return loss(model, test_set) end
|
|
|
+ return eval_model(model)
|
|
|
end
|
|
|
|
|
|
function random_search()
|
|
|
rng = MersenneTwister()
|
|
|
+ results = []
|
|
|
for search in 1:500
|
|
|
# create random set
|
|
|
momentum = rand(rng, rs_momentum)
|
|
@@ -225,8 +228,10 @@ function random_search()
|
|
|
@printf(io, "%s\n", config1)
|
|
|
@printf(io, "%s\n\n", config2)
|
|
|
|
|
|
- train_model()
|
|
|
+ loss = train_model()
|
|
|
+ push!(results, (search, loss))
|
|
|
end
|
|
|
+ return results
|
|
|
end
|
|
|
|
|
|
|
|
@@ -264,5 +269,8 @@ if (usegpu)
|
|
|
const test_set = gpu.(test)
|
|
|
end
|
|
|
|
|
|
-if(!runD) random_search()
|
|
|
+if(!runD)
|
|
|
+ results = random_search()
|
|
|
+ BSON.@save "results.bson" results
|
|
|
+ #TODO sort and print best 5-10 results
|
|
|
else train_model() end
|