package org.deeplearning4j.nn.layers.recurrent;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.class */
/**
 * Runtime implementation of a bidirectional recurrent layer.
 *
 * <p>Wraps two {@link RecurrentLayer} instances. The forward layer sees the input in its original
 * time order. The backward layer sees a time-reversed copy of the input (and of the mask). Their
 * outputs are combined according to {@link Bidirectional#getMode()}:
 * <ul>
 *   <li>ADD: element-wise sum;</li>
 *   <li>MUL: element-wise product;</li>
 *   <li>AVERAGE: element-wise mean;</li>
 *   <li>CONCAT: concatenation along dimension 1.</li>
 * </ul>
 *
 * <p>Parameters and gradients live in single flat view arrays that are split into two equal
 * halves. The first half backs the forward layer and the second half backs the backward layer.
 * Keys in {@link #paramTable()} and in the backprop gradient are prefixed with
 * {@link BidirectionalParamInitializer#FORWARD_PREFIX} for the forward layer and {@code "b"} for
 * the backward layer.
 *
 * <p>Stateful RNN usage ({@code rnnTimeStep}, truncated BPTT) is not supported.
 */
public class BidirectionalLayer implements RecurrentLayer {
    /** Key prefix for parameters/gradients that belong to the backward sub-layer. */
    private static final String BACKWARD_PREFIX = "b";

    private NeuralNetConfiguration conf;
    /** Sub-layer that processes the sequence in its original time order. */
    private final RecurrentLayer fwd;
    /** Sub-layer that processes the time-reversed sequence. */
    private final RecurrentLayer bwd;
    private final Bidirectional layerConf;
    private INDArray paramsView;
    private INDArray gradientView;
    private INDArray input;
    /** Forward-direction output, cached by {@link #activate(boolean)} in MUL mode for backprop. */
    private INDArray outFwd;
    /** Backward-direction output (re-aligned to original time order), cached in MUL mode. */
    private INDArray outBwd;

    /**
     * Creates a bidirectional wrapper around two already-constructed recurrent layers.
     *
     * @param conf configuration whose layer is a {@link Bidirectional} layer configuration
     * @param fwd  layer applied to the input in its original time order
     * @param bwd  layer applied to the time-reversed input
     * @throws NullPointerException if any argument is null
     * @throws ClassCastException   if {@code conf.getLayer()} is not a {@link Bidirectional}
     */
    public BidirectionalLayer(@NonNull NeuralNetConfiguration conf, @NonNull RecurrentLayer fwd,
                              @NonNull RecurrentLayer bwd) {
        if (conf == null) {
            throw new NullPointerException("conf");
        }
        if (fwd == null) {
            throw new NullPointerException("fwd");
        }
        if (bwd == null) {
            throw new NullPointerException("bwd");
        }
        this.conf = conf;
        this.fwd = fwd;
        this.bwd = bwd;
        this.layerConf = (Bidirectional) conf.getLayer();
    }

    /** Not supported for bidirectional layers (the backward pass needs the whole sequence). */
    @Override
    public INDArray rnnTimeStep(INDArray input) {
        throw new UnsupportedOperationException("Cannot RnnTimeStep bidirectional layers");
    }

    /** Not supported: this layer keeps no time-step state. */
    @Override
    public Map<String, INDArray> rnnGetPreviousState() {
        throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore no previous state is supported");
    }

    /** Not supported: this layer keeps no time-step state. */
    @Override
    public void rnnSetPreviousState(Map<String, INDArray> stateMap) {
        throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore no previous state is supported");
    }

    /** No-op: since {@link #rnnTimeStep(INDArray)} is unsupported, there is no state to clear. */
    @Override
    public void rnnClearPreviousState() {
    }

    /** Not supported: truncated BPTT cannot be used with bidirectional layers. */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    /** Not supported: truncated BPTT cannot be used with bidirectional layers. */
    @Override
    public Map<String, INDArray> rnnGetTBPTTState() {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    /** Not supported: truncated BPTT cannot be used with bidirectional layers. */
    @Override
    public void rnnSetTBPTTState(Map<String, INDArray> state) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    /** Not supported: truncated BPTT cannot be used with bidirectional layers. */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    /** Applies the cache mode to both sub-layers. */
    @Override
    public void setCacheMode(CacheMode mode) {
        fwd.setCacheMode(mode);
        bwd.setCacheMode(mode);
    }

    /** @return the sum of both sub-layers' L2 regularization terms */
    @Override
    public double calcL2(boolean backpropOnlyParams) {
        return fwd.calcL2(backpropOnlyParams) + bwd.calcL2(backpropOnlyParams);
    }

    /** @return the sum of both sub-layers' L1 regularization terms */
    @Override
    public double calcL1(boolean backpropOnlyParams) {
        return fwd.calcL1(backpropOnlyParams) + bwd.calcL1(backpropOnlyParams);
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Backpropagates through both directions.
     *
     * <p>The incoming epsilon is split or scaled according to the combination mode. The backward
     * layer's share is time-reversed, because that layer ran on reversed input. The two input
     * epsilons are summed after the backward one is reversed back to original time order.
     *
     * @param epsilon gradient w.r.t. this layer's combined output
     * @return the combined gradient (forward keys prefixed with FORWARD_PREFIX, backward keys
     *         with {@code "b"}) and the epsilon w.r.t. this layer's input
     * @throws IllegalStateException if the mode is MUL and no forward pass has cached outputs
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray epsFwd;
        INDArray epsBwd;
        switch (layerConf.getMode()) {
            case ADD:
                // d(f+b)/df = d(f+b)/db = 1: both directions receive the incoming epsilon.
                epsFwd = epsilon;
                epsBwd = epsilon;
                break;
            case MUL:
                if (outFwd == null || outBwd == null) {
                    throw new IllegalStateException("Cannot backprop in MUL mode: no cached forward pass outputs."
                            + " activate(...) must be called before backpropGradient(...)");
                }
                // d(f*b)/df = b and d(f*b)/db = f, using the outputs cached during activate().
                epsFwd = epsilon.dup(epsilon.ordering()).muli(outBwd);
                epsBwd = epsilon.dup(epsilon.ordering()).muli(outFwd);
                break;
            case AVERAGE:
                // d(0.5*(f+b))/df = d(0.5*(f+b))/db = 0.5
                epsFwd = epsilon.dup(epsilon.ordering()).muli(0.5);
                epsBwd = epsFwd;
                break;
            case CONCAT: {
                // Output was [fwd ; bwd] along dimension 1: split the epsilon the same way.
                int half = epsilon.size(1) / 2;
                epsFwd = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, half), NDArrayIndex.all());
                epsBwd = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(half, 2 * half), NDArrayIndex.all());
                break;
            }
            default:
                throw new RuntimeException("Unknown mode: " + layerConf.getMode());
        }

        // The backward layer operated on time-reversed input, so its epsilon is reversed too.
        INDArray epsBwdReversed = TimeSeriesUtils.reverseTimeSeries(epsBwd);
        Pair<Gradient, INDArray> fwdResult = fwd.backpropGradient(epsFwd);
        Pair<Gradient, INDArray> bwdResult = bwd.backpropGradient(epsBwdReversed);

        Gradient combined = new DefaultGradient(gradientView);
        putPrefixed(combined.gradientForVariable(), BidirectionalParamInitializer.FORWARD_PREFIX,
                fwdResult.getFirst().gradientForVariable());
        putPrefixed(combined.gradientForVariable(), BACKWARD_PREFIX,
                bwdResult.getFirst().gradientForVariable());

        // Re-align the backward input epsilon to original time order before summing.
        INDArray inputEpsilon = fwdResult.getRight().addi(TimeSeriesUtils.reverseTimeSeries(bwdResult.getRight()));
        return new Pair<>(combined, inputEpsilon);
    }

    /** Equivalent to {@link #activate(INDArray)}. */
    @Override
    public INDArray preOutput(INDArray x) {
        return activate(x);
    }

    /** Equivalent to {@link #activate(INDArray, Layer.TrainingMode)}. */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        return activate(x, training);
    }

    /** Activates on the current input; {@code TRAIN} maps to {@code activate(true)}. */
    @Override
    public INDArray activate(Layer.TrainingMode training) {
        return activate(training == Layer.TrainingMode.TRAIN);
    }

    /** Sets the input, then activates with the given training mode. */
    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        setInput(input);
        return activate(training);
    }

    /** Equivalent to {@link #activate(INDArray, boolean)}. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return activate(x, training);
    }

    /**
     * Runs both directions on the current input and combines their outputs.
     *
     * <p>The backward layer's output is time-reversed so that it lines up with the original time
     * order before it is combined. In MUL mode, both directional outputs are cached (detached)
     * for use by {@link #backpropGradient(INDArray)}.
     *
     * @param training whether the sub-layers run in training mode
     * @return the combined activations
     */
    @Override
    public INDArray activate(boolean training) {
        INDArray forwardOut = fwd.activate(training);
        INDArray backwardOut = TimeSeriesUtils.reverseTimeSeries(bwd.activate(training));
        switch (layerConf.getMode()) {
            case ADD:
                return forwardOut.addi(backwardOut);
            case MUL:
                outFwd = forwardOut.detach();
                outBwd = backwardOut.detach();
                return forwardOut.mul(backwardOut);
            case AVERAGE:
                return forwardOut.addi(backwardOut).muli(0.5);
            case CONCAT:
                return Nd4j.concat(1, forwardOut, backwardOut);
            default:
                throw new RuntimeException("Unknown mode: " + layerConf.getMode());
        }
    }

    /** Sets the input, then activates. */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        setInput(input);
        return activate(training);
    }

    /** Activates on the current input in inference mode ({@code training = false}). */
    @Override
    public INDArray activate() {
        return activate(false);
    }

    /** Sets the input, then activates in inference mode. */
    @Override
    public INDArray activate(INDArray input) {
        setInput(input);
        return activate();
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Cannot transpose layer");
    }

    /** Not supported. */
    @Override
    public Layer clone() {
        throw new UnsupportedOperationException("Clone not supported");
    }

    /** @return the forward layer's listeners (both sub-layers are given the same listeners) */
    @Override
    public Collection<IterationListener> getListeners() {
        return fwd.getListeners();
    }

    @Override
    public void setListeners(IterationListener... listeners) {
        fwd.setListeners(listeners);
        bwd.setListeners(listeners);
    }

    @Override
    public void addListeners(IterationListener... listeners) {
        fwd.addListeners(listeners);
        bwd.addListeners(listeners);
    }

    @Override
    public void fit() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void update(Gradient gradient) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** @return the sum of both sub-layers' scores */
    @Override
    public double score() {
        return fwd.score() + bwd.score();
    }

    @Override
    public void computeGradientAndScore() {
        fwd.computeGradientAndScore();
        bwd.computeGradientAndScore();
    }

    @Override
    public void accumulateScore(double accum) {
        fwd.accumulateScore(accum);
        bwd.accumulateScore(accum);
    }

    /** @return the flat parameter view shared by both sub-layers */
    @Override
    public INDArray params() {
        return paramsView;
    }

    @Override
    public int numParams() {
        return fwd.numParams() + bwd.numParams();
    }

    @Override
    public int numParams(boolean backwards) {
        return fwd.numParams(backwards) + bwd.numParams(backwards);
    }

    /** Copies {@code params} into the existing parameter view (shapes must match). */
    @Override
    public void setParams(INDArray params) {
        paramsView.assign(params);
    }

    /**
     * Sets the flat parameter view and hands each sub-layer its half.
     *
     * <p>The first half backs the forward layer and the second half the backward layer. This
     * mirrors {@link #setBackpropGradientsViewArray(INDArray)}.
     *
     * @param params row vector holding both directions' parameters
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        this.paramsView = params;
        // Each direction owns exactly half of the flat vector. Using the full length here would
        // give the forward layer everything and push the backward view past the array's end.
        int half = params.length() / 2;
        fwd.setParamsViewArray(params.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, half)));
        bwd.setParamsViewArray(params.get(NDArrayIndex.point(0), NDArrayIndex.interval(half, 2 * half)));
    }

    @Override
    public INDArray getGradientsViewArray() {
        return gradientView;
    }

    /**
     * Sets the flat gradient view and hands each sub-layer its half (forward first, then backward).
     *
     * @param gradients row vector holding both directions' gradients
     * @throws IllegalArgumentException if parameters are already set and the length differs from
     *                                  {@link #numParams()}
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        if (paramsView != null && gradients.length() != numParams()) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams() + ", got array of length " + gradients.length());
        }
        this.gradientView = gradients;
        int half = gradients.length() / 2;
        INDArray fwdGrad = gradients.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, half));
        INDArray bwdGrad = gradients.get(NDArrayIndex.point(0), NDArrayIndex.interval(half, 2 * half));
        fwd.setBackpropGradientsViewArray(fwdGrad);
        bwd.setBackpropGradientsViewArray(bwdGrad);
    }

    @Override
    public void fit(INDArray data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int batchSize() {
        return fwd.batchSize();
    }

    @Override
    public NeuralNetConfiguration conf() {
        return conf;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    /** @return the input last passed to {@link #setInput(INDArray)} (not time-reversed) */
    @Override
    public INDArray input() {
        return input;
    }

    @Override
    public void validateInput() {
    }

    /** @return {@code null}: this layer has no optimizer of its own */
    @Override
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /**
     * Looks up a parameter by prefixed key.
     *
     * <p>The first character selects the direction. A key starting with FORWARD_PREFIX goes to
     * the forward layer; anything else goes to the backward layer. The rest of the key is the
     * sub-layer's own parameter name.
     */
    @Override
    public INDArray getParam(String param) {
        String subKey = param.substring(1);
        return param.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX) ? fwd.getParam(subKey) : bwd.getParam(subKey);
    }

    @Override
    public void initParams() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /**
     * Builds a combined parameter table from both sub-layers.
     *
     * @param backpropParamsOnly forwarded to each sub-layer's {@code paramTable}
     * @return an insertion-ordered map: forward entries (keys prefixed with FORWARD_PREFIX) first,
     *         then backward entries (keys prefixed with {@code "b"})
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        Map<String, INDArray> table = new LinkedHashMap<>();
        putPrefixed(table, BidirectionalParamInitializer.FORWARD_PREFIX, fwd.paramTable(backpropParamsOnly));
        putPrefixed(table, BACKWARD_PREFIX, bwd.paramTable(backpropParamsOnly));
        return table;
    }

    /** Sets every entry of {@code paramTable} via {@link #setParam(String, INDArray)}. */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        for (Map.Entry<String, INDArray> entry : paramTable.entrySet()) {
            setParam(entry.getKey(), entry.getValue());
        }
    }

    /** Routes a prefixed key to the matching sub-layer (same key scheme as {@link #getParam}). */
    @Override
    public void setParam(String key, INDArray val) {
        String subKey = key.substring(1);
        if (key.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)) {
            fwd.setParam(subKey, val);
        } else {
            bwd.setParam(subKey, val);
        }
    }

    /** Clears both sub-layers and drops the cached input and MUL-mode outputs. */
    @Override
    public void clear() {
        fwd.clear();
        bwd.clear();
        input = null;
        outFwd = null;
        outBwd = null;
    }

    @Override
    public void applyConstraints(int iteration, int epoch) {
        fwd.applyConstraints(iteration, epoch);
        bwd.applyConstraints(iteration, epoch);
    }

    @Override
    public void init() {
    }

    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        fwd.setListeners(listeners);
        bwd.setListeners(listeners);
    }

    @Override
    public void setIndex(int index) {
        fwd.setIndex(index);
        bwd.setIndex(index);
    }

    @Override
    public int getIndex() {
        return fwd.getIndex();
    }

    @Override
    public int getIterationCount() {
        return fwd.getIterationCount();
    }

    @Override
    public int getEpochCount() {
        return fwd.getEpochCount();
    }

    @Override
    public void setIterationCount(int iterationCount) {
        fwd.setIterationCount(iterationCount);
        bwd.setIterationCount(iterationCount);
    }

    @Override
    public void setEpochCount(int epochCount) {
        fwd.setEpochCount(epochCount);
        bwd.setEpochCount(epochCount);
    }

    /**
     * Stores the input and passes it to the forward layer. The backward layer gets a
     * time-reversed copy. A {@code null} input is passed through to both sub-layers as
     * {@code null}.
     */
    @Override
    public void setInput(INDArray input) {
        this.input = input;
        fwd.setInput(input);
        bwd.setInput(input == null ? null : TimeSeriesUtils.reverseTimeSeries(input));
    }

    @Override
    public void migrateInput() {
        fwd.migrateInput();
        bwd.migrateInput();
    }

    @Override
    public void setInputMiniBatchSize(int size) {
        fwd.setInputMiniBatchSize(size);
        bwd.setInputMiniBatchSize(size);
    }

    @Override
    public int getInputMiniBatchSize() {
        return fwd.getInputMiniBatchSize();
    }

    /**
     * Sets the mask on the forward layer and a time-reversed copy on the backward layer.
     * A {@code null} mask (clearing the mask) is passed to both sub-layers as {@code null}.
     */
    @Override
    public void setMaskArray(INDArray maskArray) {
        fwd.setMaskArray(maskArray);
        bwd.setMaskArray(maskArray == null ? null : TimeSeriesUtils.reverseTimeSeriesMask(maskArray));
    }

    /** @return the forward layer's mask (in original time order) */
    @Override
    public INDArray getMaskArray() {
        return fwd.getMaskArray();
    }

    @Override
    public boolean isPretrainLayer() {
        return fwd.isPretrainLayer();
    }

    @Override
    public void clearNoiseWeightParams() {
        fwd.clearNoiseWeightParams();
        bwd.clearNoiseWeightParams();
    }

    /**
     * Feeds the mask through both sub-layers (time-reversed for the backward layer).
     *
     * @return the forward layer's result; the backward layer's result is discarded
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        Pair<INDArray, MaskState> result = fwd.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
        INDArray reversedMask = maskArray == null ? null : TimeSeriesUtils.reverseTimeSeriesMask(maskArray);
        bwd.feedForwardMaskArray(reversedMask, currentMaskState, minibatchSize);
        return result;
    }

    /** Copies every entry of {@code source} into {@code target}, prepending {@code prefix} to each key. */
    private static void putPrefixed(Map<String, INDArray> target, String prefix, Map<String, INDArray> source) {
        for (Map.Entry<String, INDArray> entry : source.entrySet()) {
            target.put(prefix + entry.getKey(), entry.getValue());
        }
    }
}
