net.jl 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. """
  2. Author: Sebastian Vendt, University of Ulm
  3. """
  4. using Flux, Statistics
  5. using Flux: onecold
  6. using BSON
  7. using Dates
  8. using Printf
  9. using NNlib
  10. include("./dataManager.jl")
  11. using .dataManager: make_batch
  12. using Logging
  13. import LinearAlgebra: norm
  14. norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T))
  15. ######################
  16. # PARAMETERS
  17. ######################
  18. const batch_size = 100
  19. const momentum = 0.9f0
  20. const lambda = 0.0005f0
  21. init_learning_rate = 0.1f0
  22. learning_rate = init_learning_rate
  23. const epochs = 100
  24. const decay_rate = 0.1f0
  25. const decay_step = 40
  26. const usegpu = true
  27. const printout_interval = 5
  28. const save_interval = 25
  29. const time_format = "HH:MM:SS"
  30. const date_format = "dd_mm_yyyy"
  31. data_size = (50, 1)
  32. # ARCHITECTURE
  33. inputDense1
  34. inputDense2
  35. inputDense3
  36. classes = 2
  37. # enter the datasets and models you want to train
  38. dataset_folderpath = "../MATLAB/TrainingData/"
  39. const model_save_location = "../trainedModels/"
  40. const log_save_location = "../logs/"
  41. if usegpu
  42. using CuArrays
  43. end
  44. debug_str = ""
  45. @debug begin
  46. global debug_str
  47. debug_str = "DEBUG_"
  48. "------DEBUGGING ACTIVATED------"
  49. end
  50. io = nothing
  51. function adapt_learnrate(epoch_idx)
  52. return init_learning_rate * decay_rate^(epoch_idx / decay_step)
  53. end
  54. function load_dataset(dataset_name)
  55. end
  56. model = Chain(
  57. Conv(kernel, channels=>features, relu, pad=map(x -> x ÷ 2, kernel)),
  58. MaxPool(pooldims1, stride=()),
  59. Conv(relu, pad=map(x -> x ÷ 2, kernel)),
  60. MaxPool(),
  61. Conv(relu, pad=map(x -> x ÷ 2, kernel)),
  62. MaxPool(),
  63. flatten,
  64. Dense(inputDense1, inputDense2, σ),
  65. Dense(inputDense2, inputDense3, σ),
  66. Dense(inputDense3, classes) # identity to output coordinates!
  67. )