package org.nd4j.linalg.activations.impl;

import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Objects;

/**
 * Rectified linear unit activation with three optional parameters:
 * an output cap ({@code max}), a {@code threshold}, and a {@code negativeSlope}
 * used below the threshold.
 *
 * <p>Behaviour of {@link #getActivation}:
 * <ul>
 *   <li>{@code threshold} and {@code negativeSlope} both {@code null}: delegates to the
 *       {@code RectifiedLinear} op.</li>
 *   <li>Otherwise, a {@code null} threshold or negative slope is treated as {@code 0.0}.
 *       If the threshold is 0 the work is delegated to the {@code LeakyReLU} op. If it is
 *       non-zero, the output is computed explicitly as
 *       {@code x} for {@code x >= threshold} and
 *       {@code negativeSlope * (x - threshold)} for {@code x < threshold}.</li>
 *   <li>Finally, if {@code max} is non-null, the result is clamped with
 *       {@code min(result, max)}.</li>
 * </ul>
 */
public class ActivationReLU extends BaseActivationFunction {
    /** Optional upper cap applied to the output; {@code null} means no cap. */
    private Double max;
    /** Optional threshold; {@code null} is treated as 0 in the non-standard paths. */
    private Double threshold;
    /** Optional slope applied below the threshold; {@code null} is treated as 0. */
    private Double negativeSlope;

    /** Creates a plain ReLU (no cap, no threshold, no negative slope). */
    public ActivationReLU() {
        this(null, null, null);
    }

    /**
     * @param max           optional output cap, or {@code null} for none
     * @param threshold     optional threshold, or {@code null}
     * @param negativeSlope optional slope below the threshold, or {@code null}
     */
    public ActivationReLU(Double max, Double threshold, Double negativeSlope) {
        this.max = max;
        this.threshold = threshold;
        this.negativeSlope = negativeSlope;
    }

    /** True when neither threshold nor negative slope is configured (plain ReLU ops are used). */
    private boolean isPlainRelu() {
        return negativeSlope == null && threshold == null;
    }

    /** The threshold as a primitive, with {@code null} mapped to 0. */
    private double thresholdOrZero() {
        return threshold == null ? 0.0d : threshold.doubleValue();
    }

    /** The negative slope as a primitive, with {@code null} mapped to 0. */
    private double negativeSlopeOrZero() {
        return negativeSlope == null ? 0.0d : negativeSlope.doubleValue();
    }

    /**
     * Applies the activation. The input array is modified in place and also returned.
     *
     * @param in       pre-activation values; overwritten with the activation output
     * @param training unused by this implementation
     * @return {@code in}, now holding the activation output
     */
    @Override // org.nd4j.linalg.activations.IActivation
    public INDArray getActivation(INDArray in, boolean training) {
        if (isPlainRelu()) {
            Nd4j.getExecutioner().exec((ScalarOp) new RectifiedLinear(in, in));
        } else {
            double thresh = thresholdOrZero();
            double slope = negativeSlopeOrZero();
            if (thresh == 0.0d) {
                Nd4j.getExecutioner().execAndReturn((ScalarOp) new LeakyReLU(in, slope));
            } else {
                // (x < t) * slope * (x - t)  +  (x >= t) * x
                // Boolean masks are cast to the input dtype so they can be multiplied with data.
                // Both terms are built from the unmodified input before assign() overwrites it.
                INDArray below = in.lt(thresh).castTo(in.dataType()).muli(slope).muli(in.sub(thresh));
                INDArray atOrAbove = in.gte(thresh).castTo(in.dataType()).muli(in);
                in.assign(below.addi(atOrAbove));
            }
        }
        if (max != null) {
            Nd4j.exec(new ScalarMin(in, null, in, max));
        }
        return in;
    }

    /**
     * Computes the gradient with respect to the pre-activation input.
     *
     * <p>In the non-zero-threshold path, {@code in} is overwritten in place with the local
     * derivative before being scaled by {@code epsilon}. The other paths write into a fresh
     * array from {@code in.ulike()}.
     *
     * @param in      pre-activation values (shape must match {@code epsilon})
     * @param epsilon incoming gradient, multiplied element-wise into the local derivative
     * @return pair of (gradient w.r.t. {@code in}, {@code null})
     */
    @Override // org.nd4j.linalg.activations.IActivation
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);

        // The forward pass clamps with ScalarMin whenever max != null (including max == 0),
        // so the gradient must be masked under exactly the same condition. The previous
        // "max == 0 means no mask" special case disagreed with the forward pass.
        // Computed first, because the threshold branch below overwrites 'in'.
        INDArray maxMask = max == null ? null : in.lt(max);

        INDArray dLdz;
        if (isPlainRelu()) {
            // threshold is necessarily null on this path, so the cutoff is always 0.
            dLdz = Nd4j.getExecutioner().exec(new RectifiedLinearDerivative(in, epsilon, in.ulike(), 0.0d))[0];
        } else {
            double thresh = thresholdOrZero();
            double slope = negativeSlopeOrZero();
            if (thresh == 0.0d) {
                dLdz = Nd4j.getExecutioner().exec(new LeakyReLUBp(in, epsilon, in.ulike(), slope))[0];
            } else {
                // Local derivative: slope below the threshold, 1 at or above it.
                INDArray below = in.lt(thresh).castTo(in.dataType()).muli(slope);
                INDArray atOrAbove = in.gte(thresh).castTo(in.dataType());
                dLdz = in.assign(below.addi(atOrAbove)).muli(epsilon);
            }
        }

        if (maxMask != null) {
            dLdz.muli(maxMask);
        }
        // Plain 'null' lets the diamond infer V = INDArray from the return type; the previous
        // '(Object) null' forced V = Object and did not compile against Pair<INDArray, INDArray>.
        return new Pair<>(dLdz, null);
    }

    @Override
    public String toString() {
        return "relu";
    }

    /**
     * Two instances are equal when max, threshold and negative slope are all equal
     * (compared with {@link Double#equals}, nulls equal only to nulls).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ActivationReLU)) {
            return false;
        }
        ActivationReLU other = (ActivationReLU) obj;
        return other.canEqual(this)
                && Objects.equals(getMax(), other.getMax())
                && Objects.equals(getThreshold(), other.getThreshold())
                && Objects.equals(getNegativeSlope(), other.getNegativeSlope());
    }

    /** Symmetry hook for {@link #equals}: subclasses may restrict which types they equal. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ActivationReLU;
    }

    /** Hash consistent with {@link #equals}; keeps the original 59/43 formula so values are unchanged. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + hashOrSentinel(getMax());
        result = result * 59 + hashOrSentinel(getThreshold());
        result = result * 59 + hashOrSentinel(getNegativeSlope());
        return result;
    }

    /** {@code value.hashCode()}, or 43 for {@code null}. */
    private static int hashOrSentinel(Double value) {
        return value == null ? 43 : value.hashCode();
    }

    public Double getMax() {
        return this.max;
    }

    public Double getThreshold() {
        return this.threshold;
    }

    public Double getNegativeSlope() {
        return this.negativeSlope;
    }
}
