소스 검색

adding csv support

Sebastian Vendt 6 년 전
부모
커밋
880081ab45
1개의 변경된 파일, 35개의 추가작업 그리고 13개의 삭제
  1. 35 13
      julia/net.jl

+ 35 - 13
julia/net.jl

@@ -27,6 +27,9 @@ s = ArgParseSettings()
 		help = "additional message describing the training log"
 		arg_type = String
 		default = ""
+	"--csv"
+		help = "set, if you additionally want a csv output of the learning process"
+		action = :store_true
 end
 parsed_args = parse_args(ARGS, s)
 
@@ -94,6 +97,7 @@ end
 
 debug_str = ""
 log_msg = parsed_args["logmsg"]
# Bugfix: was `parse_args["csv"]` — that indexes the ArgParse *function*,
# which throws a MethodError; the parsed options live in the `parsed_args` dict.
csv_out = parsed_args["csv"]
 @debug begin
 	global debug_str
 	debug_str = "DEBUG_"
@@ -101,13 +105,14 @@ log_msg = parsed_args["logmsg"]
 end
 
 io = nothing
+io_csv = nothing
 
 function adapt_learnrate(epoch_idx)
    # Exponential decay schedule: lr_t = init_learning_rate * decay_rate^(t / decay_step).
    # NOTE(review): `epoch_idx / decay_step` is float division, so the rate decays
    # smoothly every epoch rather than in discrete steps — confirm that is intended.
    return init_learning_rate * decay_rate^(epoch_idx / decay_step)
 end
 
 function loss(x, y) 
-	# quadratic euclidean distance + parameternorm?
+	# quadratic euclidean distance + parameternorm
 	return Flux.mse(model(x), y) + lambda * sum(norm, params(model))
 end
 
@@ -141,25 +146,34 @@ model = Chain(
 	Dense(inputDense3, 2, σ), # coordinates between 0 and 1
 )
 
-function train_model(model, train_set, validation_set, test_set)
-   Flux.testmode!(model, true)
-	opt = Momentum(learning_rate, momentum)
-	if(validate) @printf(io, "[%s] INIT with Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set)) 
-	else @printf(io, "[%s] INIT with Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set)) end
# Write the current loss to the text log `io` and, when csv output is enabled,
# append an "epoch, train-loss" row to `io_csv`.
# At epoch 0 and the final epoch the evaluation loss (test or validation set,
# per `use_testset`) is reported; otherwise the training-set loss is reported.
# NOTE(review): the name `log` shadows Base.log — consider renaming
# (e.g. `log_progress`) in a follow-up; kept here so existing calls still work.
function log(epoch, use_testset)
	Flux.testmode!(model, true)

	# Bugfix: was `epoch == 0 | epoch == epochs`. Julia parses that as the
	# chained comparison `epoch == (0 | epoch) == epochs`, which reduces to
	# `epoch == epochs` and silently skipped the epoch-0 case. `||` is the
	# intended (short-circuit) boolean or.
	if (epoch == 0 || epoch == epochs) # evaluation phase
		# NOTE(review): `loss` is defined as loss(x, y); these one-argument
		# calls only work if the sets splat / another method exists — confirm.
		if (use_testset) @printf(io, "[%s] Epoch %3d: Loss(test): %f\n", Dates.format(now(), time_format), epoch, loss(test_set))
		else @printf(io, "[%s] Epoch %3d: Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(validation_set)) end
	else # learning phase
		@printf(io, "[%s] Epoch %3d: Loss(train): %f\n", Dates.format(now(), time_format), epoch, loss(train_set))
	end

	if (csv_out) @printf(io_csv, "%d, %f\n", epoch, loss(train_set)) end

	Flux.testmode!(model, false)
end
+
# Run the full training loop: log the baseline loss, train for `epochs`
# epochs with momentum SGD and a decaying learning rate, logging the
# training loss every `printout_interval` epochs, then log the final loss.
# All configuration (model, data sets, hyperparameters) comes from globals.
function train_model()
	opt = Momentum(learning_rate, momentum)
	log(0, !validate) # baseline loss before any training step
	for epoch in 1:epochs
		flush(io)
		Flux.testmode!(model, false) # bring model into training mode
		Flux.train!(loss, params(model), train_set, opt)
		opt.eta = adapt_learnrate(epoch)
		if rem(epoch, printout_interval) == 0
			log(epoch, false) # periodic training-loss snapshot
		end
	end
	log(epochs, !validate) # final evaluation (test set unless validating)
end
 
 # logging framework 
@@ -169,6 +183,14 @@ global_logger(SimpleLogger(io)) # for debug outputs
 @printf(Base.stdout, "Logging to File: %s\n", fp)
 @printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
 @printf(io, "%s\n", log_msg)
+
# csv handling: open a per-run csv file next to the text log and write the
# header row; `log` appends one "epoch, loss" row per logged epoch.
if (csv_out)
	fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format)).csv"
	io_csv = open(fp_csv, "w+") # read, write, create, truncate
	# Fix: io_csv was never closed or flushed anywhere, risking lost rows on
	# exit; register a close handler so buffered output reaches disk.
	atexit(() -> close(io_csv))
	@printf(io_csv, "epoch, loss(train)\n")
end
+
 # dump configuration 
 @debug begin
 	for symbol in names(Main)
@@ -190,7 +212,7 @@ if (usegpu)
 end
 
 
-train_model(model, train_set, validation_set, test_set)
+train_model() # signature simplified: model and data sets are read from module-level globals