package org.nd4j.linalg.activations.impl;

import java.util.Arrays;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/activations/impl/ActivationPReLU.class */
/**
 * Parametric ReLU activation, evaluated through the native {@code "prelu"} custom op
 * (forward pass) and {@code "prelu_bp"} custom op (backward pass).
 *
 * <p>The {@code alpha} array is passed to both ops as an input. When {@code sharedAxes}
 * is non-null, each entry is forwarded to the op as an integer argument, in order.
 */
public class ActivationPReLU extends BaseActivationFunction {
    private INDArray alpha;
    private long[] sharedAxes;

    /**
     * @param alpha      array handed to the native ops as the PReLU alpha input; stored by reference
     * @param sharedAxes axes forwarded to the ops as integer arguments; may be {@code null};
     *                   stored by reference (no copy is made)
     */
    public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
        this.alpha = alpha;
        this.sharedAxes = sharedAxes;
    }

    /**
     * Applies PReLU in place: {@code in} is registered as both the op input and the op output,
     * so the returned array is {@code in} itself.
     *
     * @param in       array to activate; overwritten with the result
     * @param training unused by this implementation
     * @return {@code in}, after the op has written into it
     */
    @Override // org.nd4j.linalg.activations.IActivation
    public INDArray getActivation(INDArray in, boolean training) {
        DynamicCustomOp.DynamicCustomOpsBuilder prelu = DynamicCustomOp.builder("prelu")
                .addOutputs(in)
                .addInputs(in, this.alpha);
        appendSharedAxes(prelu);
        Nd4j.getExecutioner().execAndReturn(prelu.build());
        return in;
    }

    /**
     * Runs {@code "prelu_bp"} with inputs {@code (in, alpha, epsilon)}.
     *
     * <p>The op writes two outputs. The first is a freshly allocated array shaped like {@code in},
     * which is then copied back into {@code in}. The second is a freshly allocated array shaped
     * like {@code alpha}.
     *
     * @param in      activation input; overwritten with the op's first output
     * @param epsilon incoming gradient; must have the same shape as {@code in}
     *                (checked by {@code assertShape})
     * @return pair of ({@code in} holding the first op output, the alpha-shaped second op output)
     */
    @Override // org.nd4j.linalg.activations.IActivation
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        INDArray alphaGrad = this.alpha.ulike();
        INDArray inputGrad = in.ulike();

        DynamicCustomOp.DynamicCustomOpsBuilder preluBp = DynamicCustomOp.builder("prelu_bp")
                .addInputs(in, this.alpha, epsilon)
                .addOutputs(inputGrad, alphaGrad);
        appendSharedAxes(preluBp);
        Nd4j.exec(preluBp.build());

        // Results are reported through the caller's array, so copy the input gradient back into it.
        in.assign(inputGrad);
        return new Pair<>(in, alphaGrad);
    }

    /** Forwards every shared axis (if any) to {@code builder} as an integer argument, in order. */
    private void appendSharedAxes(DynamicCustomOp.DynamicCustomOpsBuilder builder) {
        if (this.sharedAxes == null) {
            return;
        }
        for (long axis : this.sharedAxes) {
            builder.addIntegerArguments(axis);
        }
    }

    @Override
    public String toString() {
        return "prelu";
    }

    /**
     * Two instances are equal when both their {@code alpha} arrays are equal (or both are
     * {@code null}) and their {@code sharedAxes} arrays have equal contents.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ActivationPReLU)) {
            return false;
        }
        ActivationPReLU other = (ActivationPReLU) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray mine = getAlpha();
        INDArray theirs = other.getAlpha();
        // Deliberately always delegates to INDArray.equals when non-null, including the self case.
        boolean sameAlpha = (mine == null) ? (theirs == null) : mine.equals(theirs);
        return sameAlpha && Arrays.equals(getSharedAxes(), other.getSharedAxes());
    }

    /** Symmetry hook used by {@link #equals(Object)} so subclasses can opt out of equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ActivationPReLU;
    }

    /** Hash consistent with {@link #equals(Object)}, combining {@code alpha} and {@code sharedAxes}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        INDArray a = getAlpha();
        int result = 1;
        result = result * prime + (a == null ? 43 : a.hashCode());
        result = result * prime + Arrays.hashCode(getSharedAxes());
        return result;
    }

    /** @return the alpha array passed to the constructor (not a copy) */
    public INDArray getAlpha() {
        return this.alpha;
    }

    /** @return the shared-axes array passed to the constructor (not a copy); may be {@code null} */
    public long[] getSharedAxes() {
        return this.sharedAxes;
    }
}
