Browse Source

added command line arguments
finished architecture
added training function

Sebastian Vendt 6 years ago
parent
commit
b6a60c289e
1 changed files with 106 additions and 17 deletions
  1. 106 17
      julia/net.jl

+ 106 - 17
julia/net.jl

@@ -6,12 +6,33 @@ Author: Sebastian Vendt, University of Ulm
 
 
 """
 """
 
 
# Command-line interface for the training script: device selection,
# evaluation mode, and the two tunable hyperparameters.
using ArgParse

s = ArgParseSettings()
@add_arg_table s begin
    "--gpu"
        help = "set, if you want to train on the GPU"
        action = :store_true
    "--eval"
        help = "set, if you want to validate instead of test after training"
        action = :store_true
    "--learn"
        help = "learning rate"
        arg_type = Float32
        default = 0.1f0
    "--epochs"
        help = "Number of epochs"
        arg_type = Int64
        default = 100
end

# Dict mapping option name (without dashes) => parsed value.
parsed_args = parse_args(ARGS, s)
+
 using Flux, Statistics
 using Flux, Statistics
 using Flux: onecold
 using Flux: onecold
 using BSON
 using BSON
 using Dates
 using Dates
 using Printf
 using Printf
 using NNlib
 using NNlib
+using FeedbackNets
 include("./dataManager.jl")
 include("./dataManager.jl")
 using .dataManager: make_batch
 using .dataManager: make_batch
 using Logging
 using Logging
@@ -25,12 +46,12 @@ norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T))
 const batch_size = 100
 const batch_size = 100
 const momentum = 0.9f0
 const momentum = 0.9f0
 const lambda = 0.0005f0
 const lambda = 0.0005f0
-init_learning_rate = 0.1f0
-learning_rate = init_learning_rate
-const epochs = 100
# Hyperparameters taken from the command line (see ArgParse table above).
learning_rate = parsed_args["learn"]
# BUG FIX: was `parse_args["eval"]`, which indexes the ArgParse *function*
# instead of the parsed-arguments dict and raises a MethodError at runtime.
validate = parsed_args["eval"]
const epochs = parsed_args["epochs"]
 const decay_rate = 0.1f0
 const decay_rate = 0.1f0
 const decay_step = 40
 const decay_step = 40
-const usegpu = true
+const usegpu = parsed_args["gpu"]
 const printout_interval = 5
 const printout_interval = 5
 const save_interval = 25
 const save_interval = 25
 const time_format = "HH:MM:SS"
 const time_format = "HH:MM:SS"
@@ -38,10 +59,20 @@ const date_format = "dd_mm_yyyy"
 data_size = (48, 6) # resulting in a 240ms frame
 data_size = (48, 6) # resulting in a 240ms frame
 
 
 # ARCHITECTURE
 # ARCHITECTURE
-inputDense1 = 
-inputDense2 = 
-inputDense3 = 
-# enter the datasets and models you want to train
# Network architecture hyperparameters for the 48x6 sensor-frame input.
channels = 1
features1 = 32
features2 = 64
features3 = 256 # needs to find the relation between the axis which represents the screen position 
kernel1 = (3,1)  # convolute only horizontally
kernel2 = kernel1  # same here
kernel3 = (3, 6) # this should convolute all 6 rows together to map relations between the channels  
pooldims1 = (2,1)# (24,6)
pooldims2 = (2,1)# (12,6)
# pooldims3 = (2,1)# (1, 4)
# Flattened size after two 2x1 poolings and the final (unpadded) kernel3 conv.
# NOTE(review): dividing by kernel3 only approximates the "valid" conv output
# size (true size is input - kernel + 1) — confirm against the actual Chain output.
inputDense1 = prod(data_size .÷ pooldims1 .÷ pooldims2 .÷ kernel3) * features3
inputDense2 = 500
inputDense3 = 500
dropout_rate = 0.1f0
 
 
 dataset_folderpath = "../MATLAB/TrainingData/"
 dataset_folderpath = "../MATLAB/TrainingData/"
 
 
"""
    adapt_learnrate(epoch_idx)

Return the exponentially decayed learning rate for epoch `epoch_idx`:
the initial rate scaled by `decay_rate^(epoch_idx / decay_step)`.
"""
function adapt_learnrate(epoch_idx)
    # BUG FIX: `init_learning_rate` was removed in this commit; the initial
    # rate now lives in `learning_rate` (taken from --learn on the CLI).
    return learning_rate * decay_rate^(epoch_idx / decay_step)
end
 
 
-function load_dataset(dataset_name)
-
-
# Quadratic (mean-squared-error) loss between the model prediction and target.
# TODO(review, from original comment): possibly add a parameter-norm penalty.
loss(x, y) = Flux.mse(model(x), y)
 
 
# Mean (untracked) loss over all `(data, labels)` batches in `dataset`.
function loss(dataset)
    total = 0.0f0
    for batch in dataset
        # Tracker.data strips the gradient tracking; we only want the number.
        total += Tracker.data(loss(batch[1], batch[2]))
    end
    return total / length(dataset)
end
 
 
# Load the (train, validation, test) splits as batched datasets.
# NOTE(review): all three splits are currently built from the same empty
# filename — confirm the intended per-split files with dataManager.
function load_dataset()
    batch() = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
    return (batch(), batch(), batch())
end
 
 
 model = Chain(
 model = Chain(
-	Conv(kernel, channels=>features, relu, pad=map(x -> x ÷ 2, kernel)),
-	MaxPool(pooldims1, stride=()),
-	Conv(relu, pad=map(x -> x ÷ 2, kernel)),
-	MaxPool(),
-	Conv(relu, pad=map(x -> x ÷ 2, kernel)),
-	MaxPool(),
+	Conv(kernel1, channels=>features1, relu, pad=map(x -> x ÷ 2, kernel1)),
+	MaxPool(pooldims1, stride=pooldims1), 
+	Conv(kernel2, features1=>features2, relu, pad=map(x -> x ÷ 2, kernel2)),
+	MaxPool(pooldims2, stride=pooldims2),
+	Conv(kernel3, features2=>features3, relu),
+	# MaxPool(),
 	flatten, 
 	flatten, 
 	Dense(inputDense1, inputDense2, σ),
 	Dense(inputDense1, inputDense2, σ),
+	Dropout(dropout_rate)
 	Dense(inputDense2, inputDense3, σ),
 	Dense(inputDense2, inputDense3, σ),
+	Dropout(dropout_rate)
 	Dense(inputDense3, 2) # identity to output coordinates!
 	Dense(inputDense3, 2) # identity to output coordinates!
 )
 )
 
 
"""
    train_model(model, train_set, validation_set, test_set)

Train `model` on `train_set` with Momentum SGD for `epochs` epochs,
decaying the learning rate via `adapt_learnrate`. Logs progress to the
global log stream `io`; evaluates on `validation_set` when the global
`validate` flag is set, otherwise on `test_set`.
"""
function train_model(model, train_set, validation_set, test_set)
	# BUG FIX: the `function` keyword was missing, turning the definition
	# line into a call; the closing `end` of the final if/else was also missing.
	opt = Momentum(learning_rate, momentum)
	if validate
		@printf(io, "[%s] INIT with Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set))
	else
		@printf(io, "[%s] INIT with Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set))
	end

	for i in 1:epochs
		flush(io)
		Flux.train!(loss, params(model), train_set, opt)
		# Decay the step size after each epoch.
		opt.eta = adapt_learnrate(i)
		if rem(i, printout_interval) == 0
			@printf(io, "[%s] Epoch %3d: Loss: %f\n", Dates.format(now(), time_format), i, loss(train_set))
		end
	end

	if validate
		@printf(io, "[%s] FINAL Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set))
	else
		@printf(io, "[%s] FINAL Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set))
	end
end
+
# logging framework 
# Opens an append-mode log file and routes both debug logging and the
# training printouts ( `io` is used globally by train_model ) to it.
# NOTE(review): `log_save_location` and `debug_str` are not defined anywhere
# in this diff — confirm they exist earlier in the file, else this line throws.
fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
io = open(fp, "a+")
global_logger(SimpleLogger(io)) # for debug outputs
@printf(Base.stdout, "Logging to File: %s\n", fp)
@printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
# dump configuration 
# Debug-only: dumps every Main binding (config constants, parsed args, …)
# into the log. `eval(symbol)` is acceptable here since it only runs when
# debug logging is enabled and only on names from `names(Main)`.
@debug begin
	for symbol in names(Main)
		var = "$(symbol) = $(eval(symbol))"
		@printf(io, "%s\n", var)
	end
	"--------End of VAR DUMP--------"
end
flush(io)
flush(Base.stdout)
+
# Script entry point: load the three dataset splits and run training.
train_set, validation_set, test_set = load_dataset()

train_model(model, train_set, validation_set, test_set)
+
+
+
+