# net.jl
  1. """
  2. Author: Sebastian Vendt, University of Ulm
  3. """
  4. using ArgParse
  5. s = ArgParseSettings()
  6. @add_arg_table s begin
  7. "--gpu"
  8. help = "set, if you want to train on the GPU"
  9. action = :store_true
  10. "--eval"
  11. help = "set, if you want to validate instead of test after training"
  12. action = :store_true
  13. "--learn"
  14. help = "learning rate"
  15. arg_type = Float32
  16. default = 0.1f0
  17. "--epochs"
  18. help = "Number of epochs"
  19. arg_type = Int64
  20. default = 100
  21. end
  22. parsed_args = parse_args(ARGS, s)
  23. using Flux, Statistics
  24. using Flux: onecold
  25. using BSON
  26. using Dates
  27. using Printf
  28. using NNlib
  29. using FeedbackNets
  30. include("./dataManager.jl")
  31. using .dataManager: make_batch
  32. using Logging
  33. import LinearAlgebra: norm
# Extend LinearAlgebra.norm for Flux's TrackedArray. The explicit broadcast
# `abs2.(x)` (rather than `sum(abs2, x)`) keeps the expression on a path
# Tracker can differentiate, and `eps(T)` avoids a NaN gradient of sqrt at 0.
# NOTE(review): presumably used as a parameter-norm penalty (see `lambda`) —
# it is not referenced in the visible loss; confirm intent.
norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T))
  35. ######################
  36. # PARAMETERS
  37. ######################
  38. const batch_size = 100
  39. const momentum = 0.9f0
  40. const lambda = 0.0005f0
  41. learning_rate = parsed_args["learn"]
  42. validate = parse_args["eval"]
  43. const epochs = parsed_args["epochs"]
  44. const decay_rate = 0.1f0
  45. const decay_step = 40
  46. const usegpu = parsed_args["gpu"]
  47. const printout_interval = 5
  48. const save_interval = 25
  49. const time_format = "HH:MM:SS"
  50. const date_format = "dd_mm_yyyy"
  51. data_size = (48, 6) # resulting in a 240ms frame
  52. # ARCHITECTURE
  53. channels = 1
  54. features1 = 32
  55. features2 = 64
  56. features3 = 256 # needs to find the relation between the axis which represents the screen position
  57. kernel1 = (3,1) # convolute only horizontally
  58. kernel2 = kernel1 # same here
  59. kernel3 = (3, 6) # this should convolute all 6 rows together to map relations between the channels
  60. pooldims1 = (2,1)# (24,6)
  61. pooldims2 = (2,1)# (12,6)
  62. # pooldims3 = (2,1)# (1, 4)
  63. inputDense1 = prod(data_size .÷ pooldims1 .÷ pooldims2 .÷ kernel3) * features3
  64. inputDense2 = 500
  65. inputDense3 = 500
  66. dropout_rate = 0.1f0
  67. dataset_folderpath = "../MATLAB/TrainingData/"
  68. const model_save_location = "../trainedModels/"
  69. const log_save_location = "../logs/"
  70. if usegpu
  71. using CuArrays
  72. end
# Prefix prepended to the log filename when debug logging is active.
debug_str = ""
@debug begin
    # This block is evaluated only when the active logger's level includes
    # Debug, so the filename prefix is set as a side effect of enabling
    # debug logging; the trailing string becomes the log message.
    global debug_str
    debug_str = "DEBUG_"
    "------DEBUGGING ACTIVATED------"
end
# Log IO stream; replaced by an open file handle further below.
io = nothing
  80. function adapt_learnrate(epoch_idx)
  81. return init_learning_rate * decay_rate^(epoch_idx / decay_step)
  82. end
  83. function loss(x, y)
  84. # quadratic euclidean distance + parameternorm?
  85. return Flux.mse(model(x), y)
  86. end
  87. function loss(dataset)
  88. loss_val = 0.0f0
  89. for (data, labels) in dataset
  90. loss_val += Tracker.data(loss(data, labels))
  91. end
  92. return loss_val / length(dataset)
  93. end
  94. function load_dataset()
  95. train = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
  96. val = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
  97. test = make_batch(dataset_folderpath, "", normalize_data=false, truncate_data=false)
  98. return (train, val, test)
  99. end
  100. model = Chain(
  101. Conv(kernel1, channels=>features1, relu, pad=map(x -> x ÷ 2, kernel1)),
  102. MaxPool(pooldims1, stride=pooldims1),
  103. Conv(kernel2, features1=>features2, relu, pad=map(x -> x ÷ 2, kernel2)),
  104. MaxPool(pooldims2, stride=pooldims2),
  105. Conv(kernel3, features2=>features3, relu),
  106. # MaxPool(),
  107. flatten,
  108. Dense(inputDense1, inputDense2, σ),
  109. Dropout(dropout_rate)
  110. Dense(inputDense2, inputDense3, σ),
  111. Dropout(dropout_rate)
  112. Dense(inputDense3, 2) # identity to output coordinates!
  113. )
  114. train_model(model, train_set, validation_set, test_set)
  115. opt = Momentum(learning_rate, momentum)
  116. if(validate) @printf(io, "[%s] INIT with Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set))
  117. else @printf(io, "[%s] INIT with Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set)) end
  118. for i in 1:epochs
  119. flush(io)
  120. Flux.train!(loss, params(model), train_set, opt)
  121. opt.eta = adapt_learnrate(i)
  122. if ( rem(i, printout_interval) == 0 )
  123. @printf(io, "[%s] Epoch %3d: Loss: %f\n", Dates.format(now(), time_format), i, loss(train_set))
  124. end
  125. end
  126. if(validate) @printf(io, "[%s] FINAL Loss(val_set): %f\n", Dates.format(now(), time_format), loss(validation_set))
  127. else @printf(io, "[%s] FINAL Loss(test_set): %f\n", Dates.format(now(), time_format), loss(test_set))
  128. end
# logging framework
# Assemble the log file path: optional DEBUG_ prefix plus today's date,
# under the configured log folder.
fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
io = open(fp, "a+")
global_logger(SimpleLogger(io)) # for debug outputs
@printf(Base.stdout, "Logging to File: %s\n", fp)
# Session separator inside the (appended-to) log file.
@printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
# dump configuration
@debug begin
    # Evaluated only when debug logging is enabled: dumps every binding in
    # Main via eval. NOTE(review): this will also print entire datasets if
    # they are already loaded — acceptable for a debug dump, but verify.
    for symbol in names(Main)
        var = "$(symbol) = $(eval(symbol))"
        @printf(io, "%s\n", var)
    end
    "--------End of VAR DUMP--------"
end
flush(io)
flush(Base.stdout)
# Entry point: load the train/validation/test batches, then run training.
train_set, validation_set, test_set = load_dataset()
train_model(model, train_set, validation_set, test_set)