package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.Or;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.OldSubOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.class */
/**
 * Graph vertex without a layer or parameters that combines its input activations
 * element-wise according to a fixed {@link Op}.
 *
 * <p>When only a single input is present, the forward pass returns that input as-is and the
 * backward pass returns the incoming epsilon as-is, regardless of the configured op.
 */
public class ElementWiseVertex extends BaseGraphVertex {
    /** Element-wise operation applied by this vertex; assigned once in the constructor. */
    private final Op op;

    /** Number of inputs seen by the most recent {@link #doForward}; read by {@link #doBackward}. */
    private int nInForwardPass;

    /** Element-wise operations supported by {@link ElementWiseVertex}. */
    public enum Op {
        Add,
        Subtract,
        Product,
        Average,
        Max
    }

    /**
     * Creates a vertex with no input/output vertex indices; delegates to the six-argument
     * constructor with {@code null} for both index arrays.
     *
     * @param computationGraph forwarded to {@link BaseGraphVertex}
     * @param str              forwarded to {@link BaseGraphVertex} (reported as the vertex name by {@link #toString()})
     * @param i                forwarded to {@link BaseGraphVertex} (reported as the vertex index by {@link #toString()})
     * @param op               element-wise operation this vertex applies
     */
    public ElementWiseVertex(ComputationGraph computationGraph, String str, int i, Op op) {
        this(computationGraph, str, i, null, null, op);
    }

    /**
     * Creates a vertex, forwarding every argument except {@code op} to {@link BaseGraphVertex}.
     *
     * @param computationGraph  forwarded to {@link BaseGraphVertex}
     * @param str               forwarded to {@link BaseGraphVertex}
     * @param i                 forwarded to {@link BaseGraphVertex}
     * @param vertexIndicesArr  forwarded to {@link BaseGraphVertex}
     * @param vertexIndicesArr2 forwarded to {@link BaseGraphVertex}
     * @param op                element-wise operation this vertex applies
     */
    public ElementWiseVertex(ComputationGraph computationGraph, String str, int i, VertexIndices[] vertexIndicesArr, VertexIndices[] vertexIndicesArr2, Op op) {
        super(computationGraph, str, i, vertexIndicesArr, vertexIndicesArr2);
        this.op = op;
    }

    /** @return always {@code false}; this vertex does not wrap a layer */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public boolean hasLayer() {
        return false;
    }

    /** @return always {@code null}; this vertex does not wrap a layer */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Layer getLayer() {
        return null;
    }

    /**
     * Combines the current inputs element-wise according to the configured {@link Op} and
     * records the number of inputs for the subsequent backward pass.
     *
     * @param z                 not used by this vertex
     * @param layerWorkspaceMgr workspace manager used to allocate the result as {@link ArrayType#ACTIVATIONS}
     * @return the single input itself when there is exactly one input, otherwise a newly allocated result array
     * @throws IllegalStateException         if {@code canDoForward()} is false
     * @throws IllegalArgumentException      if the op is {@link Op#Subtract} and there are not exactly 2 inputs
     * @throws UnsupportedOperationException if the op is not one of the handled constants
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public INDArray doForward(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set");
        }
        this.nInForwardPass = this.inputs.length;
        if (this.nInForwardPass == 1) {
            // Nothing to combine: pass the single input straight through.
            return this.inputs[0];
        }
        switch (this.op) {
            case Add:
                return sumInputs(layerWorkspaceMgr);
            case Average:
                return sumInputs(layerWorkspaceMgr).divi(this.inputs.length);
            case Subtract: {
                if (this.inputs.length != 2) {
                    throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
                }
                INDArray difference = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, this.inputs[0].shape());
                return Nd4j.getExecutioner().execAndReturn(new OldSubOp(this.inputs[0], this.inputs[1], difference));
            }
            case Product: {
                INDArray product = layerWorkspaceMgr.dup(ArrayType.ACTIVATIONS, this.inputs[0]);
                for (int idx = 1; idx < this.inputs.length; idx++) {
                    product.muli(this.inputs[idx]);
                }
                return product;
            }
            case Max: {
                INDArray max = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, this.inputs[0].shape(), this.inputs[0].ordering());
                // NOTE(review): "mergemax" is presumably the element-wise maximum over all inputs - confirm against the op's docs.
                Nd4j.getExecutioner().exec(DynamicCustomOp.builder("mergemax").addInputs(this.inputs).addOutputs(new INDArray[]{max}).callInplace(false).build());
                return max;
            }
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /**
     * Produces one epsilon array per input of the last forward pass, derived from the current
     * epsilon according to the configured {@link Op}. The gradient half of the returned pair
     * is always {@code null}.
     *
     * @param z                 not used by this vertex
     * @param layerWorkspaceMgr workspace manager used to allocate the returned epsilons
     * @return pair of ({@code null}, epsilons for each input)
     * @throws IllegalStateException         if {@code canDoBackward()} is false
     * @throws UnsupportedOperationException if the op is not one of the handled constants
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Pair<Gradient, INDArray[]> doBackward(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }
        if (this.nInForwardPass == 1) {
            // Forward pass was a pass-through, so the epsilon is passed back untouched.
            return epsilonsOnly(new INDArray[]{this.epsilon});
        }
        switch (this.op) {
            case Add: {
                INDArray[] epsilons = new INDArray[this.nInForwardPass];
                for (int idx = 0; idx < this.nInForwardPass; idx++) {
                    epsilons[idx] = layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, this.epsilon);
                }
                return epsilonsOnly(epsilons);
            }
            case Average: {
                INDArray[] epsilons = new INDArray[this.nInForwardPass];
                // The borrowed scope stays open while div(...) allocates its results and is
                // closed afterwards, including when div(...) throws.
                try (MemoryWorkspace ws = layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
                    for (int idx = 0; idx < this.nInForwardPass; idx++) {
                        epsilons[idx] = this.epsilon.div(this.nInForwardPass);
                    }
                }
                return epsilonsOnly(epsilons);
            }
            case Subtract: {
                INDArray forMinuend = layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, this.epsilon);
                INDArray forSubtrahend = layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, this.epsilon).negi();
                return epsilonsOnly(new INDArray[]{forMinuend, forSubtrahend});
            }
            case Product: {
                INDArray[] epsilons = new INDArray[this.nInForwardPass];
                for (int target = 0; target < this.nInForwardPass; target++) {
                    // Epsilon for one input is the incoming epsilon times every *other* input.
                    epsilons[target] = layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, this.epsilon);
                    for (int other = 0; other < this.nInForwardPass; other++) {
                        if (other != target) {
                            epsilons[target].muli(this.inputs[other]);
                        }
                    }
                }
                return epsilonsOnly(epsilons);
            }
            case Max: {
                INDArray[] epsilons = new INDArray[this.nInForwardPass];
                INDArray maxIndex = layerWorkspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, this.epsilon.shape(), this.epsilon.ordering());
                // NOTE(review): "mergemaxindex" presumably writes, per element, the index of the input
                // holding the maximum; the loop below relies on that by matching against each input index.
                Nd4j.getExecutioner().exec(DynamicCustomOp.builder("mergemaxindex").addInputs(this.inputs).addOutputs(new INDArray[]{maxIndex}).callInplace(false).build());
                for (int idx = 0; idx < this.nInForwardPass; idx++) {
                    epsilons[idx] = layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, maxIndex);
                    // In-place match against this input's index (presumably yielding a 0/1 selector), then scale by epsilon.
                    Nd4j.getExecutioner().exec(new MatchConditionTransform(epsilons[idx], epsilons[idx], Conditions.equals(idx)));
                    epsilons[idx].muli(this.epsilon);
                }
                return epsilonsOnly(epsilons);
            }
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /**
     * Rejects any non-null gradients view, since this vertex has no gradients.
     *
     * @param iNDArray must be {@code null}
     * @throws RuntimeException if {@code iNDArray} is not {@code null}
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        if (iNDArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Combines the input mask arrays into a single output mask.
     *
     * <p>If the array of masks is {@code null}, or any entry in it is {@code null}, no mask is
     * returned. A single mask is returned as-is. Two or more masks are folded together with the
     * {@link Or} op into a fresh copy, leaving the input masks unmodified.
     *
     * @param iNDArrayArr mask arrays, one per input; may be {@code null}
     * @param maskState   mask state, returned unchanged in the result
     * @param i           not used by this vertex
     * @return pair of (combined mask or {@code null}, {@code maskState})
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        if (iNDArrayArr == null) {
            return new Pair<>(null, maskState);
        }
        for (INDArray mask : iNDArrayArr) {
            if (mask == null) {
                return new Pair<>(null, maskState);
            }
        }
        if (iNDArrayArr.length == 1) {
            return new Pair<>(iNDArrayArr[0], maskState);
        }
        // The copy only serves as an output buffer of matching shape/order; the first Or overwrites it.
        INDArray combined = iNDArrayArr[0].dup(iNDArrayArr[0].ordering());
        Nd4j.getExecutioner().exec(new Or(iNDArrayArr[0], iNDArrayArr[1], combined));
        for (int idx = 2; idx < iNDArrayArr.length; idx++) {
            Nd4j.getExecutioner().exec(new Or(iNDArrayArr[idx], combined, combined));
        }
        return new Pair<>(combined, maskState);
    }

    /** @return a description containing the vertex index, vertex name and configured op */
    @Override // org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
    public String toString() {
        return "ElementWiseVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\",op=" + this.op + ")";
    }

    /** Copies the first input into an ACTIVATIONS array and adds every remaining input to it in place. */
    private INDArray sumInputs(LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray sum = layerWorkspaceMgr.dup(ArrayType.ACTIVATIONS, this.inputs[0]);
        for (int idx = 1; idx < this.inputs.length; idx++) {
            sum.addi(this.inputs[idx]);
        }
        return sum;
    }

    /** Wraps the per-input epsilons in a backward-pass result that carries no gradient. */
    private static Pair<Gradient, INDArray[]> epsilonsOnly(INDArray[] epsilons) {
        return new Pair<>(null, epsilons);
    }
}
