package org.deeplearning4j.nn.conf.preprocessor;

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

/* loaded from: input_file:org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.class */
/**
 * Input pre-processor that flattens rank-3 recurrent-layer activations into a rank-2 matrix for a
 * feed-forward layer, and rebuilds rank-3 arrays from the rank-2 epsilons during backprop.
 *
 * <p>In the general case a rank-3 input of shape {@code [d0, d1, d2]} is flattened to
 * {@code [d0 * d2, d1]}, and {@link #backprop} turns a {@code [miniBatchSize * k, c]} array into
 * {@code [miniBatchSize, c, k]}. NOTE(review): the rank-3 layout is presumably
 * {@code [miniBatchSize, layerSize, timeSeriesLength]}; confirm against the RNN layers.
 *
 * <p>This class has no fields, so all instances compare equal and share one hash code.
 */
public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
    private static final long serialVersionUID = 1410433625085923838L;

    /**
     * Flattens rank-3 activations into a rank-2 matrix.
     *
     * @param input rank-3 activations
     * @param miniBatchSize not used by this implementation (presumably the minibatch size)
     * @return a rank-2 array; in the general case of shape {@code [shape[0] * shape[2], shape[1]]}
     * @throws IllegalArgumentException if {@code input} is not rank 3
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize) {
        if (input.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3 (i.e., activations for RNN layer)");
        }
        int[] shape = input.shape();
        if (shape[0] == 1) {
            // Leading dimension of 1: take the single [d1, d2] tensor directly instead of
            // permuting and reshaping.
            // NOTE(review): the general path below yields [d2, d1]; confirm that
            // tensorAlongDimension(0, {1, 2}) returns that orientation in this ND4J version.
            return input.tensorAlongDimension(0, new int[]{1, 2});
        }
        if (shape[2] == 1) {
            // Trailing dimension of 1: take the single [d0, d1] tensor directly.
            // NOTE(review): the general path below yields [d0, d1]; confirm the orientation of
            // tensorAlongDimension(0, {1, 0}) in this ND4J version.
            return input.tensorAlongDimension(0, new int[]{1, 0});
        }
        // Move dimension 2 next to dimension 0 so that reshaping merges them into the row axis
        // and leaves dimension 1 as the columns.
        return input.permute(new int[]{0, 2, 1}).reshape(shape[0] * shape[2], shape[1]);
    }

    /**
     * Rebuilds a rank-3 array from rank-2 epsilons.
     *
     * @param epsilons rank-2 array of shape {@code [miniBatchSize * k, c]}
     * @param miniBatchSize size of the leading dimension of the result; must be positive and must
     *     evenly divide the number of rows of {@code epsilons}
     * @return a rank-3 array of shape {@code [miniBatchSize, c, k]}
     * @throws IllegalArgumentException if {@code epsilons} is not rank 2, if {@code miniBatchSize}
     *     is not positive, or if the row count is not a multiple of {@code miniBatchSize}
     */
    @Override
    public INDArray backprop(INDArray epsilons, int miniBatchSize) {
        if (epsilons.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2 (i.e., epsilons from feed forward layer)");
        }
        // Previously a zero value surfaced as an ArithmeticException from the division below.
        if (miniBatchSize <= 0) {
            throw new IllegalArgumentException("Invalid miniBatchSize: must be positive, got " + miniBatchSize);
        }
        // 'f'-ordered arrays are copied to 'c' order first, presumably so the reshape below
        // interprets elements in row-major order (mirroring the preProcess flattening).
        INDArray rowMajor = epsilons.ordering() == 'f' ? Shape.toOffsetZeroCopy(epsilons, 'c') : epsilons;
        int[] shape = rowMajor.shape();
        if (shape[0] % miniBatchSize != 0) {
            throw new IllegalArgumentException("Invalid input: number of rows (" + shape[0]
                    + ") is not a multiple of miniBatchSize (" + miniBatchSize + ")");
        }
        int perExample = shape[0] / miniBatchSize;
        return rowMajor.reshape(new int[]{miniBatchSize, perExample, shape[1]}).permute(new int[]{0, 2, 1});
    }

    /**
     * Returns a shallow copy of this pre-processor.
     *
     * <p>The decompiled source emitted this method under the synthetic name {@code m43clone}, which
     * overrode nothing; restoring the {@code clone()} name makes it the public override.
     *
     * @return a new instance produced by {@link Object#clone()}
     * @throws RuntimeException wrapping {@link CloneNotSupportedException} if the class is not
     *     {@link Cloneable}
     */
    @Override
    public RnnToFeedForwardPreProcessor clone() {
        try {
            return (RnnToFeedForwardPreProcessor) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Kept for source compatibility with callers that use the decompiler-generated name.
     *
     * @return the result of {@link #clone()}
     */
    public RnnToFeedForwardPreProcessor m43clone() {
        return clone();
    }

    /**
     * Equality holds for any other {@code RnnToFeedForwardPreProcessor} whose {@link #canEqual}
     * accepts this instance; there are no fields to compare.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof RnnToFeedForwardPreProcessor) && ((RnnToFeedForwardPreProcessor) obj).canEqual(this);
    }

    /**
     * Reports whether {@code obj} may be considered equal to this instance. Subclasses can
     * override this to restrict equality to their own type.
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof RnnToFeedForwardPreProcessor;
    }

    /** Returns a constant, since the class has no state; this is consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        return 1;
    }

    @Override
    public String toString() {
        return "RnnToFeedForwardPreProcessor()";
    }
}
