package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Max;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Min;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Min-max feature normalizer.
 *
 * <p>{@link #fit} records the per-column minimum and maximum of the feature matrix.
 * {@link #preProcess} then maps each feature value {@code x} to
 * {@code (x - min) / (max - min + EPS) * (maxRange - minRange) + minRange}. The fitted data
 * therefore lands (up to EPS) in {@code [minRange, maxRange]}, which defaults to {@code [0, 1]}.
 * Here {@code EPS} is {@link Nd4j#EPS_THRESHOLD}; it keeps constant columns from causing a
 * division by zero. {@link #revertPreProcess} applies the exact inverse mapping.
 *
 * <p>Only the array returned by {@code getFeatures()} is modified, and it is modified in place.
 * Labels are left untouched.
 */
public class NormalizerMinMaxScaler implements DataNormalization {
    private static final Logger logger = LoggerFactory.getLogger(NormalizerMinMaxScaler.class);

    private static final String NOT_FITTED_MESSAGE =
            "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";

    /** Per-column minimum of the fitted features (used as a row vector); null before fitting. */
    private INDArray min;
    /** Per-column maximum of the fitted features (used as a row vector); null before fitting. */
    private INDArray max;
    /** Cached {@code max - min + EPS_THRESHOLD}; the divisor applied per column. */
    private INDArray maxMinusMin;
    /** Lower bound of the target output range. */
    private double minRange;
    /** Upper bound of the target output range. */
    private double maxRange;

    /**
     * Creates a scaler that maps fitted features onto {@code [minRange, maxRange]}.
     *
     * @param minRange lower bound of the output range
     * @param maxRange upper bound of the output range. It must be strictly greater than
     *                 {@code minRange} by the time the scaler is used; this is checked on
     *                 every transform.
     */
    public NormalizerMinMaxScaler(double minRange, double maxRange) {
        // Assign directly rather than through the overridable setters.
        this.minRange = minRange;
        this.maxRange = maxRange;
    }

    /** Creates a scaler with the default target range {@code [0, 1]}. */
    public NormalizerMinMaxScaler() {
        this(0.0d, 1.0d);
    }

    /**
     * Sets the lower bound of the output range.
     *
     * @param d the new lower bound
     */
    public void setMinRange(double d) {
        this.minRange = d;
    }

    /**
     * Sets the upper bound of the output range.
     *
     * @param d the new upper bound
     */
    public void setMaxRange(double d) {
        this.maxRange = d;
    }

    /**
     * Computes the per-column min and max of the data set's feature matrix.
     * Any previously fitted statistics are replaced.
     *
     * @param dataSet data whose features define the scaling
     */
    @Override
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.min = features.min(0);
        this.max = features.max(0);
        updateMaxMinusMin();
    }

    /**
     * Computes the per-column min and max over every batch of the iterator.
     * Any previously fitted statistics are replaced. The iterator is reset afterwards.
     *
     * @param dataSetIterator source of the batches to fit on
     * @throws IllegalStateException if the iterator yields no batches
     */
    @Override
    public void fit(DataSetIterator dataSetIterator) {
        // Accumulate into locals so stale statistics from an earlier fit are never merged in.
        INDArray runningMin = null;
        INDArray runningMax = null;
        while (dataSetIterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet next = dataSetIterator.next();
            INDArray batchMin = next.getFeatureMatrix().min(0);
            INDArray batchMax = next.getFeatureMatrix().max(0);
            if (runningMin == null) {
                runningMin = batchMin;
                runningMax = batchMax;
            } else {
                // Element-wise min/max of the batch statistics against the running ones,
                // written into the running arrays (the op's z argument).
                runningMin = Nd4j.getExecutioner().execAndReturn(
                        (TransformOp) new Min(batchMin, runningMin, runningMin, runningMin.length()));
                runningMax = Nd4j.getExecutioner().execAndReturn(
                        (TransformOp) new Max(batchMax, runningMax, runningMax, runningMax.length()));
            }
        }
        dataSetIterator.reset();
        if (runningMin == null) {
            throw new IllegalStateException("API_USE_ERROR: Cannot fit a normalizer on an empty DataSetIterator");
        }
        this.min = runningMin;
        this.max = runningMax;
        updateMaxMinusMin();
    }

    /**
     * Scales the features of {@code dataSet} in place onto {@code [minRange, maxRange]}.
     *
     * @param dataSet data to normalize; its feature array is modified in place
     * @throws RuntimeException if the scaler has not been fit, or if
     *                          {@code maxRange - minRange} is not positive
     */
    @Override
    public void preProcess(DataSet dataSet) {
        assertFitted();
        double span = targetSpan();
        INDArray features = dataSet.getFeatures();
        features.subiRowVector(this.min);          // each column's fitted min -> 0
        features.diviRowVector(this.maxMinusMin);  // each column's fitted range -> [0, 1]
        features.muli(span);                       // stretch [0, 1] to the target span
        features.addi(this.minRange);              // shift so the span starts at minRange
    }

    /**
     * Same as {@link #preProcess(DataSet)}.
     *
     * @param dataSet data to normalize in place
     */
    @Override
    public void transform(DataSet dataSet) {
        preProcess(dataSet);
    }

    /**
     * Applies {@link #preProcess(DataSet)} to every batch returned by the iterator, then resets
     * the iterator. NOTE(review): the changes only persist if the iterator hands back the same
     * DataSet instances on later passes. Confirm this for the iterator in use.
     *
     * @param dataSetIterator batches to normalize
     */
    @Override
    public void transform(DataSetIterator dataSetIterator) {
        while (dataSetIterator.hasNext()) {
            preProcess(dataSetIterator.next());
        }
        dataSetIterator.reset();
    }

    /**
     * Undoes {@link #preProcess(DataSet)} in place, mapping features back to the original scale.
     *
     * @param dataSet normalized data; its feature array is modified in place
     * @throws RuntimeException if the scaler has not been fit, or if
     *                          {@code maxRange - minRange} is not positive
     */
    public void revertPreProcess(DataSet dataSet) {
        assertFitted();
        double span = targetSpan();
        INDArray features = dataSet.getFeatures();
        features.subi(this.minRange);
        features.divi(span);
        features.muliRowVector(this.maxMinusMin);
        features.addiRowVector(this.min);
    }

    /**
     * Same as {@link #revertPreProcess(DataSet)}.
     *
     * @param dataSet normalized data to restore in place
     */
    public void revert(DataSet dataSet) {
        revertPreProcess(dataSet);
    }

    /**
     * Applies {@link #revertPreProcess(DataSet)} to every batch, then resets the iterator.
     *
     * @param dataSetIterator batches to restore
     */
    public void revert(DataSetIterator dataSetIterator) {
        while (dataSetIterator.hasNext()) {
            revertPreProcess(dataSetIterator.next());
        }
        dataSetIterator.reset();
    }

    /**
     * Returns the fitted per-column minimum.
     *
     * @return the internal min array (not a copy)
     * @throws RuntimeException if the scaler has not been fit
     */
    public INDArray getMin() {
        if (this.min == null) {
            throw new RuntimeException(NOT_FITTED_MESSAGE);
        }
        return this.min;
    }

    /**
     * Returns the fitted per-column maximum.
     *
     * @return the internal max array (not a copy)
     * @throws RuntimeException if the scaler has not been fit
     */
    public INDArray getMax() {
        if (this.max == null) {
            throw new RuntimeException(NOT_FITTED_MESSAGE);
        }
        return this.max;
    }

    /**
     * Loads fitted statistics previously written by {@link #save(File...)}.
     * NOTE: {@code minRange}/{@code maxRange} are not persisted; they keep their current values.
     *
     * @param fileArr at least two files: {@code [0]} the min array, {@code [1]} the max array
     * @throws IOException if either file cannot be read
     * @throws IllegalArgumentException if fewer than two files are given
     */
    @Override
    public void load(File... fileArr) throws IOException {
        requireTwoFiles(fileArr);
        this.min = Nd4j.readBinary(fileArr[0]);
        this.max = Nd4j.readBinary(fileArr[1]);
        // The cached divisor must be rebuilt, otherwise transforms after a load would fail.
        updateMaxMinusMin();
    }

    /**
     * Writes the fitted statistics in ND4J binary format.
     * NOTE: {@code minRange}/{@code maxRange} are not persisted.
     *
     * @param fileArr at least two files: {@code [0]} receives min, {@code [1]} receives max
     * @throws IOException if either file cannot be written
     * @throws IllegalArgumentException if fewer than two files are given
     * @throws RuntimeException if the scaler has not been fit
     */
    @Override
    public void save(File... fileArr) throws IOException {
        requireTwoFiles(fileArr);
        assertFitted();
        Nd4j.saveBinary(this.min, fileArr[0]);
        Nd4j.saveBinary(this.max, fileArr[1]);
    }

    /** Throws the standard not-fitted error unless min, max and the cached range are all set. */
    private void assertFitted() {
        if (this.min == null || this.max == null || this.maxMinusMin == null) {
            throw new RuntimeException(NOT_FITTED_MESSAGE);
        }
    }

    /** Returns {@code maxRange - minRange}, rejecting spans that are zero, negative or NaN. */
    private double targetSpan() {
        double span = this.maxRange - this.minRange;
        if (!(span > 0.0d)) { // negated form so NaN is rejected too
            throw new RuntimeException("API_USE_ERROR: The given max value minus min value has to be greater than 0");
        }
        return span;
    }

    /**
     * Recomputes {@link #maxMinusMin} from {@link #min} and {@link #max}. Logs when some column
     * has zero range, since EPS is then the only thing preventing a division by zero.
     */
    private void updateMaxMinusMin() {
        INDArray range = this.max.sub(this.min);
        // Test before adding EPS: zero is exact, whereas EPS after a float round-trip may not
        // compare equal. Reducing over dimension 1 assumes range has a single row (2-D features).
        if (range.min(1).getDouble(0) <= 0.0d) {
            logger.info("API_INFO: max val minus min val found to be zero. Transform will round upto epsilon to avoid nans.");
        }
        this.maxMinusMin = range.addi(Nd4j.EPS_THRESHOLD);
    }

    /** Validates that a save/load call received both the min and the max file. */
    private static void requireTwoFiles(File... fileArr) {
        if (fileArr == null || fileArr.length < 2) {
            throw new IllegalArgumentException("Expected two files (min, max), got "
                    + (fileArr == null ? 0 : fileArr.length));
        }
    }
}
