Pārlūkot izejas kodu

created printout macro with timestamps

Sebastian Vendt 6 gadi atpakaļ
vecāks
revīzija
440a21f863
2 mainītis faili ar 23 papildinājumiem un 9 dzēšanām
  1. 22 8
      julia/net.jl
  2. 1 1
      julia/verbose.jl

+ 22 - 8
julia/net.jl

@@ -26,6 +26,8 @@ s = ArgParseSettings()
 end
 parsed_args = parse_args(ARGS, s)
 
+
+
 using Flux, Statistics
 using Flux: onecold
 using BSON
@@ -34,7 +36,9 @@ using Printf
 using NNlib
 using FeedbackNets
 include("./dataManager.jl")
+include("./verbose.jl")
 using .dataManager: make_batch
+using .verbose
 using Logging
 import LinearAlgebra: norm
 norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T)) 
@@ -75,6 +79,7 @@ inputDense3 = 500
 dropout_rate = 0.1f0
 
 dataset_folderpath = "../MATLAB/TrainingData/"
+dataset_name = "2019_09_09_1658"
 
 const model_save_location = "../trainedModels/"
 const log_save_location = "../logs/"
@@ -110,9 +115,9 @@ function loss(dataset)
 end
 
 function load_dataset()
-	train = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
-	val = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
-	test = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
+	train = make_batch(dataset_folderpath, "$(dataset_name)_TRAIN.mat", normalize_data=false, truncate_data=false)
+	val = make_batch(dataset_folderpath, "$(dataset_name)_VAL.mat", normalize_data=false, truncate_data=false)
+	test = make_batch(dataset_folderpath, "$(dataset_name)_TEST.mat", normalize_data=false, truncate_data=false)
 	return (train, val, test)
 end
 
@@ -133,8 +138,8 @@ model = Chain(
 
 train_model(model, train_set, validation_set, test_set)
 	opt = Momentum(learning_rate, momentum)
-	if(validate) @printf(io, "[%s] INIT with Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set)) 
-	else @printf(io, "[%s] INIT with Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set)) end
+	if(validate) @tprintf(io, "INIT with Loss(val_set): %f\n", loss(validation_set)) 
+	else @tprintf(io, "INIT with Loss(test_set): %f\n", loss(test_set)) end
 	
 	 
     for i in 1:epochs
@@ -142,12 +147,12 @@ train_model(model, train_set, validation_set, test_set)
         Flux.train!(loss, params(model), train_set, opt)
         opt.eta = adapt_learnrate(i)
         if ( rem(i, printout_interval) == 0 ) 
-			@printf(io, "[%s] Epoch %3d: Loss: %f\n", Dates.format(now(), time_format), i, loss(train_set)) 
+			@tprintf(io, "Epoch %3d: Loss: %f\n", i, loss(train_set)) 
 		end 
     end
 	
-	if(validate) @printf(io, "[%s] FINAL Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set)) 
-	else @printf(io, "[%s] FINAL Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set)) 
+	if(validate) @tprintf(io, "FINAL Loss(val_set): %f\n", loss(validation_set)) 
+	else @tprintf(io, "FINAL Loss(test_set): %f\n", loss(test_set)) 
 end
 
 # logging framework 
@@ -169,8 +174,17 @@ flush(Base.stdout)
 
 train_set, validation_set, test_set = load_dataset()
 
+if (usegpu)
+	train_set = gpu.(train_set)
+	validation_set = gpu.(validation_set)
+	test_set = gpu.(test_set)
+	model = gpu(model)
+end
+
+
 train_model(model, train_set, validation_set, test_set)
 
 
 
 
+

+ 1 - 1
julia/verbose.jl

@@ -23,7 +23,7 @@ macro tprintf(args...)
     else
         (length(args) >= 2 && (isa(args[2], AbstractString) || is_str_expr(args[2]))) ||
             throw(ArgumentError("@printf: first or second argument must be a format string"))
-        _printf("@printf", esc(args[1]), "[$(Dates.format(now(), time_format))]$(args[2])", args[3:end])
+        _printf("@printf", esc(args[1]), "[$(Dates.format(now(), time_format))] $(args[2])", args[3:end])
     end
 end