package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Computation-graph vertex that combines all of its inputs element-wise using one {@link Op}.
 *
 * <p>In the backward pass, an input whose shape differs from epsilon (the gradient with respect to this
 * vertex's output) receives a gradient that is summed over the dimensions reported by
 * {@link Shape#getBroadcastDimensions}, with {@code keepDims = true}.
 *
 * <p>NOTE(review): this class was recovered with a decompiler and {@link #doForward} could not be
 * decompiled; it currently always throws. Restore the upstream implementation before use.
 */
public class ElementWiseVertex extends BaseGraphVertex {

    /** The element-wise operation applied by this vertex. */
    private final Op op;

    /**
     * Number of inputs used in the forward pass; {@link #doBackward} sizes its gradient arrays from it.
     * NOTE(review): nothing in this file assigns it. Presumably the original {@code doForward} did;
     * TODO confirm when restoring that method.
     */
    private int nInForwardPass;

    /** Element-wise operations supported by {@link ElementWiseVertex}. */
    public enum Op {
        Add,
        Subtract,
        Product,
        Average,
        Max
    }

    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op, DataType dataType) {
        this(graph, name, vertexIndex, null, null, op, dataType);
    }

    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
                    VertexIndices[] outputVertices, Op op, DataType dataType) {
        super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
        this.op = op;
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Forward pass.
     *
     * <p>NOTE(review): the decompiler could not recover this method's body (it reported roughly 940
     * instructions it failed to decompile), so it currently always throws. Restore the upstream
     * implementation; {@link #doBackward} presumably relies on it to populate {@link #nInForwardPass}.
     *
     * @throws UnsupportedOperationException always, until the real body is restored
     */
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.nn.graph.vertex.impl.ElementWiseVertex.doForward(boolean, org.deeplearning4j.nn.workspace.LayerWorkspaceMgr):org.nd4j.linalg.api.ndarray.INDArray");
    }

    /**
     * Backward pass: computes one gradient per input from {@code epsilon}.
     *
     * <p>No parameter gradients exist for this vertex, so the first element of the returned pair is null.
     *
     * @param tbptt        not used by this vertex
     * @param workspaceMgr workspace manager used to place the returned gradients
     * @return {@code (null, inputGradients)}
     * @throws IllegalStateException         if epsilon has not been set
     * @throws UnsupportedOperationException if {@link #op} is not a known operation
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }

        if (nInForwardPass == 1) {
            return new Pair<>(null, new INDArray[] {workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)});
        }

        // Broadcast case: at least one input has a different shape from the first input.
        boolean broadcastCase = false;
        for (int i = 1; i < nInForwardPass; i++) {
            if (!inputs[0].equalShapes(inputs[i])) {
                broadcastCase = true;
                break;
            }
        }

        switch (op) {
            case Add:
                return new Pair<>(null, backpropAdd(broadcastCase, workspaceMgr));
            case Average:
                return new Pair<>(null, backpropAverage(workspaceMgr));
            case Subtract:
                return new Pair<>(null, backpropSubtract(broadcastCase, workspaceMgr));
            case Product:
                return new Pair<>(null, backpropProduct(broadcastCase, workspaceMgr));
            case Max:
                return new Pair<>(null, backpropMax(broadcastCase, workspaceMgr));
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /** Add: every input receives epsilon, reduced over broadcast dimensions where shapes differ. */
    private INDArray[] backpropAdd(boolean broadcastCase, LayerWorkspaceMgr workspaceMgr) {
        INDArray[] out = new INDArray[nInForwardPass];
        for (int i = 0; i < nInForwardPass; i++) {
            if (!broadcastCase || inputs[i].equalShapes(epsilon)) {
                out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
            } else {
                out[i] = sumOverBroadcastDims(epsilon, inputs[i], workspaceMgr);
            }
        }
        return out;
    }

    /** Average: every input receives epsilon / nInForwardPass, reduced over broadcast dimensions if needed. */
    private INDArray[] backpropAverage(LayerWorkspaceMgr workspaceMgr) {
        INDArray[] out = new INDArray[nInForwardPass];
        // All results are allocated inside one borrowed ACTIVATION_GRAD scope, closed exactly once.
        try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
            for (int i = 0; i < nInForwardPass; i++) {
                if (inputs[i].equalShapes(epsilon)) {
                    out[i] = epsilon.div(nInForwardPass);
                } else {
                    int[] bcDims = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
                    out[i] = epsilon.div(nInForwardPass).sum(true, bcDims);
                }
            }
        }
        return out;
    }

    /** Subtract: exactly two gradients, +epsilon for the first input and -epsilon for the second. */
    private INDArray[] backpropSubtract(boolean broadcastCase, LayerWorkspaceMgr workspaceMgr) {
        INDArray[] out = new INDArray[2];
        if (!broadcastCase) {
            out[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
            out[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
        } else if (inputs[0].equalShapes(epsilon)) {
            // Second input was broadcast.
            out[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
            out[1] = sumOverBroadcastDims(epsilon, inputs[1], workspaceMgr).negi();
        } else {
            // First input was broadcast.
            out[0] = sumOverBroadcastDims(epsilon, inputs[0], workspaceMgr);
            out[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
        }
        return out;
    }

    /** Product: input i receives epsilon times every other input, reduced over broadcast dims if needed. */
    private INDArray[] backpropProduct(boolean broadcastCase, LayerWorkspaceMgr workspaceMgr) {
        INDArray[] out = new INDArray[nInForwardPass];
        INDArray[] factors = broadcastCase ? broadcastInputsToEpsilonShape() : inputs;
        for (int i = 0; i < nInForwardPass; i++) {
            INDArray grad = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
            for (int j = 0; j < nInForwardPass; j++) {
                if (j != i) {
                    grad.muli(factors[j]);
                }
            }
            out[i] = inputs[i].equalShapes(epsilon) ? grad : sumOverBroadcastDims(grad, inputs[i], workspaceMgr);
        }
        return out;
    }

    /**
     * Max: input i receives epsilon at the positions where the {@code mergemaxindex} op reports index i
     * (presumably: input i held the maximum there), and zero elsewhere.
     */
    private INDArray[] backpropMax(boolean broadcastCase, LayerWorkspaceMgr workspaceMgr) {
        INDArray[] out = new INDArray[nInForwardPass];
        INDArray maxIndices = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, DataType.INT,
                        epsilon.shape(), epsilon.ordering());
        INDArray[] candidates = broadcastCase ? broadcastInputsToEpsilonShape() : inputs;
        Nd4j.getExecutioner().exec(DynamicCustomOp.builder("mergemaxindex")
                        .addInputs(candidates)
                        .addOutputs(maxIndices)
                        .callInplace(false)
                        .build());

        for (int i = 0; i < nInForwardPass; i++) {
            INDArray isMax = workspaceMgr.create(ArrayType.BP_WORKING_MEM, DataType.BOOL, maxIndices.shape());
            Nd4j.getExecutioner().exec(new MatchConditionTransform(maxIndices, isMax, Conditions.equals(i)));
            INDArray mask = isMax.castTo(epsilon.dataType());
            if (!broadcastCase || epsilon.equalShapes(inputs[i])) {
                out[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, mask.muli(epsilon));
            } else {
                out[i] = sumOverBroadcastDims(mask.mul(epsilon), inputs[i], workspaceMgr);
            }
        }
        return out;
    }

    /**
     * Sums {@code grad} over the dimensions along which {@code input} differs from epsilon's shape.
     * The dimensions are kept ({@code keepDims = true}). The sum runs inside a borrowed
     * ACTIVATION_GRAD scope, which try-with-resources closes exactly once, also on exceptions;
     * a null scope is skipped.
     */
    private INDArray sumOverBroadcastDims(INDArray grad, INDArray input, LayerWorkspaceMgr workspaceMgr) {
        int[] bcDims = Shape.getBroadcastDimensions(input.shape(), epsilon.shape());
        try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
            return grad.sum(true, bcDims);
        }
    }

    /** Returns the inputs, with each input whose shape differs from epsilon broadcast to epsilon's shape. */
    private INDArray[] broadcastInputsToEpsilonShape() {
        INDArray[] result = new INDArray[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i].equalShapes(epsilon)) {
                result[i] = inputs[i];
            } else {
                result[i] = epsilon.ulike();
                Nd4j.exec(new BroadcastTo(inputs[i], epsilon.shape(), result[i]));
            }
        }
        return result;
    }

    /**
     * This vertex has no parameters, so only a null gradient view is accepted.
     *
     * @throws RuntimeException if {@code backpropGradientsViewArray} is non-null
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Combines the input mask arrays into a single output mask.
     *
     * <p>If {@code maskArrays} is null or any entry is null, no mask (null) is returned. A single mask is
     * passed through unchanged. Otherwise the masks are combined with an element-wise OR and the result
     * is cast to {@link Nd4j#defaultFloatingPointType()}.
     *
     * @param maskArrays       one mask per input; may be null
     * @param currentMaskState returned unchanged
     * @param minibatchSize    not used here
     * @return the combined mask (possibly null) paired with {@code currentMaskState}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState,
                    int minibatchSize) {
        if (maskArrays == null) {
            return new Pair<>(null, currentMaskState);
        }
        for (INDArray mask : maskArrays) {
            if (mask == null) {
                return new Pair<>(null, currentMaskState);
            }
        }
        if (maskArrays.length == 1) {
            return new Pair<>(maskArrays[0], currentMaskState);
        }

        INDArray combined = Nd4j.createUninitialized(DataType.BOOL, maskArrays[0].shape());
        Nd4j.getExecutioner().exec(new Or(maskArrays[0].castTo(DataType.BOOL), maskArrays[1].castTo(DataType.BOOL), combined));
        for (int i = 2; i < maskArrays.length; i++) {
            Nd4j.getExecutioner().exec(new Or(maskArrays[i].castTo(DataType.BOOL), combined, combined));
        }
        return new Pair<>(combined.castTo(Nd4j.defaultFloatingPointType()), currentMaskState);
    }

    @Override
    public String toString() {
        return "ElementWiseVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\",op=" + this.op + ")";
    }
}
