# net.jl
  1. """
  2. Author: Sebastian Vendt, University of Ulm
  3. """
  4. using ArgParse
  # Command-line interface. Switches (`:store_true`) default to false;
  # `--epochs` and `--logmsg` carry a value.
  s = ArgParseSettings()
  @add_arg_table s begin
      # --- where / how to run ---
      "--gpu"
          help = "set, if you want to train on the GPU"
          action = :store_true
      "--eval"
          help = "set, if you want to validate instead of test after training"
          action = :store_true
      "--epochs"
          help = "Number of epochs"
          arg_type = Int64
          default = 40
      # --- logging ---
      "--logmsg"
          help = "additional message describing the training log"
          arg_type = String
          default = ""
      "--csv"
          help = "set, if you additionally want a csv output of the learning process"
          action = :store_true
      # --- mode ---
      "--runD"
          help = "set, if you want to run the default config"
          action = :store_true
  end

  # Dictionary of parsed options, keyed by option name without the leading dashes.
  parsed_args = parse_args(ARGS, s)
  29. using Flux, Statistics
  30. using Flux: onecold
  31. using BSON
  32. using Dates
  33. using Printf
  34. using NNlib
  35. using FeedbackNets
  36. include("./dataManager.jl")
  37. include("./verbose.jl")
  38. using .dataManager: make_batch
  39. using .verbose
  40. using Logging
  41. using Random
  42. import LinearAlgebra: norm
  """
      norm(x::TrackedArray{T})

  Euclidean norm of a tracked array with `eps(T)` added under the square root,
  i.e. `sqrt(sum(abs2.(x)) + eps(T))`. The result is therefore never exactly zero.
  """
  function norm(x::TrackedArray{T}) where {T}
      squared_sum = sum(abs2.(x))
      return sqrt(squared_sum + eps(T))
  end
  ######################
  # PARAMETERS
  ######################

  # Values taken from the command line.
  validate = parsed_args["eval"]        # evaluate on the validation set instead of the test set
  const epochs = parsed_args["epochs"]
  const usegpu = parsed_args["gpu"]

  # Optimiser settings. `momentum` and `learning_rate` are deliberately not
  # `const`: random_search() overwrites them for every trial.
  momentum = 0.9f0
  learning_rate = 0.1f0
  const batch_size = 100

  # Learning-rate schedule, see adapt_learnrate().
  const decay_rate = 0.1f0
  const decay_step = 40

  # Regularisation weight of the parameter-norm term in loss().
  const lambda = 0.0005f0

  # Early stopping: minimum change of the training loss between two epochs.
  const delta = 0.00001

  # Logging.
  const printout_interval = 2           # write a log line every n-th epoch
  const time_format = "HH:MM:SS"
  const date_format = "dd_mm_yyyy"

  data_size = (60, 6) # resulting in a 300ms frame
  # DEFAULT ARCHITECTURE
  channels = 1

  # Output feature maps of the three convolution layers.
  # needs to find the relation between the axis which represents the screen position
  features = [32, 64, 128]

  # Convolute only horizontally; the last kernel convolutes all 6 rows together
  # to map relations between the channels.
  kernel = [(3, 1), (3, 1), (3, 6)]

  # Pooling windows (also used as stride): (60,6) -> (30,6) -> (15,6)
  pooldims = [(2, 1), (2, 1)]

  # formula for calculating output dimensions of convolution:
  # dim1 = ((dim1 - Filtersize + 2 * padding) / stride) + 1
  #
  # Dense layer sizes. The first entry equals
  # prod((data_size .÷ pooldims[1] .÷ pooldims[2]) .- kernel[3] .+ 1) * features[3]
  # for the defaults above; create_model() recomputes it instead of reading it.
  inputDense = [1664, 600, 300]

  dropout_rate = 0.3f0
  # random search values
  # random_search() draws one entry from each pool per trial.
  rs_momentum = [0.9, 0.92, 0.94, 0.96, 0.98, 0.99]

  rs_features = [
      [32, 64, 128],
      [64, 64, 64],
      [32, 32, 32],
      [96, 192, 192],
  ]

  rs_dropout_rate = [0.1, 0.3, 0.4, 0.6, 0.8]

  rs_kernel = [
      [(3, 1), (3, 1), (3, 6)],
      [(5, 1), (5, 1), (3, 6)],
      [(7, 1), (7, 1), (3, 6)],
      [(3, 1), (3, 1), (2, 6)],
      [(5, 1), (5, 1), (2, 6)],
      [(7, 1), (7, 1), (2, 6)],
      [(7, 1), (5, 1), (2, 6)],
      [(7, 1), (5, 1), (3, 6)],
      [(5, 1), (3, 1), (2, 6)],
      [(5, 1), (3, 1), (3, 6)],
  ]

  rs_pooldims = [
      [(2, 1), (2, 1)],
      [(3, 1), (3, 1)],
  ]

  rs_learning_rate = [1, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001]
  # Dataset location; load_dataset() appends "_TRAIN.mat" / "_VAL.mat" / "_TEST.mat".
  dataset_folderpath = "../MATLAB/TrainingData/"
  dataset_name = "2019_09_09_1658"

  # Output locations.
  const model_save_location = "../trainedModels/"
  const log_save_location = "./logs/"

  # The CUDA backend is only loaded when training on the GPU was requested.
  if usegpu
      using CuArrays
  end

  # Prefix for log file names; replaced by "DEBUG_" when debug logging is active.
  debug_str = ""

  # Remaining command-line options.
  log_msg = parsed_args["logmsg"]
  csv_out = parsed_args["csv"]
  runD = parsed_args["runD"]

  # Output streams; assigned once the log files have been opened.
  io = nothing
  io_csv = nothing
  # The body of a `@debug` message is only evaluated when debug logging is
  # enabled, so the file-name prefix is switched as a side effect of the message.
  @debug begin
      global debug_str = "DEBUG_"
      "------DEBUGGING ACTIVATED------"
  end
  """
      adapt_learnrate(epoch_idx)

  Return the exponentially decayed learning rate for epoch `epoch_idx`:
  `learning_rate * decay_rate^(epoch_idx / decay_step)`. After every
  `decay_step` epochs the rate has shrunk by another factor of `decay_rate`.
  """
  function adapt_learnrate(epoch_idx)
      decay_exponent = epoch_idx / decay_step
      return learning_rate * decay_rate^decay_exponent
  end
  # TODO different idea for the accuracy: draw circle around ground truth and if prediction lays within the circle count this as a hit
  """
      accuracy(model, x, y)

  Fraction of columns for which the model output and the label `y` map to the
  same button. Both the untracked output of `model(x)` and `y` are reduced
  column by column with `button_number`; the mean of the element-wise
  comparison is returned.
  """
  function accuracy(model, x, y)
      prediction = Tracker.data(model(x))
      predicted_buttons = mapslices(button_number, prediction, dims=1)
      expected_buttons = mapslices(button_number, y, dims=1)
      return mean(predicted_buttons .== expected_buttons)
  end
  """
      accuracy(model, dataset)

  Average of the per-batch accuracies over `dataset`, an iterable of
  `(data, labels)` pairs that supports `length`. The sum starts from `0.0f0`
  and is divided by the number of batches.
  """
  function accuracy(model, dataset)
      total = foldl(dataset; init=0.0f0) do running, batch
          data, labels = batch
          return running + accuracy(model, data, labels)
      end
      return total / length(dataset)
  end
  """
      button_number(X)

  Map a coordinate pair `X` to a button index. `X[1]` is scaled by 1080 and
  floor-divided by 360, `X[2]` is scaled by 980 and floor-divided by 245; the
  index is `column + 3 * row`.

  NOTE(review): presumably a 1080x980 px screen with coordinates normalised to
  [0, 1] and a grid of 3 columns by 4 rows — confirm against the data recorder.
  """
  function button_number(X)
      column = div(X[1] * 1080, 360)
      row = div(X[2] * 980, 245)
      return column + 3 * row
  end
  """
      loss(model, x, y)

  Training objective for one batch: the mean squared error between `model(x)`
  and `y`, plus `lambda` times the sum of the norms of all model parameters.
  """
  function loss(model, x, y)
      # quadratic euclidean distance
      data_term = Flux.mse(model(x), y)
      # parameter norm
      penalty = lambda * sum(norm, params(model))
      return data_term + penalty
  end
  """
      loss(model, dataset)

  Average of the untracked per-batch losses over `dataset`, an iterable of
  `(data, labels)` pairs that supports `length`. The sum starts from `0.0f0`
  and is divided by the number of batches.
  """
  function loss(model, dataset)
      total = foldl(dataset; init=0.0f0) do running, batch
          data, labels = batch
          return running + Tracker.data(loss(model, data, labels))
      end
      return total / length(dataset)
  end
  """
      load_dataset()

  Load the three splits of the configured dataset through `make_batch`, with
  normalisation and truncation both switched off. Returns the tuple
  `(train, val, test)`; the splits are read in that order.
  """
  function load_dataset()
      read_split(filename) = make_batch(dataset_folderpath, filename, normalize_data=false, truncate_data=false)
      return (
          read_split("$(dataset_name)_TRAIN.mat"),
          read_split("$(dataset_name)_VAL.mat"),
          read_split("$(dataset_name)_TEST.mat"),
      )
  end
  """
      create_model()

  Build the network from the current global architecture settings: two padded
  convolutions each followed by max pooling, a third unpadded convolution,
  `flatten`, and three dense layers with dropout in between. The final layer
  has two sigmoid outputs.
  """
  function create_model()
      # Padding of half the kernel size per dimension.
      half_pad(k) = map(x -> x ÷ 2, k)

      # Spatial size left after both pooling stages and the unpadded third
      # convolution, times its feature maps = input width of the first Dense.
      spatial_out = (data_size .÷ pooldims[1] .÷ pooldims[2]) .- kernel[3] .+ 1
      flat_features = prod(spatial_out) * features[3]
      hidden1, hidden2 = inputDense[2], inputDense[3]

      return Chain(
          Conv(kernel[1], channels => features[1], relu, pad=half_pad(kernel[1])),
          MaxPool(pooldims[1], stride=pooldims[1]),
          Conv(kernel[2], features[1] => features[2], relu, pad=half_pad(kernel[2])),
          MaxPool(pooldims[2], stride=pooldims[2]),
          Conv(kernel[3], features[2] => features[3], relu),
          flatten,
          Dense(flat_features, hidden1, relu),
          Dropout(dropout_rate),
          Dense(hidden1, hidden2, relu),
          Dropout(dropout_rate),
          Dense(hidden2, 2, σ), # coordinates between 0 and 1
      )
  end
  """
      log(model, epoch, use_testset)

  Write a progress line for `epoch` to the log stream `io`.

  - `epoch == 0`: initial loss and accuracy on the test set (`use_testset`) or
    on the validation set.
  - `epoch == epochs`: training/validation loss plus the final loss and
    accuracy on the selected set.
  - otherwise: training/validation loss, but only every `printout_interval`-th
    epoch.

  The model is switched to test mode while the metrics are computed and back
  to training mode before returning.
  """
  function log(model, epoch, use_testset)
      Flux.testmode!(model, true)

      if epoch == 0 # evaluation phase
          if use_testset
              # Fixed: the loss conversion was written `f%` instead of `%f`.
              @printf(io, "[%s] INIT Loss(test): %f Accuarcy: %f\n", Dates.format(now(), time_format), loss(model, test_set), accuracy(model, test_set))
          else
              @printf(io, "[%s] INIT Loss(val): %f Accuarcy: %f\n", Dates.format(now(), time_format), loss(model, validation_set), accuracy(model, validation_set))
          end
      elseif epoch == epochs
          @printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(model, train_set), loss(model, validation_set))
          if use_testset
              @printf(io, "[%s] FINAL(%d) Loss(test): %f Accuarcy: %f\n", Dates.format(now(), time_format), epoch, loss(model, test_set), accuracy(model, test_set))
          else
              @printf(io, "[%s] FINAL(%d) Loss(val): %f Accuarcy: %f\n", Dates.format(now(), time_format), epoch, loss(model, validation_set), accuracy(model, validation_set))
          end
      elseif rem(epoch, printout_interval) == 0 # learning phase
          @printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f\n", Dates.format(now(), time_format), epoch, loss(model, train_set), loss(model, validation_set))
      end

      return Flux.testmode!(model, false)
  end
  """
      log_csv(model, epoch)

  Append one `epoch, loss(train), loss(val)` row to the csv stream `io_csv`
  when csv output is enabled. The model is put into test mode for the
  computation and back into training mode afterwards, whether or not a row
  was written.
  """
  function log_csv(model, epoch)
      Flux.testmode!(model, true)
      if csv_out
          @printf(io_csv, "%d, %f, %f\n", epoch, loss(model, train_set), loss(model, validation_set))
      end
      return Flux.testmode!(model, false)
  end
  """
      eval_model(model)

  Switch `model` to test mode and return `(loss, accuracy)` on the validation
  set when `validate` is set, otherwise on the test set. The model is left in
  test mode.
  """
  function eval_model(model)
      Flux.testmode!(model, true)
      dataset = validate ? validation_set : test_set
      return (loss(model, dataset), accuracy(model, dataset))
  end
  """
      train_model()

  Create a model from the current global configuration and train it for up to
  `epochs` epochs with a `Momentum` optimiser. After every epoch the learning
  rate is updated via `adapt_learnrate`, progress is logged, and training
  stops early once the training loss changes by less than `delta` between two
  consecutive epochs. Returns the result of `eval_model` for the trained model.
  """
  function train_model()
      net = usegpu ? gpu(create_model()) : create_model()
      optimiser = Momentum(learning_rate, momentum)
      objective = (x, y) -> loss(net, x, y)
      use_testset = !validate

      log(net, 0, use_testset)
      Flux.testmode!(net, false) # bring model in training mode

      previous_loss = loss(net, train_set)
      for epoch in 1:epochs
          flush(io)
          Flux.train!(objective, params(net), train_set, optimiser)
          optimiser.eta = adapt_learnrate(epoch)
          log_csv(net, epoch)
          log(net, epoch, use_testset)

          # early stopping
          current_loss = loss(net, train_set)
          if abs(previous_loss - current_loss) < delta
              @printf(io, "Early stopping with Loss(train) %f at epoch %d (Accuracy: %f)\n", current_loss, epoch, accuracy(net, validation_set))
              break
          end
          previous_loss = current_loss
      end

      return eval_model(net)
  end
  """
      random_search(n_searches=800)

  Run `n_searches` independent trainings, each with hyperparameters drawn at
  random from the `rs_*` pools. The drawn values are written to the globals
  `momentum`, `features`, `dropout_rate`, `kernel`, `pooldims` and
  `learning_rate`, which `train_model` reads. Each configuration is written to
  the log stream `io` before training.

  Returns a vector of `(search_index, loss, accuracy)` tuples in search order.
  """
  function random_search(n_searches::Integer=800)
      rng = MersenneTwister()
      results = []
      for search in 1:n_searches
          # create random set
          global momentum = rand(rng, rs_momentum)
          global features = rand(rng, rs_features)
          global dropout_rate = rand(rng, rs_dropout_rate)
          global kernel = rand(rng, rs_kernel)
          global pooldims = rand(rng, rs_pooldims)
          global learning_rate = rand(rng, rs_learning_rate)

          # print configuration
          config1 = "momentum=$(momentum), features=$(features), dropout_rate=$(dropout_rate)"
          config2 = "kernel=$(kernel), pooldims=$(pooldims), learning_rate=$(learning_rate)"
          # The total is the actual loop bound (previously a hard-coded 500
          # while the loop ran 800 times).
          @printf(io, "\nSearch %d of %d\n", search, n_searches)
          @printf(io, "%s\n", config1)
          @printf(io, "%s\n\n", config2)

          # Local names chosen so they do not shadow the global `loss` and
          # `accuracy` functions inside this function.
          (final_loss, final_accuracy) = train_model()
          push!(results, (search, final_loss, final_accuracy))
      end
      return results
  end
  # logging framework
  # Create the log directory if needed; `open` in append mode does not create
  # missing parent directories.
  mkpath(log_save_location)
  fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
  io = open(fp, "a+")
  global_logger(SimpleLogger(io)) # for debug outputs
  @printf(Base.stdout, "Logging to File: %s\n", fp)

  # Session header; date and time are taken from one timestamp so they cannot
  # disagree around midnight.
  session_start = now()
  @printf(io, "\n--------[%s %s]--------\n", Dates.format(session_start, date_format), Dates.format(session_start, time_format))
  @printf(io, "%s\n", log_msg)
  # csv handling
  # Only when --csv was given: open a fresh csv file next to the log and write
  # the column header that log_csv() rows follow.
  if csv_out
      fp_csv = string(log_save_location, debug_str, "csv_", Dates.format(now(), date_format), ".csv")
      io_csv = open(fp_csv, "w+") # read, write, create, truncate
      print(io_csv, "epoch, loss(train), loss(val)\n")
  end
  # dump configuration
  # With debug logging enabled, write every name exported from Main together
  # with its current value to the log file.
  @debug begin
      for symbol in names(Main)
          println(io, "$(symbol) = $(eval(symbol))")
      end
      "--------End of VAR DUMP--------"
  end

  # Push everything written so far to disk / terminal before training starts.
  flush(io)
  flush(Base.stdout)
  # Load the three dataset splits and publish them under the names used by the
  # training and logging functions.
  train, validation, test = load_dataset()

  # Fixed: the *_set constants were only defined when training on the GPU, so
  # CPU runs failed with UndefVarError. Without --gpu the batches are used as
  # loaded; with --gpu every batch is moved to the device.
  const train_set = usegpu ? gpu.(train) : train
  const validation_set = usegpu ? gpu.(validation) : validation
  const test_set = usegpu ? gpu.(test) : test
  # Entry point: random hyperparameter search unless --runD requests a single
  # training run with the default configuration.
  if !runD
      results = random_search()
      BSON.@save "results.bson" results

      # Never index past the end when fewer than five searches were run.
      n_best = min(5, length(results))

      # print results
      sort!(results, by = x -> x[2])
      @printf("Best results by Loss:\n")
      for idx in 1:n_best
          @printf("#%d: Loss %f, accuracy %f in Search: %d\n", idx, results[idx][2], results[idx][3], results[idx][1])
      end

      sort!(results, by = x -> x[3], rev=true)
      @printf("Best results by Accuarcy:\n")
      for idx in 1:n_best
          @printf("#%d: Accuarcy: %f, Loss %f in Search: %d\n", idx, results[idx][3], results[idx][2], results[idx][1])
      end
  else
      train_model()
  end

  # train_model() only flushes at the start of an epoch; make sure the final
  # lines reach the files. The streams stay open because the global logger
  # still refers to `io`.
  flush(io)
  io_csv === nothing || flush(io_csv)