package org.deeplearning4j.nn.graph.vertex.impl.rnn;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.class */
/**
 * Graph vertex that reverses a time series along its time axis (dimension 2 of the activations).
 *
 * <p>If an input name is configured, the mask array of that network input is used. Only the
 * unmasked time steps of each example are reversed, and masked steps (mask value 0) keep the
 * value they had in the input. Without an input name, all time steps are reversed. The same
 * reversal is applied to the epsilons during backpropagation. The vertex has no parameters.
 */
public class ReverseTimeSeriesVertex extends BaseGraphVertex {
    /** Name of the network input whose mask is used, or {@code null} when no mask is used. */
    private final String inputName;
    /** Position of {@link #inputName} in the network inputs, or -1 when no input name was given. */
    private final int inputIdx;

    /**
     * Creates the vertex.
     *
     * @param graph       the computation graph this vertex belongs to
     * @param name        name of this vertex
     * @param vertexIndex index of this vertex within the graph
     * @param inputName   name of the network input whose mask should be applied; may be {@code null}
     * @throws IllegalArgumentException if {@code inputName} is not one of the graph's network inputs
     */
    public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputName) {
        super(graph, name, vertexIndex, null, null);
        this.inputName = inputName;
        if (inputName == null) {
            this.inputIdx = -1;
        } else {
            int idx = graph.getConfiguration().getNetworkInputs().indexOf(inputName);
            if (idx == -1) {
                throw new IllegalArgumentException("Invalid input name: \"" + inputName + "\" not found in list of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")");
            }
            this.inputIdx = idx;
        }
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public boolean isOutputVertex() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Reverses the (unmasked) time steps of the single input.
     *
     * @param training unused; reversal does not depend on the training mode
     * @return a new array containing the reversed time series
     */
    @Override
    public INDArray doForward(boolean training) {
        return revertTimeSeries(this.inputs[0], getMask());
    }

    /**
     * Reverses the incoming epsilons with the same mapping used in the forward pass.
     *
     * @param tbptt unused
     * @return a pair with a {@code null} gradient (no parameters) and the reversed epsilon
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        // Plain null (not "(Object) null") so the diamond infers Pair<Gradient, INDArray[]>.
        return new Pair<>(null, new INDArray[]{revertTimeSeries(this.epsilon, getMask())});
    }

    /**
     * Returns the mask array of the configured network input, or {@code null} if no input name was
     * configured or the graph currently has no input masks.
     */
    private INDArray getMask() {
        if (this.inputIdx < 0) {
            return null;
        }
        INDArray[] inputMasks = this.graph.getInputMaskArrays();
        return inputMasks == null ? null : inputMasks[this.inputIdx];
    }

    /**
     * Reverses the time axis of a time series array, one example at a time.
     *
     * <p>For each example, the value at the first unmasked step is written to the last unmasked
     * step, the second to the second-last, and so on. Steps whose mask value is 0 are skipped at
     * both ends. Because the output starts as a copy of {@code input}, those masked steps keep
     * their original values.
     *
     * @param input array indexed as [example, feature, timeStep]; not modified
     * @param mask  array indexed as [example, timeStep], or {@code null} to reverse every step
     * @return a new array holding the reversed time series
     */
    private static INDArray revertTimeSeries(INDArray input, INDArray mask) {
        int numExamples = input.size(0);
        int numSteps = input.size(2);
        INDArray out = input.dup();
        for (int example = 0; example < numExamples; example++) {
            int src = 0;
            int dst = numSteps - 1;
            while (src < numSteps && dst >= 0) {
                if (mask != null) {
                    while (src < numSteps && mask.getDouble(example, src) == 0.0) {
                        src++;
                    }
                    while (dst >= 0 && mask.getDouble(example, dst) == 0.0) {
                        dst--;
                    }
                    // Only padding is left on both sides (e.g. mask [0,1,0] after the middle step).
                    // Without this check, point(src)/point(dst) would index outside the time axis.
                    if (src >= numSteps || dst < 0) {
                        break;
                    }
                }
                INDArray step = input.get(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(src)});
                out.put(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(dst)}, step);
                src++;
                dst--;
            }
        }
        return out;
    }

    /**
     * Rejects any non-null gradient view, because this vertex has no parameters.
     *
     * @throws RuntimeException if {@code backpropGradientsViewArray} is non-null
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Passes the single input mask through unchanged, since reversal does not alter which steps
     * are present.
     *
     * @param maskArrays       mask arrays of the inputs; at most one is accepted
     * @param currentMaskState mask state, returned unchanged
     * @param minibatchSize    unused
     * @return the input mask (or {@code null} if none was supplied) with the unchanged mask state
     * @throws IllegalArgumentException if more than one mask array is supplied
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        if (maskArrays == null || maskArrays.length == 0) {
            return new Pair<>(null, currentMaskState);
        }
        if (maskArrays.length > 1) {
            throw new IllegalArgumentException("This vertex can only handle one input and hence only one mask");
        }
        return new Pair<>(maskArrays[0], currentMaskState);
    }

    @Override
    public String toString() {
        String inputPart = this.inputName == null ? "" : "inputName=" + this.inputName;
        return "ReverseTimeSeriesVertex(" + inputPart + ")";
    }
}
