# dataManager.jl
  1. module dataManager
  2. using MAT
  3. using Base.Iterators: repeated, partition
  4. using Statistics
  5. using Flux.Data.MNIST
  6. using Flux:onehotbatch
  """
      make_minibatch(X, Y, idxset)

  Bundle the samples selected by `idxset` into one `Float32` mini-batch.

  # Arguments
  - `X`: sample array indexed as `width x height x channels x setsize`.
  - `Y`: target array indexed as `targets x setsize`.
  - `idxset`: iterable of sample indices, i.e. positions along the last dimension of
    `X` and `Y`.

  # Returns
  A tuple `(X_batch, Y_batch)` of `Float32` arrays with sizes
  `width x height x channels x length(idxset)` and `targets x length(idxset)`.
  """
  function make_minibatch(X, Y, idxset)
      nsamples = length(idxset)
      # Take the channel and target counts from the inputs instead of assuming
      # exactly one channel and two target rows.
      X_batch = Array{Float32}(undef, size(X, 1), size(X, 2), size(X, 3), nsamples)
      Y_batch = Array{Float32}(undef, size(Y, 1), nsamples)
      # Broadcast assignment into views converts to Float32 without temporaries.
      @views for (k, idx) in enumerate(idxset)
          X_batch[:, :, :, k] .= X[:, :, :, idx]
          Y_batch[:, k] .= Y[:, idx]
      end
      return (X_batch, Y_batch)
  end
  """
      make_batch(filepath, filenames...; batch_size=100, normalize_data=true,
                 truncate_data=false)

  Load samples from one or more `.mat` files and split them into `Float32` mini-batches.

  Every file name is appended to `filepath` by plain string concatenation, so `filepath`
  has to end with a path separator (or be a file-name prefix). The samples of all files
  are concatenated in the order the file names are given.

  Expected structure of every `.mat` file:

  | variable | size       |
  |:---------|:-----------|
  | `data`   | 50 x 6 x N |
  | `labels` | 2 x N      |

  where N denotes the number of samples, 50 is the window size and 6 is the number of
  channels.

  # Keywords
  - `batch_size`: number of samples per batch; `-1` puts the whole data set into a
    single batch.
  - `normalize_data`: if `true`, standardize the data with `normalize!` before batching.
  - `truncate_data`: forwarded to `normalize!` as its `truncate` argument.

  # Returns
  A vector of `(X_batch, Y_batch)` tuples as produced by `make_minibatch`.

  # Throws
  - `ArgumentError` if no file name is given.
  - `DimensionMismatch` if data and labels hold a different number of samples.
  """
  function make_batch(
      filepath, filenames...; batch_size=100, normalize_data=true, truncate_data=false
  )
      isempty(filenames) &&
          throw(ArgumentError("make_batch needs at least one file name"))

      # Read every file first and concatenate once afterwards; this avoids re-copying
      # the growing array for each additional file.
      parts = map(enumerate(filenames)) do (i, filename)
          file = "$filepath$filename"
          @debug("Reading $(i) of $(length(filenames)) from $(file)")
          matfile = matopen(file)
          try
              return (read(matfile, "data"), read(matfile, "labels"))
          finally
              # Release the file handle even when a variable is missing.
              close(matfile)
          end
      end
      data = cat(first.(parts)...; dims=3)
      labels = cat(last.(parts)...; dims=2)

      # Add a singleton dimension and move the sample axis last so the layout matches
      # the Flux convention width x height x channels x setsize, e.g. (50, 6, 1, N).
      data = permutedims(cat(data; dims=4), (1, 2, 4, 3))
      @debug("Dimension of data $(size(data))")
      @debug("Dimension of binary targets $(size(labels))")

      size(labels, 2) == size(data, 4) || throw(
          DimensionMismatch(
              "got $(size(data, 4)) data samples but $(size(labels, 2)) label columns"
          ),
      )

      if normalize_data
          # `float` is a no-op for floating-point arrays; integer data is promoted so
          # that the in-place normalization cannot hit an InexactError.
          data = float(data)
          normalize!(data, truncate_data)
      end

      labels = convert(Array{Float32}, labels)
      data = convert(Array{Float32}, data)

      if batch_size == -1
          # One batch holding the whole set; keep the step valid for an empty set.
          batch_size = max(size(data, 4), 1)
      end
      idxsets = partition(1:size(data, 4), batch_size)
      return [make_minibatch(data, labels, idx) for idx in idxsets]
  end # function make_batch
  """
      normalize!(data, truncate; clip=2.576)

  Standardize `data` in place to zero mean and unit standard deviation.

  Mean and (corrected) standard deviation are computed along the fourth (sample)
  dimension only, i.e. separately for every `width x height x channel` position.
  Positions whose standard deviation is zero or not finite (constant values, or a set
  holding a single sample) are only centered, not scaled.

  # Arguments
  - `data`: floating-point array in Flux order `width x height x channels x setsize`;
    it is overwritten with the normalized values.
  - `truncate`: if `true`, clamp the normalized values to `[-clip, clip]`.

  # Keywords
  - `clip`: clipping bound in standard deviations. The default `2.576` cuts off the
    outermost 1 % of a normal distribution.

  # Returns
  `(mean_data, std_data)`: the statistics computed from the unnormalized input, each of
  size `width x height x channels x 1`.

  # Throws
  - `ArgumentError` if `truncate` is `true` and `clip` is negative.
  """
  function normalize!(data, truncate; clip=2.576)
      mean_data = mean(data; dims=4)
      std_data = std(data; mean=mean_data, dims=4)
      @debug("normalize dataset")
      # Divide by one wherever the deviation is unusable: exactly zero for constant
      # positions, NaN when the set holds fewer than two samples.
      divisor = map(s -> (iszero(s) || !isfinite(s)) ? one(s) : s, std_data)
      # Single fused pass; the statistics broadcast along the sample dimension.
      data .= (data .- mean_data) ./ divisor
      if truncate
          clip >= 0 || throw(ArgumentError("clip must be non-negative, got $clip"))
          clamp!(data, -clip, clip)
      end
      return (mean_data, std_data)
  end
  96. end # module dataManager