package org.deeplearning4j.models.classifiers.lstm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/models/classifiers/lstm/LSTM.class */
public class LSTM extends BaseLayer {
    private INDArray iFog;
    private INDArray iFogF;
    private INDArray c;
    private INDArray x;
    private INDArray hIn;
    private INDArray hOut;
    private INDArray u;
    private INDArray u2;
    private INDArray xi;
    private INDArray xs;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/models/classifiers/lstm/LSTM$Beam.class */
    public static class Beam {
        private double logProba;
        private List<Integer> indices;
        private INDArray hidden;
        private INDArray c;

        public Beam(double d, List<Integer> list, INDArray iNDArray, INDArray iNDArray2) {
            this.logProba = 0.0d;
            this.logProba = d;
            this.indices = list;
            this.hidden = iNDArray;
            this.c = iNDArray2;
        }

        public double getLogProba() {
            return this.logProba;
        }

        public void setLogProba(double d) {
            this.logProba = d;
        }

        public List<Integer> getIndices() {
            return this.indices;
        }

        public void setIndices(List<Integer> list) {
            this.indices = list;
        }

        public INDArray getHidden() {
            return this.hidden;
        }

        public void setHidden(INDArray iNDArray) {
            this.hidden = iNDArray;
        }

        public INDArray getC() {
            return this.c;
        }

        public void setC(INDArray iNDArray) {
            this.c = iNDArray;
        }
    }

    /* loaded from: input_file:org/deeplearning4j/models/classifiers/lstm/LSTM$BeamSearch.class */
    private class BeamSearch {
        private int nSteps;
        private INDArray h;
        private INDArray c;
        private INDArray ws;
        private List<Beam> beams = new ArrayList();
        private int beamSize = 5;

        public BeamSearch(int i, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
            this.nSteps = 0;
            this.nSteps = i;
            this.h = iNDArray2;
            this.c = iNDArray3;
            this.ws = iNDArray;
            this.beams.add(new Beam(0.0d, new ArrayList(), iNDArray2, iNDArray3));
        }

        public Collection<Pair<List<Integer>, Double>> search() {
            if (this.beamSize <= 1) {
                double d = 0.0d;
                ArrayList arrayList = new ArrayList();
                do {
                    Pair yMax = LSTM.this.yMax((INDArray) LSTM.this.lstmTick(this.ws.slice(0), this.h, this.c).getFirst());
                    arrayList.add(yMax.getFirst());
                    d += ((Double) yMax.getSecond()).doubleValue();
                    this.nSteps++;
                    if (0 == 0) {
                        break;
                    }
                } while (this.nSteps < 20);
                return Collections.singletonList(new Pair(arrayList, Double.valueOf(d)));
            }
            do {
                ArrayList arrayList2 = new ArrayList();
                for (Beam beam : this.beams) {
                    int intValue = beam.getIndices().get(beam.getIndices().size() - 1).intValue();
                    if (intValue != 0 || beam.getIndices().isEmpty()) {
                        Triple lstmTick = LSTM.this.lstmTick(this.ws.slice(intValue), beam.getHidden(), beam.getC());
                        INDArray ravel = ((INDArray) lstmTick.getFirst()).ravel();
                        INDArray exp = Transforms.exp(ravel.subi(Double.valueOf(ravel.max(Integer.MAX_VALUE).getDouble(0))));
                        INDArray log = Transforms.log(exp.divi(Nd4j.sum(exp, Integer.MAX_VALUE)).addi(Double.valueOf(Nd4j.EPS_THRESHOLD)));
                        INDArray[] sortWithIndices = Nd4j.sortWithIndices(log, 0, false);
                        for (int i = 0; i < this.beamSize; i++) {
                            int i2 = sortWithIndices[0].getInt(new int[]{i});
                            ArrayList arrayList3 = new ArrayList(beam.getIndices());
                            arrayList3.add(Integer.valueOf(i2));
                            arrayList2.add(new Beam(beam.getLogProba() + log.getDouble(i2), arrayList3, (INDArray) lstmTick.getSecond(), (INDArray) lstmTick.getThird()));
                        }
                    } else {
                        arrayList2.add(beam);
                    }
                }
                this.nSteps++;
            } while (this.nSteps < 20);
            ArrayList arrayList4 = new ArrayList();
            for (Beam beam2 : this.beams) {
                arrayList4.add(new Pair(beam2.getIndices(), Double.valueOf(beam2.getLogProba())));
            }
            return arrayList4;
        }
    }

    public LSTM(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    public LSTM(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    public INDArray forward(INDArray iNDArray, INDArray iNDArray2) {
        this.xs = iNDArray2;
        this.xi = iNDArray;
        this.x = Nd4j.vstack(new INDArray[]{iNDArray, iNDArray2});
        return activate(this.x);
    }

    public Gradient backward(INDArray iNDArray) {
        INDArray param = getParam(LSTMParamInitializer.DECODER_WEIGHTS);
        INDArray param2 = getParam(LSTMParamInitializer.RECURRENT_WEIGHTS);
        INDArray vstack = Nd4j.vstack(new INDArray[]{Nd4j.zeros(iNDArray.columns()), iNDArray});
        INDArray mmul = this.hOut.transpose().mmul(vstack);
        INDArray sum = Nd4j.sum(mmul, 0);
        INDArray mmul2 = vstack.mmul(param.transpose());
        if (this.conf.getDropOut() > 0.0d) {
            mmul2.muli(this.u2);
        }
        INDArray zeros = Nd4j.zeros(this.iFog.shape());
        INDArray zeros2 = Nd4j.zeros(this.iFogF.shape());
        INDArray zeros3 = Nd4j.zeros(param2.shape());
        INDArray zeros4 = Nd4j.zeros(this.hIn.shape());
        INDArray zeros5 = Nd4j.zeros(this.c.shape());
        INDArray zeros6 = Nd4j.zeros(this.x.shape());
        int rows = this.hOut.rows();
        int columns = this.hOut.columns();
        for (int i = rows - 1; i > 0; i--) {
            if (this.conf.getActivationFunction().type().equals("tanh")) {
                INDArray tanh = Transforms.tanh(this.c.slice(i));
                zeros2.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(2 * columns, 3 * columns)}, tanh.mul(mmul2.slice(i)));
                zeros5.slice(i).addi(Transforms.pow(tanh, 2).rsubi(1).muli(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(2 * columns, 3 * columns)}).mul(mmul2.slice(i))));
            } else {
                zeros2.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(2 * columns, 3 * columns)}, this.c.slice(i).mul(mmul2.slice(i)));
                zeros5.slice(i).addi(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(2 * columns, 3 * columns)}).mul(mmul2.slice(i)));
            }
            if (i > 0) {
                zeros2.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(columns, 2 * columns)}, this.c.slice(i - 1).mul(zeros5.slice(i)));
                zeros5.slice(i - 1).addi(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(columns, 2 * columns)}).mul(zeros5.slice(i)));
            }
            zeros2.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(0, columns)}, this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(3 * columns, this.iFogF.columns())}).mul(zeros5.slice(i)));
            zeros2.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(3 * columns, zeros2.columns())}, this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(0, columns)}).mul(zeros5.slice(i)));
            zeros.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(3 * columns, zeros.columns())}, Transforms.pow(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(3 * columns, this.iFogF.columns())}), 2).rsubi(1).mul(zeros2.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(3 * columns, zeros2.columns())})));
            INDArray iNDArray2 = this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(0, 3 * columns)});
            zeros2.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(0, 3 * columns)}, iNDArray2.mul(iNDArray2.rsub(1)).mul(zeros2.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(0, 3 * columns)})));
            zeros3.addi(this.hIn.slice(i).transpose().mmul(zeros.slice(i)));
            zeros4.slice(i).assign(zeros.slice(i).mmul(param2.transpose()));
            zeros6.slice(i).assign(zeros4.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(1, 1 + columns)}));
            if (i > 0) {
                mmul2.slice(i - 1).addi(zeros4.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(1 + columns, zeros4.columns())}));
            }
            if (this.conf.getDropOut() > 0.0d) {
                zeros6.muli(this.u);
            }
        }
        clear();
        DefaultGradient defaultGradient = new DefaultGradient();
        defaultGradient.gradientLookupTable().put(LSTMParamInitializer.DECODER_BIAS, sum);
        defaultGradient.gradientLookupTable().put(LSTMParamInitializer.DECODER_WEIGHTS, mmul);
        defaultGradient.gradientLookupTable().put(LSTMParamInitializer.RECURRENT_WEIGHTS, zeros3);
        return defaultGradient;
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        INDArray param = getParam(LSTMParamInitializer.DECODER_WEIGHTS);
        INDArray param2 = getParam(LSTMParamInitializer.RECURRENT_WEIGHTS);
        INDArray param3 = getParam(LSTMParamInitializer.DECODER_BIAS);
        if (this.conf.getDropOut() > 0.0d) {
            this.u = Nd4j.rand(this.x.shape()).lti(Double.valueOf(1.0d - this.conf.getDropOut())).muli(Double.valueOf(1.0d / (1.0d - this.conf.getDropOut())));
            this.x.muli(this.u);
        }
        int rows = this.x.rows();
        int rows2 = param.rows();
        this.hIn = Nd4j.zeros(rows, param2.rows());
        this.hOut = Nd4j.zeros(rows, rows2);
        this.iFog = Nd4j.zeros(rows, rows2 * 4);
        this.iFogF = Nd4j.zeros(this.iFog.shape());
        this.c = Nd4j.zeros(rows, rows2);
        int i = 0;
        while (i < rows) {
            INDArray zeros = i == 0 ? Nd4j.zeros(rows2) : this.hOut.getRow(i - 1);
            this.hIn.put(i, 0, Double.valueOf(1.0d));
            this.hIn.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(1, 1 + rows2)}, this.x.slice(i));
            this.hIn.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(1 + rows2, this.hIn.columns())}, zeros);
            this.iFog.putRow(i, this.hIn.slice(i).mmul(param2));
            this.iFogF.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(0, 3 * rows2)}, Transforms.sigmoid(this.iFog.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(0, 3 * rows2)})));
            this.iFogF.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(3 * rows2, this.iFogF.columns())}, Transforms.tanh(this.iFog.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(3 * rows2, this.iFog.columns())})));
            this.c.slice(i).put(new NDArrayIndex[]{NDArrayIndex.interval(3 * rows2, this.iFogF.columns())}, this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(0, rows2)}).mul(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(3 * rows2, this.iFogF.columns())})));
            if (i > 0) {
                this.c.slice(i).addi(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(rows2, 2 * rows2)}).mul(this.c.getRow(i - 1)));
            }
            if (this.conf.getActivationFunction().type().equals("tanh")) {
                this.hOut.slice(i).assign(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(2 * rows2, 3 * rows2)}).mul(Transforms.tanh(this.c.getRow(i))));
            } else {
                this.hOut.slice(i).assign(this.iFogF.slice(i).get(new NDArrayIndex[]{NDArrayIndex.interval(2 * rows2, 3 * rows2)}).mul(this.c.getRow(i)));
            }
            i++;
        }
        if (this.conf.getDropOut() > 0.0d) {
            this.u2 = Nd4j.rand(this.hOut.shape()).lti(Double.valueOf(1.0d - this.conf.getDropOut())).muli(Double.valueOf(1.0d / (1.0d - this.conf.getDropOut())));
            this.hOut.muli(this.u2);
        }
        return this.hOut.get(new NDArrayIndex[]{NDArrayIndex.interval(1, this.hOut.rows())}).mmul(param).addiRowVector(param3);
    }

    public Collection<Pair<List<Integer>, Double>> predict(INDArray iNDArray, INDArray iNDArray2) {
        int rows = getParam(LSTMParamInitializer.DECODER_WEIGHTS).rows();
        Triple<INDArray, INDArray, INDArray> lstmTick = lstmTick(iNDArray, Nd4j.zeros(rows), Nd4j.zeros(rows));
        return new BeamSearch(20, iNDArray2, lstmTick.getSecond(), lstmTick.getThird()).search();
    }

    private void clear() {
        this.u = null;
        this.hIn = null;
        this.hOut = null;
        this.iFog = null;
        this.iFogF = null;
        this.c = null;
        this.x = null;
        this.u2 = null;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Pair<Integer, Double> yMax(INDArray iNDArray) {
        INDArray exp = Transforms.exp(iNDArray.linearView().rsub(Double.valueOf(iNDArray.max(Integer.MAX_VALUE).getDouble(0))));
        INDArray[] sortWithIndices = Nd4j.sortWithIndices(Transforms.log(exp.divi(exp.sum(Integer.MAX_VALUE)).addi(Double.valueOf(Nd4j.EPS_THRESHOLD))), 0, true);
        int i = sortWithIndices[0].getInt(new int[]{0});
        return new Pair<>(Integer.valueOf(i), Double.valueOf(sortWithIndices[1].getDouble(i)));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Triple<INDArray, INDArray, INDArray> lstmTick(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        INDArray param = getParam(LSTMParamInitializer.DECODER_WEIGHTS);
        INDArray param2 = getParam(LSTMParamInitializer.RECURRENT_WEIGHTS);
        INDArray param3 = getParam(LSTMParamInitializer.DECODER_BIAS);
        int rows = param.rows();
        INDArray zeros = Nd4j.zeros(1, param2.rows());
        zeros.putRow(0, Nd4j.ones(zeros.columns()));
        zeros.slice(0).put(new NDArrayIndex[]{NDArrayIndex.interval(1, 1 + rows)}, iNDArray);
        zeros.slice(0).put(new NDArrayIndex[]{NDArrayIndex.interval(1 + rows, zeros.columns())}, iNDArray2);
        INDArray zeros2 = Nd4j.zeros(1, rows * 4);
        INDArray zeros3 = Nd4j.zeros(zeros2.shape());
        INDArray zeros4 = Nd4j.zeros(rows);
        zeros2.putScalar(0, zeros.slice(0).mmul(param2).getDouble(0));
        NDArrayIndex[] nDArrayIndexArr = {NDArrayIndex.interval(0, 3 * rows)};
        zeros3.slice(0).put(nDArrayIndexArr, Transforms.sigmoid(this.iFogF.slice(0).get(nDArrayIndexArr)));
        NDArrayIndex[] nDArrayIndexArr2 = {NDArrayIndex.interval(3 * rows, zeros3.columns())};
        zeros3.slice(0).put(nDArrayIndexArr2, Transforms.tanh(zeros3.slice(0).get(nDArrayIndexArr2)));
        zeros4.slice(0).assign(zeros3.slice(0).get(new NDArrayIndex[]{NDArrayIndex.interval(0, rows)}).mul(zeros3.slice(0).get(new NDArrayIndex[]{NDArrayIndex.interval(3 * rows, zeros3.columns())})).addi(zeros3.slice(0).get(new NDArrayIndex[]{NDArrayIndex.interval(rows, 2 * rows)})).muli(iNDArray3));
        if (this.conf.getActivationFunction().equals("tanh")) {
            this.hOut.slice(0).assign(zeros3.slice(0).get(new NDArrayIndex[]{NDArrayIndex.interval(2 * rows, 3 * rows)}).mul(Transforms.tanh(zeros4.slice(0))));
        } else {
            this.hOut.slice(0).assign(zeros3.slice(0).get(new NDArrayIndex[]{NDArrayIndex.interval(2 * rows, 3 * rows)}).mul(zeros4.slice(0)));
        }
        return new Triple<>(this.hOut.mmul(param).addiRowVector(param3), this.hOut, zeros4);
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void fit() {
        new Solver.Builder().model(this).configure(conf()).listeners(conf().getListeners()).build().optimize();
    }

    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
        setParams(params().addi(gradient.gradient()));
    }

    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        return LossFunctions.score(this.xs, this.conf.getLossFunction(), (INDArray) Activations.softMaxRows().apply(forward(this.xi, this.xs)), this.conf.getL2(), this.conf.isUseRegularization());
    }

    @Override // org.deeplearning4j.nn.api.Model
    public INDArray transform(INDArray iNDArray) {
        return (INDArray) Activations.softMaxRows().apply(forward(this.xi, this.xs));
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        int i = 0;
        INDArray param = getParam(LSTMParamInitializer.DECODER_WEIGHTS);
        INDArray param2 = getParam(LSTMParamInitializer.RECURRENT_WEIGHTS);
        INDArray param3 = getParam(LSTMParamInitializer.DECODER_BIAS);
        INDArray linearView = param2.linearView();
        INDArray linearView2 = param.linearView();
        INDArray linearView3 = param3.linearView();
        int length = linearView.length() + linearView2.length();
        boolean z = false;
        for (int i2 = 0; i2 < iNDArray.length(); i2++) {
            if (i == linearView.length()) {
                i = 0;
                z = true;
            } else if (i == linearView2.length() && z) {
                i = 0;
            }
            if (i2 < param2.length()) {
                int i3 = i;
                i++;
                param2.linearView().putScalar(i3, iNDArray.getDouble(i2));
            } else if (i2 < length) {
                int i4 = i;
                i++;
                linearView2.putScalar(i4, iNDArray.getDouble(i2));
            } else {
                int i5 = i;
                i++;
                linearView3.putScalar(i5, iNDArray.getDouble(i2));
            }
        }
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray) {
        this.xi = iNDArray.slice(0);
        this.xs = iNDArray.get(new NDArrayIndex[]{NDArrayIndex.interval(1, iNDArray.rows()), NDArrayIndex.interval(0, iNDArray.columns())});
        new Solver.Builder().configure(this.conf).model(this).listeners(this.conf.getListeners()).build().optimize();
    }

    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray iNDArray) {
    }

    @Override // org.deeplearning4j.nn.api.Model
    public Gradient getGradient() {
        return backward(Activations.softMaxRows().applyDerivative(forward(this.xi, this.xs)));
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(getGradient(), Double.valueOf(score()));
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public int batchSize() {
        return this.xi.rows();
    }
}
