package org.nd4j.autodiff.samediff.ops;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;

/* loaded from: input_file:org/nd4j/autodiff/samediff/ops/SDRNN.class */
/**
 * Recurrent-network op namespace: factory methods that validate their inputs via
 * {@link SDValidation#validateNumerical}, construct the matching recurrent op with this
 * namespace's {@code sd} instance, and hand back the op's output variable(s).
 *
 * <p>Every op comes in two flavours: one returning the outputs as produced by the op, and
 * one that additionally passes them through {@code updateVariableNameAndReference} /
 * {@code updateVariableNamesAndReferences} together with caller-supplied name(s).
 */
public class SDRNN extends SDOps {
    /**
     * Creates the namespace for the given instance.
     *
     * @param sameDiff instance forwarded to the {@code SDOps} constructor
     */
    public SDRNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    // ------------------------------------------------------------------ GRU

    /**
     * Builds a {@link GRU} op and returns its output variable.
     *
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "hLast"
     * @param sDVariable3 input validated under the name "Wx"
     * @param sDVariable4 input validated under the name "Wh"
     * @param sDVariable5 input validated under the name "biases"
     * @return result of {@code outputVariable()} on the new op
     */
    public SDVariable gru(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5) {
        return gruOutput(sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5);
    }

    /**
     * Same as {@link #gru(SDVariable, SDVariable, SDVariable, SDVariable, SDVariable)}, with the
     * output passed through {@code updateVariableNameAndReference}.
     *
     * @param str         name handed to {@code updateVariableNameAndReference}
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "hLast"
     * @param sDVariable3 input validated under the name "Wx"
     * @param sDVariable4 input validated under the name "Wh"
     * @param sDVariable5 input validated under the name "biases"
     * @return value returned by {@code updateVariableNameAndReference}
     */
    public SDVariable gru(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5) {
        SDVariable output = gruOutput(sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5);
        return sd.updateVariableNameAndReference(output, str);
    }

    /** Validates the five GRU inputs (in declaration order) and constructs the op. */
    private SDVariable gruOutput(SDVariable x, SDVariable hLast, SDVariable wx, SDVariable wh, SDVariable biases) {
        final String op = "gru";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "hLast", hLast);
        SDValidation.validateNumerical(op, "Wx", wx);
        SDValidation.validateNumerical(op, "Wh", wh);
        SDValidation.validateNumerical(op, "biases", biases);
        GRU gruOp = new GRU(sd, x, hLast, wx, wh, biases);
        return gruOp.outputVariable();
    }

    // -------------------------------------------------------------- GRU cell

    /**
     * Builds a {@link GRUCell} op and returns its output variables.
     *
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "hLast"
     * @param gRUWeights  weights object forwarded unchanged to the op constructor
     * @return result of {@code outputVariables()} on the new op
     */
    public SDVariable[] gruCell(SDVariable sDVariable, SDVariable sDVariable2, GRUWeights gRUWeights) {
        return gruCellOutputs(sDVariable, sDVariable2, gRUWeights);
    }

    /**
     * Same as {@link #gruCell(SDVariable, SDVariable, GRUWeights)}, with the outputs passed
     * through {@code updateVariableNamesAndReferences}.
     *
     * @param strArr      names handed to {@code updateVariableNamesAndReferences}
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "hLast"
     * @param gRUWeights  weights object forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] gruCell(String[] strArr, SDVariable sDVariable, SDVariable sDVariable2, GRUWeights gRUWeights) {
        SDVariable[] outputs = gruCellOutputs(sDVariable, sDVariable2, gRUWeights);
        return sd.updateVariableNamesAndReferences(outputs, strArr);
    }

    /** Validates the two variable inputs of a GRU cell and constructs the op. */
    private SDVariable[] gruCellOutputs(SDVariable x, SDVariable hLast, GRUWeights weights) {
        final String op = "gruCell";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "hLast", hLast);
        GRUCell cellOp = new GRUCell(sd, x, hLast, weights);
        return cellOp.outputVariables();
    }

    // ------------------------------------------------------------- LSTM cell

    /**
     * Builds an {@link LSTMBlockCell} op and returns its output variables.
     *
     * @param sDVariable        input validated under the name "x"
     * @param sDVariable2       input validated under the name "cLast"
     * @param sDVariable3       input validated under the name "yLast"
     * @param lSTMWeights       weights object forwarded unchanged to the op constructor
     * @param lSTMConfiguration configuration forwarded unchanged to the op constructor
     * @return result of {@code outputVariables()} on the new op
     */
    public SDVariable[] lstmCell(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, LSTMWeights lSTMWeights, LSTMConfiguration lSTMConfiguration) {
        return lstmCellOutputs(sDVariable, sDVariable2, sDVariable3, lSTMWeights, lSTMConfiguration);
    }

    /**
     * Same as {@link #lstmCell(SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)},
     * with the outputs passed through {@code updateVariableNamesAndReferences}.
     *
     * @param strArr            names handed to {@code updateVariableNamesAndReferences}
     * @param sDVariable        input validated under the name "x"
     * @param sDVariable2       input validated under the name "cLast"
     * @param sDVariable3       input validated under the name "yLast"
     * @param lSTMWeights       weights object forwarded unchanged to the op constructor
     * @param lSTMConfiguration configuration forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] lstmCell(String[] strArr, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, LSTMWeights lSTMWeights, LSTMConfiguration lSTMConfiguration) {
        SDVariable[] outputs = lstmCellOutputs(sDVariable, sDVariable2, sDVariable3, lSTMWeights, lSTMConfiguration);
        return sd.updateVariableNamesAndReferences(outputs, strArr);
    }

    /** Validates the three variable inputs of an LSTM block cell and constructs the op. */
    private SDVariable[] lstmCellOutputs(SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration config) {
        final String op = "lstmCell";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "cLast", cLast);
        SDValidation.validateNumerical(op, "yLast", yLast);
        LSTMBlockCell cellOp = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
        return cellOp.outputVariables();
    }

    // ------------------------------------------------------------ LSTM layer

    /**
     * Builds an {@link LSTMLayer} op from all four variable inputs and returns its output
     * variables.
     *
     * @param sDVariable       input validated under the name "x"
     * @param sDVariable2      input validated under the name "cLast"
     * @param sDVariable3      input validated under the name "yLast"
     * @param sDVariable4      input validated under the name "maxTSLength"
     * @param lSTMLayerWeights weights object forwarded unchanged to the op constructor
     * @param lSTMLayerConfig  configuration forwarded unchanged to the op constructor
     * @return result of {@code outputVariables()} on the new op
     */
    public SDVariable[] lstmLayer(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, LSTMLayerWeights lSTMLayerWeights, LSTMLayerConfig lSTMLayerConfig) {
        return lstmLayerOutputs(sDVariable, sDVariable2, sDVariable3, sDVariable4, lSTMLayerWeights, lSTMLayerConfig);
    }

    /**
     * Same as the six-argument {@code lstmLayer}, with the outputs passed through
     * {@code updateVariableNamesAndReferences}.
     *
     * @param strArr           names handed to {@code updateVariableNamesAndReferences}
     * @param sDVariable       input validated under the name "x"
     * @param sDVariable2      input validated under the name "cLast"
     * @param sDVariable3      input validated under the name "yLast"
     * @param sDVariable4      input validated under the name "maxTSLength"
     * @param lSTMLayerWeights weights object forwarded unchanged to the op constructor
     * @param lSTMLayerConfig  configuration forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] lstmLayer(String[] strArr, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, LSTMLayerWeights lSTMLayerWeights, LSTMLayerConfig lSTMLayerConfig) {
        SDVariable[] outputs = lstmLayerOutputs(sDVariable, sDVariable2, sDVariable3, sDVariable4, lSTMLayerWeights, lSTMLayerConfig);
        return sd.updateVariableNamesAndReferences(outputs, strArr);
    }

    /**
     * Builds an {@link LSTMLayer} op from the "x" input only; the three remaining variable
     * slots of the op constructor receive {@code null}.
     *
     * @param sDVariable       input validated under the name "x"
     * @param lSTMLayerWeights weights object forwarded unchanged to the op constructor
     * @param lSTMLayerConfig  configuration forwarded unchanged to the op constructor
     * @return result of {@code outputVariables()} on the new op
     */
    public SDVariable[] lstmLayer(SDVariable sDVariable, LSTMLayerWeights lSTMLayerWeights, LSTMLayerConfig lSTMLayerConfig) {
        return lstmLayerOutputs(sDVariable, lSTMLayerWeights, lSTMLayerConfig);
    }

    /**
     * Same as the three-argument {@code lstmLayer}, with the outputs passed through
     * {@code updateVariableNamesAndReferences}.
     *
     * @param strArr           names handed to {@code updateVariableNamesAndReferences}
     * @param sDVariable       input validated under the name "x"
     * @param lSTMLayerWeights weights object forwarded unchanged to the op constructor
     * @param lSTMLayerConfig  configuration forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] lstmLayer(String[] strArr, SDVariable sDVariable, LSTMLayerWeights lSTMLayerWeights, LSTMLayerConfig lSTMLayerConfig) {
        SDVariable[] outputs = lstmLayerOutputs(sDVariable, lSTMLayerWeights, lSTMLayerConfig);
        return sd.updateVariableNamesAndReferences(outputs, strArr);
    }

    /** Validates all four variable inputs of an LSTM layer and constructs the op. */
    private SDVariable[] lstmLayerOutputs(SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig config) {
        final String op = "lstmLayer";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "cLast", cLast);
        SDValidation.validateNumerical(op, "yLast", yLast);
        SDValidation.validateNumerical(op, "maxTSLength", maxTSLength);
        LSTMLayer layerOp = new LSTMLayer(sd, x, cLast, yLast, maxTSLength, weights, config);
        return layerOp.outputVariables();
    }

    /** Validates only "x" and constructs the LSTM layer op with null for the other variable inputs. */
    private SDVariable[] lstmLayerOutputs(SDVariable x, LSTMLayerWeights weights, LSTMLayerConfig config) {
        SDValidation.validateNumerical("lstmLayer", "x", x);
        // The absent inputs are deliberately not validated here, only forwarded as null.
        LSTMLayer layerOp = new LSTMLayer(sd, x, null, null, null, weights, config);
        return layerOp.outputVariables();
    }

    // ------------------------------------------------------------ LSTM block

    /**
     * Builds an {@link LSTMBlock} op from all four variable inputs and returns its output
     * variable. Note that "maxTSLength" is the first variable argument here, ahead of "x".
     *
     * @param sDVariable        input validated under the name "maxTSLength"
     * @param sDVariable2       input validated under the name "x"
     * @param sDVariable3       input validated under the name "cLast"
     * @param sDVariable4       input validated under the name "yLast"
     * @param lSTMWeights       weights object forwarded unchanged to the op constructor
     * @param lSTMConfiguration configuration forwarded unchanged to the op constructor
     * @return result of {@code outputVariable()} on the new op
     */
    public SDVariable lstmblock(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, LSTMWeights lSTMWeights, LSTMConfiguration lSTMConfiguration) {
        return lstmBlockOutput(sDVariable, sDVariable2, sDVariable3, sDVariable4, lSTMWeights, lSTMConfiguration);
    }

    /**
     * Same as the six-argument {@code lstmblock}, with the output passed through
     * {@code updateVariableNameAndReference}.
     *
     * @param str               name handed to {@code updateVariableNameAndReference}
     * @param sDVariable        input validated under the name "maxTSLength"
     * @param sDVariable2       input validated under the name "x"
     * @param sDVariable3       input validated under the name "cLast"
     * @param sDVariable4       input validated under the name "yLast"
     * @param lSTMWeights       weights object forwarded unchanged to the op constructor
     * @param lSTMConfiguration configuration forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNameAndReference}
     */
    public SDVariable lstmblock(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, LSTMWeights lSTMWeights, LSTMConfiguration lSTMConfiguration) {
        SDVariable output = lstmBlockOutput(sDVariable, sDVariable2, sDVariable3, sDVariable4, lSTMWeights, lSTMConfiguration);
        return sd.updateVariableNameAndReference(output, str);
    }

    /**
     * Builds an {@link LSTMBlock} op from the "x" input only; the other three variable slots
     * of the op constructor receive {@code null}.
     *
     * @param sDVariable        input validated under the name "x"
     * @param lSTMWeights       weights object forwarded unchanged to the op constructor
     * @param lSTMConfiguration configuration forwarded unchanged to the op constructor
     * @return result of {@code outputVariable()} on the new op
     */
    public SDVariable lstmblock(SDVariable sDVariable, LSTMWeights lSTMWeights, LSTMConfiguration lSTMConfiguration) {
        return lstmBlockOutput(sDVariable, lSTMWeights, lSTMConfiguration);
    }

    /**
     * Same as the three-argument {@code lstmblock}, with the output passed through
     * {@code updateVariableNameAndReference}.
     *
     * @param str               name handed to {@code updateVariableNameAndReference}
     * @param sDVariable        input validated under the name "x"
     * @param lSTMWeights       weights object forwarded unchanged to the op constructor
     * @param lSTMConfiguration configuration forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNameAndReference}
     */
    public SDVariable lstmblock(String str, SDVariable sDVariable, LSTMWeights lSTMWeights, LSTMConfiguration lSTMConfiguration) {
        SDVariable output = lstmBlockOutput(sDVariable, lSTMWeights, lSTMConfiguration);
        return sd.updateVariableNameAndReference(output, str);
    }

    /** Validates all four variable inputs of an LSTM block and constructs the op. */
    private SDVariable lstmBlockOutput(SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration config) {
        final String op = "lstmblock";
        SDValidation.validateNumerical(op, "maxTSLength", maxTSLength);
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "cLast", cLast);
        SDValidation.validateNumerical(op, "yLast", yLast);
        LSTMBlock blockOp = new LSTMBlock(sd, maxTSLength, x, cLast, yLast, weights, config);
        return blockOp.outputVariable();
    }

    /** Validates only "x" and constructs the LSTM block op with null for the other variable inputs. */
    private SDVariable lstmBlockOutput(SDVariable x, LSTMWeights weights, LSTMConfiguration config) {
        SDValidation.validateNumerical("lstmblock", "x", x);
        // "x" sits in the second variable slot of the constructor; the first slot is left null.
        LSTMBlock blockOp = new LSTMBlock(sd, null, x, null, null, weights, config);
        return blockOp.outputVariable();
    }

    // ------------------------------------------------------------------ SRU

    /**
     * Builds an {@link SRU} op, including a mask input, and returns its output variable.
     *
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "initialC"
     * @param sDVariable3 input validated under the name "mask"
     * @param sRUWeights  weights object forwarded unchanged to the op constructor
     * @return result of {@code outputVariable()} on the new op
     */
    public SDVariable sru(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SRUWeights sRUWeights) {
        return sruOutput(sDVariable, sDVariable2, sDVariable3, sRUWeights);
    }

    /**
     * Same as the masked four-argument {@code sru}, with the output passed through
     * {@code updateVariableNameAndReference}.
     *
     * @param str         name handed to {@code updateVariableNameAndReference}
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "initialC"
     * @param sDVariable3 input validated under the name "mask"
     * @param sRUWeights  weights object forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sru(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SRUWeights sRUWeights) {
        SDVariable output = sruOutput(sDVariable, sDVariable2, sDVariable3, sRUWeights);
        return sd.updateVariableNameAndReference(output, str);
    }

    /**
     * Builds an {@link SRU} op without a mask ({@code null} is passed in the mask slot) and
     * returns its output variable.
     *
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "initialC"
     * @param sRUWeights  weights object forwarded unchanged to the op constructor
     * @return result of {@code outputVariable()} on the new op
     */
    public SDVariable sru(SDVariable sDVariable, SDVariable sDVariable2, SRUWeights sRUWeights) {
        return sruOutput(sDVariable, sDVariable2, sRUWeights);
    }

    /**
     * Same as the unmasked three-argument {@code sru}, with the output passed through
     * {@code updateVariableNameAndReference}.
     *
     * @param str         name handed to {@code updateVariableNameAndReference}
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "initialC"
     * @param sRUWeights  weights object forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sru(String str, SDVariable sDVariable, SDVariable sDVariable2, SRUWeights sRUWeights) {
        SDVariable output = sruOutput(sDVariable, sDVariable2, sRUWeights);
        return sd.updateVariableNameAndReference(output, str);
    }

    /** Validates "x", "initialC" and "mask" and constructs the SRU op. */
    private SDVariable sruOutput(SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights weights) {
        final String op = "sru";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "initialC", initialC);
        SDValidation.validateNumerical(op, "mask", mask);
        SRU sruOp = new SRU(sd, x, initialC, mask, weights);
        return sruOp.outputVariable();
    }

    /** Validates "x" and "initialC" and constructs the SRU op with a null mask. */
    private SDVariable sruOutput(SDVariable x, SDVariable initialC, SRUWeights weights) {
        final String op = "sru";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "initialC", initialC);
        SRU sruOp = new SRU(sd, x, initialC, null, weights);
        return sruOp.outputVariable();
    }

    // -------------------------------------------------------------- SRU cell

    /**
     * Builds an {@link SRUCell} op and returns its output variable.
     *
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "cLast"
     * @param sRUWeights  weights object forwarded unchanged to the op constructor
     * @return result of {@code outputVariable()} on the new op
     */
    public SDVariable sruCell(SDVariable sDVariable, SDVariable sDVariable2, SRUWeights sRUWeights) {
        return sruCellOutput(sDVariable, sDVariable2, sRUWeights);
    }

    /**
     * Same as {@link #sruCell(SDVariable, SDVariable, SRUWeights)}, with the output passed
     * through {@code updateVariableNameAndReference}.
     *
     * @param str         name handed to {@code updateVariableNameAndReference}
     * @param sDVariable  input validated under the name "x"
     * @param sDVariable2 input validated under the name "cLast"
     * @param sRUWeights  weights object forwarded unchanged to the op constructor
     * @return value returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sruCell(String str, SDVariable sDVariable, SDVariable sDVariable2, SRUWeights sRUWeights) {
        SDVariable output = sruCellOutput(sDVariable, sDVariable2, sRUWeights);
        return sd.updateVariableNameAndReference(output, str);
    }

    /** Validates the two variable inputs of an SRU cell and constructs the op. */
    private SDVariable sruCellOutput(SDVariable x, SDVariable cLast, SRUWeights weights) {
        final String op = "sruCell";
        SDValidation.validateNumerical(op, "x", x);
        SDValidation.validateNumerical(op, "cLast", cLast);
        SRUCell cellOp = new SRUCell(sd, x, cLast, weights);
        return cellOp.outputVariable();
    }
}
