|
@@ -18,7 +18,7 @@ s = ArgParseSettings()
|
|
|
"--epochs"
|
|
|
help = "Number of epochs"
|
|
|
arg_type = Int64
|
|
|
- default = 40
|
|
|
+ default = 500
|
|
|
"--logmsg"
|
|
|
help = "additional message describing the training log"
|
|
|
arg_type = String
|
|
@@ -52,10 +52,10 @@ norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T))
|
|
|
# PARAMETERS
|
|
|
######################
|
|
|
const batch_size = 100
|
|
|
-momentum = 0.9f0
|
|
|
+momentum = 0.99f0
|
|
|
const lambda = 0.0005f0
|
|
|
-const delta = 0.00001
|
|
|
-learning_rate = 0.1f0
|
|
|
+const delta = 0.000001
|
|
|
+learning_rate = 0.003f0
|
|
|
validate = parsed_args["eval"]
|
|
|
const epochs = parsed_args["epochs"]
|
|
|
const decay_rate = 0.1f0
|
|
@@ -69,14 +69,14 @@ data_size = (60, 6) # resulting in a 300ms frame
|
|
|
# ARCHITECTURE
|
|
|
channels = 1
|
|
|
features = [32, 64, 128] # needs to find the relation between the axis which represents the screen position
|
|
|
-kernel = [(3,1), (3,1), (3,6)] # convolute only horizontally, last should convolute all 6 rows together to map relations between the channels
|
|
|
-pooldims = [(2,1), (2,1)]# (30,6) -> (15,6)
|
|
|
+kernel = [(5,1), (5,1), (2,6)] # convolute only horizontally, last should convolute all 6 rows together to map relations between the channels
|
|
|
+pooldims = [(3,1), (3,1)]# (30,6) -> (15,6)
|
|
|
# formula for calculating output dimensions of convolution:
|
|
|
# dim1 = ((dim1 - Filtersize + 2 * padding) / stride) + 1
|
|
|
-inputDense = [1664, 600, 300] # prod((data_size .÷ pooldims[1] .÷ pooldims[2]) .- kernel[3] .+ 1) * features[3]
|
|
|
+inputDense = [0, 600, 300]
|
|
|
dropout_rate = 0.3f0
|
|
|
|
|
|
-rs_learning_rate = [0.3, 0.1, 0.03] # [1, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001]
|
|
|
+rs_learning_rate = [0.03, 0.01, 0.003] # [1, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001]
|
|
|
rs_decay_step = [20, 40, 60]
|
|
|
|
|
|
dataset_folderpath = "../MATLAB/TrainingData/"
|
|
@@ -213,25 +213,25 @@ function train_model()
|
|
|
log_csv(model, i)
|
|
|
log(model, i, !validate)
|
|
|
|
|
|
- # stopp if network converged or is showing signs of overfitting
|
|
|
+ # stop if network converged or is showing signs of overfitting
|
|
|
curr_loss_train = loss(model, train_set)
|
|
|
curr_loss_val = loss(model, validation_set)
|
|
|
if(abs(last_loss_train - curr_loss_train) < delta)
|
|
|
- converged_epochs++
|
|
|
- if(converged_epochs == 5)
|
|
|
+ converged_epochs += 1
|
|
|
+ if(converged_epochs == 8)
|
|
|
@printf(io, "Converged at Loss(train): %f, Loss(val): %f in epoch %d with accuracy(val): %f\n", curr_loss_train, curr_loss_val, i, accuracy(model, validation_set))
|
|
|
+ return eval_model(model)
|
|
|
end
|
|
|
- return eval_model(model)
|
|
|
else
|
|
|
converged_epochs = 0
|
|
|
end
|
|
|
|
|
|
if((curr_loss_val - last_loss_val) > 0 )
|
|
|
- overfitting_epochs++
|
|
|
- if(overfitting_epochs == 8)
|
|
|
+ overfitting_epochs += 1
|
|
|
+ if(overfitting_epochs == 10)
|
|
|
@printf(io, "Stopping before overfitting at Loss(train): %f, Loss(val): %f in epoch %d with accuracy(val): %f\n", curr_loss_train, curr_loss_val, i, accuracy(model, validation_set))
|
|
|
+ return eval(model)
|
|
|
end
|
|
|
- return eval(model)
|
|
|
else
|
|
|
overfitting_epochs = 0
|
|
|
end
|
|
@@ -242,32 +242,6 @@ function train_model()
|
|
|
return eval_model(model)
|
|
|
end
|
|
|
|
|
|
-function random_search()
|
|
|
- rng = MersenneTwister()
|
|
|
- results = []
|
|
|
- for search in 1:800
|
|
|
- # create random set
|
|
|
- global momentum = rand(rng, rs_momentum)
|
|
|
- global features = rand(rng, rs_features)
|
|
|
- global dropout_rate = rand(rng, rs_dropout_rate)
|
|
|
- global kernel = rand(rng, rs_kernel)
|
|
|
- global pooldims = rand(rng, rs_pooldims)
|
|
|
- global learning_rate = rand(rng, rs_learning_rate)
|
|
|
-
|
|
|
- # printf configuration
|
|
|
- config1 = "momentum$(momentum), features=$(features), dropout_rate=$(dropout_rate)"
|
|
|
- config2 = "kernel=$(kernel), pooldims=$(pooldims), learning_rate=$(learning_rate)"
|
|
|
- @printf(io, "\nSearch %d of %d\n", search, 500)
|
|
|
- @printf(io, "%s\n", config1)
|
|
|
- @printf(io, "%s\n\n", config2)
|
|
|
-
|
|
|
- (loss, accuracy) = train_model()
|
|
|
- push!(results, (search, loss, accuracy))
|
|
|
- end
|
|
|
- return results
|
|
|
-end
|
|
|
-
|
|
|
-
|
|
|
# logging framework
|
|
|
fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
|
|
|
io = open(fp, "a+")
|
|
@@ -278,7 +252,7 @@ global_logger(SimpleLogger(io)) # for debug outputs
|
|
|
|
|
|
# csv handling
|
|
|
if (csv_out)
|
|
|
- fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format)).csv"
|
|
|
+ fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format))_$(Dates.format(now(), time_format)).csv"
|
|
|
io_csv = open(fp_csv, "w+") # read, write, create, truncate
|
|
|
@printf(io_csv, "epoch, loss(train), loss(val)\n")
|
|
|
end
|