package org.deeplearning4j.nn.conf.layers.setup;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/setup/ConvolutionLayerSetup.class */
/**
 * Walks the layers of a {@link MultiLayerConfiguration.Builder} and wires up a
 * convolutional network: it sets {@code nIn} on layers from the output of the
 * layer before them, registers the input pre-processors needed where the data
 * changes between convolutional (4d) and flat/recurrent layers, and records
 * the spatial output size produced by each convolution / subsampling layer.
 *
 * <p>All of the work happens in the constructor; afterwards the instance only
 * exposes the values it computed (last height/width/channels, per-layer output
 * sizes and per-layer {@code nIn}).
 */
public class ConvolutionLayerSetup {
    /** Height of the activations leaving the most recently processed convolution/subsampling layer. */
    private int lastHeight;
    /** Width of the activations leaving the most recently processed convolution/subsampling layer. */
    private int lastWidth;
    /** Channel count ({@code nOut}) of the most recently processed convolution layer, or the input channels. */
    private int lastOutChannels;
    /** Number of layers found in the builder. */
    private int numLayers;
    /** Layer index -> {height, width} produced by that convolution/subsampling layer. */
    private Map<Integer, int[]> outSizesEachLayer = new HashMap<>();
    /** Layer index -> {@code nIn} value that this setup applied to that layer. */
    private Map<Integer, Integer> nInForLayer = new HashMap<>();

    /**
     * Configures every layer held by {@code builder} for the given input dimensions.
     *
     * @param builder  configuration builder whose layers are inspected and mutated
     * @param height   height of the network input
     * @param width    width of the network input
     * @param channels number of input channels; used as {@code nIn} of a leading convolution layer
     * @throws UnsupportedOperationException if a dense/output/recurrent layer is directly followed by a
     *         convolution or subsampling layer, or if the final layer is a convolution or subsampling layer
     * @throws IllegalStateException if a list builder has no entry for a required layer index
     */
    public ConvolutionLayerSetup(MultiLayerConfiguration.Builder builder, int height, int width, int channels) {
        this.lastHeight = height;
        this.lastWidth = width;
        this.lastOutChannels = channels;
        if (builder instanceof NeuralNetConfiguration.ListBuilder) {
            this.numLayers = ((NeuralNetConfiguration.ListBuilder) builder).getLayerwise().size();
        } else {
            this.numLayers = builder.getConfs().size();
        }

        for (int index = 0; index < this.numLayers; index++) {
            Layer current = getLayer(index, builder);
            boolean hasNext = index < this.numLayers - 1;
            // True once a branch below has already updated lastHeight/lastWidth for this layer.
            boolean sizesRecorded = false;

            // The type check now also applies to index 0: previously a non-convolution first
            // layer was cast blindly and failed with a ClassCastException.
            // NOTE(review): a convolution layer at index numLayers - 2 is deliberately not handled
            // here (only its size is tracked below and nIn is set after the loop), so no
            // CNN-to-flat pre-processor is registered for the final layer in that case - confirm intended.
            if (current instanceof ConvolutionLayer && (index == 0 || index < this.numLayers - 2)) {
                ConvolutionLayer convolution = (ConvolutionLayer) current;
                if (index == 0) {
                    convolution.setNIn(channels);
                }
                sizesRecorded = connectConvolutionLayer(index, convolution, builder);
            } else if (hasNext && current instanceof SubsamplingLayer) {
                sizesRecorded = connectSubsamplingLayer(index, (SubsamplingLayer) current, builder);
            } else if (hasNext && isFlatOrRecurrent(current)) {
                connectFlatLayer(index, (FeedForwardLayer) current, builder);
            }

            // Track the running spatial size for layers whose successor did not need it above.
            if (hasNext && !sizesRecorded) {
                if (current instanceof ConvolutionLayer) {
                    ConvolutionLayer convolution = (ConvolutionLayer) current;
                    int[] outSize = getConvolutionOutputSize(new int[]{this.lastHeight, this.lastWidth},
                            convolution.getKernelSize(), convolution.getPadding(), convolution.getStride());
                    this.lastHeight = outSize[0];
                    this.lastWidth = outSize[1];
                    this.lastOutChannels = convolution.getNOut();
                    this.outSizesEachLayer.put(index, outSize);
                } else if (current instanceof SubsamplingLayer) {
                    SubsamplingLayer subsampling = (SubsamplingLayer) current;
                    int[] outSize = getSubSamplingOutputSize(new int[]{this.lastHeight, this.lastWidth},
                            subsampling.getKernelSize(), subsampling.getStride());
                    this.lastHeight = outSize[0];
                    this.lastWidth = outSize[1];
                    this.outSizesEachLayer.put(index, outSize);
                }
            }
        }

        configureFinalLayer(builder);

        // A network that starts with a convolutional layer needs its flat input reshaped first.
        Layer first = getLayer(0, builder);
        if (first instanceof ConvolutionLayer || first instanceof SubsamplingLayer) {
            builder.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels));
        }
    }

    /**
     * Wires a convolution layer to the layer that follows it.
     *
     * @return {@code true} if the running height/width/channels were updated here
     */
    private boolean connectConvolutionLayer(int index, ConvolutionLayer convolution,
                                            MultiLayerConfiguration.Builder builder) {
        Layer next = getLayer(index + 1, builder);
        if (isFlatOrRecurrent(next)) {
            int[] outSize = getConvolutionOutputSize(new int[]{this.lastHeight, this.lastWidth},
                    convolution.getKernelSize(), convolution.getPadding(), convolution.getStride());
            registerFlatteningPreProcessor(builder, index + 1, next, outSize[0], outSize[1], convolution.getNOut());
            FeedForwardLayer target = (FeedForwardLayer) next;
            this.outSizesEachLayer.put(index, outSize);
            this.lastHeight = outSize[0];
            this.lastWidth = outSize[1];
            this.lastOutChannels = convolution.getNOut();
            int flattened = this.lastWidth * this.lastHeight * convolution.getNOut();
            // Keyed by the layer that receives this nIn (index + 1), matching every other
            // write to nInForLayer; it was previously stored under the convolution layer's index.
            this.nInForLayer.put(index + 1, flattened);
            target.setNIn(flattened);
            return true;
        }
        if (next instanceof SubsamplingLayer) {
            SubsamplingLayer subsampling = (SubsamplingLayer) next;
            // A subsampling layer without explicit padding inherits the convolution's padding.
            if (subsampling.getPadding() == null) {
                subsampling.setPadding(convolution.getPadding());
            }
        } else if (next instanceof ConvolutionLayer) {
            ((ConvolutionLayer) next).setNIn(convolution.getNOut());
        }
        return false;
    }

    /**
     * Wires a subsampling layer to the layer that follows it.
     *
     * @return {@code true} if the running height/width were updated here
     */
    private boolean connectSubsamplingLayer(int index, SubsamplingLayer subsampling,
                                            MultiLayerConfiguration.Builder builder) {
        Layer next = getLayer(index + 1, builder);
        if (isFlatOrRecurrent(next)) {
            int[] outSize = getSubSamplingOutputSize(new int[]{this.lastHeight, this.lastWidth},
                    subsampling.getKernelSize(), subsampling.getStride());
            this.outSizesEachLayer.put(index, outSize);
            this.lastHeight = outSize[0];
            this.lastWidth = outSize[1];
            registerFlatteningPreProcessor(builder, index + 1, next, this.lastHeight, this.lastWidth,
                    this.lastOutChannels);
            FeedForwardLayer target = (FeedForwardLayer) next;
            int flattened = this.lastWidth * this.lastHeight * this.lastOutChannels;
            target.setNIn(flattened);
            this.nInForLayer.put(index + 1, flattened);
            // setFourDtoTwoD used to be invoked here as well. For a subsampling source it only
            // re-registers a CnnToFeedForwardPreProcessor with these same dimensions, which is
            // redundant for dense/output targets and clobbers the CnnToRnnPreProcessor chosen
            // above for recurrent targets (assuming a later registration replaces the earlier one).
            return true;
        }
        if (next instanceof ConvolutionLayer) {
            ((ConvolutionLayer) next).setNIn(this.lastOutChannels);
        }
        return false;
    }

    /**
     * Wires a dense/output/recurrent layer to the layer that follows it.
     *
     * @throws UnsupportedOperationException if the next layer is a convolution or subsampling layer
     */
    private void connectFlatLayer(int index, FeedForwardLayer current, MultiLayerConfiguration.Builder builder) {
        Layer next = getLayer(index + 1, builder);
        if (next instanceof ConvolutionLayer || next instanceof SubsamplingLayer) {
            throw new UnsupportedOperationException("2d to 4d needs to be implemented");
        }
        if (isDenseOrOutput(next)) {
            ((FeedForwardLayer) next).setNIn(current.getNOut());
            this.nInForLayer.put(index + 1, current.getNOut());
        }
        setFourDtoTwoD(index, builder, current);
    }

    /**
     * Sets {@code nIn} of a final dense/output layer from the layer before it, and rejects
     * configurations that end in a convolution or subsampling layer.
     */
    private void configureFinalLayer(MultiLayerConfiguration.Builder builder) {
        Layer last = getLayer(this.numLayers - 1, builder);
        if (isDenseOrOutput(last)) {
            // A single-layer configuration has no predecessor to infer nIn from.
            if (this.numLayers < 2) {
                return;
            }
            FeedForwardLayer lastFeedForward = (FeedForwardLayer) last;
            Layer previous = getLayer(this.numLayers - 2, builder);
            if (isDenseOrOutput(previous)) {
                int nIn = ((FeedForwardLayer) previous).getNOut();
                lastFeedForward.setNIn(nIn);
                this.nInForLayer.put(this.numLayers - 1, nIn);
            } else if (previous instanceof SubsamplingLayer || previous instanceof ConvolutionLayer) {
                int nIn = this.lastHeight * this.lastWidth * this.lastOutChannels;
                lastFeedForward.setNIn(nIn);
                this.nInForLayer.put(this.numLayers - 1, nIn);
            }
        } else if (last instanceof ConvolutionLayer) {
            throw new UnsupportedOperationException("Unsupported path: final convolution layer");
        } else if (last instanceof SubsamplingLayer) {
            throw new UnsupportedOperationException("Unsupported path: final subsampling layer");
        }
    }

    /**
     * Registers the pre-processor that converts convolutional activations into the input of
     * {@code target}: feed-forward for dense/output layers, RNN otherwise.
     */
    private static void registerFlatteningPreProcessor(MultiLayerConfiguration.Builder builder, int targetIndex,
                                                       Layer target, int height, int width, int channels) {
        if (isDenseOrOutput(target)) {
            builder.inputPreProcessor(Integer.valueOf(targetIndex),
                    new CnnToFeedForwardPreProcessor(height, width, channels));
        } else {
            builder.inputPreProcessor(Integer.valueOf(targetIndex),
                    new CnnToRnnPreProcessor(height, width, channels));
        }
    }

    /** Returns whether {@code layer} is a {@link DenseLayer} or an {@link OutputLayer}. */
    private static boolean isDenseOrOutput(Layer layer) {
        return layer instanceof DenseLayer || layer instanceof OutputLayer;
    }

    /** Returns whether {@code layer} is a dense, output, recurrent or RNN output layer. */
    private static boolean isFlatOrRecurrent(Layer layer) {
        return isDenseOrOutput(layer) || layer instanceof BaseRecurrentLayer || layer instanceof RnnOutputLayer;
    }

    /**
     * Computes the per-dimension output size of a subsampling layer.
     *
     * <p>For a kernel of 1 the result is {@code input / stride}; otherwise it is
     * {@code (input - kernel) / stride + 1}. All divisions are integer divisions.
     *
     * @param inputSize  input size per dimension
     * @param kernelSize kernel size per dimension
     * @param stride     stride per dimension
     * @return a new array with one output size per input dimension
     */
    private int[] getSubSamplingOutputSize(int[] inputSize, int[] kernelSize, int[] stride) {
        int[] outputSize = new int[inputSize.length];
        for (int dim = 0; dim < outputSize.length; dim++) {
            if (kernelSize[dim] == 1) {
                outputSize[dim] = inputSize[dim] / stride[dim];
            } else {
                outputSize[dim] = ((inputSize[dim] - kernelSize[dim]) / stride[dim]) + 1;
            }
        }
        return outputSize;
    }

    /**
     * Computes the per-dimension output size of a convolution layer as
     * {@code (input - kernel + 2 * padding) / stride + 1} using integer division.
     *
     * @param inputSize  input size per dimension
     * @param kernelSize kernel size per dimension
     * @param padding    padding per dimension
     * @param stride     stride per dimension
     * @return a new array with one output size per input dimension
     */
    private int[] getConvolutionOutputSize(int[] inputSize, int[] kernelSize, int[] padding, int[] stride) {
        int[] outputSize = new int[inputSize.length];
        for (int dim = 0; dim < outputSize.length; dim++) {
            outputSize[dim] = (((inputSize[dim] - kernelSize[dim]) + (2 * padding[dim])) / stride[dim]) + 1;
        }
        return outputSize;
    }

    /**
     * Looks up the layer configured at {@code index}, for both list builders and plain builders.
     *
     * @param index   position of the layer
     * @param builder builder to read the layer from
     * @return the layer at that position
     * @throws IllegalStateException if {@code builder} is a list builder with no entry for {@code index}
     */
    public Layer getLayer(int index, MultiLayerConfiguration.Builder builder) {
        if (builder instanceof NeuralNetConfiguration.ListBuilder) {
            NeuralNetConfiguration.ListBuilder listBuilder = (NeuralNetConfiguration.ListBuilder) builder;
            if (listBuilder.getLayerwise().get(Integer.valueOf(index)) == null) {
                throw new IllegalStateException("Undefined layer " + index);
            }
            return listBuilder.getLayerwise().get(Integer.valueOf(index)).getLayer();
        }
        return builder.getConfs().get(index).getLayer();
    }

    /**
     * Registers a {@link CnnToFeedForwardPreProcessor} for layer {@code index + 1} when the layer at
     * {@code index} is a convolution or subsampling layer; does nothing for any other source layer
     * or when {@code target} is itself a convolution layer.
     *
     * <p>NOTE(review): the only remaining caller passes a dense/output/recurrent layer index, so
     * neither branch appears reachable today - confirm against the layer class hierarchy before removing.
     *
     * @throws IllegalStateException if the source convolution layer has no kernel size
     */
    private void setFourDtoTwoD(int index, MultiLayerConfiguration.Builder builder, FeedForwardLayer target) {
        if (target instanceof ConvolutionLayer) {
            return;
        }
        Layer source = getLayer(index, builder);
        if (source instanceof ConvolutionLayer) {
            ConvolutionLayer convolution = (ConvolutionLayer) source;
            if (convolution.getKernelSize() == null) {
                throw new IllegalStateException("Unable to infer width and height without convolution layer kernel size");
            }
            // NOTE(review): this sets nOut (not nIn) on the target layer - looks suspicious, verify intent.
            target.setNOut(this.lastHeight * this.lastWidth * convolution.getNOut());
            builder.inputPreProcessor(Integer.valueOf(index + 1),
                    new CnnToFeedForwardPreProcessor(this.lastHeight, this.lastWidth, this.lastOutChannels));
        } else if (source instanceof SubsamplingLayer) {
            builder.inputPreProcessor(Integer.valueOf(index + 1),
                    new CnnToFeedForwardPreProcessor(this.lastHeight, this.lastWidth, this.lastOutChannels));
        }
    }

    /** @return the height tracked after the last processed convolution/subsampling layer */
    public int getLastHeight() {
        return this.lastHeight;
    }

    /** @param lastHeight new value for the tracked height */
    public void setLastHeight(int lastHeight) {
        this.lastHeight = lastHeight;
    }

    /** @return the width tracked after the last processed convolution/subsampling layer */
    public int getLastWidth() {
        return this.lastWidth;
    }

    /** @param lastWidth new value for the tracked width */
    public void setLastWidth(int lastWidth) {
        this.lastWidth = lastWidth;
    }

    /** @return the channel count tracked after the last processed convolution layer */
    public int getLastOutChannels() {
        return this.lastOutChannels;
    }

    /** @param lastOutChannels new value for the tracked channel count */
    public void setLastOutChannels(int lastOutChannels) {
        this.lastOutChannels = lastOutChannels;
    }

    /** @return the live map of layer index to {height, width} output size (not a copy) */
    public Map<Integer, int[]> getOutSizesEachLayer() {
        return this.outSizesEachLayer;
    }

    /** @param outSizesEachLayer replacement map of layer index to {height, width} output size */
    public void setOutSizesEachLayer(Map<Integer, int[]> outSizesEachLayer) {
        this.outSizesEachLayer = outSizesEachLayer;
    }

    /** @return the live map of layer index to the nIn value applied to that layer (not a copy) */
    public Map<Integer, Integer> getnInForLayer() {
        return this.nInForLayer;
    }

    /** @param nInForLayer replacement map of layer index to nIn */
    public void setnInForLayer(Map<Integer, Integer> nInForLayer) {
        this.nInForLayer = nInForLayer;
    }
}
