package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/NormalizerStandardize.class */
/**
 * Standardizes features (and optionally labels) by subtracting a per-column mean and dividing by a
 * per-column standard deviation.
 *
 * <p>Statistics are obtained either from a single {@link DataSet} ({@link #fit(DataSet)}) or by a
 * single pass over a {@link DataSetIterator} ({@link #fit(DataSetIterator)}), which merges
 * per-batch statistics incrementally. Rank-3 and rank-4 arrays are first flattened to a 2-D matrix
 * whose columns correspond to dimension 1 of the original array.
 *
 * <p>A small constant ({@code Nd4j.EPS_THRESHOLD}) is added to every standard deviation so that
 * dividing by it never divides by zero.
 *
 * <p>Limitations kept from the original implementation: {@link #revert(DataSet)} restores features
 * only, and {@link #load(File...)} / {@link #save(File...)} do not validate the number of files.
 */
public class NormalizerStandardize implements DataNormalization {
    private static final Logger logger = LoggerFactory.getLogger(NormalizerStandardize.class);

    /** Message used whenever statistics are requested or applied before any fit/load. */
    private static final String NOT_FIT_MESSAGE =
            "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";

    /** Number of flattened feature rows seen so far by {@link #fit(DataSetIterator)}. */
    private int runningTotal;
    /** Number of flattened feature rows in the most recent batch. */
    private int batchCount;
    /** Row 0: feature means. Row 1: feature std (or, while iterating, the running sum of squared deviations). */
    private INDArray featureMeanStd;
    /** Same layout as {@link #featureMeanStd}, for labels. */
    private INDArray labelMeanStd;
    private INDArray featureMean;
    private INDArray featureStd;
    private INDArray labelMean;
    private INDArray labelStd;
    /** Number of flattened label rows seen so far by {@link #fit(DataSetIterator)}. */
    private int labelRunningTotal = 0;
    /** Number of flattened label rows in the most recent batch. */
    private int labelbatchCount = 0;
    /**
     * Rank of the features seen by the most recent fit. Transform/revert decide by the rank of the
     * array they are handed, so this is informational only; the field is retained so the set of
     * instance fields is unchanged.
     */
    private int featureRank = 2;
    private boolean fitLabels = false;

    /**
     * Computes column-wise mean and standard deviation of a 2-D matrix.
     *
     * @param data matrix with one example per row
     * @return a new 2-row matrix: row 0 is the mean, row 1 is the std plus {@code Nd4j.EPS_THRESHOLD}
     */
    private INDArray fit(INDArray data) {
        INDArray mean = data.mean(0);
        INDArray std = data.std(0);
        // Inspect the values BEFORE epsilon is added; comparing INDArray objects with == would only
        // compare references and never detect a zero standard deviation.
        boolean zeroStd = std.min(1).getDouble(0) == 0.0;
        std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        if (zeroStd) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
        }
        return Nd4j.vstack(mean, std).dup();
    }

    /**
     * Builds the running-statistics accumulator from the first batch.
     *
     * @param firstBatch flattened 2-D data of the first batch
     * @param count      number of rows in {@code firstBatch}
     * @return a new 2-row matrix: row 0 is the batch mean, row 1 is {@code std^2 * count}, i.e. the
     *         same quantity later batches add in {@link #accumulate}
     */
    private static INDArray seedRunningStats(INDArray firstBatch, int count) {
        INDArray mean = firstBatch.mean(0);
        INDArray sumSq = Transforms.pow(firstBatch.std(0), 2).muli(count);
        return Nd4j.vstack(mean, sumSq).dup();
    }

    /**
     * Merges one more batch into the running statistics, in place.
     *
     * <p>Uses the pairwise update for partitioned data: with A the data seen so far and B the
     * current batch, {@code M2 = M2_A + M2_B + (mean_B - mean_A)^2 * nA * nB / (nA + nB)}.
     *
     * @param batchData    flattened 2-D data of the current batch
     * @param meanStd      accumulator (row 0 mean, row 1 running sum of squared deviations)
     * @param batchCount   number of rows in {@code batchData}
     * @param runningTotal number of rows seen so far, INCLUDING the current batch
     */
    private static void accumulate(INDArray batchData, INDArray meanStd, int batchCount, int runningTotal) {
        INDArray currentMean = meanStd.getRow(0);
        INDArray sumSq = meanStd.getRow(1);

        INDArray newMean = currentMean.add(batchData.subRowVector(currentMean).sum(0).divi(runningTotal));

        // Floating-point on purpose: integer arithmetic here would truncate the quotient and could
        // overflow for large counts.
        double seenBefore = (double) runningTotal - batchCount;
        double weight = seenBefore * batchCount / runningTotal;
        INDArray deltaSqScaled = Transforms.pow(batchData.mean(0).subRowVector(currentMean), 2).mul(weight);

        INDArray batchSumSq = Transforms.pow(batchData.std(0), 2);
        batchSumSq.muli(batchCount);

        sumSq.addi(batchSumSq);
        sumSq.addi(deltaSqScaled);
        meanStd.putRow(0, newMean);
    }

    /**
     * Converts the accumulator's row 1 from a running sum of squared deviations into a standard
     * deviation (plus epsilon), in place.
     *
     * @param meanStd accumulator produced by {@link #seedRunningStats} / {@link #accumulate}
     * @param total   total number of rows that were accumulated
     */
    private static void finalizeStats(INDArray meanStd, int total) {
        meanStd.getRow(1).divi(total);
        INDArray std = Transforms.sqrt(meanStd.getRow(1));
        // Check the std row (row 1), before epsilon is added.
        boolean zeroStd = std.min(1).getDouble(0) == 0.0;
        meanStd.putRow(1, std);
        meanStd.getRow(1).addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        if (zeroStd) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
        }
    }

    /**
     * Returns the features or labels of {@code dataSet} as a 2-D matrix, flattening rank-3 and
     * rank-4 arrays. The decision is based on the rank of the selected array itself, so features
     * and labels of different rank are each handled correctly.
     */
    private INDArray flattenTo2d(DataSet dataSet, boolean isFeatures) {
        INDArray data = isFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        switch (data.rank()) {
            case 3:
                return tailor3d2d(dataSet, isFeatures);
            case 4:
                return tailor4d2d(dataSet, isFeatures);
            default:
                return data;
        }
    }

    /**
     * Chooses whether labels are fit and transformed alongside features.
     *
     * @param fitLabels {@code true} to also compute and apply label statistics
     */
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /**
     * Computes the statistics from a single data set, replacing any previous statistics.
     *
     * @param dataSet data to fit on; labels are used only if label fitting is enabled
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.DataNormalization
    public void fit(DataSet dataSet) {
        this.featureRank = dataSet.getFeatures().rank();
        this.featureMeanStd = fit(flattenTo2d(dataSet, true));
        this.featureMean = this.featureMeanStd.getRow(0).dup();
        this.featureStd = this.featureMeanStd.getRow(1).dup();
        if (this.fitLabels) {
            this.labelMeanStd = fit(flattenTo2d(dataSet, false));
            this.labelMean = this.labelMeanStd.getRow(0).dup();
            this.labelStd = this.labelMeanStd.getRow(1).dup();
        }
    }

    /**
     * Computes the statistics in one pass over the iterator, then resets the iterator.
     *
     * <p>The per-batch count is the number of rows of the flattened 2-D data, so that it always
     * matches the number of rows that contribute to the batch mean and deviation.
     *
     * @param dataSetIterator source of batches; must yield at least one batch
     * @throws IllegalArgumentException if the iterator yields no batches
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.DataNormalization
    public void fit(DataSetIterator dataSetIterator) {
        this.featureMeanStd = null;
        this.labelMeanStd = null;
        this.runningTotal = 0;
        this.labelRunningTotal = 0;
        while (dataSetIterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet next = dataSetIterator.next();

            INDArray features = flattenTo2d(next, true);
            this.batchCount = features.size(0);
            this.runningTotal += this.batchCount;
            if (this.featureMeanStd == null) {
                this.featureRank = next.getFeatures().rank();
                this.featureMeanStd = seedRunningStats(features, this.batchCount);
            } else {
                accumulate(features, this.featureMeanStd, this.batchCount, this.runningTotal);
            }

            if (this.fitLabels) {
                INDArray labels = flattenTo2d(next, false);
                this.labelbatchCount = labels.size(0);
                this.labelRunningTotal += this.labelbatchCount;
                if (this.labelMeanStd == null) {
                    this.labelMeanStd = seedRunningStats(labels, this.labelbatchCount);
                } else {
                    accumulate(labels, this.labelMeanStd, this.labelbatchCount, this.labelRunningTotal);
                }
            }
        }
        if (this.featureMeanStd == null) {
            throw new IllegalArgumentException(
                    "API_USE_ERROR: Cannot fit on an iterator that returns no data. Reset the iterator before calling fit.");
        }
        finalizeStats(this.featureMeanStd, this.runningTotal);
        this.featureMean = this.featureMeanStd.getRow(0).dup();
        this.featureStd = this.featureMeanStd.getRow(1).dup();
        if (this.fitLabels) {
            finalizeStats(this.labelMeanStd, this.labelRunningTotal);
            this.labelMean = this.labelMeanStd.getRow(0).dup();
            this.labelStd = this.labelMeanStd.getRow(1).dup();
        }
        dataSetIterator.reset();
    }

    /**
     * Standardizes the data set in place: features always, labels only if label fitting is enabled.
     *
     * @throws RuntimeException if the required statistics have not been fit or loaded
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.DataNormalization, org.nd4j.linalg.dataset.api.DataSetPreProcessor
    public void preProcess(DataSet dataSet) {
        if (this.featureMean == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        preProcess(dataSet.getFeatures(), true);
        if (this.fitLabels) {
            preProcess(dataSet.getLabels(), false);
        }
    }

    /**
     * Standardizes {@code data} in place using the feature or label statistics.
     *
     * <p>Rank-2 arrays use row-vector arithmetic; any other rank broadcasts along dimension 1. The
     * rank of {@code data} itself is used, so the method behaves correctly after {@link #load} and
     * for labels whose rank differs from the features.
     */
    private void preProcess(INDArray data, boolean isFeatures) {
        INDArray mean = isFeatures ? this.featureMean : this.labelMean;
        INDArray std = isFeatures ? this.featureStd : this.labelStd;
        if (mean == null || std == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        if (data.rank() == 2) {
            data.subiRowVector(mean);
            data.diviRowVector(std);
        } else {
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastSubOp(data, mean, data, 1));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastDivOp(data, std, data, 1));
        }
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.DataNormalization
    public void transform(DataSet dataSet) {
        preProcess(dataSet);
    }

    /** Standardizes a feature array in place. */
    public void transform(INDArray features) {
        transform(features, true);
    }

    /**
     * Standardizes an array in place.
     *
     * @param data       array to standardize
     * @param isFeatures {@code true} to use the feature statistics, {@code false} for the label statistics
     * @throws RuntimeException if the selected statistics have not been fit or loaded
     */
    public void transform(INDArray data, boolean isFeatures) {
        preProcess(data, isFeatures);
    }

    /**
     * Undoes the standardization of the data set's features, in place. Labels are not touched.
     *
     * @throws RuntimeException if the feature statistics have not been fit or loaded
     */
    public void revert(DataSet dataSet) {
        if (this.featureMean == null || this.featureStd == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        INDArray features = dataSet.getFeatures();
        if (features.rank() == 2) {
            features.muliRowVector(this.featureStd);
            features.addiRowVector(this.featureMean);
        } else {
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(features, this.featureStd, features, 1));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastAddOp(features, this.featureMean, features, 1));
        }
    }

    /** Reverts every data set returned by the iterator, then resets the iterator. */
    public void revert(DataSetIterator dataSetIterator) {
        while (dataSetIterator.hasNext()) {
            revert(dataSetIterator.next());
        }
        dataSetIterator.reset();
    }

    /**
     * @return the feature means
     * @throws RuntimeException if the feature statistics have not been fit or loaded
     */
    public INDArray getMean() {
        if (this.featureMean == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.featureMean;
    }

    /**
     * @return the label means; may be {@code null} if labels were never fit or loaded
     * @throws RuntimeException if the feature statistics have not been fit or loaded
     */
    public INDArray getLabelMean() {
        if (this.featureMean == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.labelMean;
    }

    /**
     * @return the feature standard deviations
     * @throws RuntimeException if the feature statistics have not been fit or loaded
     */
    public INDArray getStd() {
        if (this.featureStd == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.featureStd;
    }

    /**
     * @return the label standard deviations; may be {@code null} if labels were never fit or loaded
     * @throws RuntimeException if the feature statistics have not been fit or loaded
     */
    public INDArray getLabelStd() {
        if (this.featureStd == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.labelStd;
    }

    /**
     * Reads the statistics from binary files.
     *
     * @param fileArr feature mean, feature std and, if label fitting is enabled, label mean and
     *                label std, in that order
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.DataNormalization
    public void load(File... fileArr) throws IOException {
        this.featureMean = Nd4j.readBinary(fileArr[0]);
        this.featureStd = Nd4j.readBinary(fileArr[1]);
        if (this.fitLabels) {
            this.labelMean = Nd4j.readBinary(fileArr[2]);
            this.labelStd = Nd4j.readBinary(fileArr[3]);
        }
    }

    /**
     * Writes the statistics to binary files.
     *
     * @param fileArr destination files, in the same order as for {@link #load(File...)}
     */
    @Override // org.nd4j.linalg.dataset.api.preprocessor.DataNormalization
    public void save(File... fileArr) throws IOException {
        Nd4j.saveBinary(this.featureMean, fileArr[0]);
        Nd4j.saveBinary(this.featureStd, fileArr[1]);
        if (this.fitLabels) {
            Nd4j.saveBinary(this.labelMean, fileArr[2]);
            Nd4j.saveBinary(this.labelStd, fileArr[3]);
        }
    }

    /**
     * Flattens a rank-3 array to 2-D with {@code size(1)} columns and {@code size(0) * size(2)}
     * rows. If the matching mask array is present, only rows whose mask entry is non-zero are kept.
     *
     * <p>The input data is never modified.
     *
     * @param dataSet    source data set
     * @param isFeatures {@code true} to flatten features (with the features mask), {@code false} for labels
     */
    private INDArray tailor3d2d(DataSet dataSet, boolean isFeatures) {
        INDArray data = isFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = isFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        int dim0 = data.size(0);
        int columns = data.size(1);
        int dim2 = data.size(2);
        boolean hasMask = mask != null;

        INDArray byColumn = Nd4j.create(columns, dim2 * dim0);
        int tensorCount = data.tensorssAlongDimension(2, 0);
        for (int i = 0; i < tensorCount; i++) {
            INDArray slice = data.tensorAlongDimension(i, 2, 0);
            if (hasMask) {
                // Copying multiply: an in-place multiply on the slice would write the masked values
                // back into the caller's data.
                slice = slice.mul(mask);
            }
            byColumn.putRow(i, Nd4j.toFlattened('c', slice));
        }
        INDArray flattened = byColumn.transpose();
        if (!hasMask) {
            return flattened;
        }

        INDArray flatMask = Nd4j.toFlattened('c', mask).transpose();
        INDArray kept = Nd4j.create(flatMask.sumNumber().intValue(), columns);
        int nextRow = 0;
        for (int row = 0; row < dim2 * dim0; row++) {
            if (flatMask.getInt(row, 0) != 0) {
                kept.putRow(nextRow, flattened.getRow(row));
                nextRow++;
            }
        }
        return kept;
    }

    /**
     * Flattens a rank-4 array to 2-D with {@code size(1)} columns and
     * {@code size(0) * size(2) * size(3)} rows.
     *
     * @param dataSet    source data set
     * @param isFeatures {@code true} to flatten features, {@code false} for labels
     */
    private INDArray tailor4d2d(DataSet dataSet, boolean isFeatures) {
        INDArray data = isFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray byColumn = Nd4j.create(data.size(1), data.size(2) * data.size(3) * data.size(0));
        int tensorCount = data.tensorssAlongDimension(3, 2, 0);
        for (int i = 0; i < tensorCount; i++) {
            byColumn.putRow(i, Nd4j.toFlattened(data.tensorAlongDimension(i, 3, 2, 0)));
        }
        return byColumn.transposei();
    }
}
