package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInList;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInRange;
import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDSortUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for the reorg operations transpose ({@code r'}), reverse ({@code rev}),
 * diagonal ({@code rdiag}) and sort ({@code rsort}) on blocked matrix RDDs.
 */
public class ReorgSPInstruction extends UnarySPInstruction {
    private static final Log LOG = LogFactory.getLog(ReorgSPInstruction.class.getName());

    // rsort-only operands (null for all other opcodes): order-by column(s), descending flag, index-return flag
    private final CPOperand _col;
    private final CPOperand _desc;
    private final CPOperand _ixret;
    // rsort-only: when true, data sorts use RDDSortUtils.sortDataByValMemSort instead of sortDataByVal
    private final boolean _bSortIndInMem;

    /** Creates an instruction for the non-sort opcodes (r', rev, rdiag). */
    private ReorgSPInstruction(Operator operator, CPOperand in, CPOperand out, String opcode, String istr) {
        this(operator, in, null, null, null, out, opcode, false, istr);
    }

    /** Creates an instruction carrying the additional rsort operands. */
    private ReorgSPInstruction(Operator operator, CPOperand in, CPOperand col, CPOperand desc, CPOperand ixret,
        CPOperand out, String opcode, boolean bSortIndInMem, String istr) {
        super(SPInstruction.SPType.Reorg, operator, in, out, opcode, istr);
        _col = col;
        _desc = desc;
        _ixret = ixret;
        _bSortIndInMem = bSortIndInMem;
    }

    /**
     * Parses a reorg instruction string into a {@code ReorgSPInstruction}.
     *
     * @param str the full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not one of r', rev, rdiag, rsort
     */
    public static ReorgSPInstruction parseInstruction(String str) {
        CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        String opcode = InstructionUtils.getOpCode(str);

        if (opcode.equalsIgnoreCase("r'")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rev")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rdiag")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rsort")) {
            // layout: parts[1]=input, [2]=order-by col(s), [3]=desc, [4]=ixret, [5]=output, [6]=optional mem-sort flag
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(parts, 5, 6);
            in.split(parts[1]);
            out.split(parts[5]);
            CPOperand col = new CPOperand(parts[2]);
            CPOperand desc = new CPOperand(parts[3]);
            CPOperand ixret = new CPOperand(parts[4]);
            // parts[5] is always present, so the optional flag exists only when length exceeds 6
            // (a "> 5" guard would read parts[6] out of bounds for the 5-field form)
            boolean bSortIndInMem = parts.length > 6 && Boolean.parseBoolean(parts[6]);
            return new ReorgSPInstruction(new ReorgOperator(new SortIndex(1, false, false)),
                in, col, desc, ixret, out, opcode, bSortIndInMem, str);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a ReorgInstruction: " + str);
    }

    /**
     * Executes the reorg operation on the RDD of {@code input1}, infers the output data
     * characteristics, and registers the result RDD (plus lineage) for {@code output}.
     *
     * @throws DMLRuntimeException for an unsupported opcode
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        String opcode = getOpcode();

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;

        if (opcode.equalsIgnoreCase("r'")) {
            out = in1.mapToPair(new ReorgMapFunction(opcode));
        } else if (opcode.equalsIgnoreCase("rev")) {
            out = in1.flatMapToPair(new RDDRevFunction(mcIn));
            // if rows are not a multiple of the blocksize, reversed rows may land in several
            // partial blocks per output index; combine them by key
            if (mcIn.getRows() % mcIn.getBlocksize() != 0) {
                out = RDDAggregateUtils.mergeByKey(out, false);
            }
        } else if (opcode.equalsIgnoreCase("rdiag")) {
            // column vector -> diagonal matrix; otherwise matrix -> diagonal vector
            // (FilterDiagMatrixBlocksFunction presumably keeps only the block-diagonal blocks)
            out = mcIn.getCols() == 1
                ? in1.flatMapToPair(new RDDDiagV2MFunction(mcIn))
                : in1.filter(new FilterDiagMatrixBlocksFunction()).mapToPair(new ReorgMapFunction(opcode));
        } else if (opcode.equalsIgnoreCase("rsort")) {
            out = processSort(sec, in1, mcIn);
        } else {
            throw new DMLRuntimeException("Error: Incorrect opcode in ReorgSPInstruction:" + opcode);
        }

        updateReorgDataCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
    }

    /**
     * Executes rsort: sorts rows of the input by the order-by column(s), returning either the
     * sorted data or (if ixret) the sort indexes.
     */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> processSort(SparkExecutionContext sec,
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1, DataCharacteristics mcIn) {
        long[] cols = readOrderByColumns(sec);
        boolean desc = sec.getScalarInput(_desc).getBooleanValue();
        boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
        boolean singleCol = mcIn.getCols() == 1;
        long rows = mcIn.getRows();
        long ncol = mcIn.getCols();
        int blen = mcIn.getBlocksize();

        if (cols.length > blen) {
            LOG.warn("Unsupported sort with number of order-by columns large than blocksize: " + cols.length);
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> sortKeys = in1;
        if (singleCol || cols.length == 1) {
            // single order-by column: extract it unless the input already is that column
            if (!singleCol) {
                sortKeys = sortKeys.filter(new IsBlockInRange(1L, rows, cols[0], cols[0], mcIn))
                    .mapValues(new ExtractColumn(UtilFunctions.computeCellInBlock(cols[0], blen)));
            }
            if (ixret) {
                return RDDSortUtils.sortIndexesByVal(sortKeys, !desc, rows, blen);
            }
            if (singleCol && !desc) {
                return RDDSortUtils.sortByVal(sortKeys, rows, blen);
            }
            if (_bSortIndInMem) {
                return RDDSortUtils.sortDataByValMemSort(sortKeys, in1, !desc, rows, ncol, blen, sec,
                    (ReorgOperator) _optr);
            }
            return RDDSortUtils.sortDataByVal(sortKeys, in1, !desc, rows, ncol, blen);
        }

        // multi-column sort: the raw data can serve as sort keys only if the order-by list is exactly
        // 1..ncol; any other list (subset, permutation, duplicates) requires extracting the columns
        // in the requested order
        boolean naturalOrder = isNaturalColumnOrder(cols, ncol);
        if (!naturalOrder) {
            sortKeys = sortKeys.filter(new IsBlockInList(cols, mcIn))
                .mapToPair(new ExtractColumns(cols, mcIn));
        }
        if (ncol > blen) {
            // combine per-column-block pieces of each row block
            sortKeys = RDDAggregateUtils.mergeByKey(sortKeys);
        }
        if (ixret) {
            return RDDSortUtils.sortIndexesByVals(sortKeys, !desc, rows, cols.length, blen);
        }
        if (naturalOrder && !desc) {
            // keys are the data itself, so sorting the keys yields the sorted data
            return RDDSortUtils.sortByVals(sortKeys, rows, cols.length, blen);
        }
        return RDDSortUtils.sortDataByVals(sortKeys, in1, !desc, rows, ncol, cols.length, blen);
    }

    /**
     * Reads the rsort order-by column indexes (1-based, as passed to IsBlockInRange) from either a
     * matrix or a scalar operand; a pinned matrix input is released even if conversion fails.
     */
    private long[] readOrderByColumns(SparkExecutionContext sec) {
        if (!_col.getDataType().isMatrix()) {
            return new long[] {sec.getScalarInput(_col).getLongValue()};
        }
        MatrixBlock colsBlock = sec.getMatrixInput(_col.getName());
        try {
            return DataConverter.convertToLongVector(colsBlock);
        } finally {
            sec.releaseMatrixInput(_col.getName());
        }
    }

    /** Returns true iff {@code cols} is exactly {@code 1, 2, ..., ncol}. */
    private static boolean isNaturalColumnOrder(long[] cols, long ncol) {
        if (cols.length != ncol) {
            return false;
        }
        for (int i = 0; i < cols.length; i++) {
            if (cols[i] != i + 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Infers unknown output dimensions and non-zeros from the input characteristics.
     *
     * @throws DMLRuntimeException if output and input dimensions are both unknown
     */
    private void updateReorgDataCharacteristics(SparkExecutionContext sec) {
        DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        String opcode = getOpcode();

        if (!mcOut.dimsKnown()) {
            if (!mcIn.dimsKnown()) {
                throw new DMLRuntimeException("Unable to compute output matrix characteristics from input.");
            }
            long rows = mcIn.getRows();
            long cols = mcIn.getCols();
            int blen = mcIn.getBlocksize();
            // use the 3-arg set: the 4-arg overload's last parameter is presumably nnz, and passing the
            // blocksize there would mark nnz as known and skip the nnz inference below
            if (opcode.equalsIgnoreCase("r'")) {
                mcOut.set(cols, rows, blen);
            } else if (opcode.equalsIgnoreCase("rev")) {
                mcOut.set(rows, cols, blen);
            } else if (opcode.equalsIgnoreCase("rdiag")) {
                mcOut.set(rows, cols > 1 ? 1L : rows, blen);
            } else if (opcode.equalsIgnoreCase("rsort")) {
                mcOut.set(rows, sec.getScalarInput(_ixret).getBooleanValue() ? 1L : cols, blen);
            }
        }

        if (mcOut.nnzKnown() || !mcIn.nnzKnown()) {
            return;
        }
        if (opcode.equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret).getBooleanValue()) {
            // index output has one value per row
            mcOut.setNonZeros(mcIn.getRows());
        } else {
            mcOut.setNonZeros(mcIn.getNonZeros());
        }
    }

    /** Slices a single in-block column out of each block. */
    private static class ExtractColumn implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -1472164797288449559L;
        // 0-based column position within the block
        private final int _col;

        public ExtractColumn(int col) {
            _col = col;
        }

        @Override
        public MatrixBlock call(MatrixBlock block) throws Exception {
            return block.slice(0, block.getNumRows() - 1, _col, _col, (CacheBlock) new MatrixBlock());
        }
    }

    /**
     * Copies the requested (1-based) columns that fall into a block into column i of a new block of
     * width {@code cols.length}; the result is keyed to column-block 1 so that pieces from different
     * column blocks of the same row block can be merged by key.
     */
    private static class ExtractColumns implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2902729186431711506L;
        private final long[] _cols;
        private final int _blen;

        public ExtractColumns(long[] cols, DataCharacteristics mc) {
            _cols = cols;
            _blen = mc.getBlocksize();
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            MatrixIndexes ix = arg0._1();
            MatrixBlock in = arg0._2();
            int lastRow = in.getNumRows() - 1;
            MatrixBlock out = new MatrixBlock(in.getNumRows(), _cols.length, true);
            for (int i = 0; i < _cols.length; i++) {
                IndexRange colRange = new IndexRange(1L, Long.MAX_VALUE, _cols[i], _cols[i]);
                if (UtilFunctions.isInBlockRange(ix, _blen, colRange)) {
                    int pos = UtilFunctions.computeCellInBlock(_cols[i], _blen);
                    MatrixBlock colBlock = in.slice(0, lastRow, pos, pos, (CacheBlock) new MatrixBlock());
                    out.leftIndexingOperations(colBlock, 0, lastRow, i, i, out, MatrixObject.UpdateType.INPLACE);
                }
            }
            return new Tuple2<>(new MatrixIndexes(ix.getRowIndex(), 1L), out);
        }
    }

    /**
     * Expands one block of a column vector (row-block r) into row-block r of the diagonal matrix:
     * the diagonal block (r, r) plus empty blocks (r, i) for every other column block i.
     */
    private static class RDDDiagV2MFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 31065772250744103L;
        private final ReorgOperator _reorgOp;
        private final DataCharacteristics _mcIn;

        public RDDDiagV2MFunction(DataCharacteristics mcIn) {
            _reorgOp = new ReorgOperator(DiagIndex.getDiagIndexFnObject());
            _mcIn = mcIn;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            long rowIndex = arg0._1().getRowIndex();
            long rows = _mcIn.getRows();
            int blen = _mcIn.getBlocksize();

            MatrixBlock diag = (MatrixBlock) arg0._2().reorgOperations(_reorgOp, (MatrixValue) new MatrixBlock(), -1, -1, -1);
            ret.add(new Tuple2<>(new MatrixIndexes(rowIndex, rowIndex), diag));

            // floating-point division: integer division would truncate before ceil and drop the
            // last (partial) column block when rows is not a multiple of the blocksize
            int numBlocks = (int) Math.ceil((double) rows / blen);
            for (int i = 1; i <= numBlocks; i++) {
                if (i != rowIndex) {
                    MatrixBlock empty = new MatrixBlock(UtilFunctions.computeBlockSize(rows, rowIndex, blen),
                        UtilFunctions.computeBlockSize(rows, i, blen), true);
                    ret.add(new Tuple2<>(new MatrixIndexes(rowIndex, i), empty));
                }
            }
            return ret.iterator();
        }
    }

    /** Applies LibMatrixReorg.rev to a single block, which may yield several output blocks. */
    private static class RDDRevFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1183373828539843938L;
        private final DataCharacteristics _mcIn;

        public RDDRevFunction(DataCharacteristics mcIn) {
            _mcIn = mcIn;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);
            ArrayList<IndexedMatrixValue> out = new ArrayList<>();
            LibMatrixReorg.rev(in, _mcIn.getRows(), _mcIn.getBlocksize(), out);
            return SparkUtils.fromIndexedMatrixBlock(out).iterator();
        }
    }
}
