Przeglądaj źródła

now saving best loss with search index as bson
TODO: sort list and print best 10 results

Sebastian Vendt 6 lat temu
rodzic
commit
2ad570dc7d
1 zmienionych plików z 16 dodań i 8 usunięć
  1. 16 8
      julia/net.jl

+ 16 - 8
julia/net.jl

@@ -178,6 +178,12 @@ function log_csv(model, epoch)
 	Flux.testmode!(model, false)
 end
 
+function eval_model(model)
+	Flux.testmode!(model, true)
+	if (validate) return loss(model, validation_set)
+	else return loss(model, test_set) end
+end
+
 function train_model()
 	model = create_model()
 	if (usegpu) model = gpu(model) end
@@ -196,19 +202,16 @@ function train_model()
 		curr_loss = loss(model, train_set)
 		if(abs(last_loss - curr_loss) < delta)
 			@printf(io, "Early stopping with %f at %d", curr_loss, i)
-			Flux.testmode!(model, true)
-			if (validate) return loss(model, validation_set)
-			else return loss(model, test_set) end
+			return eval_model(model)
 		end
 		last_loss = curr_loss
     end
-    Flux.testmode!(model, true)
-    if (validate) return loss(model, validation_set)
-    else return loss(model, test_set) end
+    return eval_model(model)
 end
 
 function random_search()
 	rng = MersenneTwister()
+	results = []
 	for search in 1:500
 		# create random set
 		momentum = rand(rng, rs_momentum)
@@ -225,8 +228,10 @@ function random_search()
 		@printf(io, "%s\n", config1)
 		@printf(io, "%s\n\n", config2)
 		
-		train_model()
+		loss = train_model()
+		push!(results, (search, loss))
 	end
+	return results
 end
 
 
@@ -264,5 +269,8 @@ if (usegpu)
 	const test_set = gpu.(test)
 end
 
-if(!runD) random_search()
+if(!runD)
+	results = random_search()
+	BSON.@save "results.bson" results
+	#TODO sort and print best 5-10 results 
 else train_model() end