net.jl 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  # Author: Sebastian Vendt, University of Ulm
  #
  # Kept as a comment on purpose: a bare string literal directly above `using` is parsed
  # as a docstring for that `using` statement, which Julia refuses to document.
  using ArgParse

  # Command-line interface; the parsed values feed the PARAMETERS section.
  s = ArgParseSettings()
  @add_arg_table s begin
      "--gpu"
          help = "set, if you want to train on the GPU"
          action = :store_true
      "--eval"
          help = "set, if you want to validate instead of test after training"
          action = :store_true
      "--learn"
          help = "learning rate"
          arg_type = Float32
          default = 0.1f0
      "--epochs"
          help = "Number of epochs"
          arg_type = Int64
          default = 100
  end
  parsed_args = parse_args(ARGS, s)
  23. using Flux, Statistics
  24. using Flux: onecold
  25. using BSON
  26. using Dates
  27. using Printf
  28. using NNlib
  29. using FeedbackNets
  30. include("./dataManager.jl")
  31. include("./verbose.jl")
  32. using .dataManager: make_batch
  33. using .verbose
  34. using Logging
  35. import LinearAlgebra: norm
  # Euclidean norm for Tracker arrays, used by the parameter penalty in `loss`.
  # eps(T) is added under the square root, presumably so the gradient of sqrt stays finite
  # when every entry is zero (d/ds sqrt(s) is unbounded at s = 0).
  # NOTE(review): this adds a method to LinearAlgebra.norm for Tracker's TrackedArray
  # (type piracy); a locally owned function would avoid affecting other code.
  function norm(x::TrackedArray{T}) where {T}
      squares = abs2.(x)
      return sqrt(sum(squares) + eps(T))
  end
  ######################
  # PARAMETERS
  ######################
  const batch_size = 100
  const momentum = 0.9f0
  const lambda = 0.0005f0 # weight of the parameter-norm penalty in `loss`
  learning_rate = parsed_args["learn"]
  validate = parsed_args["eval"]
  const epochs = parsed_args["epochs"]
  const decay_rate = 0.1f0 # learning rate is scaled by this factor every `decay_step` epochs
  const decay_step = 40
  const usegpu = parsed_args["gpu"]
  const printout_interval = 5
  const save_interval = 25
  const time_format = "HH:MM:SS"
  const date_format = "dd_mm_yyyy"
  data_size = (60, 6) # resulting in a 300ms frame
  # ARCHITECTURE
  channels = 1
  features1 = 32
  features2 = 64
  features3 = 128 # needs to find the relation between the axis which represents the screen position
  kernel1 = (3, 1) # convolve only along the first axis
  kernel2 = kernel1 # same here
  kernel3 = (3, 6) # spans all 6 rows, so it can relate the channels to each other
  pooldims1 = (2, 1) # (60, 6) -> (30, 6)
  pooldims2 = (2, 1) # (30, 6) -> (15, 6)
  # Input width of the first Dense layer, derived from the conv/pool stack instead of being
  # hard-coded: conv1/conv2 are padded with kernel .÷ 2 (size-preserving for these odd
  # (3, 1) kernels), each MaxPool divides by its pool dims, and the unpadded conv3 shrinks
  # each axis by kernel3 .- 1. For the values above: (13, 1) * 128 = 1664.
  inputDense1 = prod((data_size .÷ pooldims1 .÷ pooldims2) .- kernel3 .+ 1) * features3
  inputDense2 = 600
  inputDense3 = 300
  dropout_rate = 0.3f0
  dataset_folderpath = "../MATLAB/TrainingData/"
  dataset_name = "2019_09_09_1658"
  const model_save_location = "../trainedModels/"
  const log_save_location = "./logs/"
  if usegpu
      using CuArrays
  end

  # Prefix for the log file name; becomes "DEBUG_" when debug logging is enabled.
  debug_str = ""
  # The message expression of @debug is evaluated only when debug-level records are
  # enabled for this module (e.g. JULIA_DEBUG=Main; confirm how the script is launched).
  # The assignment inside therefore doubles as a "debug mode" switch. The message block
  # runs in a local scope, hence the explicit `global`.
  @debug begin
      global debug_str = "DEBUG_"
      "------DEBUGGING ACTIVATED------"
  end

  # Placeholder for the log stream; it is opened once the log file path is known.
  io = nothing
  """
      adapt_learnrate(epoch_idx; init_rate = learning_rate, rate = decay_rate, period = decay_step)

  Return the learning rate for the epoch after `epoch_idx`:
  `init_rate * rate^(epoch_idx / period)`. The rate is scaled by `rate` once every
  `period` epochs and decays smoothly in between.

  The defaults read the script-level settings `learning_rate`, `decay_rate` and
  `decay_step`. Earlier code referenced an undefined `init_learning_rate`.

  Throws `ArgumentError` if `period` is not positive.
  """
  function adapt_learnrate(epoch_idx; init_rate = learning_rate, rate = decay_rate, period = decay_step)
      period > 0 || throw(ArgumentError("period must be positive, got $period"))
      return init_rate * rate^(epoch_idx / period)
  end
  # Per-batch training objective on the global `model`:
  # the mean squared error of the prediction, plus `lambda` times the sum of the
  # (non-squared) Euclidean norms of every parameter array.
  function loss(x, y)
      prediction = model(x)
      data_term = Flux.mse(prediction, y)
      penalty = sum(norm, params(model))
      return data_term + lambda * penalty
  end
  # Unweighted mean of the per-batch losses over `dataset`, an iterable of
  # (data, labels) pairs. Tracker.data drops gradient tracking so only plain values
  # are accumulated. An empty dataset yields 0/0 (NaN).
  function loss(dataset)
      batch_value((data, labels)) = Tracker.data(loss(data, labels))
      total = mapfoldl(batch_value, +, dataset; init = 0.0f0)
      return total / length(dataset)
  end
  # Load the TRAIN, VAL and TEST splits of `dataset_name` from `dataset_folderpath`
  # via make_batch (no normalization, no truncation). Returns them as a 3-tuple in
  # that order.
  function load_dataset()
      filenames = ("$(dataset_name)_TRAIN.mat", "$(dataset_name)_VAL.mat", "$(dataset_name)_TEST.mat")
      return map(filenames) do filename
          make_batch(dataset_folderpath, filename, normalize_data=false, truncate_data=false)
      end
  end
  # Network structure:
  #   - two padded conv + max-pool stages along the first axis;
  #   - one unpadded conv spanning all six rows;
  #   - three Dense layers with Dropout in between;
  #   - two outputs squashed into (0, 1) by σ.
  model = let
      # Half of each kernel extent, used as padding for the first two convolutions.
      halfpad(kernel) = map(k -> k ÷ 2, kernel)
      Chain(
          Conv(kernel1, channels => features1, relu, pad = halfpad(kernel1)),
          MaxPool(pooldims1, stride = pooldims1),
          Conv(kernel2, features1 => features2, relu, pad = halfpad(kernel2)),
          MaxPool(pooldims2, stride = pooldims2),
          Conv(kernel3, features2 => features3, relu),
          flatten,
          Dense(inputDense1, inputDense2, relu),
          Dropout(dropout_rate),
          Dense(inputDense2, inputDense3, relu),
          Dropout(dropout_rate),
          Dense(inputDense3, 2, σ), # coordinates between 0 and 1
      )
  end
  """
      train_model(model, train_set, validation_set, test_set)

  Train `model` on `train_set` for `epochs` epochs with a Momentum optimizer. The
  learning rate is updated after every epoch via `adapt_learnrate`.

  Logging to the stream `io`:
  - the loss on `validation_set` (when `validate` is set) or `test_set` is written
    before and after training;
  - the training-set loss is written every `printout_interval` epochs.

  Note: `loss` reads the global `model`, so the `model` argument is expected to be that
  same object (as in the call at the bottom of this script). Returns `nothing`.
  """
  function train_model(model, train_set, validation_set, test_set)
      Flux.testmode!(model, true)
      opt = Momentum(learning_rate, momentum)
      _log_eval_loss("INIT with", validation_set, test_set)
      for i in 1:epochs
          flush(io)
          Flux.testmode!(model, false) # bring model in training mode (Dropout active)
          Flux.train!(loss, params(model), train_set, opt)
          opt.eta = adapt_learnrate(i)
          if rem(i, printout_interval) == 0
              Flux.testmode!(model, true) # evaluate without Dropout
              @printf(io, "[%s] Epoch %3d: Loss: %f\n", Dates.format(now(), time_format), i, loss(train_set))
          end
      end
      Flux.testmode!(model, true)
      _log_eval_loss("FINAL", validation_set, test_set)
      flush(io) # the loop only flushes at epoch start; make sure the final result is written
      return nothing
  end

  # Write "[time] <stage> Loss(val_set|test_set): <value>" to `io`, picking the split by `validate`.
  function _log_eval_loss(stage, validation_set, test_set)
      name, dataset = validate ? ("val_set", validation_set) : ("test_set", test_set)
      @printf(io, "[%s] %s Loss(%s): %f\n", Dates.format(now(), time_format), stage, name, loss(dataset))
      return nothing
  end
  # logging framework
  # open() does not create missing directories, so make sure the log folder exists.
  mkpath(log_save_location)
  fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
  io = open(fp, "a+")
  # Route log records (including @debug output) into the log file. Keep the previous
  # logger so it can be restored before the file is closed.
  previous_logger = global_logger(SimpleLogger(io))
  @printf(Base.stdout, "Logging to File: %s\n", fp)
  @printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
  # dump configuration: every name listed by names(Main) with its value (debug runs only)
  @debug begin
      for symbol in names(Main)
          var = "$(symbol) = $(eval(symbol))"
          @printf(io, "%s\n", var)
      end
      "--------End of VAR DUMP--------"
  end
  flush(io)
  flush(Base.stdout)
  train_set, validation_set, test_set = load_dataset()
  if usegpu
      train_set = gpu.(train_set)
      validation_set = gpu.(validation_set)
      test_set = gpu.(test_set)
      model = gpu(model)
  end
  # No global assignments inside `try`: it opens a new scope at top level.
  try
      train_model(model, train_set, validation_set, test_set)
  finally
      # Restore the previous logger first so nothing writes to the closed file afterwards.
      global_logger(previous_logger)
      close(io)
  end