package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.class */
/**
 * Input pre-processor that reshapes 2d feed-forward activations into a 4d array of shape
 * {@code [size(0), numChannels, inputHeight, inputWidth]} for a convolutional layer, and reshapes
 * the gradients back on the backward pass.
 *
 * <p>The shape of the most recent forward-pass input is remembered in {@code shape} so that
 * {@link #backprop(INDArray)} can restore it.
 *
 * <p>NOTE(review): {@code shape} is mutable per-call state that also participates in
 * {@code equals}/{@code hashCode}/{@code toString}. An instance is presumably not safe to share
 * between concurrently running networks; confirm with callers.
 */
public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
    private int inputHeight;
    private int inputWidth;
    private int numChannels;
    // Shape of the last array passed to preProcess (or synthesized by backprop); null until then.
    private int[] shape;

    /**
     * Creates a pre-processor for the given image geometry.
     *
     * @param inputHeight number of rows of the reshaped image
     * @param inputWidth number of columns of the reshaped image
     * @param numChannels number of channels of the reshaped image
     */
    @JsonCreator
    public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") int inputHeight,
                                        @JsonProperty("inputWidth") int inputWidth,
                                        @JsonProperty("numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
    }

    /**
     * Creates a single-channel pre-processor.
     *
     * <p>Note the argument order: width first, then height. This is the reverse of the 3-argument
     * constructor and is kept as-is for backward compatibility with existing callers.
     *
     * @param inputWidth number of columns of the reshaped image
     * @param inputHeight number of rows of the reshaped image
     */
    public FeedForwardToCnnPreProcessor(int inputWidth, int inputHeight) {
        this(inputHeight, inputWidth, 1);
    }

    /**
     * Reshapes feed-forward activations to {@code [size(0), numChannels, inputHeight, inputWidth]}.
     * A 4d input is returned unchanged. In both cases the input's shape is recorded for
     * {@link #backprop(INDArray)}.
     *
     * @param input activations; either 4d, or with {@code numChannels * inputHeight * inputWidth}
     *     columns
     * @return the reshaped activations, or {@code input} itself if it is already 4d
     * @throws IllegalArgumentException if a non-4d input does not have
     *     {@code numChannels * inputHeight * inputWidth} columns
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public INDArray preProcess(INDArray input) {
        int[] inputShape = input.shape();
        this.shape = inputShape;
        if (inputShape.length == 4) {
            return input;
        }
        // Every channel contributes inputHeight * inputWidth values per row, so the channel
        // count must be part of the expected column count (otherwise reshape cannot succeed).
        int expectedColumns = this.numChannels * this.inputHeight * this.inputWidth;
        if (input.columns() != expectedColumns) {
            throw new IllegalArgumentException("Invalid input: expected number of columns to equal channels "
                    + this.numChannels + " x rows " + this.inputHeight + " x columns " + this.inputWidth
                    + " = " + expectedColumns + " but input shape was " + Arrays.toString(inputShape));
        }
        return input.reshape(new int[]{input.size(0), this.numChannels, this.inputHeight, this.inputWidth});
    }

    /**
     * Reshapes gradients back to the shape seen on the forward pass.
     *
     * <p>If no forward shape is recorded, or its element count differs from {@code epsilons}, a
     * new 2d target shape {@code [shape[0], product of remaining dims]} is derived from
     * {@code epsilons} and stored in place of the recorded shape. A 2d {@code epsilons} in that
     * situation is returned unchanged.
     *
     * @param epsilons gradients flowing back from the following layer
     * @return {@code epsilons} reshaped to the recorded (or derived) shape
     * @throws IllegalArgumentException if a shape must be derived and {@code epsilons} has rank below 2
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public INDArray backprop(INDArray epsilons) {
        if (this.shape == null || ArrayUtil.prod(this.shape) != epsilons.length()) {
            int[] epsShape = epsilons.shape();
            if (epsShape.length == 2) {
                return epsilons;
            }
            if (epsShape.length < 2) {
                throw new IllegalArgumentException("Invalid epsilons: expected rank >= 2 but shape was "
                        + Arrays.toString(epsShape));
            }
            // Collapse every dimension after the first into one; works for any rank >= 3.
            int flattened = ArrayUtil.prod(Arrays.copyOfRange(epsShape, 1, epsShape.length));
            this.shape = new int[]{epsShape[0], flattened};
        }
        return epsilons.reshape(this.shape);
    }

    /**
     * Returns a copy of this pre-processor. The recorded shape, if any, is copied, not shared.
     *
     * <p>NOTE(review): the method name {@code m29clone} appears to be a decompiler rename of
     * {@code clone()} (see the note below); confirm before renaming.
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public FeedForwardToCnnPreProcessor m29clone() {
        try {
            FeedForwardToCnnPreProcessor copy = (FeedForwardToCnnPreProcessor) super.clone();
            if (copy.shape != null) {
                copy.shape = copy.shape.clone();
            }
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    public int getInputHeight() {
        return this.inputHeight;
    }

    public int getInputWidth() {
        return this.inputWidth;
    }

    public int getNumChannels() {
        return this.numChannels;
    }

    public void setInputHeight(int inputHeight) {
        this.inputHeight = inputHeight;
    }

    public void setInputWidth(int inputWidth) {
        this.inputWidth = inputWidth;
    }

    public void setNumChannels(int numChannels) {
        this.numChannels = numChannels;
    }

    /**
     * Two pre-processors are equal when geometry and the recorded runtime shape all match.
     * NOTE(review): including {@code shape} makes equality depend on the last forward pass.
     */
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof FeedForwardToCnnPreProcessor)) {
            return false;
        }
        FeedForwardToCnnPreProcessor other = (FeedForwardToCnnPreProcessor) obj;
        return other.canEqual(this)
                && getInputHeight() == other.getInputHeight()
                && getInputWidth() == other.getInputWidth()
                && getNumChannels() == other.getNumChannels()
                && Arrays.equals(this.shape, other.shape);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof FeedForwardToCnnPreProcessor;
    }

    /** Hash over the same fields as {@link #equals(Object)}, using a multiplier of 59. */
    public int hashCode() {
        int result = 1;
        result = result * 59 + getInputHeight();
        result = result * 59 + getInputWidth();
        result = result * 59 + getNumChannels();
        result = result * 59 + Arrays.hashCode(this.shape);
        return result;
    }

    public String toString() {
        return "FeedForwardToCnnPreProcessor(inputHeight=" + getInputHeight() + ", inputWidth=" + getInputWidth() + ", numChannels=" + getNumChannels() + ", shape=" + Arrays.toString(this.shape) + ")";
    }
}
