net.jl 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. Author: Sebastian Vendt, University of Ulm
  3. """
  4. using ArgParse
  # Command-line interface of the training script. Both switches are plain
  # flags; the remaining options carry a default, so the script can be started
  # without any arguments.
  s = ArgParseSettings()

  @add_arg_table s begin
      # --- switches -------------------------------------------------------
      "--gpu"
          help = "set, if you want to train on the GPU"
          action = :store_true
      "--eval"
          help = "set, if you want to validate instead of test after training"
          action = :store_true

      # --- hyper-parameters -----------------------------------------------
      "--learn"
          help = "learning rate"
          arg_type = Float32
          default = 0.1f0
      "--epochs"
          help = "Number of epochs"
          arg_type = Int64
          default = 100

      # --- bookkeeping ----------------------------------------------------
      "--logmsg"
          help = "additional message describing the training log"
          arg_type = String
          default = ""
  end

  # Parsed options, looked up further down by option name without the
  # leading dashes (e.g. parsed_args["learn"]).
  parsed_args = parse_args(ARGS, s)
  27. using Flux, Statistics
  28. using Flux: onecold
  29. using BSON
  30. using Dates
  31. using Printf
  32. using NNlib
  33. using FeedbackNets
  34. include("./dataManager.jl")
  35. include("./verbose.jl")
  36. using .dataManager: make_batch
  37. using .verbose
  38. using Logging
  39. import LinearAlgebra: norm
  """
      norm(x::TrackedArray{T}) where T

  Return `sqrt(sum(abs2.(x)) + eps(T))`, i.e. the Euclidean norm of `x` with
  `eps(T)` added under the square root.

  NOTE(review): presumably the offset keeps the gradient finite when `x` is all
  zeros — confirm. This extends `LinearAlgebra.norm` for a type owned by another
  package (type piracy), so it changes `norm` for every caller in the session.
  """
  function norm(x::TrackedArray{T}) where T
      squared_sum = sum(abs2.(x))
      return sqrt(squared_sum + eps(T))
  end
  ######################
  # PARAMETERS
  ######################

  # --- values taken from the command line ---
  learning_rate = parsed_args["learn"]   # initial step size handed to the optimiser
  validate      = parsed_args["eval"]    # true: report on the validation set, false: on the test set
  const epochs  = parsed_args["epochs"]  # number of passes over the training set
  const usegpu  = parsed_args["gpu"]     # true: move data and model to the GPU

  # --- optimisation ---
  const batch_size = 100
  const momentum   = 0.9f0      # momentum term of the optimiser
  const lambda     = 0.0005f0   # weight of the parameter-norm penalty in the loss
  const decay_rate = 0.1f0      # base of the learning-rate decay
  const decay_step = 40         # epoch scale of the learning-rate decay

  # --- reporting ---
  const printout_interval = 5   # log the training loss every N epochs
  const save_interval     = 25  # NOTE(review): not referenced in the code visible here
  const time_format = "HH:MM:SS"
  const date_format = "dd_mm_yyyy"

  # --- input ---
  data_size = (60, 6)           # resulting in a 300ms frame
  # ARCHITECTURE

  # Number of feature maps entering the network and leaving each convolution.
  channels  = 1
  features1 = 32
  features2 = 64
  features3 = 128      # needs to find the relation between the axis which represents the screen position

  # Convolution kernels.
  kernel1 = (3, 1)     # convolve only horizontally
  kernel2 = kernel1    # same here
  kernel3 = (3, 6)     # spans all 6 rows at once to map relations between the channels

  # Pooling windows; the trailing tuple is the feature-map size after pooling.
  pooldims1 = (2, 1)   # (30, 6)
  pooldims2 = (2, 1)   # (15, 6)
  # A third pooling stage is not used: pooldims3 = (2, 1)  # (1, 4)

  # Widths of the dense layers.
  # inputDense1 = 13 * 1 * features3: the (15, 6) map left after pooldims2
  # shrinks to (13, 1) under the unpadded kernel3 (assuming the convolution
  # removes kernel - 1 entries per axis — confirm against the Conv layer used).
  inputDense1 = 1664
  inputDense2 = 600
  inputDense3 = 300
  dropout_rate = 0.3f0

  # Locations of the training data and of the generated artefacts.
  dataset_folderpath = "../MATLAB/TrainingData/"
  dataset_name = "2019_09_09_1658"
  const model_save_location = "../trainedModels/"
  const log_save_location = "./logs/"
  # CuArrays is only loaded when GPU training was requested.
  if usegpu
      using CuArrays
  end

  # Prefix of the log file name; stays empty unless the @debug block below runs.
  debug_str = ""
  log_msg = parsed_args["logmsg"]

  # The block is the message expression of @debug, so its side effect (marking
  # the log file as a debug log) happens only when the active logger actually
  # processes debug-level messages.
  @debug begin
      global debug_str = "DEBUG_"
      "------DEBUGGING ACTIVATED------"
  end

  # Placeholder for the log stream; rebound to the opened log file before
  # training starts.
  io = nothing
  """
      adapt_learnrate(epoch_idx; init_learning_rate=learning_rate,
                      rate=decay_rate, step=decay_step)

  Return the exponentially decayed learning rate for epoch `epoch_idx`:

      init_learning_rate * rate^(epoch_idx / step)

  # Arguments
  - `epoch_idx`: index of the epoch the rate is computed for; `0` yields
    `init_learning_rate` unchanged.

  # Keywords
  - `init_learning_rate`: rate before any decay. Defaults to the global
    `learning_rate` (the value given on the command line). The function used to
    read a global named `init_learning_rate` that is defined nowhere, which
    raised an `UndefVarError` on the first call.
  - `rate`: factor applied once per `step` epochs. Defaults to `decay_rate`.
  - `step`: number of epochs after which the rate has shrunk by `rate`.
    Defaults to `decay_step`.
  """
  function adapt_learnrate(epoch_idx; init_learning_rate=learning_rate,
                           rate=decay_rate, step=decay_step)
      # The defaults are looked up at call time, so callers that pass all three
      # keywords do not depend on the script-level globals.
      return init_learning_rate * rate^(epoch_idx / step)
  end
  """
      loss(x, y)

  Training objective for one batch: the mean squared error between `model(x)`
  and the targets `y`, plus `lambda` times the sum of the norms of all model
  parameters.

  Reads the globals `model` and `lambda`.
  """
  function loss(x, y)
      prediction_error = Flux.mse(model(x), y)
      # Penalty term: one norm per parameter array, summed over the model.
      parameter_penalty = lambda * sum(norm, params(model))
      return prediction_error + parameter_penalty
  end
  """
      loss(dataset)

  Average of `loss(data, labels)` over all `(data, labels)` batches in
  `dataset`, returned as a plain (untracked) number.
  """
  function loss(dataset)
      # Left fold in iteration order, starting from a Float32 zero; the tracked
      # loss of each batch is unwrapped before it is added.
      total = foldl(dataset; init=0.0f0) do running_sum, batch
          data, labels = batch
          running_sum + Tracker.data(loss(data, labels))
      end
      return total / length(dataset)
  end
  """
      load_dataset()

  Load the training, validation and test splits of `dataset_name` from
  `dataset_folderpath` and return them as the tuple `(train, val, test)`.

  Every split is read with `make_batch`, with normalisation and truncation
  switched off.
  """
  function load_dataset()
      # One reader shared by all three splits keeps the options in a single place.
      read_split(filename) = make_batch(dataset_folderpath, filename,
                                        normalize_data=false, truncate_data=false)

      train = read_split("$(dataset_name)_TRAIN.mat")
      val = read_split("$(dataset_name)_VAL.mat")
      test = read_split("$(dataset_name)_TEST.mat")
      return (train, val, test)
  end
  # Network: three convolutions (the first two each followed by max pooling),
  # then three dense layers that end in two sigmoid outputs.
  model = Chain(
      # Convolution 1 and 2 are padded by half the kernel size per axis.
      Conv(kernel1, channels => features1, relu, pad = kernel1 .÷ 2),
      MaxPool(pooldims1, stride = pooldims1),

      Conv(kernel2, features1 => features2, relu, pad = kernel2 .÷ 2),
      MaxPool(pooldims2, stride = pooldims2),

      # Convolution 3 is unpadded and is not followed by a pooling layer.
      Conv(kernel3, features2 => features3, relu),

      flatten,

      Dense(inputDense1, inputDense2, relu),
      Dropout(dropout_rate),
      Dense(inputDense2, inputDense3, relu),
      Dropout(dropout_rate),
      Dense(inputDense3, 2, σ),  # coordinates between 0 and 1
  )
  """
      train_model(model, train_set, validation_set, test_set)

  Train `model` on `train_set` for `epochs` epochs with a `Momentum` optimiser
  and write progress to the global log stream `io`.

  The loss on the evaluation set is logged once before and once after training:
  on `validation_set` when the global `validate` is true, on `test_set`
  otherwise. Every `printout_interval` epochs the loss on `train_set` is logged.
  After each epoch the optimiser's step size is replaced by
  `adapt_learnrate(epoch)`.

  NOTE(review): `loss` evaluates the global `model`, not the `model` argument;
  the two are the same object in the call at the end of this script.
  """
  function train_model(model, train_set, validation_set, test_set)
      timestamp() = Dates.format(now(), time_format)

      Flux.testmode!(model, true)
      opt = Momentum(learning_rate, momentum)
      if validate
          @printf(io, "[%s] INIT with Loss(val_set): %f\n", timestamp(), loss(validation_set))
      else
          @printf(io, "[%s] INIT with Loss(test_set): %f\n", timestamp(), loss(test_set))
      end

      for epoch in 1:epochs
          flush(io)
          Flux.testmode!(model, false)  # bring model in training mode
          Flux.train!(loss, params(model), train_set, opt)
          opt.eta = adapt_learnrate(epoch)

          if epoch % printout_interval == 0
              # Evaluate in test mode; the next epoch switches back to training mode.
              Flux.testmode!(model, true)
              @printf(io, "[%s] Epoch %3d: Loss: %f\n", timestamp(), epoch, loss(train_set))
          end
      end

      Flux.testmode!(model, true)
      if validate
          @printf(io, "[%s] FINAL Loss(val_set): %f\n", timestamp(), loss(validation_set))
      else
          @printf(io, "[%s] FINAL Loss(test_set): %f\n", timestamp(), loss(test_set))
      end
  end
  # logging framework
  # One log file per day (and per debug mode); runs on the same day append to it.
  # The directory is created first because `open` fails when it does not exist.
  mkpath(log_save_location)
  fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
  io = open(fp, "a+")
  global_logger(SimpleLogger(io)) # for debug outputs
  @printf(Base.stdout, "Logging to File: %s\n", fp)

  # Header of this run: date, time and the free-text message from the command line.
  @printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
  @printf(io, "%s\n", log_msg)

  # dump configuration
  # Runs only when debug messages are processed: writes every name defined in
  # Main together with its current value to the log.
  @debug begin
      for symbol in names(Main)
          var = "$(symbol) = $(eval(symbol))"
          @printf(io, "%s\n", var)
      end
      "--------End of VAR DUMP--------"
  end

  flush(io)
  flush(Base.stdout)
  # Load the three data splits and, if requested, move data and model to the GPU.
  train_set, validation_set, test_set = load_dataset()
  if usegpu
      train_set = gpu.(train_set)
      validation_set = gpu.(validation_set)
      test_set = gpu.(test_set)
      model = gpu(model)
  end

  # Run the training. The log stream is closed afterwards — also when training
  # throws — so that everything written so far reaches the file.
  try
      train_model(model, train_set, validation_set, test_set)
  finally
      close(io)
  end