# net_local.jl
"""
Author: Sebastian Vendt, University of Ulm
"""

# Command-line interface: training hyperparameters and output switches.
# Parsed once at startup into the `parsed_args` dict used throughout the script.
using ArgParse
s = ArgParseSettings()
@add_arg_table s begin
    "--gpu"
        help = "set, if you want to train on the GPU"
        action = :store_true
    "--eval"
        help = "set, if you want to validate instead of test after training"
        action = :store_true
    "--learn"
        help = "learning rate"
        arg_type = Float32
        default = 0.1f0
    "--epochs"
        help = "Number of epochs"
        arg_type = Int64
        default = 100
    "--logmsg"
        help = "additional message describing the training log"
        arg_type = String
        default = ""
    "--csv"
        help = "set, if you additionally want a csv output of the learning process"
        action = :store_true
end
parsed_args = parse_args(ARGS, s)
  30. using Flux, Statistics
  31. using Flux: onecold
  32. using BSON
  33. using Dates
  34. using Printf
  35. using NNlib
  36. using FeedbackNets
  37. include("./dataManager.jl")
  38. include("./verbose.jl")
  39. using .dataManager: make_batch
  40. using .verbose
  41. using Logging
  42. import LinearAlgebra: norm
  43. norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T))
######################
# PARAMETERS
######################
const batch_size = 100
const momentum = 0.9f0   # momentum term for the SGD optimizer
const lambda = 0.0005f0  # weight of the parameter-norm regularization penalty
learning_rate = parsed_args["learn"]
validate = parsed_args["eval"]  # true: report validation loss instead of test loss
const epochs = parsed_args["epochs"]
const decay_rate = 0.1f0  # multiplicative learning-rate decay factor
const decay_step = 40     # epochs per decay application
const usegpu = parsed_args["gpu"]
const printout_interval = 5  # log training loss every N epochs
const save_interval = 25     # NOTE(review): unused in the visible code — confirm before removing
const time_format = "HH:MM:SS"
const date_format = "dd_mm_yyyy"
data_size = (60, 6) # resulting in a 300ms frame

# ARCHITECTURE
channels = 1
features1 = 32
features2 = 64
features3 = 128 # needs to find the relation between the axis which represents the screen position
kernel1 = (3,1) # convolute only horizontally
kernel2 = kernel1 # same here
kernel3 = (3, 6) # this should convolute all 6 rows together to map relations between the channels
pooldims1 = (2,1)# (30,6)
pooldims2 = (2,1)# (15,6)
# pooldims3 = (2,1)# (1, 4)
inputDense1 = 1664 # prod(data_size .÷ pooldims1 .÷ pooldims2 .÷ kernel3) * features3
inputDense2 = 600
inputDense3 = 300
dropout_rate = 0.3f0
dataset_folderpath = "../MATLAB/TrainingData/"
dataset_name = "2019_09_09_1658"
const model_save_location = "../trainedModels/"
const log_save_location = "./logs/"
# GPU backend is only loaded when requested on the command line.
if usegpu
    using CuArrays
end
  83. debug_str = ""
  84. log_msg = parsed_args["logmsg"]
  85. csv_out = parse_args["csv"]
  86. @debug begin
  87. global debug_str
  88. debug_str = "DEBUG_"
  89. "------DEBUGGING ACTIVATED------"
  90. end
  91. io = nothing
  92. io_csv = nothing
  93. function adapt_learnrate(epoch_idx)
  94. return init_learning_rate * decay_rate^(epoch_idx / decay_step)
  95. end
  96. function loss(model, x, y)
  97. # quadratic euclidean distance + parameternorm
  98. return Flux.mse(model(x), y) + lambda * sum(norm, params(model))
  99. end
  100. function loss(model, dataset)
  101. loss_val = 0.0f0
  102. for (data, labels) in dataset
  103. loss_val += Tracker.data(loss(model, data, labels))
  104. end
  105. return loss_val / length(dataset)
  106. end
  107. function load_dataset()
  108. train = make_batch(dataset_folderpath, "$(dataset_name)_TRAIN.mat", normalize_data=false, truncate_data=false)
  109. val = make_batch(dataset_folderpath, "$(dataset_name)_VAL.mat", normalize_data=false, truncate_data=false)
  110. test = make_batch(dataset_folderpath, "$(dataset_name)_TEST.mat", normalize_data=false, truncate_data=false)
  111. return (train, val, test)
  112. end
  113. function create_model()
  114. return Chain(
  115. Conv(kernel1, channels=>features1, relu, pad=map(x -> x ÷ 2, kernel1)),
  116. MaxPool(pooldims1, stride=pooldims1),
  117. Conv(kernel2, features1=>features2, relu, pad=map(x -> x ÷ 2, kernel2)),
  118. MaxPool(pooldims2, stride=pooldims2),
  119. Conv(kernel3, features2=>features3, relu),
  120. # MaxPool(),
  121. flatten,
  122. Dense(inputDense1, inputDense2, relu),
  123. Dropout(dropout_rate),
  124. Dense(inputDense2, inputDense3, relu),
  125. Dropout(dropout_rate),
  126. Dense(inputDense3, 2, σ), # coordinates between 0 and 1
  127. )
  128. end
  129. function log(model, epoch, use_testset)
  130. Flux.testmode!(model, true)
  131. if(epoch == 0 | epoch == epochs) # evalutation phase
  132. if(use_testset) @printf(io, "[%s] Epoch %3d: Loss(test): %f\n", Dates.format(now(), time_format), epoch, loss(model, test_set))
  133. else @printf(io, "[%s] Epoch %3d: Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(model, validation_set)) end
  134. else # learning phase
  135. @printf(io, "[%s] Epoch %3d: Loss(train): %f\n", Dates.format(now(), time_format), epoch, loss(model, train_set))
  136. end
  137. if(csv_out) @printf(io_csv, "%d, %f\n", epoch, loss(model, train_set)) end
  138. Flux.testmode!(model, false)
  139. end
  140. function train_model()
  141. model = create_model()
  142. if(usegpu) model = gpu(model) end
  143. opt = Momentum(learning_rate, momentum)
  144. log(model, 0, !validate)
  145. for i in 1:epochs
  146. flush(io)
  147. Flux.testmode!(model, false) # bring model in training mode
  148. Flux.train!((x, y) -> loss(model, x, y), params(model), train_set, opt)
  149. opt.eta = adapt_learnrate(i)
  150. if (rem(i, printout_interval) == 0)
  151. log(model, i, false)
  152. end
  153. end
  154. log(model, epochs, !validate)
  155. end
# ---------------------------------------------------------------------------
# Script entry: open the log (and optional csv) streams, dump the
# configuration when debugging, load the dataset, move it to the GPU if
# requested, and run training.
# ---------------------------------------------------------------------------

# logging framework
fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
io = open(fp, "a+")  # append: repeated runs on the same day share one log file
global_logger(SimpleLogger(io)) # for debug outputs
@printf(Base.stdout, "Logging to File: %s\n", fp)
@printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
@printf(io, "%s\n", log_msg)
# csv handling
if (csv_out)
    fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format)).csv"
    io_csv = open(fp_csv, "w+") # read, write, create, truncate
    @printf(io_csv, "epoch, loss(train)\n")
end
# dump configuration
@debug begin
    # Debug-only introspection: print every binding in Main to the log.
    for symbol in names(Main)
        var = "$(symbol) = $(eval(symbol))"
        @printf(io, "%s\n", var)
    end
    "--------End of VAR DUMP--------"
end
flush(io)
flush(Base.stdout)
train_set, validation_set, test_set = load_dataset()
if (usegpu)
    # broadcast gpu over the (data, label) tuples of each split
    train_set = gpu.(train_set)
    validation_set = gpu.(validation_set)
    test_set = gpu.(test_set)
end
train_model()