/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonProperty;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Input pre-processor that turns 3d (recurrent-style) activations into 4d (convolutional-style)
 * activations, and turns the corresponding 4d gradients back into 3d.
 *
 * <p>Forward: {@code [miniBatch, features, T] -> [miniBatch * T, numChannels, inputHeight, inputWidth]},
 * i.e. every (example, step along the last axis) pair becomes one image. The reshape requires
 * {@code features == numChannels * inputHeight * inputWidth}. Backward performs the inverse mapping.
 */
public class RnnToCnnPreProcessor implements InputPreProcessor {
    private int inputHeight;
    private int inputWidth;
    private int numChannels;
    /**
     * Cached {@code inputHeight * inputWidth * numChannels}. It must be refreshed whenever one of the
     * three dimensions changes, otherwise {@link #backprop} and {@link #equals} see a stale value.
     */
    private int product;

    public RnnToCnnPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
        recomputeProduct();
    }

    /** Refreshes the cached per-image feature count from the current dimensions. */
    private void recomputeProduct() {
        this.product = this.inputHeight * this.inputWidth * this.numChannels;
    }

    /**
     * Reshapes 3d activations {@code [miniBatch, features, T]} into 4d activations
     * {@code [miniBatch * T, numChannels, inputHeight, inputWidth]}.
     *
     * @param input 3d activations; the final reshape requires
     *     {@code features == numChannels * inputHeight * inputWidth}
     * @param miniBatchSize not used here; the mini-batch size is read from {@code input.shape()[0]}
     * @return the 4d activations, 'c'-ordered
     * @throws IllegalArgumentException if {@code input} is not rank 3
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize) {
        // The general-case reshape below is done in 'f' order, so make sure the data is 'f'-ordered.
        if (input.ordering() == 'c') {
            input = input.dup('f');
        }
        int[] shape = input.shape();
        if (shape.length != 3) {
            throw new IllegalArgumentException("RnnToCnnPreProcessor expects rank 3 input, got rank " + shape.length);
        }

        // Build a 2d view/copy of shape [shape[0] * shape[2], shape[1]]: one row per image.
        INDArray in2d;
        if (shape[0] == 1) {
            // Single example: take the [features, T] slice and transpose it to [T, features].
            in2d = input.tensorAlongDimension(0, new int[]{1, 2}).permutei(new int[]{1, 0});
        } else if (shape[2] == 1) {
            // Single step along the last axis: a single 2d tensor already holds all rows.
            in2d = input.tensorAlongDimension(0, new int[]{1, 0});
        } else {
            // Move the last axis next to the batch axis so the 'f' reshape groups rows correctly.
            INDArray permuted = input.permute(new int[]{0, 2, 1});
            in2d = permuted.reshape('f', shape[0] * shape[2], shape[1]);
        }
        // Each row is split in 'c' order into [numChannels, inputHeight, inputWidth].
        return in2d.dup('c').reshape('c', new int[]{shape[0] * shape[2], this.numChannels, this.inputHeight, this.inputWidth});
    }

    /**
     * Maps gradients for the 4d output back to the 3d input layout: the leading dimension of
     * {@code output} (assumed to be {@code miniBatchSize * T}) is split into {@code [miniBatchSize, T]},
     * and the remaining dimensions are flattened into {@code numChannels * inputHeight * inputWidth}
     * features.
     *
     * @param output gradient w.r.t. this pre-processor's 4d output
     * @param miniBatchSize number of examples; {@code output.size(0)} is expected to be a multiple of it
     * @return a permuted view of shape {@code [miniBatchSize, numChannels * inputHeight * inputWidth, T]}
     */
    @Override
    public INDArray backprop(INDArray output, int miniBatchSize) {
        if (output.ordering() == 'f') {
            output = output.dup('c');
        }
        int[] shape = output.shape();
        // Flatten each image (in 'c' order, matching preProcess) into one row.
        INDArray twod = output.reshape('c', shape[0], ArrayUtil.prod(shape) / shape[0]);
        // Undo the 'f'-order grouping used in preProcess, then move the features back to axis 1.
        INDArray reshaped = twod.dup('f').reshape('f', new int[]{miniBatchSize, shape[0] / miniBatchSize, this.product});
        return reshaped.permute(new int[]{0, 2, 1});
    }

    @Override
    public RnnToCnnPreProcessor clone() {
        return new RnnToCnnPreProcessor(this.inputHeight, this.inputWidth, this.numChannels);
    }

    public int getInputHeight() {
        return this.inputHeight;
    }

    public int getInputWidth() {
        return this.inputWidth;
    }

    public int getNumChannels() {
        return this.numChannels;
    }

    /** Sets the image height and refreshes the cached feature count. */
    public void setInputHeight(int inputHeight) {
        this.inputHeight = inputHeight;
        recomputeProduct();
    }

    /** Sets the image width and refreshes the cached feature count. */
    public void setInputWidth(int inputWidth) {
        this.inputWidth = inputWidth;
        recomputeProduct();
    }

    /** Sets the channel count and refreshes the cached feature count. */
    public void setNumChannels(int numChannels) {
        this.numChannels = numChannels;
        recomputeProduct();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RnnToCnnPreProcessor)) {
            return false;
        }
        RnnToCnnPreProcessor other = (RnnToCnnPreProcessor)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.getInputHeight() == other.getInputHeight()
                && this.getInputWidth() == other.getInputWidth()
                && this.getNumChannels() == other.getNumChannels()
                && this.product == other.product;
    }

    /** Symmetry hook for {@link #equals}: subclasses may override to restrict equality. */
    protected boolean canEqual(Object other) {
        return other instanceof RnnToCnnPreProcessor;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + this.getInputHeight();
        result = result * 59 + this.getInputWidth();
        result = result * 59 + this.getNumChannels();
        result = result * 59 + this.product;
        return result;
    }

    @Override
    public String toString() {
        return "RnnToCnnPreProcessor(inputHeight=" + this.getInputHeight() + ", inputWidth=" + this.getInputWidth() + ", numChannels=" + this.getNumChannels() + ", product=" + this.product + ")";
    }
}

