package org.deeplearning4j.nn.layers.recurrent;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/ImageLSTM.class */
/**
 * Experimental LSTM layer that generates a sequence of token indices conditioned on an initial
 * (presumably image) input row, with greedy or beam-search decoding in {@link #predict}.
 *
 * <p><b>Disabled:</b> both constructors throw {@link UnsupportedOperationException}, so this class
 * cannot currently be instantiated.
 *
 * <p>Parameters, looked up via {@code getParam}:
 * <ul>
 *   <li>{@code "RW"}: LSTM weights. They are multiplied with the per-step row
 *       {@code [1 | x_t | h_{t-1}]} to give {@code 4 * hidden} pre-activations. The columns are laid
 *       out as input gate, forget gate, output gate and cell candidate ({@code [i | f | o | g]}).</li>
 *   <li>{@code "W"}: decoder weights ({@code hidden x outputs}). Their row count defines the hidden
 *       size.</li>
 *   <li>{@code "b"}: decoder bias, added to every decoded row.</li>
 * </ul>
 *
 * <p>The input matrix has one row per time step. {@link #setInput(INDArray, INDArray)} stacks the
 * first (image) row on top of the sequence rows. The decoder ignores the hidden state of the first
 * step.
 */
public class ImageLSTM extends BaseLayer<org.deeplearning4j.nn.conf.layers.ImageLSTM> {
    /** Upper bound on decoding steps used by {@link #predict(INDArray, INDArray)}. */
    private static final int MAX_DECODE_STEPS = 20;

    /** Orders beams from most to least probable. */
    private static final Comparator<Beam> BY_LOG_PROBA_DESCENDING = new Comparator<Beam>() {
        @Override
        public int compare(Beam a, Beam b) {
            return Double.compare(b.getLogProba(), a.getLogProba());
        }
    };

    // Forward-pass caches, one row per time step; consumed by backpropGradient and reset by clear().
    private INDArray iFogZ;              // gate pre-activations [i | f | o | g]
    private INDArray iFogA;              // gate activations: sigmoid for i, f, o; tanh for g
    private INDArray memCellActivations; // cell state c_t
    private INDArray hIn;                // LSTM input rows [1 | x_t | h_{t-1}]
    private INDArray hOut;               // hidden state h_t (decoder dropout applied in place)
    private INDArray outputActivations;  // decoded output of the last activate() call
    private INDArray u;                  // encoder dropout mask, or null when none was applied
    private INDArray u2;                 // decoder dropout mask
    private INDArray xi;                 // first (image) input row
    private INDArray xs;                 // remaining (sequence) input rows

    /**
     * One beam-search hypothesis. It holds the accumulated log-probability, the token indices
     * emitted so far, and the LSTM hidden and cell state reached after the last token.
     */
    public static class Beam {
        private double logProba;
        private List<Integer> indices;
        private INDArray hidden;
        private INDArray c;

        public Beam(double logProba, List<Integer> indices, INDArray hidden, INDArray c) {
            this.logProba = logProba;
            this.indices = indices;
            this.hidden = hidden;
            this.c = c;
        }

        public double getLogProba() {
            return this.logProba;
        }

        public void setLogProba(double logProba) {
            this.logProba = logProba;
        }

        public List<Integer> getIndices() {
            return this.indices;
        }

        public void setIndices(List<Integer> indices) {
            this.indices = indices;
        }

        public INDArray getHidden() {
            return this.hidden;
        }

        public void setHidden(INDArray hidden) {
            this.hidden = hidden;
        }

        public INDArray getC() {
            return this.c;
        }

        public void setC(INDArray c) {
            this.c = c;
        }
    }

    /**
     * Decodes token indices from an initial LSTM state. It runs greedily when
     * {@code beamSize <= 1} and as a beam search otherwise. Token index {@code 0} is fed as the
     * input of the first step and also ends a hypothesis once it is emitted. Instances are
     * single-use: {@code nSteps} is not reset between calls to {@link #search()}.
     */
    private class BeamSearch {
        /** Maximum number of decoding steps; guards against hypotheses that never emit token 0. */
        private final int maxSteps;
        private int nSteps = 0;
        private final INDArray h;
        private final INDArray c;
        /** Row {@code k} is fed to the LSTM after token {@code k} was emitted (presumably word vectors). */
        private final INDArray ws;
        private List<Beam> beams = new ArrayList<>();
        private final int beamSize = 5;

        public BeamSearch(int maxSteps, INDArray ws, INDArray h, INDArray c) {
            this.maxSteps = maxSteps;
            this.ws = ws;
            this.h = h;
            this.c = c;
            this.beams.add(new Beam(0.0d, new ArrayList<Integer>(), h, c));
        }

        /**
         * Runs the search.
         *
         * @return one (token indices, total log-probability) pair per surviving hypothesis, most
         *     probable first for beam search
         */
        public Collection<Pair<List<Integer>, Double>> search() {
            if (beamSize <= 1) {
                return greedy();
            }
            while (nSteps < maxSteps) {
                List<Beam> candidates = new ArrayList<>();
                for (Beam beam : beams) {
                    List<Integer> indices = beam.getIndices();
                    // Check emptiness before peeking at the last token: the seed beam has none.
                    int previous = indices.isEmpty() ? 0 : indices.get(indices.size() - 1).intValue();
                    if (previous == 0 && !indices.isEmpty()) {
                        // This hypothesis already emitted the end token; keep it without expanding.
                        candidates.add(beam);
                        continue;
                    }
                    Triple<INDArray, INDArray, INDArray> step =
                            lstmTick(ws.slice(previous), beam.getHidden(), beam.getC());
                    // ranked[0] = token indices, ranked[1] = their log-probabilities, both descending.
                    INDArray[] ranked = Nd4j.sortWithIndices(logSoftmax(step.getFirst()), 0, false);
                    int width = Math.min(beamSize, ranked[0].length());
                    for (int k = 0; k < width; k++) {
                        int token = ranked[0].getInt(k);
                        List<Integer> extended = new ArrayList<>(indices);
                        extended.add(Integer.valueOf(token));
                        candidates.add(new Beam(beam.getLogProba() + ranked[1].getDouble(k), extended,
                                step.getSecond(), step.getThird()));
                    }
                }
                // Keep only the beamSize most probable hypotheses for the next step.
                Collections.sort(candidates, BY_LOG_PROBA_DESCENDING);
                beams = new ArrayList<>(candidates.subList(0, Math.min(beamSize, candidates.size())));
                nSteps++;
            }
            List<Pair<List<Integer>, Double>> result = new ArrayList<>();
            for (Beam beam : beams) {
                result.add(new Pair<List<Integer>, Double>(beam.getIndices(), Double.valueOf(beam.getLogProba())));
            }
            return result;
        }

        /** Greedy decoding: always take the most probable token and feed it back in. */
        private Collection<Pair<List<Integer>, Double>> greedy() {
            List<Integer> tokens = new ArrayList<>();
            double logProba = 0.0d;
            INDArray hidden = h;
            INDArray cell = c;
            int previous = 0;
            do {
                Triple<INDArray, INDArray, INDArray> step = lstmTick(ws.slice(previous), hidden, cell);
                hidden = step.getSecond();
                cell = step.getThird();
                Pair<Integer, Double> best = yMax(step.getFirst());
                previous = best.getFirst().intValue();
                tokens.add(Integer.valueOf(previous));
                logProba += best.getSecond().doubleValue();
                nSteps++;
            } while (previous != 0 && nSteps < maxSteps);
            return Collections.singletonList(new Pair<List<Integer>, Double>(tokens, Double.valueOf(logProba)));
        }
    }

    public ImageLSTM(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
        throw new UnsupportedOperationException("Layer disabled: Version in development and will be provided in a later release.");
    }

    public ImageLSTM(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
        throw new UnsupportedOperationException("Layer disabled: Version in development and will be provided in a later release.");
    }

    /**
     * Sets the layer input to {@code xi} stacked on top of {@code xs}.
     *
     * @param xi first input row(s) (presumably the image encoding)
     * @param xs remaining sequence rows
     */
    public void setInput(INDArray xi, INDArray xs) {
        this.xi = xi;
        this.xs = xs;
        setInput(Nd4j.vstack(new INDArray[]{xi, xs}));
    }

    /**
     * Backpropagates through the decoder and the LSTM using the caches from the last
     * {@link #activate(boolean)} call, then clears those caches.
     *
     * @param gradient incoming gradient; its {@code "b"} entry is read as the gradient of the
     *     decoded output (one row per decoded time step)
     * @param epsilon unused
     * @return gradients for {@code "W"}, {@code "RW"} and {@code "b"}, plus the gradient with respect
     *     to this layer's input
     */
    public Pair<Gradient, INDArray> backpropGradient(Gradient gradient, INDArray epsilon) {
        INDArray decoderWeights = getParam("W");
        INDArray lstmWeights = getParam("RW");
        INDArray dY = gradient.getGradientFor("b");
        boolean tanhCells = usesTanhCells();

        // Decoder: output = hOut[1:] * W + b (the first time step is not decoded).
        INDArray decodedHidden = this.hOut.get(span(1, this.hOut.rows()));
        INDArray dW = decodedHidden.transpose().mmul(dY);
        INDArray dB = Nd4j.sum(dY, 0);
        INDArray dHoutDecoded = dY.mmul(decoderWeights.transpose());
        // The undecoded first step receives no gradient from the decoder.
        INDArray dHout = Nd4j.vstack(new INDArray[]{Nd4j.zeros(dHoutDecoded.columns()), dHoutDecoded});
        if (this.conf.isUseDropConnect() && this.conf.getLayer().getDropOut() > 0.0d && this.u2 != null) {
            dHout.muli(this.u2);
        }

        INDArray dHin = Nd4j.zeros(this.hIn.shape());
        INDArray dX = Nd4j.zeros(this.input.shape());
        INDArray dIfogZ = Nd4j.zeros(this.iFogZ.shape());
        INDArray dIfogA = Nd4j.zeros(this.iFogA.shape());
        INDArray dRW = Nd4j.zeros(lstmWeights.shape());
        INDArray dC = Nd4j.zeros(this.memCellActivations.shape());
        int steps = this.hOut.rows();
        int hidden = this.hOut.columns();
        int gates = this.iFogA.columns();

        // Walk time backwards, including step 0.
        for (int t = steps - 1; t >= 0; t--) {
            INDArray a = this.iFogA.slice(t);
            INDArray dA = dIfogA.slice(t);
            INDArray dCt = dC.slice(t);
            INDArray dHt = dHout.slice(t);
            INDArray outputGate = a.get(span(2 * hidden, 3 * hidden));

            // h_t = o * act(c_t): gradients into the output gate and the cell.
            if (tanhCells) {
                INDArray tanhC = Transforms.tanh(this.memCellActivations.slice(t));
                dA.put(span(2 * hidden, 3 * hidden), tanhC.mul(dHt));
                dCt.addi(Transforms.pow(tanhC, 2).rsubi(1).muli(outputGate.mul(dHt)));
            } else {
                dA.put(span(2 * hidden, 3 * hidden), this.memCellActivations.slice(t).mul(dHt));
                dCt.addi(outputGate.mul(dHt));
            }
            // c_t = i * g + f * c_{t-1}.
            if (t > 0) {
                dA.put(span(hidden, 2 * hidden), this.memCellActivations.slice(t - 1).mul(dCt));
                dC.slice(t - 1).addi(a.get(span(hidden, 2 * hidden)).mul(dCt));
            }
            dA.put(span(0, hidden), a.get(span(3 * hidden, gates)).mul(dCt));
            dA.put(span(3 * hidden, gates), a.get(span(0, hidden)).mul(dCt));

            // Through the non-linearities: tanh for g, sigmoid for i, f, o.
            INDArray dZ = dIfogZ.slice(t);
            dZ.put(span(3 * hidden, gates),
                    Transforms.pow(a.get(span(3 * hidden, gates)), 2).rsubi(1).mul(dA.get(span(3 * hidden, gates))));
            INDArray sig = a.get(span(0, 3 * hidden));
            dZ.put(span(0, 3 * hidden), sig.mul(sig.rsub(1)).mul(dA.get(span(0, 3 * hidden))));

            // Through the matrix multiply hIn_t * RW, then split dHin_t into [bias | x_t | h_{t-1}].
            dRW.addi(this.hIn.slice(t).transpose().mmul(dZ));
            dHin.slice(t).assign(dZ.mmul(lstmWeights.transpose()));
            dX.slice(t).assign(dHin.slice(t).get(span(1, 1 + hidden)));
            if (t > 0) {
                dHout.slice(t - 1).addi(dHin.slice(t).get(span(1 + hidden, dHin.columns())));
            }
        }
        // Encoder dropout mask is applied exactly once, mirroring the forward pass.
        if (this.u != null) {
            dX.muli(this.u);
        }
        clear();
        DefaultGradient defaultGradient = new DefaultGradient();
        defaultGradient.gradientForVariable().put("W", dW);
        defaultGradient.gradientForVariable().put("RW", dRW);
        defaultGradient.gradientForVariable().put("b", dB);
        return new Pair<>(defaultGradient, dX);
    }

    /**
     * Runs the LSTM over every input row and decodes the hidden states of steps 1..n-1.
     * It caches the intermediate values needed by {@link #backpropGradient}.
     *
     * @param training whether this is a training pass; dropout masks are only drawn when true
     * @return decoded outputs, one row per time step except the first
     */
    @Override
    public INDArray activate(boolean training) {
        INDArray decoderWeights = getParam("W");
        INDArray lstmWeights = getParam("RW");
        INDArray decoderBias = getParam("b");
        double dropOut = this.conf.getLayer().getDropOut();
        boolean tanhCells = usesTanhCells();

        // Encoder dropout: scaled keep-mask (rand < 1 - p) / (1 - p), kept in u for backprop.
        this.u = null;
        if (training && dropOut > 0.0d) {
            this.u = Nd4j.rand(this.input.shape()).lti(Double.valueOf(1.0d - dropOut))
                    .muli(Double.valueOf(1.0d / (1.0d - dropOut)));
            this.input.muli(this.u);
        }

        int steps = this.input.size(0);
        int hidden = decoderWeights.size(0);
        this.hIn = Nd4j.zeros(steps, lstmWeights.size(0));
        this.hOut = Nd4j.zeros(steps, hidden);
        this.iFogZ = Nd4j.zeros(steps, hidden * 4);
        this.iFogA = Nd4j.zeros(this.iFogZ.shape());
        this.memCellActivations = Nd4j.zeros(steps, hidden);

        for (int t = 0; t < steps; t++) {
            INDArray hPrev = t == 0 ? Nd4j.zeros(hidden) : this.hOut.slice(t - 1);

            // LSTM input row: [1 (bias) | x_t | h_{t-1}].
            INDArray hInRow = this.hIn.slice(t);
            hInRow.putScalar(0, 1.0d);
            hInRow.put(span(1, 1 + hidden), this.input.slice(t));
            hInRow.put(span(1 + hidden, this.hIn.columns()), hPrev);

            INDArray zRow = this.iFogZ.slice(t);
            zRow.assign(hInRow.mmul(lstmWeights));
            INDArray aRow = this.iFogA.slice(t);
            aRow.put(span(0, 3 * hidden), Transforms.sigmoid(zRow.get(span(0, 3 * hidden))));
            aRow.put(span(3 * hidden, this.iFogA.columns()),
                    Transforms.tanh(zRow.get(span(3 * hidden, this.iFogZ.columns()))));

            // c_t = i * g (+ f * c_{t-1} after the first step).
            INDArray cRow = this.memCellActivations.slice(t);
            cRow.put(span(0, hidden), aRow.get(span(0, hidden)).mul(aRow.get(span(3 * hidden, this.iFogA.columns()))));
            if (t > 0) {
                cRow.addi(aRow.get(span(hidden, 2 * hidden)).mul(this.memCellActivations.slice(t - 1)));
            }
            INDArray outputGate = aRow.get(span(2 * hidden, 3 * hidden));
            this.hOut.slice(t).assign(tanhCells ? outputGate.mul(Transforms.tanh(cRow)) : outputGate.mul(cRow));
        }
        if (this.conf.isUseDropConnect() && training && dropOut > 0.0d) {
            // NOTE(review): confirm whether Dropout.applyDropout already multiplies hOut by the
            // mask; if it does, the mask is applied twice here.
            this.u2 = Dropout.applyDropout(this.hOut, dropOut, this.u2);
            this.hOut.muli(this.u2);
        }
        this.outputActivations = this.hOut.get(span(1, this.hOut.rows())).mmul(decoderWeights).addiRowVector(decoderBias);
        return this.outputActivations;
    }

    /**
     * Primes the LSTM with {@code xi} from a zero state, then decodes a token sequence.
     *
     * @param xi first input row (presumably the image encoding)
     * @param ws matrix whose row {@code k} is fed to the LSTM after token {@code k} is emitted
     * @return decoded (token indices, log-probability) hypotheses
     */
    public Collection<Pair<List<Integer>, Double>> predict(INDArray xi, INDArray ws) {
        int hidden = getParam("W").rows();
        Triple<INDArray, INDArray, INDArray> primed = lstmTick(xi, Nd4j.zeros(hidden), Nd4j.zeros(hidden));
        return new BeamSearch(MAX_DECODE_STEPS, ws, primed.getSecond(), primed.getThird()).search();
    }

    /** Drops the forward-pass caches, the layer input and the dropout masks. */
    @Override
    public void clear() {
        this.hIn = null;
        this.input = null;
        this.iFogZ = null;
        this.iFogA = null;
        this.u = null;
        this.u2 = null;
        this.memCellActivations = null;
        this.outputActivations = null;
    }

    /**
     * Returns the most probable token and its log-probability.
     *
     * @param scores unnormalised output scores
     * @return (argmax index, log-softmax value at that index)
     */
    public Pair<Integer, Double> yMax(INDArray scores) {
        INDArray[] ranked = Nd4j.sortWithIndices(logSoftmax(scores), 0, false);
        return new Pair<>(Integer.valueOf(ranked[0].getInt(0)), Double.valueOf(ranked[1].getDouble(0)));
    }

    /**
     * Advances the LSTM by a single step without touching the training caches.
     *
     * @param x input row for this step
     * @param hPrev previous hidden state
     * @param cPrev previous cell state
     * @return (decoded output, new hidden state, new cell state)
     */
    public Triple<INDArray, INDArray, INDArray> lstmTick(INDArray x, INDArray hPrev, INDArray cPrev) {
        INDArray decoderWeights = getParam("W");
        INDArray lstmWeights = getParam("RW");
        INDArray decoderBias = getParam("b");
        int hidden = decoderWeights.rows();

        // LSTM input row: [1 (bias) | x | hPrev].
        INDArray hIn = Nd4j.zeros(1, lstmWeights.rows());
        INDArray hInRow = hIn.slice(0);
        hInRow.putScalar(0, 1.0d);
        hInRow.put(span(1, 1 + hidden), x);
        hInRow.put(span(1 + hidden, hIn.columns()), hPrev);

        INDArray z = hIn.mmul(lstmWeights);
        INDArray zRow = z.slice(0);
        INDArray a = Nd4j.zeros(z.shape());
        INDArray aRow = a.slice(0);
        aRow.put(span(0, 3 * hidden), Transforms.sigmoid(zRow.get(span(0, 3 * hidden))));
        aRow.put(span(3 * hidden, a.columns()), Transforms.tanh(zRow.get(span(3 * hidden, z.columns()))));

        // c = i * g + f * cPrev ; h = o * act(c)
        INDArray cell = aRow.get(span(0, hidden)).mul(aRow.get(span(3 * hidden, a.columns())))
                .addi(aRow.get(span(hidden, 2 * hidden)).mul(cPrev));
        INDArray outputGate = aRow.get(span(2 * hidden, 3 * hidden));
        INDArray h = usesTanhCells() ? outputGate.mul(Transforms.tanh(cell)) : outputGate.mul(cell);
        INDArray y = h.mmul(decoderWeights).addiRowVector(decoderBias);
        return new Triple<>(y, h, cell);
    }

    /** L2 penalty {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}, or 0 when disabled. */
    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getL2() <= 0.0d) {
            return 0.0d;
        }
        return 0.5d * this.conf.getL2() * (Transforms.pow(getParam("RW"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0) + Transforms.pow(getParam("W"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0));
    }

    /** L1 penalty {@code l1 * (sum|RW| + sum|W|)}, or 0 when disabled. */
    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getL1() <= 0.0d) {
            return 0.0d;
        }
        return this.conf.getL1() * (Transforms.abs(getParam("RW")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0) + Transforms.abs(getParam("W")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0));
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Splits {@code data} into its first row ({@code xi}) and the remaining rows ({@code xs}), then
     * runs the configured optimizer.
     */
    @Override
    public void fit(INDArray data) {
        this.xi = data.slice(0);
        this.xs = data.get(new INDArrayIndex[]{NDArrayIndex.interval(1, data.rows()), NDArrayIndex.interval(0, data.columns())});
        new Solver.Builder().configure(this.conf).model(this).listeners(getListeners()).build().optimize();
    }

    /** Rows of {@code xi}; requires a prior {@link #fit} or two-argument {@code setInput}. */
    @Override
    public int batchSize() {
        return this.xi.rows();
    }

    /** Whether hidden states use {@code tanh(c)} rather than {@code c} directly. */
    private boolean usesTanhCells() {
        return "tanh".equals(this.conf.getLayer().getActivationFunction());
    }

    /** Numerically stable log-softmax of the flattened scores, with an epsilon guard against log(0). */
    private static INDArray logSoftmax(INDArray scores) {
        INDArray flat = scores.ravel().dup();
        double max = flat.max(new int[]{Integer.MAX_VALUE}).getDouble(0);
        INDArray exp = Transforms.exp(flat.subi(Double.valueOf(max)));
        double total = exp.sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return Transforms.log(exp.divi(Double.valueOf(total)).addi(Double.valueOf(Nd4j.EPS_THRESHOLD)));
    }

    /** Fresh single-interval index {@code [from, to)}; a new array per call so indices are never shared. */
    private static INDArrayIndex[] span(int from, int to) {
        return new INDArrayIndex[]{NDArrayIndex.interval(from, to)};
    }
}
