package org.apache.sysds.runtime.instructions.gpu;

import java.util.ArrayList;
import jcuda.Pointer;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNN;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/DnnGPUInstruction.class */
public class DnnGPUInstruction extends GPUInstruction {
    // Operands beyond _input1/_input2/_output (those three are passed to the superclass
    // constructor). Which of these are populated depends on the constructor, i.e. on the opcode;
    // unused ones stay null.
    private CPOperand _input3;
    private CPOperand _input4;
    private CPOperand _input5;
    private CPOperand _input6;
    private CPOperand _input7;
    private CPOperand _input8;
    private CPOperand _output2;
    private CPOperand _output3;
    private CPOperand _output4;
    private CPOperand _output5;
    // Convolution/pooling parameters. parseInstruction() fills _stride and _padding with two
    // operands each, and _input_shape and _filter_shape with four operands each.
    private ArrayList<CPOperand> _input_shape;
    private ArrayList<CPOperand> _filter_shape;
    private ArrayList<CPOperand> _stride;
    private ArrayList<CPOperand> _padding;
    // Memory budget handed to every constructor; parsed from the instruction for the
    // convolution/pooling/bias opcodes, 0 otherwise.
    private double _intermediateMemoryBudget;

    /**
     * Creates a two-input, one-output instruction for {@code bias_add}, {@code bias_multiply}
     * or {@code relu_backward}.
     *
     * @param in1 first input operand
     * @param in2 second input operand
     * @param out output operand
     * @param opcode instruction opcode; matched case-insensitively, consistent with
     *        {@link #parseInstruction(String)}
     * @param istr full instruction string
     * @param intermediateMemoryBudget memory budget parsed from the instruction
     * @throws DMLRuntimeException if the opcode is not one of the three supported opcodes
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out, opcode, istr);
        // Field defaults shared by all constructors (decompiled field initializers).
        this._stride = new ArrayList<>();
        this._padding = new ArrayList<>();
        // parseInstruction() dispatches with equalsIgnoreCase; accept the same spellings here so a
        // parsed instruction is never rejected by its own constructor.
        if (!opcode.equalsIgnoreCase("bias_add") && !opcode.equalsIgnoreCase("bias_multiply") && !opcode.equalsIgnoreCase("relu_backward")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
        }
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Creates an instruction with six inputs and two outputs. {@link #parseInstruction(String)}
     * uses it for the {@code lstm} opcode.
     *
     * @param in1 first input (forwarded to the superclass)
     * @param in2 second input (forwarded to the superclass)
     * @param in3 third input
     * @param in4 fourth input
     * @param in5 fifth input
     * @param in6 sixth input
     * @param out1 primary output (forwarded to the superclass)
     * @param out2 secondary output
     * @param opcode instruction opcode (not validated here)
     * @param istr full instruction string
     * @param intermediateMemoryBudget memory budget for this instruction
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand out1, CPOperand out2, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out1, opcode, istr);
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        _stride = new ArrayList<>();
        _padding = new ArrayList<>();
        _intermediateMemoryBudget = intermediateMemoryBudget;
        _input3 = in3;
        _input4 = in4;
        _input5 = in5;
        _input6 = in6;
        _output2 = out2;
    }

    /**
     * Creates an instruction with up to eight inputs and five outputs.
     * {@link #parseInstruction(String)} uses it for {@code batch_norm2d}, {@code lstm_backward},
     * {@code batch_norm2d_backward} and {@code batch_norm2d_train}; operands an opcode does not
     * use are passed as {@code null}.
     *
     * @param in1 first input (forwarded to the superclass)
     * @param in2 second input (forwarded to the superclass)
     * @param in3 third input
     * @param in4 fourth input
     * @param in5 fifth input
     * @param in6 sixth input
     * @param in7 seventh input
     * @param in8 eighth input
     * @param out1 primary output (forwarded to the superclass)
     * @param out2 second output
     * @param out3 third output
     * @param out4 fourth output
     * @param out5 fifth output
     * @param opcode instruction opcode (not validated here)
     * @param istr full instruction string
     * @param intermediateMemoryBudget memory budget for this instruction
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand in7, CPOperand in8, CPOperand out1, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out1, opcode, istr);
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        _stride = new ArrayList<>();
        _padding = new ArrayList<>();
        _intermediateMemoryBudget = intermediateMemoryBudget;
        // Extra inputs.
        _input3 = in3;
        _input4 = in4;
        _input5 = in5;
        _input6 = in6;
        _input7 = in7;
        _input8 = in8;
        // Extra outputs.
        _output2 = out2;
        _output3 = out3;
        _output4 = out4;
        _output5 = out5;
    }

    /**
     * Creates a three-input instruction for {@code channel_sums}.
     *
     * @param in1 first input (forwarded to the superclass)
     * @param in2 second input (forwarded to the superclass)
     * @param in3 third input
     * @param out output operand
     * @param opcode instruction opcode; must be {@code channel_sums} (case-insensitive, consistent
     *        with {@link #parseInstruction(String)})
     * @param istr full instruction string
     * @param intermediateMemoryBudget memory budget for this instruction
     * @throws DMLRuntimeException if the opcode is not {@code channel_sums}
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out, opcode, istr);
        // Field defaults shared by all constructors (decompiled field initializers).
        this._stride = new ArrayList<>();
        this._padding = new ArrayList<>();
        // Match the parser's case-insensitive dispatch.
        if (!opcode.equalsIgnoreCase("channel_sums")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
        }
        this._input3 = in3;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Creates a four-input instruction for {@code update_nesterov_x}.
     *
     * @param in1 first input (forwarded to the superclass)
     * @param in2 second input (forwarded to the superclass)
     * @param in3 third input
     * @param in4 fourth input
     * @param out output operand
     * @param opcode instruction opcode; must be {@code update_nesterov_x} (case-insensitive,
     *        consistent with {@link #parseInstruction(String)})
     * @param istr full instruction string
     * @param intermediateMemoryBudget memory budget for this instruction
     * @throws DMLRuntimeException if the opcode is not {@code update_nesterov_x}
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out, opcode, istr);
        // Field defaults shared by all constructors (decompiled field initializers).
        this._stride = new ArrayList<>();
        this._padding = new ArrayList<>();
        // Match the parser's case-insensitive dispatch.
        if (!opcode.equalsIgnoreCase("update_nesterov_x")) {
            throw new DMLRuntimeException("Incorrect opcode: " + opcode);
        }
        this._input3 = in3;
        this._input4 = in4;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Creates a convolution/pooling instruction with an additional third input. It delegates to
     * the three-operand convolution constructor and then records {@code in3}.
     * {@link #parseInstruction(String)} uses it for {@code maxpooling_backward},
     * {@code avgpooling_backward} (where {@code in3} may be {@code null}) and
     * {@code conv2d_bias_add}.
     *
     * @param in1 first input
     * @param in2 second input
     * @param in3 optional third input
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr full instruction string
     * @param stride stride operands
     * @param padding padding operands
     * @param inputShape input-shape operands
     * @param filterShape filter-shape operands
     * @param intermediateMemoryBudget memory budget for this instruction
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> inputShape, ArrayList<CPOperand> filterShape, double intermediateMemoryBudget) {
        this(in1, in2, out, opcode, istr,
            stride, padding, inputShape, filterShape,
            intermediateMemoryBudget);
        _input3 = in3;
    }

    /**
     * Creates a convolution/pooling instruction. {@link #parseInstruction(String)} uses it for
     * {@code conv2d}, {@code conv2d_backward_filter}, {@code conv2d_backward_data},
     * {@code maxpooling} and {@code avgpooling}; {@code in2} is {@code null} for the pooling opcodes.
     *
     * @param in1 first input
     * @param in2 second input, may be {@code null}
     * @param out output operand
     * @param opcode instruction opcode (not validated here)
     * @param istr full instruction string
     * @param stride stride operands (stored as-is, not copied)
     * @param padding padding operands (stored as-is, not copied)
     * @param inputShape input-shape operands (stored as-is, not copied)
     * @param filterShape filter-shape operands (stored as-is, not copied)
     * @param intermediateMemoryBudget memory budget for this instruction
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> inputShape, ArrayList<CPOperand> filterShape, double intermediateMemoryBudget) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out, opcode, istr);
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        // The caller-supplied lists replace the default empty lists directly.
        _stride = stride;
        _padding = padding;
        _input_shape = inputShape;
        _filter_shape = filterShape;
        _intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Creates a six-input, one-output instruction for {@code batch_norm2d_test}.
     *
     * @param in1 first input (forwarded to the superclass)
     * @param in2 second input (forwarded to the superclass)
     * @param in3 third input
     * @param in4 fourth input
     * @param in5 fifth input
     * @param in6 sixth input
     * @param out output operand
     * @param opcode instruction opcode; must be {@code batch_norm2d_test} (case-insensitive,
     *        consistent with {@link #parseInstruction(String)})
     * @param istr full instruction string
     * @param intermediateMemoryBudget memory budget for this instruction
     * @throws DMLRuntimeException if the opcode is not {@code batch_norm2d_test}
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in1, in2, out, opcode, istr);
        // Field defaults shared by all constructors (decompiled field initializers).
        this._stride = new ArrayList<>();
        this._padding = new ArrayList<>();
        // Match the parser's case-insensitive dispatch.
        if (!opcode.equalsIgnoreCase("batch_norm2d_test")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be batch_norm2d_test, but found " + opcode);
        }
        this._input3 = in3;
        this._input4 = in4;
        this._input5 = in5;
        this._input6 = in6;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Parses a serialized GPU DNN instruction into a {@code DnnGPUInstruction}.
     *
     * <p>Field 0 is the opcode, matched case-insensitively. It determines how many operand fields
     * are expected and which constructor is used:
     * <ul>
     *   <li>Convolution/pooling opcodes carry stride (2), padding (2), input-shape (4) and
     *       filter-shape (4) operands, followed by the output and the intermediate memory budget.</li>
     *   <li>{@code bias_add}, {@code bias_multiply} and {@code relu_backward} read the budget from
     *       their last field.</li>
     *   <li>All other opcodes carry no budget field and use 0.</li>
     * </ul>
     *
     * @param str the full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unknown, or (via
     *         {@code InstructionUtils.checkNumFields}) the field count does not match the opcode
     */
    public static DnnGPUInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        // Opcodes without a serialized memory-budget field use a budget of 0.
        final double noBudget = 0;

        if (opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            // in1, in2, stride(3-4), padding(5-6), input_shape(7-10), filter_shape(11-14), out(15), budget(16)
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            return new DnnGPUInstruction(in1, in2, out, opcode, str,
                parseOperandList(parts, 3, 2), parseOperandList(parts, 5, 2),
                parseOperandList(parts, 7, 4), parseOperandList(parts, 11, 4),
                Double.parseDouble(parts[16]));
        }
        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward")) {
            // An optional extra input at index 15 (18 parts in total) shifts the output and the
            // budget one position to the right.
            boolean hasThirdInput = parts.length == 18;
            if (!hasThirdInput) {
                InstructionUtils.checkNumFields(parts, 16);
            }
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = hasThirdInput ? new CPOperand(parts[15]) : null;
            CPOperand out = new CPOperand(parts[hasThirdInput ? 16 : 15]);
            double budget = Double.parseDouble(parts[hasThirdInput ? 17 : 16]);
            return new DnnGPUInstruction(in1, in2, in3, out, opcode, str,
                parseOperandList(parts, 3, 2), parseOperandList(parts, 5, 2),
                parseOperandList(parts, 7, 4), parseOperandList(parts, 11, 4),
                budget);
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            // in1, in2, in3, stride(4-5), padding(6-7), input_shape(8-11), filter_shape(12-15), out(16), budget(17)
            InstructionUtils.checkNumFields(parts, 17);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            return new DnnGPUInstruction(in1, in2, in3, out, opcode, str,
                parseOperandList(parts, 4, 2), parseOperandList(parts, 6, 2),
                parseOperandList(parts, 8, 4), parseOperandList(parts, 12, 4),
                Double.parseDouble(parts[17]));
        }
        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
            // in1, stride(2-3), padding(4-5), input_shape(6-9), filter_shape(10-13), out(14), budget(15)
            InstructionUtils.checkNumFields(parts, 15);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            return new DnnGPUInstruction(in1, (CPOperand) null, out, opcode, str,
                parseOperandList(parts, 2, 2), parseOperandList(parts, 4, 2),
                parseOperandList(parts, 6, 4), parseOperandList(parts, 10, 4),
                Double.parseDouble(parts[15]));
        }
        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
            InstructionUtils.checkNumFields(parts, 4);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), opcode, str, Double.parseDouble(parts[4]));
        }
        if (opcode.equalsIgnoreCase("channel_sums")) {
            InstructionUtils.checkNumFields(parts, 4);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), opcode, str, noBudget);
        }
        if (opcode.equalsIgnoreCase("update_nesterov_x")) {
            InstructionUtils.checkNumFields(parts, 5);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                opcode, str, noBudget);
        }
        if (opcode.equalsIgnoreCase("lstm")) {
            InstructionUtils.checkNumFields(parts, 8);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                new CPOperand(parts[6]), new CPOperand(parts[7]), new CPOperand(parts[8]),
                opcode, str, noBudget);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
            InstructionUtils.checkNumFields(parts, 13);
            // NOTE(review): parts[9] is a serialized operand, yet it is passed to the
            // (name, ValueType, DataType) constructor rather than the parsing constructor used for
            // every other field. This looks like the operand name would include the serialized type
            // suffix; kept unchanged here to preserve behavior -- confirm against the instruction
            // generator.
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                new CPOperand(parts[6]), new CPOperand(parts[7]), new CPOperand(parts[8]),
                new CPOperand(parts[9], Types.ValueType.FP64, Types.DataType.MATRIX),
                new CPOperand(parts[10]), new CPOperand(parts[11]), new CPOperand(parts[12]),
                new CPOperand(parts[13]), opcode, str, noBudget);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
            // Six inputs and three outputs; the unused input/output slots are null.
            InstructionUtils.checkNumFields(parts, 9);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                new CPOperand(parts[6]), null, null,
                new CPOperand(parts[7]), new CPOperand(parts[8]), new CPOperand(parts[9]), null, null,
                opcode, str, noBudget);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
            InstructionUtils.checkNumFields(parts, 7);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                new CPOperand(parts[6]), new CPOperand(parts[7]), opcode, str, noBudget);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
            // Seven inputs and five outputs; the eighth input slot is null.
            InstructionUtils.checkNumFields(parts, 12);
            return new DnnGPUInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]),
                new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                new CPOperand(parts[6]), new CPOperand(parts[7]), null,
                new CPOperand(parts[8]), new CPOperand(parts[9]), new CPOperand(parts[10]),
                new CPOperand(parts[11]), new CPOperand(parts[12]), opcode, str, noBudget);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
    }

    /**
     * Parses {@code count} consecutive serialized operands starting at {@code parts[from]}.
     */
    private static ArrayList<CPOperand> parseOperandList(String[] parts, int from, int count) {
        ArrayList<CPOperand> operands = new ArrayList<>(count);
        for (int i = from; i < from + count; i++) {
            operands.add(new CPOperand(parts[i]));
        }
        return operands;
    }

    /**
     * Executes {@code bias_add} or {@code bias_multiply} on GPU 0.
     *
     * <p>Acquires {@code _input1} and {@code _input2}, and allocates a dense output with the same
     * dimensions as {@code _input1}. It then runs the matching {@code LibMatrixCUDA} kernel and
     * releases all three.
     *
     * @param str the opcode; matched case-insensitively
     * @param executionContext the execution context holding the operands
     * @throws DMLRuntimeException if {@code str} is neither {@code bias_add} nor {@code bias_multiply}
     */
    private void processBiasInstruction(String str, ExecutionContext executionContext) {
        // Validate before acquiring anything. Otherwise an unsupported opcode would allocate an
        // output, never write it, and then release it as if it were a valid result.
        boolean isAdd = str.equalsIgnoreCase("bias_add");
        if (!isAdd && !str.equalsIgnoreCase("bias_multiply")) {
            throw new DMLRuntimeException("Unsupported opcode for a GPU bias instruction: " + str);
        }
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        MatrixObject bias = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject out = getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), input.getNumRows(), input.getNumColumns());
        if (isAdd) {
            LibMatrixCUDA.biasAdd(executionContext.getGPUContext(0), getExtendedOpcode(), input, bias, out);
        } else {
            LibMatrixCUDA.biasMultiply(executionContext.getGPUContext(0), getExtendedOpcode(), input, bias, out);
        }
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes {@code batch_norm2d} on GPU 0. It dispatches on the string scalar held in
     * {@code _input6}, which must be {@code "train"} or {@code "test"} (case-insensitive).
     *
     * <p>Both modes read matrices {@code _input1}..{@code _input5} and the double scalar
     * {@code _input7}, and write a dense {@code _output} with the dimensions of {@code _input1}.
     * <ul>
     *   <li>Train mode also reads {@code _input8}, passing {@code 1 - value} to the kernel. It
     *       writes dense {@code _output2}/{@code _output4} (sized like {@code _input4}) and
     *       {@code _output3}/{@code _output5} (sized like {@code _input5}).</li>
     *   <li>Test mode sets {@code _output2}..{@code _output5} to freshly constructed
     *       {@code MatrixBlock}s of those same sizes.</li>
     * </ul>
     *
     * @param executionContext the execution context holding the operands
     * @throws DMLRuntimeException if the mode is neither train nor test
     */
    private void processBatchNorm2dInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        // Resolve and validate the mode before acquiring GPU inputs or allocating the output, so an
        // invalid mode does not leave them acquired.
        String mode = executionContext.getScalarInput(this._input6).getStringValue();
        boolean isTrain = mode.equalsIgnoreCase("train");
        if (!isTrain && !mode.equalsIgnoreCase("test")) {
            throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + mode);
        }
        MatrixObject in1 = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        MatrixObject in2 = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject in3 = getMatrixInputForGPUInstruction(executionContext, this._input3.getName());
        MatrixObject in4 = getMatrixInputForGPUInstruction(executionContext, this._input4.getName());
        MatrixObject in5 = getMatrixInputForGPUInstruction(executionContext, this._input5.getName());
        double input7Value = executionContext.getScalarInput(this._input7).getDoubleValue();
        MatrixObject out = getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), in1.getNumRows(), in1.getNumColumns());
        if (isTrain) {
            double oneMinusInput8 = 1.0d - executionContext.getScalarInput(this._input8).getDoubleValue();
            LibMatrixCuDNN.batchNormalizationForwardTraining(executionContext.getGPUContext(0), getExtendedOpcode(),
                in1, in2, in3, in4, in5, out,
                getDenseMatrixOutputForGPUInstruction(executionContext, this._output2.getName(), in4.getNumRows(), in4.getNumColumns()),
                getDenseMatrixOutputForGPUInstruction(executionContext, this._output3.getName(), in5.getNumRows(), in5.getNumColumns()),
                input7Value, oneMinusInput8,
                getDenseMatrixOutputForGPUInstruction(executionContext, this._output4.getName(), in4.getNumRows(), in4.getNumColumns()),
                getDenseMatrixOutputForGPUInstruction(executionContext, this._output5.getName(), in5.getNumRows(), in5.getNumColumns()));
            executionContext.releaseMatrixOutputForGPUInstruction(this._output2.getName());
            executionContext.releaseMatrixOutputForGPUInstruction(this._output3.getName());
            executionContext.releaseMatrixOutputForGPUInstruction(this._output4.getName());
            executionContext.releaseMatrixOutputForGPUInstruction(this._output5.getName());
        } else {
            LibMatrixCuDNN.batchNormalizationForwardInference(executionContext.getGPUContext(0), getExtendedOpcode(),
                in1, in2, in3, in4, in5, out, input7Value);
            // The remaining outputs are not produced by inference; bind fresh blocks of the expected sizes.
            executionContext.setMatrixOutput(this._output2.getName(), new MatrixBlock((int) in4.getNumRows(), (int) in4.getNumColumns(), true));
            executionContext.setMatrixOutput(this._output3.getName(), new MatrixBlock((int) in5.getNumRows(), (int) in5.getNumColumns(), true));
            executionContext.setMatrixOutput(this._output4.getName(), new MatrixBlock((int) in4.getNumRows(), (int) in4.getNumColumns(), true));
            executionContext.setMatrixOutput(this._output5.getName(), new MatrixBlock((int) in5.getNumRows(), (int) in5.getNumColumns(), true));
        }
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input3.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input4.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input5.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes {@code batch_norm2d_train} on GPU 0 via
     * {@code LibMatrixCuDNN.batchNormalizationForwardTraining}.
     *
     * <p>Reads matrices {@code _input1}..{@code _input5} and the double scalars {@code _input6}
     * (passed as-is) and {@code _input7} (passed as {@code 1 - value}). It writes five dense
     * outputs:
     * <ul>
     *   <li>{@code _output}, sized like {@code _input1};</li>
     *   <li>{@code _output2}/{@code _output4}, sized like {@code _input4};</li>
     *   <li>{@code _output3}/{@code _output5}, sized like {@code _input5}.</li>
     * </ul>
     * All operands are released afterwards.
     *
     * @param executionContext the execution context holding the operands
     */
    private void processBatchNorm2dTrainInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        final ExecutionContext ec = executionContext;
        MatrixObject input1 = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject input2 = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject input3 = getMatrixInputForGPUInstruction(ec, _input3.getName());
        MatrixObject input4 = getMatrixInputForGPUInstruction(ec, _input4.getName());
        MatrixObject input5 = getMatrixInputForGPUInstruction(ec, _input5.getName());
        double input6Value = ec.getScalarInput(_input6).getDoubleValue();
        double oneMinusInput7 = 1.0d - ec.getScalarInput(_input7).getDoubleValue();

        // Same evaluation order as passing everything inline: context, opcode, then outputs 1..5.
        GPUContext gCtx = ec.getGPUContext(0);
        String instName = getExtendedOpcode();
        MatrixObject output1 = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input1.getNumRows(), input1.getNumColumns());
        MatrixObject output2 = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), input4.getNumRows(), input4.getNumColumns());
        MatrixObject output3 = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), input5.getNumRows(), input5.getNumColumns());
        MatrixObject output4 = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), input4.getNumRows(), input4.getNumColumns());
        MatrixObject output5 = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), input5.getNumRows(), input5.getNumColumns());

        LibMatrixCuDNN.batchNormalizationForwardTraining(gCtx, instName,
            input1, input2, input3, input4, input5,
            output1, output2, output3,
            input6Value, oneMinusInput7,
            output4, output5);

        for (CPOperand in : new CPOperand[] {_input1, _input2, _input3, _input4, _input5}) {
            ec.releaseMatrixInputForGPUInstruction(in.getName());
        }
        for (CPOperand out : new CPOperand[] {_output, _output2, _output3, _output4, _output5}) {
            ec.releaseMatrixOutputForGPUInstruction(out.getName());
        }
    }

    /**
     * Executes {@code batch_norm2d_test} on GPU 0 via
     * {@code LibMatrixCuDNN.batchNormalizationForwardInference}.
     *
     * <p>Reads matrices {@code _input1}..{@code _input5} and the double scalar {@code _input6}.
     * It writes a dense {@code _output} sized like {@code _input1}, then releases all operands.
     *
     * @param executionContext the execution context holding the operands
     */
    private void processBatchNorm2dTestInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        final ExecutionContext ec = executionContext;
        MatrixObject input1 = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject input2 = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject input3 = getMatrixInputForGPUInstruction(ec, _input3.getName());
        MatrixObject input4 = getMatrixInputForGPUInstruction(ec, _input4.getName());
        MatrixObject input5 = getMatrixInputForGPUInstruction(ec, _input5.getName());
        double input6Value = ec.getScalarInput(_input6).getDoubleValue();

        // Same evaluation order as passing everything inline: context, opcode, then the output.
        GPUContext gCtx = ec.getGPUContext(0);
        String instName = getExtendedOpcode();
        MatrixObject output = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input1.getNumRows(), input1.getNumColumns());

        LibMatrixCuDNN.batchNormalizationForwardInference(gCtx, instName,
            input1, input2, input3, input4, input5, output, input6Value);

        for (CPOperand in : new CPOperand[] {_input1, _input2, _input3, _input4, _input5}) {
            ec.releaseMatrixInputForGPUInstruction(in.getName());
        }
        ec.releaseMatrixOutputForGPUInstruction(_output.getName());
    }

    /**
     * Executes the {@code batch_norm2d_backward} opcode via
     * {@code LibMatrixCuDNN.batchNormalizationBackward}.
     * The first output is shaped like input1; the second and third outputs are shaped like input3.
     * input4 is read as a double scalar; all other inputs are matrices.
     * NOTE(review): the role names used below are presumed from the LibMatrixCuDNN parameter order.
     */
    public void processBatchNorm2dBackwardInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject image = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        MatrixObject dout = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject scale = getMatrixInputForGPUInstruction(executionContext, this._input3.getName());
        double epsilon = executionContext.getScalarInput(this._input4).getDoubleValue();
        MatrixObject savedMean = getMatrixInputForGPUInstruction(executionContext, this._input5.getName());
        MatrixObject savedInvVariance = getMatrixInputForGPUInstruction(executionContext, this._input6.getName());

        GPUContext gCtx = executionContext.getGPUContext(0);
        String opcode = getExtendedOpcode();
        // Outputs are allocated in the order output, output2, output3.
        MatrixObject dImage = getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(),
                image.getNumRows(), image.getNumColumns());
        MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(executionContext, this._output2.getName(),
                scale.getNumRows(), scale.getNumColumns());
        MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(executionContext, this._output3.getName(),
                scale.getNumRows(), scale.getNumColumns());
        LibMatrixCuDNN.batchNormalizationBackward(gCtx, opcode, image, dout, scale, dImage, dScale, dBias, epsilon,
                savedMean, savedInvVariance);

        // input4 is a scalar, so it is not part of the matrix releases.
        for (CPOperand operand : new CPOperand[] {this._input1, this._input2, this._input3, this._input5, this._input6}) {
            executionContext.releaseMatrixInputForGPUInstruction(operand.getName());
        }
        for (CPOperand operand : new CPOperand[] {this._output, this._output2, this._output3}) {
            executionContext.releaseMatrixOutputForGPUInstruction(operand.getName());
        }
    }

    /**
     * Executes the {@code relu_backward} opcode via {@code LibMatrixCUDA.reluBackward}.
     * The dense output has the same shape as input1.
     */
    public void processReLUBackwardInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        GPUContext gCtx = executionContext.getGPUContext(0);
        String opcode = getExtendedOpcode();
        MatrixObject dout = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject result = getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(),
                input.getNumRows(), input.getNumColumns());
        LibMatrixCUDA.reluBackward(gCtx, opcode, input, dout, result);

        for (CPOperand operand : new CPOperand[] {this._input1, this._input2}) {
            executionContext.releaseMatrixInputForGPUInstruction(operand.getName());
        }
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes the {@code channel_sums} opcode via {@code LibMatrixCUDA.channelSums}.
     * input2 and input3 are integer scalars whose product must equal the column count of input1.
     * The dense output is (input2 x 1).
     *
     * @throws DMLRuntimeException if input2 * input3 does not match the column count of input1
     */
    private void processChannelSumsInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        // Presumably number of channels and cells per channel; confirm against the channel_sums builtin.
        int numChannels = (int) executionContext.getScalarInput(this._input2).getLongValue();
        int cellsPerChannel = (int) executionContext.getScalarInput(this._input3).getLongValue();
        // Multiply in long: an int product could wrap around and make this check pass or fail spuriously.
        if ((long) numChannels * cellsPerChannel != input.getNumColumns()) {
            throw new DMLRuntimeException("Expected rows*cols " + numChannels + "*" + cellsPerChannel
                    + " to be equal to number of columns of input " + input.getNumColumns());
        }
        LibMatrixCUDA.channelSums(executionContext.getGPUContext(0), getExtendedOpcode(), input,
                getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), numChannels, 1L),
                numChannels, cellsPerChannel);
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes the {@code update_nesterov_x} opcode by launching the CUDA kernel of the same name.
     * The kernel receives input1, input2 and input3 (dense), the double scalar input4, a dense output
     * shaped like input1, and the total cell count.
     */
    private void processNesterovUpdateInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject x = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        MatrixObject v = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject vPrev = getMatrixInputForGPUInstruction(executionContext, this._input3.getName());
        // The scalar (presumably the momentum coefficient) is handed to the kernel as a double.
        // It must NOT be truncated to int first: that turned e.g. 0.9 into 0.
        double mu = executionContext.getScalarInput(this._input4).getDoubleValue();
        int rows = LibMatrixCUDA.toInt(x.getNumRows());
        int cols = LibMatrixCUDA.toInt(x.getNumColumns());
        // Multiply in long so that an oversized matrix is rejected by toInt instead of wrapping around.
        int numCells = LibMatrixCUDA.toInt((long) rows * cols);
        MatrixObject out = getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), rows, cols);
        GPUContext gPUContext = executionContext.getGPUContext(0);
        String extendedOpcode = getExtendedOpcode();
        LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("update_nesterov_x",
                ExecutionConfig.getConfigForSimpleVectorOperations(numCells),
                LibMatrixCUDA.getDensePointer(gPUContext, x, extendedOpcode),
                LibMatrixCUDA.getDensePointer(gPUContext, v, extendedOpcode),
                LibMatrixCUDA.getDensePointer(gPUContext, vPrev, extendedOpcode),
                Double.valueOf(mu),
                LibMatrixCUDA.getDensePointer(gPUContext, out, extendedOpcode),
                Integer.valueOf(numCells));
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input3.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Narrows {@code j} to an int for GPU kernel / cuDNN arguments, failing loudly instead of
     * silently wrapping around. The bounds are exclusive: {@link Integer#MAX_VALUE} and
     * {@link Integer#MIN_VALUE} themselves are also rejected.
     *
     * @param j value to narrow
     * @return {@code j} as an int
     * @throws DMLRuntimeException if {@code j} is outside the open interval (MIN_VALUE, MAX_VALUE)
     */
    private static int toInt(long j) throws DMLRuntimeException {
        // Bound by the int range itself rather than an unrelated dense-cell-count constant.
        if (j >= Integer.MAX_VALUE || j <= Integer.MIN_VALUE) {
            throw new DMLRuntimeException("GPU : Exceeded supported size " + j);
        }
        return (int) j;
    }

    /**
     * Executes the {@code lstm_backward} opcode.
     * input2/input3 and input1 are repacked into temporary device buffers by the
     * {@code prepare_lstm_weight} / {@code prepare_lstm_input} kernels. The work is then delegated to
     * {@code LibMatrixCuDNN.lstmBackward}, which receives the output names output..output5.
     * NOTE(review): assumed layout (confirm): input1 X is N x (T*D), input2 W is (D+M) x 4M,
     * input3 bias is 1 x 4M, input4 has M columns; M = hidden size, D = input size.
     */
    private void processLstmBackwardInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        GPUContext gPUContext = executionContext.getGPUContext(0);
        String extendedOpcode = getExtendedOpcode();

        MatrixObject out0 = getMatrixInputForGPUInstruction(executionContext, this._input4.getName());
        int hiddenSize = toInt(out0.getNumColumns());
        Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gPUContext, out0, extendedOpcode);

        MatrixObject weights = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject bias = getMatrixInputForGPUInstruction(executionContext, this._input3.getName());
        int inputSize = toInt(weights.getNumRows()) - hiddenSize;
        Pointer weightsPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gPUContext, weights, extendedOpcode,
                inputSize + hiddenSize, 4 * hiddenSize);
        Pointer biasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gPUContext, bias, extendedOpcode, 1, 4 * hiddenSize);
        // Range-check the element count in long, and compute the byte size in long:
        // count * sizeOfDataType overflowed int for moderately large layers even when the count fit.
        int numCudnnWeights = toInt((long) (inputSize + hiddenSize + 2) * 4 * hiddenSize);
        Pointer cudnnWeights = gPUContext.allocate(extendedOpcode,
                (long) numCudnnWeights * LibMatrixCUDA.sizeOfDataType, true);
        LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_weight",
                ExecutionConfig.getConfigForSimpleVectorOperations(numCudnnWeights),
                weightsPointer, biasPointer, cudnnWeights, Integer.valueOf(inputSize), Integer.valueOf(hiddenSize));
        executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input3.getName());

        MatrixObject x = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        Pointer xPointer = LibMatrixCUDA.getDensePointer(gPUContext, x, extendedOpcode);
        int batchSize = toInt(x.getNumRows());
        int seqLength = toInt(x.getNumColumns() / inputSize);
        int numInputs = toInt((long) batchSize * seqLength * inputSize);
        Pointer cudnnInput = gPUContext.allocate(extendedOpcode, (long) numInputs * LibMatrixCUDA.sizeOfDataType, false);
        LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_input",
                ExecutionConfig.getConfigForSimpleVectorOperations(numInputs),
                xPointer, cudnnInput, Integer.valueOf(batchSize), Integer.valueOf(inputSize),
                Integer.valueOf(seqLength * inputSize), Integer.valueOf(numInputs));
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());

        Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gPUContext,
                getMatrixInputForGPUInstruction(executionContext, this._input5.getName()), extendedOpcode);
        boolean returnSequences = executionContext.getScalarInput(this._input6).getBooleanValue();
        LibMatrixCuDNN.lstmBackward(executionContext, gPUContext, extendedOpcode,
                cudnnInput, out0Pointer, c0Pointer, cudnnWeights,
                this._input7.getName(), this._input8.getName(),
                this._output.getName(), this._output2.getName(), this._output3.getName(),
                this._output4.getName(), this._output5.getName(),
                returnSequences, batchSize, hiddenSize, inputSize, seqLength);
        gPUContext.cudaFreeHelper(extendedOpcode, cudnnWeights, DMLScript.EAGER_CUDA_FREE);
        gPUContext.cudaFreeHelper(extendedOpcode, cudnnInput, DMLScript.EAGER_CUDA_FREE);
        executionContext.releaseMatrixInputForGPUInstruction(this._input4.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input5.getName());
    }

    /**
     * Executes the {@code lstm} opcode.
     * input2/input3 and input1 are repacked into temporary device buffers by the
     * {@code prepare_lstm_weight} / {@code prepare_lstm_input} kernels. The forward pass is then
     * delegated to {@code LibMatrixCuDNN.lstm}, which writes {@code output} and {@code output2}.
     * NOTE(review): assumed layout (confirm): input1 X is N x (T*D), input2 W is (D+M) x 4M,
     * input3 bias is 1 x 4M, input4 has M columns; M = hidden size, D = input size.
     */
    private void processLstmInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        GPUContext gPUContext = executionContext.getGPUContext(0);
        String extendedOpcode = getExtendedOpcode();

        MatrixObject out0 = getMatrixInputForGPUInstruction(executionContext, this._input4.getName());
        int hiddenSize = toInt(out0.getNumColumns());
        Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gPUContext, out0, extendedOpcode);

        MatrixObject weights = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        MatrixObject bias = getMatrixInputForGPUInstruction(executionContext, this._input3.getName());
        int inputSize = toInt(weights.getNumRows()) - hiddenSize;
        Pointer weightsPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gPUContext, weights, extendedOpcode,
                inputSize + hiddenSize, 4 * hiddenSize);
        Pointer biasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gPUContext, bias, extendedOpcode, 1, 4 * hiddenSize);
        // Range-check the element count in long, and compute the byte size in long:
        // count * sizeOfDataType overflowed int for moderately large layers even when the count fit.
        int numCudnnWeights = toInt((long) (inputSize + hiddenSize + 2) * 4 * hiddenSize);
        Pointer cudnnWeights = gPUContext.allocate(extendedOpcode,
                (long) numCudnnWeights * LibMatrixCUDA.sizeOfDataType, false);
        LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_weight",
                ExecutionConfig.getConfigForSimpleVectorOperations(numCudnnWeights),
                weightsPointer, biasPointer, cudnnWeights, Integer.valueOf(inputSize), Integer.valueOf(hiddenSize));
        executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input3.getName());

        boolean returnSequences = executionContext.getScalarInput(this._input6).getBooleanValue();

        MatrixObject x = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        Pointer xPointer = LibMatrixCUDA.getDensePointer(gPUContext, x, extendedOpcode);
        int batchSize = toInt(x.getNumRows());
        int seqLength = toInt(x.getNumColumns() / inputSize);
        int numInputs = toInt((long) batchSize * seqLength * inputSize);
        Pointer cudnnInput = gPUContext.allocate(extendedOpcode, (long) numInputs * LibMatrixCUDA.sizeOfDataType, false);
        LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_input",
                ExecutionConfig.getConfigForSimpleVectorOperations(numInputs),
                xPointer, cudnnInput, Integer.valueOf(batchSize), Integer.valueOf(inputSize),
                Integer.valueOf(seqLength * inputSize), Integer.valueOf(numInputs));
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());

        Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gPUContext,
                getMatrixInputForGPUInstruction(executionContext, this._input5.getName()), extendedOpcode);
        LibMatrixCuDNN.lstm(executionContext, gPUContext, extendedOpcode, cudnnInput, cudnnWeights, out0Pointer,
                c0Pointer, returnSequences, this._output.getName(), this._output2.getName(),
                batchSize, hiddenSize, inputSize, seqLength);
        gPUContext.cudaFreeHelper(extendedOpcode, cudnnWeights, DMLScript.EAGER_CUDA_FREE);
        gPUContext.cudaFreeHelper(extendedOpcode, cudnnInput, DMLScript.EAGER_CUDA_FREE);
        executionContext.releaseMatrixInputForGPUInstruction(this._input4.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input5.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output2.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        // Opcodes with dedicated handlers; none of them reads the convolution/pooling shape operands.
        if (this.instOpcode.equalsIgnoreCase("bias_add") || this.instOpcode.equalsIgnoreCase("bias_multiply")) {
            processBiasInstruction(this.instOpcode, executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("relu_backward")) {
            processReLUBackwardInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("channel_sums")) {
            processChannelSumsInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("update_nesterov_x")) {
            processNesterovUpdateInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("lstm")) {
            processLstmInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("lstm_backward")) {
            processLstmBackwardInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d")) {
            processBatchNorm2dInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
            processBatchNorm2dBackwardInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
            processBatchNorm2dTestInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
            processBatchNorm2dTrainInstruction(executionContext);
            return;
        }

        // Convolution / pooling opcodes. Shape operands: input_shape = [N, C, H, W],
        // filter_shape = [K, _, R, S], stride = [strideH, strideW], padding = [padH, padW].
        GPUStatistics.incrementNoOfExecutedGPUInst();
        int padH = getScalarInput(executionContext, this._padding, 0);
        int padW = getScalarInput(executionContext, this._padding, 1);
        int strideH = getScalarInput(executionContext, this._stride, 0);
        int strideW = getScalarInput(executionContext, this._stride, 1);
        int N = getScalarInput(executionContext, this._input_shape, 0);
        int C = getScalarInput(executionContext, this._input_shape, 1);
        int H = getScalarInput(executionContext, this._input_shape, 2);
        int W = getScalarInput(executionContext, this._input_shape, 3);
        int K = getScalarInput(executionContext, this._filter_shape, 0);
        int R = getScalarInput(executionContext, this._filter_shape, 2);
        int S = getScalarInput(executionContext, this._filter_shape, 3);
        int P = (int) DnnUtils.getP(H, R, strideH, padH);
        int Q = (int) DnnUtils.getQ(W, S, strideW, padW);

        // Expected column counts, computed in long so that large shapes cannot overflow int
        // and make a dimension check pass or fail spuriously.
        long chw = (long) C * H * W;
        long crs = (long) C * R * S;
        long kpq = (long) K * P * Q;
        long cpq = (long) C * P * Q;

        if (this.instOpcode.equalsIgnoreCase("conv2d")) {
            MatrixObject image = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
            MatrixObject filter = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
            if (image.getNumRows() != N || image.getNumColumns() != chw) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
            }
            if (filter.getNumRows() != K || filter.getNumColumns() != crs) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
            }
            LibMatrixCuDNN.conv2d(executionContext.getGPUContext(0), getExtendedOpcode(), image, filter,
                    getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), N, kpq),
                    N, C, H, W, K, R, S, padH, padW, strideH, strideW, P, Q, this._intermediateMemoryBudget);
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            MatrixObject image = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
            MatrixObject bias = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
            MatrixObject filter = getMatrixInputForGPUInstruction(executionContext, this._input3.getName());
            if (image.getNumRows() != N || image.getNumColumns() != chw) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
            }
            if (filter.getNumRows() != K || filter.getNumColumns() != crs) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
            }
            LibMatrixCuDNN.conv2dBiasAdd(executionContext.getGPUContext(0), getExtendedOpcode(), image, bias, filter,
                    getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), N, kpq),
                    N, C, H, W, K, R, S, padH, padW, strideH, strideW, P, Q, this._intermediateMemoryBudget);
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixObject image = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
            MatrixObject dout = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
            if (image.getNumRows() != N || image.getNumColumns() != chw) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter");
            }
            if (dout.getNumRows() != N || dout.getNumColumns() != kpq) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: "
                        + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + kpq);
            }
            LibMatrixCuDNN.conv2dBackwardFilter(executionContext.getGPUContext(0), getExtendedOpcode(), image, dout,
                    getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), K, crs),
                    N, C, H, W, K, R, S, padH, padW, strideH, strideW, P, Q, this._intermediateMemoryBudget);
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixObject filter = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
            MatrixObject dout = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
            if (filter.getNumRows() != K || filter.getNumColumns() != crs) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data");
            }
            if (dout.getNumRows() != N || dout.getNumColumns() != kpq) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: "
                        + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + kpq);
            }
            LibMatrixCuDNN.conv2dBackwardData(executionContext.getGPUContext(0), getExtendedOpcode(), filter, dout,
                    getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), N, chw),
                    N, C, H, W, K, R, S, padH, padW, strideH, strideW, P, Q, this._intermediateMemoryBudget);
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("avgpooling")) {
            MatrixObject image = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
            if (image.getNumRows() != N || image.getNumColumns() != chw) {
                throw new DMLRuntimeException("Incorrect dimensions for image in " + this.instOpcode + ": "
                        + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + chw);
            }
            LibMatrixCuDNN.pooling(executionContext.getGPUContext(0), getExtendedOpcode(), image,
                    getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), N, cpq),
                    N, C, H, W, K, R, S, padH, padW, strideH, strideW, P, Q,
                    this.instOpcode.equalsIgnoreCase("maxpooling") ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG,
                    this._intermediateMemoryBudget);
        } else {
            if (!this.instOpcode.equalsIgnoreCase("maxpooling_backward") && !this.instOpcode.equalsIgnoreCase("avgpooling_backward")) {
                throw new DMLRuntimeException("Unsupported GPU context for " + this.instOpcode);
            }
            MatrixObject image = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
            MatrixObject dout = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
            // input3 is optional; when absent, null is passed through to poolingBackward.
            MatrixObject poolOutput = this._input3 != null ? getMatrixInputForGPUInstruction(executionContext, this._input3.getName()) : null;
            if (dout.getNumRows() != N || dout.getNumColumns() != cpq) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in " + this.instOpcode);
            }
            if (image.getNumRows() != N || image.getNumColumns() != chw) {
                throw new DMLRuntimeException("Incorrect dimensions for image in " + this.instOpcode + ": "
                        + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + chw);
            }
            LibMatrixCuDNN.poolingBackward(executionContext.getGPUContext(0), getExtendedOpcode(), image, dout, poolOutput,
                    getDenseMatrixOutputForGPUInstruction(executionContext, this._output.getName(), N, chw),
                    N, C, H, W, K, R, S, padH, padW, strideH, strideW, P, Q,
                    this.instOpcode.equalsIgnoreCase("maxpooling_backward") ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG,
                    this._intermediateMemoryBudget);
        }

        // Release exactly the inputs each branch acquired: input1 always, input2 for everything
        // except forward pooling, input3 for conv2d_bias_add and (when present) pooling backward.
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        boolean isPooling = this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("avgpooling");
        boolean isPoolingBackward = this.instOpcode.equalsIgnoreCase("maxpooling_backward") || this.instOpcode.equalsIgnoreCase("avgpooling_backward");
        if (!isPooling) {
            executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        }
        if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add") || (isPoolingBackward && this._input3 != null)) {
            executionContext.releaseMatrixInputForGPUInstruction(this._input3.getName());
        }
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Reads the scalar operand at position {@code i} of {@code arrayList} from the execution
     * context and narrows its long value to an int (by plain cast, without range checking).
     */
    private static int getScalarInput(ExecutionContext executionContext, ArrayList<CPOperand> arrayList, int i) {
        CPOperand operand = arrayList.get(i);
        long value = executionContext.getScalarInput(operand).getLongValue();
        return (int) value;
    }

    /**
     * Builds the lineage item for this instruction.
     * The lineage covers input1..input8, followed by any input-shape, filter-shape, stride and
     * padding operands. Operands that are null are still passed on, as before;
     * {@code LineageItemUtils.getLineage} receives them unchanged.
     */
    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        ArrayList<CPOperand> inputs = new ArrayList<>();
        inputs.add(this._input1);
        inputs.add(this._input2);
        inputs.add(this._input3);
        inputs.add(this._input4);
        inputs.add(this._input5);
        inputs.add(this._input6);
        inputs.add(this._input7);
        inputs.add(this._input8);
        // addAll on an empty list is a no-op, so only the null check is needed.
        if (this._input_shape != null) {
            inputs.addAll(this._input_shape);
        }
        if (this._filter_shape != null) {
            inputs.addAll(this._filter_shape);
        }
        if (this._stride != null) {
            inputs.addAll(this._stride);
        }
        if (this._padding != null) {
            inputs.addAll(this._padding);
        }
        CPOperand[] operands = inputs.toArray(new CPOperand[0]);
        return Pair.of(this._output.getName(),
                new LineageItem(getOpcode(), LineageItemUtils.getLineage(executionContext, operands)));
    }
}
