# dataManager.jl
  1. module dataManager
  2. using MAT
  3. using Base.Iterators: repeated, partition
  4. using Statistics
  5. using Flux.Data.MNIST
  6. using Flux:onehotbatch
  7. # dimension of coordinates (labels): (x, y)
  8. lbls_dims = (1080, 980)
  9. # lbls_offset = (0, 699)
  10. """
  11. make_minibatch(X, Y, idxset)
  12. loads and bundles training data and labels into batches
  13. X should be of size Width x Height x channels x batchsize
  14. Y should be of size 2 x batchsize
  15. """
  16. function make_minibatch(X, Y, idxset)
  17. X_batch = Array{Float32}(undef, size(X, 1), size(X, 2), 1, length(idxset))
  18. Y_batch = Array{Float32}(undef, 2, length(idxset))
  19. for i in 1:length(idxset)
  20. X_batch[:, :, :, i] = Float32.(X[:, :, :, idxset[i]])
  21. Y_batch[:, i] = Float32.(Y[:, idxset[i]])
  22. end
  23. return (X_batch, Y_batch)
  24. end
"""
    make_batch(filepath, filenames...; batch_size=100, normalize_data=true, truncate_data=false)

Load one or more .mat files (path built as `"\$filepath\$filename"` for each
name), concatenate their contents, and return a vector of `(data, labels)`
Float32 minibatches of size `batch_size` (default 100; the last batch may be
smaller). If `batch_size == -1`, one batch containing the whole set is returned.

Each .mat file is read via its fields `"data"` and `"labels"`:
  - `"data"`: samples stacked along dim 3 (window x channels x N)
  - `"labels"`: 2 x N array of (x, y) coordinate targets

Labels are scaled row-wise by `lbls_dims` to normalize them. Data gets a
singleton dimension appended and is permuted into the Flux convention
width x height x channels x batchsize, then optionally standardized in place
by `normalize!` (with clipping when `truncate_data` is set).

NOTE(review): earlier comments here described a field `bin_targets` of size
(N, 10) and images of size (N, width, height, 1) — the code actually reads
`"data"` and `"labels"` with the layouts documented above; window size was
stated as both 60 and 50 in the old docstring. Confirm against the .mat files.
"""
function make_batch(filepath, filenames...; batch_size=100, normalize_data=true, truncate_data=false)
    # Accumulators start as `nothing` so the first file's arrays are adopted directly.
    data = nothing
    labels = nothing
    for (i, filename) in enumerate(filenames)
        # load the data from the mat file
        file = "$filepath$filename"
        @debug("Reading $(i) of $(length(filenames)) from $(file)")
        matfile = matopen(file)
        dataPart = read(matfile, "data")
        labelsPart = read(matfile, "labels")
        close(matfile)
        # First file: adopt arrays; later files: append along the sample dimension
        # (dim 3 for data, dim 2 for labels).
        if (isnothing(data)) data = dataPart; labels = labelsPart;
        else
            data = cat(dims=3, data, dataPart)
            labels = cat(dims=2, labels, labelsPart)
        end
    end
    # Append a singleton 4th dimension so a later permutedims can reorder into
    # the Flux convention width x height x channels x batchsize.
    data = cat(dims=4, data)
    # Normalize the labels: a length-2 tuple broadcasts along dim 1, so row 1 (x)
    # is divided by lbls_dims[1] and row 2 (y) by lbls_dims[2].
    # labels = (labels .- lbls_offset) ./ lbls_dims
    labels = labels ./ lbls_dims;
    # Move the singleton dim into position 3 (channels): (w, h, 1, N).
    data = permutedims(data, (1, 2, 4, 3))
    @debug("Dimension of data $(size(data))")
    @debug("Dimension of binary targets $(size(labels))")
    if(normalize_data)
        # Standardizes in place along the batch dimension; clips when truncate_data.
        normalize!(data, truncate_data)
    end
    # Convert to Float32 (Flux's preferred element type).
    labels = convert(Array{Float32}, labels)
    data = convert(Array{Float32}, data)
    if ( batch_size == -1 )
        batch_size = size(data, 4)
    end
    # partition yields consecutive index ranges; the final range may be shorter.
    idxsets = partition(1:size(data, 4), batch_size)
    data_set = [make_minibatch(data, labels, i) for i in idxsets];
    return data_set
end # function make_batch
  80. """
  81. normalize input images along the batch and channel dimension
  82. input should have standart flux order: Widht x height x channels x batchsize
  83. if truncate is set to true the last 1% beyond 2.576 sigma will be clipped to 2.576 sigma
  84. """
  85. function normalize!(data, truncate)
  86. mean_data = mean(data, dims=4)
  87. std_data = std(data, mean=mean_data, dims=4)
  88. setsize = size(data, 4)
  89. @debug("normalize dataset")
  90. std_data_tmp = copy(std_data)
  91. std_data_tmp[std_data_tmp .== 0] .= 1
  92. for i in 1:setsize
  93. data[:, :, :, i] = (data[:, :, :, i] - mean_data) ./ std_data_tmp
  94. end
  95. if(truncate)
  96. # truncate the last 1% beyond 2.576 sigma
  97. data[data .> 2.576] .= 2.576
  98. data[data .< -2.576] .= -2.576
  99. end
  100. return (mean_data, std_data)
  101. end
  102. end # module dataManager