# net.jl (9.5 KB)

  """
  Author: Sebastian Vendt, University of Ulm
  This script implements the convolutional neural network for inferring tap locations from motion data frames.
  For detailed documentation please read through ReportMotionLogger.pdf in the github repo.
  """

  using ArgParse

  # Command-line interface; the resulting `parsed_args` dictionary feeds the
  # parameter definitions that follow.
  s = ArgParseSettings()
  @add_arg_table s begin
      "--gpu"
          help = "set, if you want to train on the GPU"
          action = :store_true
      "--eval"
          help = "set, if you want to validate instead of test after training"
          action = :store_true
      "--epochs"
          help = "Number of epochs"
          arg_type = Int64
          default = 100
      "--logmsg"
          help = "additional message describing the training log"
          arg_type = String
          default = ""
      "--csv"
          help = "set, if you additionally want a csv output of the learning process"
          action = :store_true
      "--runD"
          help = "set, if you want to run the default config"
          action = :store_true
  end
  parsed_args = parse_args(ARGS, s)
  using Flux, Statistics
  using Flux: onecold
  using BSON
  using Dates
  using Printf
  using NNlib
  using FeedbackNets
  # project-local modules (must be included before the `using .module` lines)
  include("./dataManager.jl")
  include("./verbose.jl")
  using .dataManager: make_batch
  using .verbose
  using Logging
  using Random
  import LinearAlgebra: norm

  # Euclidean norm for Tracker arrays; presumably reached through
  # `sum(norm, params(model))` in `loss`, whose parameters are tracked arrays.
  # `eps(T)` is added under the square root so the result stays strictly positive
  # for an all-zero array, keeping the derivative 1 / (2 * sqrt(...)) finite.
  # Side effect: norm(zeros) == sqrt(eps(T)) instead of 0.
  # NOTE(review): this extends LinearAlgebra.norm for Tracker's TrackedArray (type
  # piracy) and therefore changes `norm` for every TrackedArray in the session.
  function norm(x::TrackedArray{T}) where {T}
      sum_of_squares = sum(abs2.(x))
      return sqrt(sum_of_squares + eps(T))
  end
  ######################
  # PARAMETERS
  ######################
  const batch_size = 100
  momentum = 0.99f0 # not const: random_search() overwrites it
  const lambda = 0.0005f0 # weight of the parameter-norm penalty in loss()
  const delta = 0.00006 # NOTE(review): appears unused in this script
  learning_rate = 0.03f0 # not const: random_search() overwrites it
  validate = parsed_args["eval"]
  # Not const: random_search() overrides the epoch count with `global epochs = 40`.
  # Reassigning a const global only prints a "redefinition of constant" warning and
  # methods compiled earlier may keep using the old, inlined value.
  epochs = parsed_args["epochs"]
  const decay_rate = 0.1f0
  const decay_step = 60
  const usegpu = parsed_args["gpu"]
  const printout_interval = 2
  const time_format = "HH:MM:SS"
  const date_format = "dd_mm_yyyy"
  data_size = (60, 6) # resulting in a 300ms frame
  # DEFAULT ARCHITECTURE
  channels = 1
  features = [32, 64, 128] # needs to find the relation between the axis which represents the screen position
  kernel = [(5,1), (5,1), (2,6)] # convolute only horizontally, last should convolute all 6 rows together to map relations between the channels
  # spatial size through the two max-pools: (60,6) -> (20,6) -> (6,6) with (3,1)
  # pooling; (2,1) pooling gives (60,6) -> (30,6) -> (15,6)
  pooldims = [(3,1), (3,1)]
  # formula for calculating output dimensions of convolution:
  # dim1 = floor((dim1 - Filtersize + 2 * padding) / stride) + 1
  inputDense = [0, 600, 300] # the first dense layer is automatically calculated
  dropout_rate = 0.3f0
  # random search values
  rs_momentum = [0.9, 0.92, 0.94, 0.96, 0.98, 0.99]
  rs_features = [[32, 64, 128], [64, 64, 64], [32, 32, 32], [96, 192, 192]]
  rs_dropout_rate = [0.1, 0.3, 0.4, 0.6, 0.8]
  rs_kernel = [[(3,1), (3,1), (3,6)], [(5,1), (5,1), (3,6)], [(7,1), (7,1), (3,6)], [(3,1), (3,1), (2,6)], [(5,1), (5,1), (2,6)], [(7,1), (7,1), (2,6)],
      [(7,1), (5,1), (2,6)], [(7,1), (5,1), (3,6)], [(5,1), (3,1), (2,6)], [(5,1), (3,1), (3,6)]]
  rs_pooldims = [[(2,1), (2,1)], [(3,1), (3,1)]]
  rs_learning_rate = [1, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001]
  dataset_folderpath = "../MATLAB/TrainingData/"
  dataset_name = "2019_09_09_1658"
  const model_save_location = "../trainedModels/"
  const log_save_location = "./logs/"
  if usegpu
      using CuArrays
  end
  # Output streams; both stay `nothing` until the log (and optional csv) files are
  # opened further down.
  io = nothing
  io_csv = nothing
  # Remaining command-line switches.
  log_msg = parsed_args["logmsg"]
  csv_out = parsed_args["csv"]
  runD = parsed_args["runD"]
  # Prefix for the log/csv file names. The @debug message expression is only
  # evaluated when debug logging is enabled (e.g. via JULIA_DEBUG); it then switches
  # the prefix to "DEBUG_" as a side effect. `global` is required because the
  # logging macro evaluates the message inside its own scope.
  debug_str = ""
  @debug begin
      global debug_str
      debug_str = "DEBUG_"
      "------DEBUGGING ACTIVATED------"
  end
  """
      adapt_learnrate(epoch_idx)

  Return the learning rate for the epoch after `epoch_idx`: the initial global
  `learning_rate` decayed continuously, shrinking by a factor of `decay_rate` every
  `decay_step` epochs.
  """
  function adapt_learnrate(epoch_idx)
      decay_periods = epoch_idx / decay_step
      return learning_rate * decay_rate^decay_periods
  end
  # TODO different idea for the accuracy: draw circle around ground truth and if prediction lays within the circle count this as a hit
  """
      accuracy(model, x, y)

  Fraction of samples (columns of `x`/`y`) whose predicted position falls on the
  same button as the target, as decided by `button_number` applied column-wise.
  """
  function accuracy(model, x, y)
      predicted = mapslices(button_number, Tracker.data(model(x)), dims=1)
      expected = mapslices(button_number, y, dims=1)
      hits = predicted .== expected
      return mean(hits)
  end
  """
      accuracy(model, dataset)

  Mean of the per-batch `accuracy(model, data, labels)` over `dataset`, an iterable
  of `(data, labels)` pairs. An empty `dataset` yields `NaN`.
  """
  function accuracy(model, dataset)
      total = mapfoldl(((data, labels),) -> accuracy(model, data, labels), +, dataset;
                       init = 0.0f0)
      return total / length(dataset)
  end
  """
      button_number(X)

  Map a normalized 2-D position `X` (components expected in [0, 1]) to the index of
  the button cell it falls into, `column + 3 * row`:
  - `X[1]` is scaled by 1080 and cut into cells of 360, giving columns 0-2;
  - `X[2]` is scaled by 980 and cut into cells of 245, giving rows 0-3;
  so the result lies in 0-11.

  A component exactly equal to 1 is clamped into the last column/row. Such a value
  can come from a saturated σ output in floating point. Without the clamp it would
  yield column 3 (which aliases the first button of the next row) or row 4 (a
  non-existent button).
  """
  function button_number(X)
      column = clamp((X[1] * 1080) ÷ 360, 0, 2)
      row = clamp((X[2] * 980) ÷ 245, 0, 3)
      return column + 3 * row
  end
  """
      loss(model, x, y)

  Training objective for one batch: `Flux.mse` between `model(x)` and the targets
  `y` (quadratic euclidean distance), plus `lambda` times the sum of `norm` over every
  array in `params(model)`.
  """
  function loss(model, x, y)
      fit_error = Flux.mse(model(x), y)
      penalty = sum(norm, params(model))
      return fit_error + lambda * penalty
  end
  """
      loss(model, dataset)

  Mean of the per-batch `loss(model, data, labels)` over `dataset`, an iterable of
  `(data, labels)` pairs. Each batch loss is unwrapped with `Tracker.data`, so the
  result is a plain number. An empty `dataset` yields `NaN`.
  """
  function loss(model, dataset)
      total = mapfoldl(((data, labels),) -> Tracker.data(loss(model, data, labels)), +,
                       dataset; init = 0.0f0)
      return total / length(dataset)
  end
  """
      load_dataset()

  Load the TRAIN, VAL and TEST `.mat` files of `dataset_name` from
  `dataset_folderpath` with `make_batch` (no normalization, no truncation) and return
  them as `(train, val, test)`.
  """
  function load_dataset()
      read_split(file) = make_batch(dataset_folderpath, file, normalize_data=false, truncate_data=false)
      files = ("$(dataset_name)_TRAIN.mat", "$(dataset_name)_VAL.mat", "$(dataset_name)_TEST.mat")
      return Tuple(read_split(file) for file in files)
  end
  """
      create_model()

  Build the network from the current global architecture settings (`data_size`,
  `channels`, `features`, `kernel`, `pooldims`, `inputDense`, `dropout_rate`):
  - three ReLU convolutions, the first two each followed by max-pooling;
  - a flatten step;
  - two ReLU dense layers with dropout;
  - a 2-unit σ output layer.
  """
  function create_model()
      # pad the first two convolutions by half the kernel size; for the odd kernel
      # heights used in `kernel`/`rs_kernel` this keeps the spatial size unchanged
      half_kernel(k) = k .÷ 2
      # spatial size after two max-pools (stride == window) and the unpadded third
      # convolution; it determines the input width of the first dense layer
      conv_out = (data_size .÷ pooldims[1] .÷ pooldims[2]) .- kernel[3] .+ 1
      flat_len = prod(conv_out) * features[3]
      return Chain(
          Conv(kernel[1], channels => features[1], relu, pad = half_kernel(kernel[1])),
          MaxPool(pooldims[1], stride = pooldims[1]),
          Conv(kernel[2], features[1] => features[2], relu, pad = half_kernel(kernel[2])),
          MaxPool(pooldims[2], stride = pooldims[2]),
          Conv(kernel[3], features[2] => features[3], relu),
          flatten,
          Dense(flat_len, inputDense[2], relu),
          Dropout(dropout_rate),
          Dense(inputDense[2], inputDense[3], relu),
          Dropout(dropout_rate),
          Dense(inputDense[3], 2, σ), # coordinates between 0 and 1
      )
  end
  # log(model, epoch, use_testset)
  #
  # Append progress for `epoch` to the global log stream `io`. `model` is switched to
  # test mode for the evaluation and back to training mode before returning.
  #   epoch == 0      -> INIT loss/accuracy on the test split (use_testset) or on
  #                      the validation split
  #   epoch == epochs -> train/validation loss, then FINAL loss/accuracy on the test
  #                      split (use_testset) or on the validation split
  #   otherwise       -> train/validation loss every `printout_interval` epochs
  # NOTE(review): this definition creates `Main.log`, which shadows `Base.log` (the
  # logarithm) for the rest of this script.
  function log(model, epoch, use_testset)
      Flux.testmode!(model, true)
      # wall-clock time, formatted right before each line is written
      stamp() = Dates.format(now(), time_format)
      if epoch == 0 # evaluation phase
          if use_testset
              # `%f` (not `f%`): "% A" would be parsed as a hex-float conversion
              @printf(io, "[%s] INIT Loss(test): %f Accuarcy: %f\n",
                      stamp(), loss(model, test_set), accuracy(model, test_set))
          else
              @printf(io, "[%s] INIT Loss(val): %f Accuarcy: %f\n",
                      stamp(), loss(model, validation_set), accuracy(model, validation_set))
          end
      elseif epoch == epochs
          @printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f\n",
                  stamp(), epoch, loss(model, train_set), loss(model, validation_set))
          if use_testset
              @printf(io, "[%s] FINAL(%d) Loss(test): %f Accuarcy: %f\n",
                      stamp(), epoch, loss(model, test_set), accuracy(model, test_set))
          else
              @printf(io, "[%s] FINAL(%d) Loss(val): %f Accuarcy: %f\n",
                      stamp(), epoch, loss(model, validation_set), accuracy(model, validation_set))
          end
      elseif rem(epoch, printout_interval) == 0 # learning phase
          @printf(io, "[%s] Epoch %3d: Loss(train): %f Loss(val): %f\n",
                  stamp(), epoch, loss(model, train_set), loss(model, validation_set))
      end
      Flux.testmode!(model, false)
  end
  """
      log_csv(model, epoch)

  When csv output is enabled (`csv_out`), append `epoch, loss(train), loss(val)` as
  one row to `io_csv`. The losses are evaluated with `model` in test mode; the model
  is returned to training mode afterwards (also when csv output is off).
  """
  function log_csv(model, epoch)
      Flux.testmode!(model, true)
      csv_out && @printf(io_csv, "%d, %f, %f\n", epoch, loss(model, train_set), loss(model, validation_set))
      return Flux.testmode!(model, false)
  end
  """
      eval_model(model)

  Put `model` into test mode and return `(loss, accuracy)` on the validation split
  when `validate` (`--eval`) is set, otherwise on the test split. The model is left
  in test mode.
  """
  function eval_model(model)
      Flux.testmode!(model, true)
      dataset = validate ? validation_set : test_set
      return (loss(model, dataset), accuracy(model, dataset))
  end
  """
      train_model()

  Build a model from the current global hyper-parameters (moved to the GPU with
  `--gpu`) and train it for `epochs` epochs on `train_set` with `Momentum`. The
  learning rate follows `adapt_learnrate` after every epoch, and each epoch is
  logged via `log_csv` and `log`. Returns `eval_model(model)`.
  """
  function train_model()
      # a single assignment keeps `model` from being boxed by the loss closure below
      model = usegpu ? gpu(create_model()) : create_model()
      opt = Momentum(learning_rate, momentum)
      log(model, 0, !validate)
      Flux.testmode!(model, false) # bring model in training mode
      for i in 1:epochs
          flush(io)
          Flux.train!((x, y) -> loss(model, x, y), params(model), train_set, opt)
          opt.eta = adapt_learnrate(i) # learning rate for the next epoch
          log_csv(model, i)
          log(model, i, !validate)
      end
      return eval_model(model)
  end
  """
      random_search(n_searches = 800)

  Random hyper-parameter search. For each of `n_searches` iterations:
  - draw momentum, features, dropout_rate, kernel, pooldims and learning_rate from
    the `rs_*` candidate lists into the corresponding globals;
  - write the configuration to the log stream `io`;
  - train it with `train_model` for 40 epochs (the global `epochs` is overwritten).

  Returns the `(search_index, loss, accuracy)` tuples in search order.
  """
  function random_search(n_searches = 800)
      rng = MersenneTwister()
      results = []
      global epochs = 40
      for search in 1:n_searches
          # create random set (the draw order fixes which values each rng draw yields)
          global momentum = rand(rng, rs_momentum)
          global features = rand(rng, rs_features)
          global dropout_rate = rand(rng, rs_dropout_rate)
          global kernel = rand(rng, rs_kernel)
          global pooldims = rand(rng, rs_pooldims)
          global learning_rate = rand(rng, rs_learning_rate)
          # printf configuration
          config1 = "momentum=$(momentum), features=$(features), dropout_rate=$(dropout_rate)"
          config2 = "kernel=$(kernel), pooldims=$(pooldims), learning_rate=$(learning_rate)"
          @printf(io, "\nSearch %d of %d\n", search, n_searches)
          @printf(io, "%s\n", config1)
          @printf(io, "%s\n\n", config2)
          # distinct names: assigning to `loss`/`accuracy` would make them locals that
          # shadow the global functions throughout this function
          (final_loss, final_acc) = train_model()
          push!(results, (search, final_loss, final_acc))
      end
      return results
  end
  # logging framework
  # open() fails with a SystemError when the log directory does not exist yet
  mkpath(log_save_location)
  fp = "$(log_save_location)$(debug_str)log_$(Dates.format(now(), date_format)).log"
  io = open(fp, "a+")
  # for debug outputs: log records go to the log file. SimpleLogger's default minimum
  # level is Info, so @debug records presumably still need JULIA_DEBUG to get through.
  global_logger(SimpleLogger(io))
  @printf(Base.stdout, "Logging to File: %s\n", fp)
  @printf(io, "\n--------[%s %s]--------\n", Dates.format(now(), date_format), Dates.format(now(), time_format))
  @printf(io, "%s\n", log_msg)
  # csv handling
  if csv_out
      fp_csv = "$(log_save_location)$(debug_str)csv_$(Dates.format(now(), date_format)).csv"
      io_csv = open(fp_csv, "w+") # read, write, create, truncate
      @printf(io_csv, "epoch, loss(train), loss(val)\n")
  end
  # dump configuration: with debug logging enabled, write every global of Main to the log
  @debug begin
      for symbol in names(Main)
          var = "$(symbol) = $(eval(symbol))"
          @printf(io, "%s\n", var)
      end
      "--------End of VAR DUMP--------"
  end
  flush(io)
  flush(Base.stdout)
  train, validation, test = load_dataset()
  # Every training/logging function reads train_set, validation_set and test_set, so
  # define them on both paths: moved to the GPU with --gpu, the loaded batches otherwise.
  const train_set = usegpu ? gpu.(train) : train
  const validation_set = usegpu ? gpu.(validation) : validation
  const test_set = usegpu ? gpu.(test) : test
  if !runD
      # random hyper-parameter search; report the best runs by loss and by accuracy
      results = random_search()
      BSON.@save "results.bson" results
      n_best = min(5, length(results))
      sort!(results, by = x -> x[2])
      # print results
      @printf("Best results by Loss:\n")
      for idx in 1:n_best
          @printf("#%d: Loss %f, accuracy %f in Search: %d\n", idx, results[idx][2], results[idx][3], results[idx][1])
      end
      sort!(results, by = x -> x[3], rev=true)
      @printf("Best results by Accuarcy:\n")
      for idx in 1:n_best
          @printf("#%d: Accuarcy: %f, Loss %f in Search: %d\n", idx, results[idx][3], results[idx][2], results[idx][1])
      end
  else
      # train and evaluate only the default configuration
      train_model()
  end
  # the FINAL log lines are written after the last per-epoch flush; push them to disk
  flush(io)
  io_csv === nothing || close(io_csv)