Quellcode durchsuchen

label normalization and more docu in dataManager.jl
changed tprintf macro to function

Sebastian Vendt vor 6 Jahren
Ursprung
Commit
b77dd443b9

BIN
MATLAB/TrainingData/2019_09_09_1658_TEST.mat


BIN
MATLAB/TrainingData/2019_09_09_1658_TRAIN.mat


BIN
MATLAB/TrainingData/2019_09_09_1658_VAL.mat


+ 10 - 2
julia/dataManager.jl

@@ -5,6 +5,11 @@ using Base.Iterators: repeated, partition
 using Statistics
 using Flux.Data.MNIST
 using Flux:onehotbatch
+
+# dimension of coordinates (labels): (x, y)
+lbls_dims = (1080, 980)
+lbls_offset = (0, 699)
+
 """
 	make_minibatch(X, Y, idxset)
 	
@@ -24,7 +29,7 @@ function make_minibatch(X, Y, idxset)
 end
 
 """
-    make_batch(filepath, batch_size=100, normalize=true)
+    make_batch(filepath, filenames...; batch_size=100, normalize_data=true, truncate_data=false)
     
 Creates batches with size batch_size(default 100) from filenames at given filepath. Images will be normalized if normalize is set (default true). 
 If batch_size equals -1 the batch size will be the size of the dataset
@@ -33,7 +38,7 @@ Structure of the .mat file:
     fieldname | size
     ----------------
        data   | 50 x 6 x N
-  bin_targets | 2 x N
+  bin_targets | 2 x N (1: x, 2: y)
 
 where N denotes the number of samples, 50 is the window size and 6 are the number of channels
 """
@@ -59,6 +64,9 @@ function make_batch(filepath, filenames...; batch_size=100, normalize_data=true,
 	
 	# add singleton dimension and permute dims so it matches the convention of Flux width x height x channels x batchsize(Setsize)   
 	data = cat(dims=4, data)
+	
+	# normalize the labels 
+	labels = (labels .- lbls_offset) ./ lbls_dims
 
     # rearrange the data array 
 	# size(data) = (50, 6, 1, N)

+ 5 - 5
julia/net.jl

@@ -139,8 +139,8 @@ model = Chain(
 function train_model(model, train_set, validation_set, test_set)
    Flux.testmode!(model, true)
 	opt = Momentum(learning_rate, momentum)
-	if(validate) @tprintf(io, "INIT with Loss(val_set): %f\n", loss(validation_set)) 
-	else @tprintf(io, "INIT with Loss(test_set): %f\n", loss(test_set)) end
+	if(validate) tprintf(io, "INIT with Loss(val_set): %f\n", loss(validation_set)) 
+	else tprintf(io, "INIT with Loss(test_set): %f\n", loss(test_set)) end
 	
     for i in 1:epochs
 		flush(io)
@@ -149,12 +149,12 @@ function train_model(model, train_set, validation_set, test_set)
         opt.eta = adapt_learnrate(i)
         if ( rem(i, printout_interval) == 0 ) 
          Flux.testmode!(model, true)
-			@tprintf(io, "Epoch %3d: Loss: %f\n", i, loss(train_set)) 
+			tprintf(io, "Epoch %3d: Loss: %f\n", i, loss(train_set)) 
 		end 
     end
 	Flux.testmode!(model, true)
-	if(validate) @tprintf(io, "FINAL Loss(val_set): %f\n", loss(validation_set)) 
-	else @tprintf(io, "FINAL Loss(test_set): %f\n", loss(test_set)) end
+	if(validate) tprintf(io, "FINAL Loss(val_set): %f\n", loss(validation_set)) 
+	else tprintf(io, "FINAL Loss(test_set): %f\n", loss(test_set)) end
 end
 
 # logging framework 

+ 6 - 6
julia/verbose.jl

@@ -6,7 +6,7 @@ using Base.Printf: _printf, is_str_expr, fix_dec, DIGITS, DIGITSs, print_fixed,
                    ini_hex, ini_HEX, print_exp_a, decode_0ct, decode_HEX, ini_dec, print_exp_e,
                    decode_oct, _limit, SmallNumber
 
-export @tprintf
+export tprintf
 
 time_format = "HH:MM:SS"
 
@@ -16,14 +16,14 @@ time_format = "HH:MM:SS"
 
 Same as printf but with leading timestamps 
 """
-macro tprintf(args...)
-	isempty(args) && throw(ArgumentError("@printf: called with no arguments"))
+function tprintf(args...)
+	isempty(args) && throw(ArgumentError("tprintf: called with no arguments"))
     if isa(args[1], AbstractString) || is_str_expr(args[1])
-        _printf("@printf", :stdout, "[$(Dates.format(now(), time_format))]$(args[1])", args[2:end])
+        _printf("tprintf", :stdout, "[$(Dates.format(now(), time_format))] $(args[1])", args[2:end])
     else
         (length(args) >= 2 && (isa(args[2], AbstractString) || is_str_expr(args[2]))) ||
-            throw(ArgumentError("@printf: first or second argument must be a format string"))
-        _printf("@printf", esc(args[1]), "[$(Dates.format(now(), time_format))] $(args[2])", args[3:end])
+            throw(ArgumentError("tprintf: first or second argument must be a format string"))
+        _printf("tprintf", esc(args[1]), "[$(Dates.format(now(), time_format))] $(args[2])", args[3:end])
     end
 end