# dataManager.jl
  1. module dataManager
  2. using MAT
  3. using Base.Iterators: repeated, partition
  4. using Statistics
  5. using Flux.Data.MNIST
  6. using Flux:onehotbatch
  7. # dimension of coordinates (labels): (x, y)
  8. lbls_dims = (1080, 980)
  9. lbls_offset = (0, 699)
  10. """
  11. make_minibatch(X, Y, idxset)
  12. loads and bundles training data and labels into batches
  13. X should be of size Width x Height x channels x batchsize
  14. Y should be of size 2 x batchsize
  15. """
  16. function make_minibatch(X, Y, idxset)
  17. X_batch = Array{Float32}(undef, size(X, 1), size(X, 2), 1, length(idxset))
  18. Y_batch = Array{Float32}(undef, 2, length(idxset))
  19. for i in 1:length(idxset)
  20. X_batch[:, :, :, i] = Float32.(X[:, :, :, idxset[i]])
  21. Y_batch[:, i] = Float32.(Y[:, idxset[i]])
  22. end
  23. return (X_batch, Y_batch)
  24. end
"""
    make_batch(filepath, filenames...; batch_size=100, normalize_data=true, truncate_data=false)

Read one or more `.mat` files from `filepath`, concatenate their contents,
and return a vector of `(X_batch, Y_batch)` minibatches of size `batch_size`
(default 100). Data is z-normalized via `normalize!` when `normalize_data`
is set (default true); `truncate_data` is forwarded to `normalize!`.
If `batch_size == -1` the whole dataset becomes a single batch.

Expected structure of each .mat file:

    fieldname | size
    ----------------
    data      | W x 6 x N
    labels    | 2 x N  (1: x, 2: y)

where N is the number of samples, W the window size (comments elsewhere in
this file say both 50 and 60 — TODO confirm against the data files) and 6
the number of channels. Labels are shifted by `lbls_offset` and scaled by
`lbls_dims` before batching.
"""
function make_batch(filepath, filenames...; batch_size=100, normalize_data=true, truncate_data=false)
    data = nothing   # accumulated "data" arrays across all files
    labels = nothing # accumulated "labels" arrays across all files
    for (i, filename) in enumerate(filenames)
        # load the raw arrays from the mat file
        file = "$filepath$filename"
        @debug("Reading $(i) of $(length(filenames)) from $(file)")
        matfile = matopen(file)
        # presumably W x channels x N — see docstring note on the window size
        dataPart = read(matfile, "data")
        # 2 x N raw pixel coordinates
        labelsPart = read(matfile, "labels")
        close(matfile)
        # first file initializes the accumulators; later files are concatenated
        # along the sample dimension (dim 3 for data, dim 2 for labels)
        if (isnothing(data)) data = dataPart; labels = labelsPart;
        else
            data = cat(dims=3, data, dataPart)
            labels = cat(dims=2, labels, labelsPart)
        end
    end
    # append a singleton 4th dimension so dims can be permuted into the
    # Flux convention width x height x channels x batchsize below
    data = cat(dims=4, data)
    # normalize the labels into the unit range defined by the module constants
    labels = (labels .- lbls_offset) ./ lbls_dims
    # rearrange: (W, 6, N, 1) -> (W, 6, 1, N), i.e. channels=1, batch last
    data = permutedims(data, (1, 2, 4, 3))
    @debug("Dimension of data $(size(data))")
    @debug("Dimension of binary targets $(size(labels))")
    if(normalize_data)
        normalize!(data, truncate_data)
    end
    # Convert to Float32 (Flux's preferred element type)
    labels = convert(Array{Float32}, labels)
    data = convert(Array{Float32}, data)
    # split sample indices into consecutive chunks of batch_size;
    # the final chunk may be smaller than batch_size
    if ( batch_size == -1 )
        batch_size = size(data, 4)
    end
    idxsets = partition(1:size(data, 4), batch_size)
    data_set = [make_minibatch(data, labels, i) for i in idxsets];
    return data_set
end # function make_batch
  79. """
  80. normalize input images along the batch and channel dimension
  81. input should have standart flux order: Widht x height x channels x batchsize
  82. if truncate is set to true the last 1% beyond 2.576 sigma will be clipped to 2.576 sigma
  83. """
  84. function normalize!(data, truncate)
  85. mean_data = mean(data, dims=4)
  86. std_data = std(data, mean=mean_data, dims=4)
  87. setsize = size(data, 4)
  88. @debug("normalize dataset")
  89. std_data_tmp = copy(std_data)
  90. std_data_tmp[std_data_tmp .== 0] .= 1
  91. for i in 1:setsize
  92. data[:, :, :, i] = (data[:, :, :, i] - mean_data) ./ std_data_tmp
  93. end
  94. if(truncate)
  95. # truncate the last 1% beyond 2.576 sigma
  96. data[data .> 2.576] .= 2.576
  97. data[data .< -2.576] .= -2.576
  98. end
  99. return (mean_data, std_data)
  100. end
  101. end # module dataManager