package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/StackVertex.class */
/**
 * Graph vertex that stacks all of its input arrays along dimension 0 in the forward pass and
 * slices the incoming epsilon back into one piece per input in the backward pass.
 *
 * <p>Rank-3 inputs whose size along dimension 2 differs are copied into a freshly created array
 * sized to the longest input; the original input shapes are remembered so that the backward pass
 * can trim each slice back to the matching length. The vertex has no layer and no parameters.
 */
public class StackVertex extends BaseGraphVertex {
    /**
     * Shapes of the inputs seen by the most recent {@link #doForward(boolean)} call, recorded only
     * when rank-3 inputs had differing sizes along dimension 2; {@code null} otherwise.
     */
    private int[][] lastInputShapes;

    /**
     * Creates a vertex without input/output vertex indices (both are passed on as {@code null}).
     *
     * @param computationGraph graph forwarded to the superclass constructor
     * @param str              forwarded to the superclass constructor (presumably the vertex name)
     * @param i                forwarded to the superclass constructor (presumably the vertex index)
     */
    public StackVertex(ComputationGraph computationGraph, String str, int i) {
        this(computationGraph, str, i, null, null);
    }

    /**
     * Creates a vertex; every argument is forwarded unchanged to the superclass constructor.
     *
     * @param computationGraph  graph forwarded to the superclass constructor
     * @param str               presumably the vertex name -- confirm against BaseGraphVertex
     * @param i                 presumably the vertex index -- confirm against BaseGraphVertex
     * @param vertexIndicesArr  presumably the input vertex indices -- confirm against BaseGraphVertex
     * @param vertexIndicesArr2 presumably the output vertex indices -- confirm against BaseGraphVertex
     */
    public StackVertex(ComputationGraph computationGraph, String str, int i, VertexIndices[] vertexIndicesArr, VertexIndices[] vertexIndicesArr2) {
        super(computationGraph, str, i, vertexIndicesArr, vertexIndicesArr2);
    }

    /** @return always {@code false}; this vertex does not wrap a layer */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public boolean hasLayer() {
        return false;
    }

    /** @return always {@code false} */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public boolean isOutputVertex() {
        return false;
    }

    /** @return always {@code null}; this vertex does not wrap a layer */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Layer getLayer() {
        return null;
    }

    /**
     * Stacks the current inputs along dimension 0.
     *
     * <p>Inputs that are not rank 3, or rank-3 inputs that all share the same size along
     * dimension 2, are concatenated directly. Otherwise each input is copied into its own row
     * interval of a new array whose dimension 2 equals the largest input size, and the input
     * shapes are stored in {@code lastInputShapes} for the backward pass.
     *
     * @param z not read by this implementation
     * @return the stacked array
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public INDArray doForward(boolean z) {
        // Reset first so a previous variable-length pass cannot leak into this one.
        this.lastInputShapes = null;

        int[] firstShape = this.inputs[0].shape();
        if (firstShape.length != 3) {
            return Nd4j.concat(0, this.inputs);
        }

        int shortest = this.inputs[0].size(2);
        int longest = shortest;
        for (int idx = 1; idx < this.inputs.length; idx++) {
            int current = this.inputs[idx].size(2);
            shortest = Math.min(shortest, current);
            longest = Math.max(longest, current);
        }
        if (shortest == longest) {
            return Nd4j.concat(0, this.inputs);
        }

        // Differing sizes along dimension 2: build the output shape only now that it is needed.
        int[] outShape = firstShape.clone();
        outShape[0] = this.inputs.length * firstShape[0];
        outShape[2] = longest;
        INDArray stacked = Nd4j.create(outShape);

        // Row offsets assume every input has the same size along dimension 0 as the first one.
        int rowsPerInput = this.inputs[0].size(0);
        this.lastInputShapes = new int[this.inputs.length][];
        for (int idx = 0; idx < this.inputs.length; idx++) {
            INDArray input = this.inputs[idx];
            // Region beyond input.size(2) is left as created by Nd4j.create (presumably zeros).
            stacked.put(new INDArrayIndex[]{
                    NDArrayIndex.interval(idx * rowsPerInput, (idx + 1) * rowsPerInput),
                    NDArrayIndex.all(),
                    NDArrayIndex.interval(0, input.size(2))}, input);
            this.lastInputShapes[idx] = input.shape();
        }
        return stacked;
    }

    /**
     * Splits the epsilon into one equally sized slice along dimension 0 per input.
     *
     * <p>For rank-3 epsilons, when the preceding forward pass recorded input shapes, each slice is
     * additionally trimmed along dimension 2 to the recorded size of the corresponding input.
     *
     * @param z not read by this implementation
     * @return a pair whose first element is always {@code null} (no gradient) and whose second
     *         element holds one epsilon per input; the array elements are all {@code null} when
     *         this vertex's epsilon is {@code null}
     * @throws IllegalStateException         if {@code canDoForward()} reports false
     * @throws UnsupportedOperationException if the epsilon rank is not 2, 3 or 4
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Pair<Gradient, INDArray[]> doBackward(boolean z) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do backward pass: input not set");
        }
        if (this.epsilon == null) {
            return new Pair<>(null, new INDArray[this.inputs.length]);
        }

        int count = this.inputs.length;
        INDArray[] slices = new INDArray[count];
        int rowsPerInput = this.epsilon.size(0) / count;
        int rank = this.epsilon.rank();
        for (int idx = 0; idx < count; idx++) {
            INDArrayIndex rows = NDArrayIndex.interval(idx * rowsPerInput, (idx + 1) * rowsPerInput);
            switch (rank) {
                case 2:
                    slices[idx] = this.epsilon.get(new INDArrayIndex[]{rows, NDArrayIndex.all()});
                    break;
                case 3:
                    // Trim dimension 2 only when the forward pass had to pad differing sizes.
                    INDArrayIndex lastDim = this.lastInputShapes != null
                            ? NDArrayIndex.interval(0, this.lastInputShapes[idx][2])
                            : NDArrayIndex.all();
                    slices[idx] = this.epsilon.get(new INDArrayIndex[]{rows, NDArrayIndex.all(), lastDim});
                    break;
                case 4:
                    slices[idx] = this.epsilon.get(new INDArrayIndex[]{rows, NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()});
                    break;
                default:
                    // Report the rank that was actually switched on.
                    throw new UnsupportedOperationException("Cannot get subset for activations of rank " + rank);
            }
        }
        return new Pair<>(null, slices);
    }

    /**
     * Accepts only {@code null}, since this vertex has no gradients.
     *
     * @param iNDArray must be {@code null}
     * @throws RuntimeException if a non-null array is supplied
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        if (iNDArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Stacks the given mask arrays along dimension 0.
     *
     * <p>Masks that all share the same size along dimension 1 are vertically stacked directly.
     * Otherwise they are copied into a new array whose dimension 1 equals the largest mask size.
     *
     * @param iNDArrayArr mask arrays, one per input; may be {@code null}
     * @param maskState   returned unchanged as the second element of the pair
     * @param i           not read by this implementation
     * @return the stacked mask (or {@code null} when {@code iNDArrayArr} is {@code null}) paired
     *         with {@code maskState}
     */
    @Override // org.deeplearning4j.nn.graph.vertex.GraphVertex
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        if (iNDArrayArr == null) {
            return new Pair<>(null, maskState);
        }

        int firstLength = iNDArrayArr[0].size(1);
        int longest = firstLength;
        boolean sameLength = true;
        for (int idx = 1; idx < iNDArrayArr.length; idx++) {
            int current = iNDArrayArr[idx].size(1);
            sameLength &= (current == firstLength);
            longest = Math.max(longest, current);
        }
        if (sameLength) {
            return new Pair<>(Nd4j.vstack(iNDArrayArr), maskState);
        }

        // Row offsets assume every mask has the same size along dimension 0 as the first one.
        int rowsPerMask = iNDArrayArr[0].size(0);
        INDArray stacked = Nd4j.create(iNDArrayArr.length * rowsPerMask, longest);
        for (int idx = 0; idx < iNDArrayArr.length; idx++) {
            INDArray mask = iNDArrayArr[idx];
            stacked.put(new INDArrayIndex[]{
                    NDArrayIndex.interval(idx * rowsPerMask, (idx + 1) * rowsPerMask),
                    NDArrayIndex.interval(0, mask.size(1))}, mask);
        }
        return new Pair<>(stacked, maskState);
    }

    /** @return a short description containing the vertex index and the quoted vertex name */
    @Override // org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
    public String toString() {
        return "StackVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\")";
    }
}
