package ai.djl.training.util;

import ai.djl.ndarray.NDArray;

/**
 * Rescales the values of an {@link NDArray} into a target range (by default {@code [0, 1]}) using
 * the minimum and maximum computed when the scaler was fitted.
 *
 * <p>The forward transform is {@code (x - min) / (max - min) * (maxRange - minRange) + minRange};
 * {@link #inverseTransform(NDArray)} undoes it. Methods whose name ends in {@code i} modify their
 * argument in place, while the other variants leave the argument untouched and return a new array.
 *
 * <p>The fitted statistics are held by this scaler and are closed by {@link #close()}.
 */
public class MinMaxScaler implements AutoCloseable {
    private NDArray fittedMin;
    private NDArray fittedMax;
    // max - min, with zero entries replaced by 1 so constant features never divide by zero.
    private NDArray fittedRange;
    private float minRange;
    private float maxRange = 1.0f;
    // When true, statistics produced by every later fit are detached as well.
    private boolean detached;

    /**
     * Computes the minimum and maximum of {@code data} along {@code axes}.
     *
     * <p>Entries whose minimum equals their maximum are given a range of 1, so transforming the
     * fitted data yields {@code minRange} for them instead of NaN.
     *
     * @param data the array to compute the statistics from
     * @param axes the axes to reduce over
     * @return this scaler
     */
    public MinMaxScaler fit(NDArray data, int[] axes) {
        fittedMin = data.min(axes);
        fittedMax = data.max(axes);
        fittedRange = fittedMax.sub(fittedMin);
        // A constant feature has range 0; add 1 exactly where the range is 0 so that
        // (x - min) / range stays finite. For such a feature, x - min is 0 on the fitted data,
        // so inverseTransform still maps the scaled value back to min.
        fittedRange.addi(fittedRange.eq(0).toType(fittedRange.getDataType(), false));
        if (detached) {
            detach();
        }
        return this;
    }

    /**
     * Computes the minimum and maximum of {@code data} along axis 0.
     *
     * @param data the array to compute the statistics from
     * @return this scaler
     */
    public MinMaxScaler fit(NDArray data) {
        return fit(data, new int[] {0});
    }

    /**
     * Scales {@code data} into the target range, returning a new array.
     *
     * <p>If the scaler has not been fitted yet, it is first fitted on {@code data} along axis 0.
     *
     * @param data the array to transform; it is not modified
     * @return the scaled array
     */
    public NDArray transform(NDArray data) {
        fitIfNeeded(data);
        return scale(data.sub(fittedMin).divi(fittedRange));
    }

    /**
     * Scales {@code data} into the target range in place.
     *
     * <p>If the scaler has not been fitted yet, it is first fitted on {@code data} along axis 0.
     *
     * @param data the array to transform; it is modified in place
     * @return the scaled array
     */
    public NDArray transformi(NDArray data) {
        fitIfNeeded(data);
        return scale(data.subi(fittedMin).divi(fittedRange));
    }

    /** Fits on {@code data} along axis 0 when no statistics are present yet. */
    private void fitIfNeeded(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
    }

    /** Returns true when the target range is the default {@code [0, 1]}, so no rescale is needed. */
    private boolean isIdentityRange() {
        return maxRange == 1.0f && minRange == 0.0f;
    }

    /** Maps {@code [0, 1]}-normalized values into {@code [minRange, maxRange]} in place. */
    private NDArray scale(NDArray normalized) {
        if (isIdentityRange()) {
            return normalized;
        }
        return normalized.muli(Float.valueOf(maxRange - minRange)).addi(Float.valueOf(minRange));
    }

    /** Maps values in {@code [minRange, maxRange]} back to {@code [0, 1]}, returning a new array. */
    private NDArray inverseScale(NDArray scaled) {
        if (isIdentityRange()) {
            // Duplicate so the caller's later in-place ops never touch the input.
            return scaled.duplicate();
        }
        return scaled.sub(Float.valueOf(minRange)).divi(Float.valueOf(maxRange - minRange));
    }

    /** Maps values in {@code [minRange, maxRange]} back to {@code [0, 1]} in place. */
    private NDArray inverseScalei(NDArray scaled) {
        if (isIdentityRange()) {
            return scaled;
        }
        return scaled.subi(Float.valueOf(minRange)).divi(Float.valueOf(maxRange - minRange));
    }

    /**
     * Undoes {@link #transform(NDArray)}, returning a new array.
     *
     * @param data the scaled array; it is not modified
     * @return the array in the original value range
     * @throws IllegalStateException if the scaler has not been fitted
     */
    public NDArray inverseTransform(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        return inverseScale(data).muli(fittedRange).addi(fittedMin);
    }

    /**
     * Undoes {@link #transformi(NDArray)} in place.
     *
     * @param data the scaled array; it is modified in place
     * @return the array in the original value range
     * @throws IllegalStateException if the scaler has not been fitted
     */
    public NDArray inverseTransformi(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        return inverseScalei(data).muli(fittedRange).addi(fittedMin);
    }

    /** Throws if no statistics are available (never fitted, or already closed). */
    private void throwsIllegalStateWhenNotFitted() {
        if (fittedRange == null) {
            throw new IllegalStateException("Min Max Scaler is not fitted");
        }
    }

    /**
     * Calls {@link NDArray#detach()} on the currently fitted statistics and marks this scaler so
     * that statistics computed by later calls to {@code fit} are detached too.
     *
     * @return this scaler
     */
    public MinMaxScaler detach() {
        detached = true;
        if (fittedMin != null) {
            fittedMin.detach();
        }
        if (fittedMax != null) {
            fittedMax.detach();
        }
        if (fittedRange != null) {
            fittedRange.detach();
        }
        return this;
    }

    /**
     * Sets the target range that transformed values are mapped into.
     *
     * @param min the value the fitted minimum maps to
     * @param max the value the fitted maximum maps to
     * @return this scaler
     * @throws IllegalArgumentException if either bound is NaN or {@code min == max}, which would
     *     make the inverse transform divide by zero
     */
    public MinMaxScaler optRange(float min, float max) {
        if (Float.isNaN(min) || Float.isNaN(max) || min == max) {
            throw new IllegalArgumentException(
                    "Min Max Scaler range must have distinct, non-NaN bounds: ["
                            + min
                            + ", "
                            + max
                            + "]");
        }
        minRange = min;
        maxRange = max;
        return this;
    }

    /**
     * Returns the fitted minimum. The array is owned by this scaler and is closed by {@link
     * #close()}.
     *
     * @return the fitted minimum
     * @throws IllegalStateException if the scaler has not been fitted
     */
    public NDArray getMin() {
        throwsIllegalStateWhenNotFitted();
        return fittedMin;
    }

    /**
     * Returns the fitted maximum. The array is owned by this scaler and is closed by {@link
     * #close()}.
     *
     * @return the fitted maximum
     * @throws IllegalStateException if the scaler has not been fitted
     */
    public NDArray getMax() {
        throwsIllegalStateWhenNotFitted();
        return fittedMax;
    }

    /**
     * Closes the fitted statistics and forgets them. Afterwards the scaler behaves as if it had
     * never been fitted: getters throw, and {@code transform} fits again. Calling this twice is
     * harmless.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (fittedMin != null) {
            fittedMin.close();
            fittedMin = null;
        }
        if (fittedMax != null) {
            fittedMax.close();
            fittedMax = null;
        }
        if (fittedRange != null) {
            fittedRange.close();
            fittedRange = null;
        }
    }
}
