package org.nd4j.linalg.dataset;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.beans.ConstructorProperties;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

/* loaded from: input_file:org/nd4j/linalg/dataset/BalanceMinibatches.class */
/**
 * Re-packs the examples produced by a {@link DataSetIterator} into label-balanced minibatches on disk.
 *
 * <p>{@link #balance()} works in two phases:
 * <ol>
 *   <li>Every example returned by the iterator is saved as its own file under
 *       {@code labelRootDirs.get(label)}, where {@code label} is {@link DataSet#outcome()}.</li>
 *   <li>Minibatches are assembled by repeatedly taking one stored example from each label
 *       ({@code 0 .. numLabels-1}, in that order) until the batch holds at least {@code miniBatchSize}
 *       examples or every label is exhausted. Each batch is merged, optionally transformed by the
 *       configured {@link DataNormalization}, and saved as {@code rootSaveDir/dataset-<n>.bin}.</li>
 * </ol>
 *
 * <p>A whole round (one example per label) is added before the size check. A batch can therefore
 * exceed {@code miniBatchSize} by up to {@code numLabels - 1} examples. A non-positive
 * {@code miniBatchSize} means "use the size of the first non-empty minibatch from the iterator".
 *
 * <p>Instances are mutable (setters below) and perform no synchronization.
 */
public class BalanceMinibatches {
    /** Directory used for per-label example files when none is supplied. */
    private static final String DEFAULT_ROOT_DIR = "minibatches";
    /** Directory used for the balanced output batches when none is supplied. */
    private static final String DEFAULT_ROOT_SAVE_DIR = "minibatchessave";

    private DataSetIterator dataSetIterator;
    private int numLabels;
    private Map<Integer, List<File>> paths;
    private int miniBatchSize;
    private File rootDir;
    private File rootSaveDir;
    private List<File> labelRootDirs;
    private DataNormalization dataNormalization;

    /** Fluent builder for {@link BalanceMinibatches}; unset fields fall back to the constructor defaults. */
    public static class BalanceMinibatchesBuilder {
        private DataSetIterator dataSetIterator;
        private int numLabels;
        private Map<Integer, List<File>> paths;
        private int miniBatchSize;
        private File rootDir;
        private File rootSaveDir;
        private List<File> labelRootDirs;
        private DataNormalization dataNormalization;

        BalanceMinibatchesBuilder() {
        }

        public BalanceMinibatchesBuilder dataSetIterator(DataSetIterator dataSetIterator) {
            this.dataSetIterator = dataSetIterator;
            return this;
        }

        public BalanceMinibatchesBuilder numLabels(int i) {
            this.numLabels = i;
            return this;
        }

        public BalanceMinibatchesBuilder paths(Map<Integer, List<File>> map) {
            this.paths = map;
            return this;
        }

        public BalanceMinibatchesBuilder miniBatchSize(int i) {
            this.miniBatchSize = i;
            return this;
        }

        public BalanceMinibatchesBuilder rootDir(File file) {
            this.rootDir = file;
            return this;
        }

        public BalanceMinibatchesBuilder rootSaveDir(File file) {
            this.rootSaveDir = file;
            return this;
        }

        public BalanceMinibatchesBuilder labelRootDirs(List<File> list) {
            this.labelRootDirs = list;
            return this;
        }

        public BalanceMinibatchesBuilder dataNormalization(DataNormalization dataNormalization) {
            this.dataNormalization = dataNormalization;
            return this;
        }

        public BalanceMinibatches build() {
            return new BalanceMinibatches(this.dataSetIterator, this.numLabels, this.paths, this.miniBatchSize,
                    this.rootDir, this.rootSaveDir, this.labelRootDirs, this.dataNormalization);
        }

        @Override
        public String toString() {
            return "BalanceMinibatches.BalanceMinibatchesBuilder(dataSetIterator=" + this.dataSetIterator + ", numLabels=" + this.numLabels + ", paths=" + this.paths + ", miniBatchSize=" + this.miniBatchSize + ", rootDir=" + this.rootDir + ", rootSaveDir=" + this.rootSaveDir + ", labelRootDirs=" + this.labelRootDirs + ", dataNormalization=" + this.dataNormalization + ")";
        }
    }

    /**
     * Spills every example from the iterator to per-label files, then writes label-balanced
     * minibatches to {@code rootSaveDir}.
     *
     * <p>Any existing contents of {@code paths} are discarded. Each label {@code 0..numLabels-1}
     * gets a fresh, empty file list, and one per-label directory under {@code rootDir} is appended
     * to {@code labelRootDirs}.
     *
     * @throws IllegalStateException if a required directory cannot be created
     */
    public void balance() {
        ensureDirectory(this.rootDir);
        ensureDirectory(this.rootSaveDir);
        if (this.paths == null) {
            this.paths = Maps.newHashMap();
        }
        if (this.labelRootDirs == null) {
            this.labelRootDirs = Lists.newArrayList();
        }
        // Keys outside 0..numLabels-1 would never be drained by the batching loop (which only
        // visits those labels) and would keep it spinning forever, so start from a clean map.
        this.paths.clear();
        for (int label = 0; label < this.numLabels; label++) {
            this.paths.put(label, new ArrayList<File>());
            this.labelRootDirs.add(new File(this.rootDir, String.valueOf(label)));
        }
        spillExamplesByLabel();
        writeBalancedBatches();
    }

    /** Phase 1: saves each single example to {@code labelRootDirs.get(outcome)} and records its file. */
    private void spillExamplesByLabel() {
        while (this.dataSetIterator.hasNext()) {
            DataSet next = this.dataSetIterator.next();
            // 0 (the builder default) and negative values both mean "infer from the data";
            // a zero batch size would otherwise make the batching phase loop forever.
            if (this.miniBatchSize <= 0) {
                this.miniBatchSize = next.numExamples();
            }
            int count = next.numExamples();
            for (int i = 0; i < count; i++) {
                DataSet example = next.get(i);
                int label = example.outcome();
                File labelDir = this.labelRootDirs.get(label);
                ensureDirectory(labelDir);
                List<File> labelFiles = this.paths.get(label);
                // Files are named by their index within the label's list: 0, 1, 2, ...
                File file = new File(labelDir, String.valueOf(labelFiles.size()));
                example.save(file);
                labelFiles.add(file);
            }
        }
    }

    /** Phase 2: drains the per-label files into merged, optionally normalized batches. */
    private void writeBalancedBatches() {
        int batchIndex = 0;
        while (!this.paths.isEmpty()) {
            List<DataSet> batch = nextBalancedBatch();
            if (batch.isEmpty()) {
                // Every label ran out exactly when the previous batch filled (or there was no data).
                // DataSet.merge cannot be given an empty list, so stop here.
                break;
            }
            DataSet merged = DataSet.merge(batch);
            if (this.dataNormalization != null) {
                this.dataNormalization.transform(merged);
            }
            // Plain concatenation keeps ASCII digits in file names regardless of the default locale,
            // which String.format("%d") does not guarantee.
            merged.save(new File(this.rootSaveDir, "dataset-" + batchIndex + ".bin"));
            batchIndex++;
        }
    }

    /**
     * Collects one batch by taking one pending example per label per round. An exhausted label
     * is removed from {@code paths} when it is encountered.
     */
    private List<DataSet> nextBalancedBatch() {
        List<DataSet> batch = new ArrayList<DataSet>();
        while (batch.size() < this.miniBatchSize && !this.paths.isEmpty()) {
            for (int label = 0; label < this.numLabels; label++) {
                List<File> pending = this.paths.get(label);
                if (pending == null || pending.isEmpty()) {
                    this.paths.remove(label);
                } else {
                    DataSet example = new DataSet();
                    example.load(pending.remove(0));
                    batch.add(example);
                }
            }
        }
        return batch;
    }

    /** Creates {@code dir} (and parents) if needed; fails loudly instead of ignoring a mkdirs failure. */
    private static void ensureDirectory(File dir) {
        // Re-check isDirectory() after mkdirs(): another process may have created it concurrently.
        if (!dir.isDirectory() && !dir.mkdirs() && !dir.isDirectory()) {
            throw new IllegalStateException("Unable to create directory " + dir.getAbsolutePath());
        }
    }

    public static BalanceMinibatchesBuilder builder() {
        return new BalanceMinibatchesBuilder();
    }

    /**
     * Creates a balancer. {@code null} values for {@code paths}, {@code rootDir}, {@code rootSaveDir}
     * and {@code labelRootDirs} are replaced by defaults: an empty map, {@code "minibatches"},
     * {@code "minibatchessave"} and an empty list, respectively. Without these defaults, a builder
     * that leaves those fields unset fails with an NPE in {@link #balance()}.
     *
     * @param miniBatchSize target examples per output batch; {@code <= 0} means infer from the data
     */
    @ConstructorProperties({"dataSetIterator", "numLabels", "paths", "miniBatchSize", "rootDir", "rootSaveDir", "labelRootDirs", "dataNormalization"})
    public BalanceMinibatches(DataSetIterator dataSetIterator, int numLabels, Map<Integer, List<File>> paths,
            int miniBatchSize, File rootDir, File rootSaveDir, List<File> labelRootDirs,
            DataNormalization dataNormalization) {
        this.dataSetIterator = dataSetIterator;
        this.numLabels = numLabels;
        this.paths = paths != null ? paths : Maps.<Integer, List<File>>newHashMap();
        this.miniBatchSize = miniBatchSize;
        this.rootDir = rootDir != null ? rootDir : new File(DEFAULT_ROOT_DIR);
        this.rootSaveDir = rootSaveDir != null ? rootSaveDir : new File(DEFAULT_ROOT_SAVE_DIR);
        this.labelRootDirs = labelRootDirs != null ? labelRootDirs : new ArrayList<File>();
        this.dataNormalization = dataNormalization;
    }

    public DataSetIterator getDataSetIterator() {
        return this.dataSetIterator;
    }

    public int getNumLabels() {
        return this.numLabels;
    }

    public Map<Integer, List<File>> getPaths() {
        return this.paths;
    }

    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    public File getRootDir() {
        return this.rootDir;
    }

    public File getRootSaveDir() {
        return this.rootSaveDir;
    }

    public List<File> getLabelRootDirs() {
        return this.labelRootDirs;
    }

    public DataNormalization getDataNormalization() {
        return this.dataNormalization;
    }

    public void setDataSetIterator(DataSetIterator dataSetIterator) {
        this.dataSetIterator = dataSetIterator;
    }

    public void setNumLabels(int i) {
        this.numLabels = i;
    }

    public void setPaths(Map<Integer, List<File>> map) {
        this.paths = map;
    }

    public void setMiniBatchSize(int i) {
        this.miniBatchSize = i;
    }

    public void setRootDir(File file) {
        this.rootDir = file;
    }

    public void setRootSaveDir(File file) {
        this.rootSaveDir = file;
    }

    public void setLabelRootDirs(List<File> list) {
        this.labelRootDirs = list;
    }

    public void setDataNormalization(DataNormalization dataNormalization) {
        this.dataNormalization = dataNormalization;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BalanceMinibatches)) {
            return false;
        }
        BalanceMinibatches other = (BalanceMinibatches) obj;
        return other.canEqual(this)
                && Objects.equals(getDataSetIterator(), other.getDataSetIterator())
                && getNumLabels() == other.getNumLabels()
                && Objects.equals(getPaths(), other.getPaths())
                && getMiniBatchSize() == other.getMiniBatchSize()
                && Objects.equals(getRootDir(), other.getRootDir())
                && Objects.equals(getRootSaveDir(), other.getRootSaveDir())
                && Objects.equals(getLabelRootDirs(), other.getLabelRootDirs())
                && Objects.equals(getDataNormalization(), other.getDataNormalization());
    }

    /** Supports symmetric {@link #equals(Object)} when subclasses add state. */
    protected boolean canEqual(Object obj) {
        return obj instanceof BalanceMinibatches;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Objects.hashCode(getDataSetIterator());
        result = result * prime + getNumLabels();
        result = result * prime + Objects.hashCode(getPaths());
        result = result * prime + getMiniBatchSize();
        result = result * prime + Objects.hashCode(getRootDir());
        result = result * prime + Objects.hashCode(getRootSaveDir());
        result = result * prime + Objects.hashCode(getLabelRootDirs());
        result = result * prime + Objects.hashCode(getDataNormalization());
        return result;
    }

    @Override
    public String toString() {
        return "BalanceMinibatches(dataSetIterator=" + getDataSetIterator() + ", numLabels=" + getNumLabels() + ", paths=" + getPaths() + ", miniBatchSize=" + getMiniBatchSize() + ", rootDir=" + getRootDir() + ", rootSaveDir=" + getRootSaveDir() + ", labelRootDirs=" + getLabelRootDirs() + ", dataNormalization=" + getDataNormalization() + ")";
    }
}
