|
@@ -27,6 +27,9 @@ s = ArgParseSettings()
|
|
|
help = "additional message describing the training log"
|
|
|
arg_type = String
|
|
|
default = ""
|
|
|
+ "--csv"
|
|
|
+ help = "set, if you additionally want a csv output of the learning process"
|
|
|
+ action = :store_true
|
|
|
end
|
|
|
parsed_args = parse_args(ARGS, s)
|
|
|
|
|
@@ -94,6 +97,7 @@ end
|
|
|
|
|
|
debug_str = ""
|
|
|
log_msg = parsed_args["logmsg"]
|
|
|
+csv_out = parse_args["csv"]
|
|
|
@debug begin
|
|
|
global debug_str
|
|
|
debug_str = "DEBUG_"
|
|
@@ -101,13 +105,14 @@ log_msg = parsed_args["logmsg"]
|
|
|
end
|
|
|
|
|
|
io = nothing
|
|
|
+io_csv = nothing
|
|
|
|
|
|
function adapt_learnrate(epoch_idx)
|
|
|
return init_learning_rate * decay_rate^(epoch_idx / decay_step)
|
|
|
end
|
|
|
|
|
|
function loss(x, y)
|
|
|
- # quadratic euclidean distance + parameternorm?
|
|
|
+ # quadratic euclidean distance + parameternorm
|
|
|
return Flux.mse(model(x), y) + lambda * sum(norm, params(model))
|
|
|
end
|
|
|
|
|
@@ -141,25 +146,34 @@ model = Chain(
|
|
|
Dense(inputDense3, 2, σ), # coordinates between 0 and 1
|
|
|
)
|
|
|
|
|
|
-function train_model(model, train_set, validation_set, test_set)
|
|
|
- Flux.testmode!(model, true)
|
|
|
- opt = Momentum(learning_rate, momentum)
|
|
|
- if(validate) @printf(io, "[%s] INIT with Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set))
|
|
|
- else @printf(io, "[%s] INIT with Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set)) end
|
|
|
+function log(epoch, use_testset)
|
|
|
+ Flux.testmode!(model, true)
|
|
|
+
|
|
|
+ if(epoch == 0 | epoch == epochs) # evalutation phase
|
|
|
+ if(use_testset) @printf(io, "[%s] Epoch %3d: Loss(test): %f\n", Dates.format(now(), time_format), epoch, loss(test_set))
|
|
|
+ else @printf(io, "[%s] Epoch %3d: Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(validation_set)) end
|
|
|
+ else # learning phase
|
|
|
+ @printf(io, "[%s] Epoch %3d: Loss(train): %f\n", Dates.format(now(), time_format), epoch, loss(train_set))
|
|
|
+ end
|
|
|
+
|
|
|
+ if(csv_out) @printf(io_csv, "%d, %f\n", epoch, loss(train_set)) end
|
|
|
|
|
|
+ Flux.testmode!(model, false)
|
|
|
+end
|
|
|
+
|
|
|
+function train_model()
|
|
|
+ opt = Momentum(learning_rate, momentum)
|
|
|
+ log(0, !validate)
|
|
|
for i in 1:epochs
|
|
|
flush(io)
|
|
|
Flux.testmode!(model, false) # bring model in training mode
|
|
|
Flux.train!(loss, params(model), train_set, opt)
|
|
|
opt.eta = adapt_learnrate(i)
|
|
|
- if ( rem(i, printout_interval) == 0 )
|
|
|
- Flux.testmode!(model, true)
|
|
|
- @printf(io, "[%s] Epoch %3d: Loss: %f\n", Dates.format(now(), time_format), i, loss(train_set))
|
|
|
+ if (rem(i, printout_interval) == 0)
|
|
|
+ log(i, false)
|
|
|
end
|
|
|
end
|
|
|
- Flux.testmode!(model, true)
|
|
|
- if(validate) @printf(io, "[%s] FINAL Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set))
|
|
|
- else @printf(io, "[%s] FINAL Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set)) end
|
|
|
+ log(epochs, !validate)
|
|
|
end
|
|
|
|
|
|
# logging framework
|
|
@@ -169,6 +183,14 @@ global_logger(SimpleLogger(io)) # for debug outputs
|
|
|
@printf(Base.stdout, "Logging to File: %s\n", fp)
|
|
|
@printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
|
|
|
@printf(io, "%s\n", log_msg)
|
|
|
+
|
|
|
+# csv handling
|
|
|
+if (csv_out)
|
|
|
+ fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format)).csv"
|
|
|
+ io_csv = open(fp_csv, "w+") # read, write, create, truncate
|
|
|
+ @printf(io_csv, "epoch, loss(train)\n")
|
|
|
+end
|
|
|
+
|
|
|
# dump configuration
|
|
|
@debug begin
|
|
|
for symbol in names(Main)
|
|
@@ -190,7 +212,7 @@ if (usegpu)
|
|
|
end
|
|
|
|
|
|
|
|
|
-train_model(model, train_set, validation_set, test_set)
|
|
|
+train_model()
|
|
|
|
|
|
|
|
|
|