package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Or;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.class */
/**
 * Graph vertex that combines its inputs element by element using a fixed {@link Op}
 * (addition, subtraction or product).
 *
 * <p>The vertex wraps no layer ({@link #hasLayer()} is {@code false}, {@link #getLayer()} is
 * {@code null}) and rejects any non-null gradient view array, so every {@link Pair} returned by
 * {@link #doBackward(boolean)} carries a {@code null} {@link Gradient}.
 *
 * <p>When only a single input is present the vertex acts as a pass-through in both directions,
 * regardless of the configured operation.
 */
public class ElementWiseVertex extends BaseGraphVertex {
    /** Element-wise operation applied by this vertex; fixed at construction time. */
    private final Op op;

    /**
     * Number of inputs seen by the most recent {@link #doForward(boolean)} call. It is read by
     * {@link #doBackward(boolean)} to decide how many epsilon arrays to produce.
     */
    private int nInForwardPass;

    /** Element-wise operations supported by {@link ElementWiseVertex}. */
    public enum Op {
        /** Sum of all inputs. */
        Add,
        /** First input minus second input; exactly two inputs are accepted. */
        Subtract,
        /** Product of all inputs. */
        Product
    }

    /**
     * Creates a vertex without explicit input/output vertex indices (both are passed on as
     * {@code null}).
     *
     * @param computationGraph forwarded to the {@link BaseGraphVertex} constructor
     * @param str forwarded to the superclass; presumably the vertex name - confirm against
     *     {@code BaseGraphVertex}
     * @param i forwarded to the superclass; presumably the vertex index - confirm against
     *     {@code BaseGraphVertex}
     * @param op element-wise operation to apply
     */
    public ElementWiseVertex(ComputationGraph computationGraph, String str, int i, Op op) {
        this(computationGraph, str, i, null, null, op);
    }

    /**
     * Creates a vertex, forwarding everything except {@code op} to the superclass constructor.
     *
     * @param computationGraph forwarded to the {@link BaseGraphVertex} constructor
     * @param str forwarded to the superclass; presumably the vertex name - confirm
     * @param i forwarded to the superclass; presumably the vertex index - confirm
     * @param vertexIndicesArr forwarded to the superclass; presumably the input vertices - confirm
     * @param vertexIndicesArr2 forwarded to the superclass; presumably the output vertices - confirm
     * @param op element-wise operation to apply
     */
    public ElementWiseVertex(ComputationGraph computationGraph, String str, int i, VertexIndices[] vertexIndicesArr, VertexIndices[] vertexIndicesArr2, Op op) {
        super(computationGraph, str, i, vertexIndicesArr, vertexIndicesArr2);
        this.op = op;
    }

    /** @return always {@code false}; this vertex does not wrap a layer */
    @Override
    public boolean hasLayer() {
        return false;
    }

    /** @return always {@code false} */
    @Override
    public boolean isOutputVertex() {
        return false;
    }

    /** @return always {@code null}; this vertex does not wrap a layer */
    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Combines the current inputs element-wise according to the configured operation and records
     * the number of inputs for the subsequent backward pass.
     *
     * @param z not used by this vertex
     * @return the single input itself (not a copy) when there is exactly one input; otherwise the
     *     element-wise sum, difference or product of the inputs
     * @throws IllegalStateException if {@code canDoForward()} reports that the vertex is not ready
     * @throws IllegalArgumentException if the operation is {@link Op#Subtract} and the number of
     *     inputs is neither 1 nor 2
     * @throws UnsupportedOperationException if the operation is not one of the known constants
     */
    @Override
    public INDArray doForward(boolean z) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set");
        }
        this.nInForwardPass = this.inputs.length;
        if (this.inputs.length == 1) {
            // Pass-through: nothing to combine with.
            return this.inputs[0];
        }
        switch (this.op) {
            case Add:
                // Accumulate into a copy so the first input array is left untouched.
                INDArray sum = this.inputs[0].dup();
                for (int k = 1; k < this.inputs.length; k++) {
                    sum.addi(this.inputs[k]);
                }
                return sum;
            case Subtract:
                if (this.inputs.length != 2) {
                    throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
                }
                return this.inputs[0].sub(this.inputs[1]);
            case Product:
                // Accumulate into a copy so the first input array is left untouched.
                INDArray product = this.inputs[0].dup();
                for (int k = 1; k < this.inputs.length; k++) {
                    product.muli(this.inputs[k]);
                }
                return product;
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /**
     * Produces one epsilon array per input of the last forward pass.
     *
     * <ul>
     *   <li>single input: the stored epsilon itself (not a copy);</li>
     *   <li>{@link Op#Add}: an independent copy of epsilon for every input;</li>
     *   <li>{@link Op#Subtract}: epsilon for the first input and {@code epsilon.neg()} for the second;</li>
     *   <li>{@link Op#Product}: for each input, a copy of epsilon multiplied by every <em>other</em>
     *       input. This reads {@code inputs}, so it assumes they are still the arrays used in the
     *       forward pass - confirm against the caller.</li>
     * </ul>
     *
     * @param z not used by this vertex
     * @return pair whose first element is always {@code null} (there are no parameter gradients)
     *     and whose second element holds the per-input epsilons
     * @throws IllegalStateException if {@code canDoBackward()} reports that the vertex is not ready
     * @throws UnsupportedOperationException if the operation is not one of the known constants
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean z) {
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }
        // A bare null (rather than a cast to Object) lets the Pair key type be inferred as
        // Gradient from the return type.
        if (this.nInForwardPass == 1) {
            return new Pair<>(null, new INDArray[]{this.epsilon});
        }
        switch (this.op) {
            case Add:
                INDArray[] addEpsilons = new INDArray[this.nInForwardPass];
                for (int k = 0; k < this.nInForwardPass; k++) {
                    addEpsilons[k] = this.epsilon.dup();
                }
                return new Pair<>(null, addEpsilons);
            case Subtract:
                return new Pair<>(null, new INDArray[]{this.epsilon, this.epsilon.neg()});
            case Product:
                INDArray[] productEpsilons = new INDArray[this.nInForwardPass];
                for (int target = 0; target < this.nInForwardPass; target++) {
                    productEpsilons[target] = this.epsilon.dup();
                    for (int other = 0; other < this.nInForwardPass; other++) {
                        if (target != other) {
                            productEpsilons[target].muli(this.inputs[other]);
                        }
                    }
                }
                return new Pair<>(null, productEpsilons);
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /**
     * Accepts only {@code null}, because this vertex has no gradients to back with a view array.
     *
     * @param iNDArray gradient view array; must be {@code null}
     * @throws RuntimeException if {@code iNDArray} is not {@code null}
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        if (iNDArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Merges the mask arrays of all inputs into a single mask for this vertex's output.
     *
     * <p>If the mask array is {@code null}, empty, or contains any {@code null} entry, no mask is
     * produced. A single mask is passed through unchanged (same instance). Two or more masks are
     * combined with the {@link Or} op into a freshly allocated array, leaving the inputs unmodified.
     *
     * @param iNDArrayArr mask arrays of the inputs; may be {@code null} or contain {@code null}s
     * @param maskState mask state, returned unchanged in every case
     * @param i not used by this vertex
     * @return pair of the merged mask (or {@code null} when no mask applies) and {@code maskState}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        // An empty array is treated like "no masks" instead of failing on index 0 further down.
        if (iNDArrayArr == null || iNDArrayArr.length == 0) {
            return new Pair<>(null, maskState);
        }
        for (INDArray mask : iNDArrayArr) {
            if (mask == null) {
                return new Pair<>(null, maskState);
            }
        }
        if (iNDArrayArr.length == 1) {
            return new Pair<>(iNDArrayArr[0], maskState);
        }
        // The copy only serves as an output buffer with matching shape and ordering; its contents
        // are overwritten by the first Or.
        INDArray combined = iNDArrayArr[0].dup(iNDArrayArr[0].ordering());
        Nd4j.getExecutioner().exec(new Or(iNDArrayArr[0], iNDArrayArr[1], combined));
        for (int k = 2; k < iNDArrayArr.length; k++) {
            Nd4j.getExecutioner().exec(new Or(iNDArrayArr[k], combined, combined));
        }
        return new Pair<>(combined, maskState);
    }

    /** @return a short description containing the vertex index, vertex name and operation */
    @Override
    public String toString() {
        return "ElementWiseVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\",op=" + this.op + ")";
    }
}
