# net.jl
  1. """
  2. Author: Sebastian Vendt, University of Ulm
  3. """
  4. using Flux, Statistics
  5. using Flux: onecold
  6. using BSON
  7. using Dates
  8. using Printf
  9. using NNlib
  10. include("./dataManager.jl")
  11. using .dataManager: make_batch
  12. using Logging
  13. import LinearAlgebra: norm
  # 2-norm for Tracker arrays with a small floor under the square root:
  #     norm(x) = sqrt(sum of abs2 over x + eps(T))
  # The eps(T) term keeps the value, and therefore the gradient x / norm(x), finite when
  # every element of `x` is zero. The cost is a result slightly above the exact 2-norm.
  # NOTE(review): this adds a method to LinearAlgebra.norm for Tracker's TrackedArray.
  # Neither is owned by this file (type piracy), so it affects every caller that passes
  # a TrackedArray, not just this script.
  function norm(x::TrackedArray{T}) where {T}
      squares = abs2.(x)
      return sqrt(sum(squares) + eps(T))
  end
  ######################
  # PARAMETERS
  ######################

  # Training loop size
  const epochs = 100
  const batch_size = 100

  # Optimiser settings. `lambda` is presumably a weight-decay (L2) coefficient; confirm
  # against where it is used.
  const momentum = 0.9f0
  const lambda = 0.0005f0

  # Learning-rate schedule. `learning_rate` is a plain (non-const) global, so it can be
  # reassigned; it starts out equal to `init_learning_rate`. `adapt_learnrate` scales the
  # initial rate by `decay_rate` once every `decay_step` epochs, interpolated continuously.
  init_learning_rate = 0.1f0
  learning_rate = init_learning_rate
  const decay_rate = 0.1f0
  const decay_step = 40

  # Hardware switch; also gates the `using CuArrays` further down.
  const usegpu = true

  # Reporting / checkpoint intervals. The unit is not shown here; presumably epochs.
  const printout_interval = 5
  const save_interval = 25

  # Presumably Dates format patterns for timestamps; confirm against their uses.
  const time_format = "HH:MM:SS"
  const date_format = "dd_mm_yyyy"

  # Input frame size
  data_size = (48, 6) # resulting in a 240ms frame
  # ARCHITECTURE
  # Widths of the three `Dense` layers at the end of `model`. These need explicit values:
  # in Julia an assignment with a trailing `=` continues onto the next line (comments are
  # skipped). Leaving the right-hand sides empty therefore chained all three names onto
  # the `dataset_folderpath` string assigned below.
  inputDense1 = nothing  # TODO: length of the flattened output of the last MaxPool
  inputDense2 = nothing  # TODO: output width of the 1st Dense = input width of the 2nd
  inputDense3 = nothing  # TODO: output width of the 2nd Dense = input width of the last

  # enter the datasets and models you want to train
  dataset_folderpath = "../MATLAB/TrainingData/"
  const model_save_location = "../trainedModels/"
  const log_save_location = "../logs/"
  # Load the CuArrays package only when GPU training is switched on via `usegpu`.
  # This `if` sits at top level, which `using` requires; keep it out of any function.
  if usegpu
      using CuArrays
  end
  # Stays "" unless debug logging is enabled. Presumably used as a prefix for output
  # names; TODO confirm against its later uses.
  debug_str = ""
  # `@debug` evaluates its message expression only when Debug-level records are enabled
  # (e.g. via the JULIA_DEBUG environment variable or a Debug-level logger). The block's
  # side effect sets `debug_str` to "DEBUG_", and its last value is the logged message.
  # `global` makes the assignment target the top-level `debug_str` rather than a local
  # inside the macro's expansion.
  @debug begin
      global debug_str
      debug_str = "DEBUG_"
      "------DEBUGGING ACTIVATED------"
  end
  # NOTE(review): placeholder. Presumably replaced later by a log-file handle
  # (cf. `log_save_location`); confirm.
  io = nothing
  """
      adapt_learnrate(epoch_idx; init=init_learning_rate, rate=decay_rate, step=decay_step)

  Return the learning rate for epoch `epoch_idx` under continuous exponential decay:

      init * rate^(epoch_idx / step)

  The rate is multiplied by `rate` once every `step` epochs and interpolated smoothly in
  between; `epoch_idx / step` is not rounded, so there is no staircase.

  The keywords default to the script-level parameters, so `adapt_learnrate(e)` behaves as
  before. Pass them explicitly to evaluate the schedule for other settings.

  The result is computed in the floating-point type of `init` (Float32 with the defaults).
  """
  function adapt_learnrate(
      epoch_idx; init=init_learning_rate, rate=decay_rate, step=decay_step
  )
      T = float(typeof(init))
      # `epoch_idx / step` is a Float64. Raising a Float32 to a Float64 power promotes the
      # whole learning rate to Float64, so convert the exponent to T first.
      exponent = T(epoch_idx / step)
      return T(init) * T(rate)^exponent
  end
  """
      load_dataset(dataset_name)

  Placeholder for loading the dataset called `dataset_name`. It is not implemented yet,
  ignores its argument and always returns `nothing`.
  """
  load_dataset(dataset_name) = nothing
  # Network: three Conv + MaxPool stages, then `flatten` and three Dense layers ending in
  # a 2-unit output with no activation.
  # NOTE(review): as visible here, this definition is incomplete:
  #   - `kernel`, `channels`, `features` and `pooldims1` are not defined in the visible
  #     part of this file; TODO confirm where they come from.
  #   - the 2nd and 3rd `Conv` omit the kernel size and the `in=>out` channel pair that
  #     the 1st `Conv` passes.
  #   - the 2nd and 3rd `MaxPool` omit the pooling window.
  #   - `inputDense1`..`inputDense3` must hold integer layer widths that match the
  #     flattened conv output and the hidden sizes.
  model = Chain(
      # pad = kernel ÷ 2 per dimension keeps the spatial size unchanged for odd kernel
      # sizes at stride 1 ("same" padding).
      Conv(kernel, channels=>features, relu, pad=map(x -> x ÷ 2, kernel)),
      # NOTE(review): `stride=()` is an empty tuple. Flux's MaxPool defaults the stride to
      # the window size; TODO confirm the intended stride.
      MaxPool(pooldims1, stride=()),
      Conv(relu, pad=map(x -> x ÷ 2, kernel)),
      MaxPool(),
      Conv(relu, pad=map(x -> x ÷ 2, kernel)),
      MaxPool(),
      flatten,
      Dense(inputDense1, inputDense2, σ),
      Dense(inputDense2, inputDense3, σ),
      Dense(inputDense3, 2) # identity to output coordinates!
  )