|
@@ -18,7 +18,7 @@ s = ArgParseSettings()
|
|
|
"--epochs"
|
|
|
help = "Number of epochs"
|
|
|
arg_type = Int64
|
|
|
- default = 500
|
|
|
+ default = 100
|
|
|
"--logmsg"
|
|
|
help = "additional message describing the training log"
|
|
|
arg_type = String
|
|
@@ -54,15 +54,16 @@ norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T))
|
|
|
const batch_size = 100
|
|
|
momentum = 0.99f0
|
|
|
const lambda = 0.0005f0
|
|
|
-const delta = 0.000001
|
|
|
+const delta = 6e-8
|
|
|
learning_rate = 0.003f0
|
|
|
validate = parsed_args["eval"]
|
|
|
const epochs = parsed_args["epochs"]
|
|
|
const decay_rate = 0.1f0
|
|
|
const decay_step = 40
|
|
|
const usegpu = parsed_args["gpu"]
|
|
|
-const printout_interval = 2
|
|
|
+const printout_interval = 1
|
|
|
const time_format = "HH:MM:SS"
|
|
|
+const time_print_format = "HH_MM_SS"
|
|
|
const date_format = "dd_mm_yyyy"
|
|
|
data_size = (60, 6) # resulting in a 300ms frame
|
|
|
|
|
@@ -166,7 +167,7 @@ function log(model, epoch, use_testset)
|
|
|
Flux.testmode!(model, true)
|
|
|
|
|
|
if(epoch == 0) # evalutation phase
|
|
|
- if(use_testset) @printf(io, "[%s] INIT Loss(test): f% Accuarcy: %f\n", Dates.format(now(), time_format), loss(model, test_set), accuracy(model, test_set))
|
|
|
+ if(use_testset) @printf(io, "[%s] INIT Loss(test): %f Accuarcy: %f\n", Dates.format(now(), time_format), loss(model, test_set), accuracy(model, test_set))
|
|
|
else @printf(io, "[%s] INIT Loss(val): %f Accuarcy: %f\n", Dates.format(now(), time_format), loss(model, validation_set), accuracy(model, validation_set)) end
|
|
|
elseif(epoch == epochs)
|
|
|
@printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(model, train_set), loss(model, validation_set))
|
|
@@ -177,7 +178,7 @@ function log(model, epoch, use_testset)
|
|
|
end
|
|
|
else # learning phase
|
|
|
if (rem(epoch, printout_interval) == 0)
|
|
|
- @printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(model, train_set), loss(model, validation_set))
|
|
|
+ @printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f acc(val): %f\n", Dates.format(now(), time_format), epoch, loss(model, train_set), loss(model, validation_set), accuracy(model, validation_set))
|
|
|
end
|
|
|
end
|
|
|
|
|
@@ -214,30 +215,32 @@ function train_model()
|
|
|
log(model, i, !validate)
|
|
|
|
|
|
# stop if network converged or is showing signs of overfitting
|
|
|
- curr_loss_train = loss(model, train_set)
|
|
|
- curr_loss_val = loss(model, validation_set)
|
|
|
- if(abs(last_loss_train - curr_loss_train) < delta)
|
|
|
- converged_epochs += 1
|
|
|
- if(converged_epochs == 8)
|
|
|
- @printf(io, "Converged at Loss(train): %f, Loss(val): %f in epoch %d with accuracy(val): %f\n", curr_loss_train, curr_loss_val, i, accuracy(model, validation_set))
|
|
|
- return eval_model(model)
|
|
|
- end
|
|
|
- else
|
|
|
- converged_epochs = 0
|
|
|
- end
|
|
|
-
|
|
|
- if((curr_loss_val - last_loss_val) > 0 )
|
|
|
- overfitting_epochs += 1
|
|
|
- if(overfitting_epochs == 10)
|
|
|
- @printf(io, "Stopping before overfitting at Loss(train): %f, Loss(val): %f in epoch %d with accuracy(val): %f\n", curr_loss_train, curr_loss_val, i, accuracy(model, validation_set))
|
|
|
- return eval(model)
|
|
|
- end
|
|
|
- else
|
|
|
- overfitting_epochs = 0
|
|
|
- end
|
|
|
+ #curr_loss_train = Tracker.data(loss(model, train_set))
|
|
|
+ #curr_loss_val = Tracker.data(loss(model, validation_set))
|
|
|
+ #if(abs(last_loss_train - curr_loss_train) < delta)
|
|
|
+ # converged_epochs += 1
|
|
|
+ # # @show converged_epochs
|
|
|
+ # if(converged_epochs == 8)
|
|
|
+ # @printf(io, "Converged at Loss(train): %f, Loss(val): %f in epoch %d with accuracy(val): %f\n", curr_loss_train, curr_loss_val, i, accuracy(model, validation_set))
|
|
|
+ # return eval_model(model)
|
|
|
+ # end
|
|
|
+ #else
|
|
|
+ # # @show "reset convereged $(abs(last_loss_train - curr_loss_train)) $(abs(last_loss_train - curr_loss_train) < delta)"
|
|
|
+ # converged_epochs = 0
|
|
|
+ #end
|
|
|
+ #
|
|
|
+ #if((curr_loss_val - last_loss_val) > 0 )
|
|
|
+ # overfitting_epochs += 1
|
|
|
+ # if(overfitting_epochs == 10)
|
|
|
+ # @printf(io, "Stopping before overfitting at Loss(train): %f, Loss(val): %f in epoch %d with accuracy(val): %f\n", curr_loss_train, curr_loss_val, i, accuracy(model, validation_set))
|
|
|
+ # return eval(model)
|
|
|
+ # end
|
|
|
+ #else
|
|
|
+ # overfitting_epochs = 0
|
|
|
+ #end
|
|
|
|
|
|
- last_loss_train = curr_loss_train
|
|
|
- last_loss_val = curr_loss_val
|
|
|
+ #last_loss_train = curr_loss_train
|
|
|
+ #last_loss_val = curr_loss_val
|
|
|
end
|
|
|
return eval_model(model)
|
|
|
end
|
|
@@ -252,7 +255,7 @@ global_logger(SimpleLogger(io)) # for debug outputs
|
|
|
|
|
|
# csv handling
|
|
|
if (csv_out)
|
|
|
- fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format))_$(Dates.format(now(), time_format)).csv"
|
|
|
+ fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format))_$(Dates.format(now(), time_print_format)).csv"
|
|
|
io_csv = open(fp_csv, "w+") # read, write, create, truncate
|
|
|
@printf(io_csv, "epoch, loss(train), loss(val)\n")
|
|
|
end
|
|
@@ -279,7 +282,7 @@ for rate in rs_learning_rate
|
|
|
learning_rate = rate
|
|
|
for decay in rs_decay_step
|
|
|
decay_step = decay
|
|
|
- config = "learning_rate=$(learning_rate), decay_rate=$(decay_rate)"
|
|
|
+ config = "learning_rate=$(learning_rate), decay_step=$(decay_step)"
|
|
|
@printf(io, "\nConfiguration %s\n", config)
|
|
|
train_model()
|
|
|
end
|