package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.utils.NativeHelper;
import scala.Tuple2;

/**
 * Spark instruction for deep-neural-network operators.
 *
 * <p>{@link #processInstruction} executes {@code conv2d}, {@code conv2d_bias_add},
 * {@code maxpooling} and {@code relu_maxpooling}. The parser additionally recognises
 * {@code maxpooling_backward}, {@code conv2d_backward_filter}, {@code conv2d_backward_data} and
 * {@code bias_add}, but {@link #processInstruction} rejects those with "Not implemented".
 *
 * <p>Execution reblocks the input so that each RDD block is expected to hold one row (one image),
 * processes every block independently inside {@code mapPartitionsToPair}, and ships the filter
 * (and bias) matrices to the executors as Spark broadcasts.
 */
public class DnnSPInstruction extends UnarySPInstruction {
    /** Second matrix input: the filter for conv2d, the bias for conv2d_bias_add. May be null. */
    private CPOperand _in2;
    /** Third matrix input: the filter for conv2d_bias_add. May be null. */
    private CPOperand _in3;
    /** Scalar operands of the input shape (index 1..3 are read as C, H, W). */
    private ArrayList<CPOperand> _input_shape;
    /** Scalar operands of the filter shape (index 0, 2, 3 are read as K, R, S). */
    private ArrayList<CPOperand> _filter_shape;
    /** Scalar operands of the stride (vertical, horizontal). */
    private ArrayList<CPOperand> _stride = new ArrayList<>();
    /** Scalar operands of the padding (vertical, horizontal). */
    private ArrayList<CPOperand> _padding = new ArrayList<>();

    /**
     * Partition function that applies the configured DNN operator to every block of a partition,
     * keeping each block's {@link MatrixIndexes} unchanged.
     */
    public static class RDDConv2dMapMMFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2106155380020232155L;
        Broadcast<MatrixBlock> filterBroadcast;
        Broadcast<MatrixBlock> biasBroadcast;
        DnnParameters params;
        String instOpcode;
        boolean enableNative;
        long numRows;

        /** Lazy iterator that validates each input block and replaces it with the operator output. */
        public class MapsideDnnPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapsideDnnPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) {
                super(it);
            }

            /**
             * Processes one block.
             *
             * @throws RuntimeException if the block is not a single-column-block row within the
             *     expected row range, or if the operator output does not have exactly one row
             */
            @Override
            public Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes ix = arg._1();
                if (ix.getRowIndex() > RDDConv2dMapMMFunction.this.numRows || ix.getColumnIndex() != 1) {
                    throw new RuntimeException("Expected the inputs to be reblocked as rectangular RDD");
                }
                MatrixBlock out = RDDConv2dMapMMFunction.this.processRectangularBlock(arg._2());
                if (out.getNumRows() != 1) {
                    throw new RuntimeException("Expected the output to have 1 row");
                }
                return new Tuple2<>(ix, out);
            }
        }

        /**
         * @param filterBroadcast broadcast filter matrix (null for pooling)
         * @param params operator parameters shared by all blocks
         * @param instOpcode operator opcode
         * @param biasBroadcast broadcast bias matrix (only used by conv2d_bias_add)
         * @param numRows number of rows of the input, used to validate block indexes
         * @param enableNative whether conv2d should use {@link LibMatrixNative}
         */
        public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast, DnnParameters params, String instOpcode,
                Broadcast<MatrixBlock> biasBroadcast, long numRows, boolean enableNative) {
            this.filterBroadcast = filterBroadcast;
            this.params = params;
            this.instOpcode = instOpcode;
            this.biasBroadcast = biasBroadcast;
            this.numRows = numRows;
            this.enableNative = enableNative;
        }

        /**
         * Applies the configured operator to a single input block.
         *
         * <p>Empty inputs short-circuit to an empty (sparse, unallocated) output of the right size.
         */
        private MatrixBlock processRectangularBlock(MatrixBlock in) throws Exception {
            if (instOpcode.equalsIgnoreCase("conv2d")) {
                MatrixBlock filter = filterBroadcast.getValue();
                if (filter.isEmptyBlock() || in.isEmptyBlock()) {
                    return emptyOutput(params.K);
                }
                MatrixBlock out = new MatrixBlock(params.N, params.K * params.P * params.Q, false).allocateDenseBlock();
                runConv2d(in, filter, out);
                return out;
            }
            if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
                MatrixBlock filter = filterBroadcast.getValue();
                MatrixBlock bias = biasBroadcast.getValue();
                // A non-empty bias alone still yields a non-zero output, so only skip when both are empty.
                if ((filter.isEmptyBlock() || in.isEmptyBlock()) && bias.isEmptyBlock()) {
                    return emptyOutput(params.K);
                }
                MatrixBlock out = new MatrixBlock(params.N, params.K * params.P * params.Q, false).allocateDenseBlock();
                if (!bias.isEmptyBlock()) {
                    params.bias = bias;
                }
                runConv2d(in, filter, out);
                return out;
            }
            if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
                // Plain max pooling starts from -inf-like values; relu_maxpooling keeps the zero init.
                return pool(in, LibMatrixDNN.PoolingType.MAX, instOpcode.equalsIgnoreCase("maxpooling"));
            }
            if (instOpcode.equalsIgnoreCase("avgpooling") || instOpcode.equalsIgnoreCase("relu_avgpooling")) {
                return pool(in, LibMatrixDNN.PoolingType.AVG, false);
            }
            throw new RuntimeException("Not implemented");
        }

        /** Creates an empty (sparse, unallocated) output row with {@code channels * P * Q} columns. */
        private MatrixBlock emptyOutput(int channels) {
            return new MatrixBlock(params.N, channels * params.P * params.Q, true);
        }

        /** Runs conv2d with the native library when enabled, otherwise with the Java implementation. */
        private void runConv2d(MatrixBlock in, MatrixBlock filter, MatrixBlock out) throws Exception {
            if (enableNative) {
                LibMatrixNative.conv2d(in, filter, out, params);
            } else {
                LibMatrixDNN.conv2d(in, filter, out, params);
            }
        }

        /**
         * Pools one input row into a {@code C * P * Q} output row.
         *
         * @param initWithLowest if true, pre-fills the output with {@code -Double.MAX_VALUE}
         */
        private MatrixBlock pool(MatrixBlock in, LibMatrixDNN.PoolingType type, boolean initWithLowest) throws Exception {
            if (in.isEmptyBlock()) {
                return emptyOutput(params.C);
            }
            MatrixBlock out = new MatrixBlock(params.N, params.C * params.P * params.Q, false).allocateBlock();
            if (initWithLowest) {
                out.getDenseBlock().set(-Double.MAX_VALUE);
            }
            LibMatrixDNN.pooling(in, out, params, type);
            return out;
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) throws Exception {
            return new MapsideDnnPartitionIterator(it);
        }
    }

    /** Single-input instruction with shape operands (max pooling variants). */
    private DnnSPInstruction(CPOperand in, CPOperand out, String opcode, String istr,
            ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
            ArrayList<CPOperand> inputShape, ArrayList<CPOperand> filterShape) {
        this(in, null, null, out, opcode, istr, stride, padding, inputShape, filterShape);
    }

    /** Two-input instruction with shape operands (conv2d and backward variants). */
    private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr,
            ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
            ArrayList<CPOperand> inputShape, ArrayList<CPOperand> filterShape) {
        this(in, in2, null, out, opcode, istr, stride, padding, inputShape, filterShape);
    }

    /** Three-input instruction with shape operands (conv2d_bias_add); all other constructors delegate here. */
    private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr,
            ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
            ArrayList<CPOperand> inputShape, ArrayList<CPOperand> filterShape) {
        super(SPInstruction.SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
        this._in2 = in2;
        this._in3 = in3;
        this._stride = stride;
        this._padding = padding;
        this._input_shape = inputShape;
        this._filter_shape = filterShape;
    }

    /** Two-input instruction without shape operands (bias_add): empty stride/padding, null shapes. */
    private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr) {
        this(in, in2, null, out, opcode, istr, new ArrayList<>(), new ArrayList<>(), null, null);
    }

    /**
     * Parses a DNN Spark instruction string.
     *
     * <p>Field layout after the opcode: matrix inputs, then stride (2), padding (2), input shape (4),
     * filter shape (4), then the output operand. {@code bias_add} carries only in, in2 and out.
     *
     * @throws DMLRuntimeException if the opcode is not a supported DNN opcode
     */
    public static DnnSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
            InstructionUtils.checkNumFields(parts, 14);
            CPOperand in = splitOperand(parts[1]);
            CPOperand out = splitOperand(parts[14]);
            return new DnnSPInstruction(in, out, opcode, str,
                shapeOperands(parts, 2, 2), shapeOperands(parts, 4, 2),
                shapeOperands(parts, 6, 4), shapeOperands(parts, 10, 4));
        }
        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("conv2d")
            || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 15);
            CPOperand in = splitOperand(parts[1]);
            CPOperand in2 = splitOperand(parts[2]);
            CPOperand out = splitOperand(parts[15]);
            return new DnnSPInstruction(in, in2, out, opcode, str,
                shapeOperands(parts, 3, 2), shapeOperands(parts, 5, 2),
                shapeOperands(parts, 7, 4), shapeOperands(parts, 11, 4));
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in = splitOperand(parts[1]);
            CPOperand in2 = splitOperand(parts[2]);
            CPOperand in3 = splitOperand(parts[3]);
            CPOperand out = splitOperand(parts[16]);
            return new DnnSPInstruction(in, in2, in3, out, opcode, str,
                shapeOperands(parts, 4, 2), shapeOperands(parts, 6, 2),
                shapeOperands(parts, 8, 4), shapeOperands(parts, 12, 4));
        }
        if (opcode.equalsIgnoreCase("bias_add")) {
            InstructionUtils.checkNumFields(parts, 3);
            CPOperand in = splitOperand(parts[1]);
            CPOperand in2 = splitOperand(parts[2]);
            CPOperand out = splitOperand(parts[3]);
            return new DnnSPInstruction(in, in2, out, opcode, str);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a DnnSPInstruction: " + str);
    }

    /** Creates a matrix/scalar operand from an instruction field via {@link CPOperand#split}. */
    private static CPOperand splitOperand(String field) {
        CPOperand op = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        op.split(field);
        return op;
    }

    /** Creates {@code count} consecutive scalar operands starting at {@code parts[start]}. */
    private static ArrayList<CPOperand> shapeOperands(String[] parts, int start, int count) {
        ArrayList<CPOperand> ops = new ArrayList<>(count);
        for (int i = start; i < start + count; i++) {
            ops.add(new CPOperand(parts[i]));
        }
        return ops;
    }

    /**
     * Returns the input RDD reblocked to block size {@code numRowsPerBlock}, or the original RDD if
     * it already uses that block size.
     *
     * <p>NOTE(review): the caller expects one full row per block (see
     * {@code MapsideDnnPartitionIterator#computeNext}, which rejects column index != 1). If
     * {@code setBlocksize} applies a single square block size, blocksize 1 yields 1x1 blocks
     * instead; confirm against {@link MatrixCharacteristics}.
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) {
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(name);
        DataCharacteristics mcIn = sec.getDataCharacteristics(name);
        if (mcIn.getBlocksize() == numRowsPerBlock) {
            return in;
        }
        MatrixCharacteristics mcOut = new MatrixCharacteristics(mcIn);
        mcOut.setBlocksize(numRowsPerBlock);
        return RDDAggregateUtils.mergeByKey(in.flatMapToPair(new ExtractBlockForBinaryReblock(mcIn, mcOut)));
    }

    /** Reads a matrix variable into memory, releases it, and broadcasts it to the executors. */
    private static Broadcast<MatrixBlock> getBroadcast(SparkExecutionContext sec, String name) {
        MatrixBlock mb = sec.getMatrixInput(name);
        sec.releaseMatrixInput(name);
        return sec.getSparkContext().broadcast(mb);
    }

    /**
     * Executes conv2d, conv2d_bias_add, maxpooling or relu_maxpooling over the input RDD.
     *
     * @throws DMLRuntimeException for unsupported opcodes, or if the per-row output would exceed
     *     {@link OptimizerUtils#MAX_NUMCELLS_CP_DENSE} cells
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean isConv2d = instOpcode.equalsIgnoreCase("conv2d");
        boolean isConv2dBiasAdd = instOpcode.equalsIgnoreCase("conv2d_bias_add");
        boolean isMaxPooling = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling");
        if (!isConv2d && !isConv2dBiasAdd && !isMaxPooling) {
            throw new DMLRuntimeException("Not implemented: " + instOpcode);
        }

        int padH = getScalarInput(ec, _padding, 0);
        int padW = getScalarInput(ec, _padding, 1);
        int strideH = getScalarInput(ec, _stride, 0);
        int strideW = getScalarInput(ec, _stride, 1);
        int C = getScalarInput(ec, _input_shape, 1);
        int H = getScalarInput(ec, _input_shape, 2);
        int W = getScalarInput(ec, _input_shape, 3);
        int K = getScalarInput(ec, _filter_shape, 0);
        int R = getScalarInput(ec, _filter_shape, 2);
        int S = getScalarInput(ec, _filter_shape, 3);
        long p = DnnUtils.getP(H, R, strideH, padH);
        long q = DnnUtils.getQ(W, S, strideW, padW);

        // Compute in long arithmetic: an int product can overflow (even to a negative value) and
        // silently pass the size check. Validate before registering any RDD or broadcast.
        long numOutCols = (isMaxPooling ? C : K) * p * q;
        if (numOutCols > OptimizerUtils.MAX_NUMCELLS_CP_DENSE) {
            throw new DMLRuntimeException("The current operator does not support large outputs ("
                + numOutCols + " output columns).");
        }

        String inputName = input1.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = reblockAsRectangularMatrices(sec, inputName, 1);
        DataCharacteristics mcIn = sec.getDataCharacteristics(inputName);

        Broadcast<MatrixBlock> filter = null;
        Broadcast<MatrixBlock> bias = null;
        if (isConv2d) {
            filter = getBroadcast(sec, _in2.getName());
        } else if (isConv2dBiasAdd) {
            filter = getBroadcast(sec, _in3.getName());
            bias = getBroadcast(sec, _in2.getName());
        }

        // One image per block, hence N = 1 and a single thread per task.
        DnnParameters params = new DnnParameters(1, C, H, W, K, R, S, strideH, strideW, padH, padW, 1);
        RDDConv2dMapMMFunction fn = new RDDConv2dMapMMFunction(filter, params, instOpcode, bias,
            mcIn.getRows(), NativeHelper.isNativeLibraryLoaded());
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapPartitionsToPair(fn, true);

        String outputName = output.getName();
        sec.setRDDHandleForVariable(outputName, out);
        sec.addLineageRDD(outputName, inputName);
        sec.setMetaData(outputName, new MetaDataFormat(
            new MatrixCharacteristics(mcIn.getRows(), numOutCols, 1, -1L), Types.FileFormat.BINARY));
    }

    /** Reads the scalar operand at {@code index} of {@code operands} as an int. */
    private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> operands, int index) {
        return (int) ec.getScalarInput(operands.get(index)).getLongValue();
    }
}
