package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

/* loaded from: input_file:org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.class */
/**
 * Input pre-processor that sits between a convolutional layer and a feed-forward layer.
 *
 * <p>Forward pass ({@link #preProcess}): a rank-4 activation array, indexed as
 * {@code [miniBatch, channels, height, width]}, is flattened in 'c' order to a rank-2 array of
 * shape {@code [miniBatch, channels * height * width]}. Rank-2 input is passed through unchanged.
 *
 * <p>Backward pass ({@link #backprop}): a rank-2 epsilon array is reshaped in 'c' order back to
 * {@code [miniBatch, numChannels, inputHeight, inputWidth]} using the dimensions configured on
 * this instance.
 */
public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
    private int inputHeight;
    private int inputWidth;
    private int numChannels;

    /**
     * Creates a pre-processor for CNN activations of the given spatial size and channel count.
     *
     * @param inputHeight height of the CNN activations
     * @param inputWidth width of the CNN activations
     * @param numChannels number of channels (depth) of the CNN activations
     */
    @JsonCreator
    public CnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") int inputHeight,
                                        @JsonProperty("inputWidth") int inputWidth,
                                        @JsonProperty("numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
    }

    /**
     * Creates a single-channel pre-processor ({@code numChannels = 1}).
     *
     * @param inputHeight height of the CNN activations
     * @param inputWidth width of the CNN activations
     */
    public CnnToFeedForwardPreProcessor(int inputHeight, int inputWidth) {
        this(inputHeight, inputWidth, 1);
    }

    /** Creates a pre-processor with all dimensions left at 0; configure it via the setters. */
    public CnnToFeedForwardPreProcessor() {
    }

    /**
     * Flattens rank-4 CNN activations to rank-2 feed-forward activations.
     *
     * @param input activations; rank 4 ({@code [miniBatch, channels, height, width]}) or rank 2
     * @param miniBatchSize unused by this implementation
     * @return {@code input} itself if it is already rank 2, otherwise a
     *     {@code [miniBatch, channels * height * width]} array in 'c' order
     * @throws IllegalArgumentException if {@code input} is neither rank 2 nor rank 4
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize) {
        int rank = input.rank();
        if (rank == 2) {
            // Already flattened: nothing to do.
            return input;
        }
        if (rank != 4) {
            // Without this check, rank 3 fails with an ArrayIndexOutOfBoundsException on shape[3]
            // and rank > 4 fails later inside reshape with a misleading length mismatch.
            throw new IllegalArgumentException("Invalid input: expected rank 4 CNN activations "
                    + "[miniBatch, channels, height, width] (or rank 2), but got rank " + rank
                    + " with shape " + Arrays.toString(input.shape()));
        }
        // The 'c'-order reshape below relies on a 'c'-ordered buffer; copy when the array is not
        // 'c' ordered or Shape.strideDescendingCAscendingF rejects its strides.
        if (input.ordering() != 'c' || !Shape.strideDescendingCAscendingF(input)) {
            input = input.dup('c');
        }
        int[] inShape = input.shape();
        return input.reshape('c', new int[] {inShape[0], inShape[1] * inShape[2] * inShape[3]});
    }

    /**
     * Reshapes rank-2 epsilons from the feed-forward layer back to CNN shape.
     *
     * @param epsilons gradients; expected rank 2 with
     *     {@code inputHeight * inputWidth * numChannels} columns
     * @param miniBatchSize unused by this implementation
     * @return a {@code [size(0), numChannels, inputHeight, inputWidth]} array in 'c' order, or the
     *     (possibly 'c'-copied) input unchanged if it is already rank 4
     * @throws IllegalArgumentException if the column count does not match the configured
     *     dimensions
     */
    @Override
    public INDArray backprop(INDArray epsilons, int miniBatchSize) {
        if (epsilons.ordering() != 'c' || !Shape.strideDescendingCAscendingF(epsilons)) {
            epsilons = epsilons.dup('c');
        }
        if (epsilons.rank() == 4) {
            // Already in CNN shape; returned after the 'c'-order normalisation above.
            return epsilons;
        }
        if (epsilons.columns() != this.inputWidth * this.inputHeight * this.numChannels) {
            throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows " + this.inputHeight + " x columns " + this.inputWidth + " x depth " + this.numChannels + " but was instead " + Arrays.toString(epsilons.shape()));
        }
        return epsilons.reshape('c',
                new int[] {epsilons.size(0), this.numChannels, this.inputHeight, this.inputWidth});
    }

    /**
     * Returns a shallow copy of this pre-processor. All state is primitive, so a shallow copy is
     * fully independent of the original.
     */
    @Override
    public CnnToFeedForwardPreProcessor clone() {
        try {
            return (CnnToFeedForwardPreProcessor) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Decompiler-generated name for {@link #clone()}, retained only for source compatibility.
     *
     * @return the same result as {@link #clone()}
     * @deprecated use {@link #clone()}
     */
    @Deprecated
    public CnnToFeedForwardPreProcessor m56clone() {
        return clone();
    }

    /**
     * Computes the feed-forward output type for a CNN input type.
     *
     * @param inputType must be of type {@link InputType.Type#CNN}
     * @return a feed-forward type of size {@code depth * height * width}
     * @throws IllegalStateException if {@code inputType} is {@code null} or not a CNN type
     */
    @Override
    public InputType getOutputType(InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
        }
        InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
        return InputType.feedForward(conv.getDepth() * conv.getHeight() * conv.getWidth());
    }

    /** @return the configured input height */
    public int getInputHeight() {
        return this.inputHeight;
    }

    /** @return the configured input width */
    public int getInputWidth() {
        return this.inputWidth;
    }

    /** @return the configured number of channels */
    public int getNumChannels() {
        return this.numChannels;
    }

    /** @param inputHeight the input height to use in {@link #backprop} */
    public void setInputHeight(int inputHeight) {
        this.inputHeight = inputHeight;
    }

    /** @param inputWidth the input width to use in {@link #backprop} */
    public void setInputWidth(int inputWidth) {
        this.inputWidth = inputWidth;
    }

    /** @param numChannels the channel count to use in {@link #backprop} */
    public void setNumChannels(int numChannels) {
        this.numChannels = numChannels;
    }

    /** Two instances are equal when all three configured dimensions match. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CnnToFeedForwardPreProcessor)) {
            return false;
        }
        CnnToFeedForwardPreProcessor other = (CnnToFeedForwardPreProcessor) obj;
        return other.canEqual(this)
                && getInputHeight() == other.getInputHeight()
                && getInputWidth() == other.getInputWidth()
                && getNumChannels() == other.getNumChannels();
    }

    /** Symmetry hook for {@link #equals}: subclasses may override to refuse equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof CnnToFeedForwardPreProcessor;
    }

    /** Hash over the three configured dimensions, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + getInputHeight();
        result = result * 59 + getInputWidth();
        result = result * 59 + getNumChannels();
        return result;
    }

    @Override
    public String toString() {
        return "CnnToFeedForwardPreProcessor(inputHeight=" + getInputHeight() + ", inputWidth=" + getInputWidth() + ", numChannels=" + getNumChannels() + ")";
    }
}
