package org.apache.sysds.runtime.instructions.spark;

import java.util.concurrent.Callable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.functions.AggregateDropCorrectionFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CommonThreadPool;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.class */
public class AggregateUnarySPInstruction extends UnarySPInstruction {
    private AggBinaryOp.SparkAggType _aggtype;
    private AggregateOperator _aop;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction$RDDAggregateTask.class */
    /**
     * Callable that runs a single-block unary aggregate over an RDD of matrix
     * blocks and returns the resulting block. It is submitted to a thread pool
     * by the enclosing instruction so the Spark job can run in the background.
     */
    public static class RDDAggregateTask implements Callable<MatrixBlock> {
        Operator _optr;
        AggregateOperator _aop;
        JavaPairRDD<MatrixIndexes, MatrixBlock> _in;
        DataCharacteristics _mc;

        /**
         * @param operator unary aggregate operator; cast to {@link AggregateUnaryOperator} in {@link #call()}
         * @param aggregateOperator operator used to combine the per-block partial results
         * @param javaPairRDD input blocks keyed by their block indexes
         * @param dataCharacteristics characteristics of the input; only the block size is read
         */
        RDDAggregateTask(Operator operator, AggregateOperator aggregateOperator, JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, DataCharacteristics dataCharacteristics) {
            _optr = operator;
            _aop = aggregateOperator;
            _in = javaPairRDD;
            _mc = dataCharacteristics;
        }

        /**
         * Aggregates every input block individually, reduces the partial results
         * with {@code RDDAggregateUtils.aggStable}, and finally calls
         * {@code dropLastRowsOrColumns} with the aggregate operator's correction.
         *
         * @return the aggregated matrix block
         */
        @Override
        public MatrixBlock call() {
            final AggregateUnaryOperator unaryOp = (AggregateUnaryOperator) _optr;

            // For sparse-safe operators the non-empty-block filter is applied before aggregation.
            final JavaPairRDD<MatrixIndexes, MatrixBlock> source = unaryOp.sparseSafe
                    ? _in.filter(new FilterNonEmptyBlocksFunction())
                    : _in;

            final JavaRDD<MatrixBlock> partials = source.map(new RDDUAggFunction2(unaryOp, _mc.getBlocksize()));
            final MatrixBlock total = RDDAggregateUtils.aggStable(partials, _aop);
            total.dropLastRowsOrColumns(_aop.correction);
            return total;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction$RDDUAggFunction.class */
    /**
     * Pair function that applies the unary aggregate to one indexed matrix
     * block and emits the aggregated block together with its output indexes,
     * both filled in by {@code OperationsOnMatrixValues.performAggregateUnary}.
     */
    public static class RDDUAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2672082409287856038L;
        private AggregateUnaryOperator _op;
        private int _blen;

        /**
         * @param aggregateUnaryOperator operator applied to every block
         * @param i block length forwarded to the aggregate call
         */
        public RDDUAggFunction(AggregateUnaryOperator aggregateUnaryOperator, int i) {
            _op = aggregateUnaryOperator;
            _blen = i;
        }

        /**
         * @param tuple2 input block indexes and block
         * @return freshly allocated output indexes and block populated by the aggregate
         */
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            final MatrixIndexes inIx = tuple2._1();
            final MatrixBlock inBlock = tuple2._2();

            // Output containers are created empty and populated by the helper.
            final MatrixIndexes outIx = new MatrixIndexes();
            final MatrixBlock outBlock = new MatrixBlock();
            OperationsOnMatrixValues.performAggregateUnary(inIx, inBlock, outIx, outBlock, _op, _blen);

            return new Tuple2<>(outIx, outBlock);
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction$RDDUAggFunction2.class */
    /**
     * Function that applies the unary aggregate to one indexed matrix block
     * and returns only the aggregated block (the block indexes are passed to
     * the aggregate call but not emitted).
     */
    public static class RDDUAggFunction2 implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 2672082409287856038L;
        private AggregateUnaryOperator _op;
        private int _blen;

        /**
         * @param aggregateUnaryOperator operator applied to every block
         * @param i block length forwarded to the aggregate call
         */
        public RDDUAggFunction2(AggregateUnaryOperator aggregateUnaryOperator, int i) {
            // Each field is assigned exactly once; the earlier default stores and the
            // duplicated _blen assignment were dead code.
            this._op = aggregateUnaryOperator;
            this._blen = i;
        }

        /**
         * @param tuple2 input block indexes and block
         * @return result of {@code aggregateUnaryOperations} on the input block, written into a new {@link MatrixBlock}
         */
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            return (MatrixBlock) ((MatrixBlock) tuple2._2).aggregateUnaryOperations(this._op, new MatrixBlock(), this._blen, (MatrixIndexes) tuple2._1());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction$RDDUAggValueFunction.class */
    /**
     * Value-only function that applies the unary aggregate to a matrix block
     * without access to its real block indexes; a fixed index of (1, 1) is
     * passed to the aggregate call instead.
     */
    public static class RDDUAggValueFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5352374590399929673L;
        private AggregateUnaryOperator _op;
        private int _blen;
        private MatrixIndexes _ix;

        /**
         * @param aggregateUnaryOperator operator applied to every block
         * @param i block length forwarded to the aggregate call
         */
        public RDDUAggValueFunction(AggregateUnaryOperator aggregateUnaryOperator, int i) {
            _op = aggregateUnaryOperator;
            _blen = i;
            // Placeholder index shared by all invocations of this function instance.
            _ix = new MatrixIndexes(1L, 1L);
        }

        /**
         * @param matrixBlock input block
         * @return a new block holding the aggregate of {@code matrixBlock}
         */
        public MatrixBlock call(MatrixBlock matrixBlock) throws Exception {
            final MatrixBlock result = new MatrixBlock();
            // The trailing 'true' flag is passed exactly as in the five-argument overload used before.
            matrixBlock.aggregateUnaryOperations(_op, (MatrixValue) result, _blen, _ix, true);
            return result;
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction$RDDUTensorAggFunction2.class */
    /**
     * Function that applies the unary aggregate to the basic tensor of one
     * indexed tensor block and wraps the result in a new {@link TensorBlock}.
     * The tensor indexes of the input pair are not used.
     */
    public static class RDDUTensorAggFunction2 implements Function<Tuple2<TensorIndexes, TensorBlock>, TensorBlock> {
        private static final long serialVersionUID = -6258769067791011763L;
        private AggregateUnaryOperator _op;

        /**
         * @param aggregateUnaryOperator operator applied to every tensor block
         */
        public RDDUTensorAggFunction2(AggregateUnaryOperator aggregateUnaryOperator) {
            _op = aggregateUnaryOperator;
        }

        /**
         * @param tuple2 input tensor indexes and tensor block
         * @return tensor block wrapping the aggregated basic tensor
         */
        public TensorBlock call(Tuple2<TensorIndexes, TensorBlock> tuple2) throws Exception {
            final TensorBlock input = tuple2._2();
            return new TensorBlock(input.getBasicTensor().aggregateUnaryOperations(_op, new BasicTensorBlock()));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction$RDDUTensorAggValueFunction.class */
    /**
     * Value-only function that aggregates the basic tensor of a tensor block
     * and returns a new 1x1 tensor block holding cell (0, 0) of the result.
     */
    public static class RDDUTensorAggValueFunction implements Function<TensorBlock, TensorBlock> {
        private static final long serialVersionUID = -968274963539513423L;
        private AggregateUnaryOperator _op;

        /**
         * @param aggregateUnaryOperator operator applied to every tensor block
         */
        public RDDUTensorAggValueFunction(AggregateUnaryOperator aggregateUnaryOperator) {
            _op = aggregateUnaryOperator;
        }

        /**
         * @param tensorBlock input tensor block
         * @return 1x1 tensor block with the value type of the aggregate and its (0, 0) cell
         */
        public TensorBlock call(TensorBlock tensorBlock) throws Exception {
            // The aggregate writes into 'aggregated'; the call's return value is not needed.
            final BasicTensorBlock aggregated = new BasicTensorBlock();
            tensorBlock.getBasicTensor().aggregateUnaryOperations(_op, aggregated);

            final TensorBlock scalar = new TensorBlock(aggregated.getValueType(), new int[]{1, 1});
            scalar.set(0, 0, aggregated.get(0, 0));
            return scalar;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates the instruction; the first six of the forwarded arguments go to
     * the superclass constructor unchanged.
     *
     * @param sPType Spark instruction type
     * @param aggregateUnaryOperator unary aggregate operator handed to the superclass
     * @param aggregateOperator operator used to combine partial aggregates
     * @param cPOperand input operand
     * @param cPOperand2 output operand
     * @param sparkAggType aggregation strategy reported by {@link #getAggType()}
     * @param str opcode
     * @param str2 full instruction string
     */
    public AggregateUnarySPInstruction(SPInstruction.SPType sPType, AggregateUnaryOperator aggregateUnaryOperator, AggregateOperator aggregateOperator, CPOperand cPOperand, CPOperand cPOperand2, AggBinaryOp.SparkAggType sparkAggType, String str, String str2) {
        super(sPType, aggregateUnaryOperator, cPOperand, cPOperand2, str, str2);
        _aggtype = sparkAggType;
        _aop = aggregateOperator;
    }

    /**
     * Builds an instruction from its string form. After splitting, the parts
     * are read as: [0] opcode, [1] input operand, [2] output operand,
     * [3] name of a {@code SparkAggType} constant.
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static AggregateUnarySPInstruction parseInstruction(String str) {
        final String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);

        final String opcode = parts[0];
        final CPOperand in = new CPOperand(parts[1]);
        final CPOperand out = new CPOperand(parts[2]);
        final AggBinaryOp.SparkAggType aggType = AggBinaryOp.SparkAggType.valueOf(parts[3]);

        // The combining aggregate operator is derived from the unary aggregate's opcode.
        final String aggOpcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
        final Types.CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);

        final AggregateUnaryOperator unaryOp = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
        final AggregateOperator aggOp = InstructionUtils.parseAggregateOperator(aggOpcode, corrLoc.toString());

        return new AggregateUnarySPInstruction(SPInstruction.SPType.AggregateUnary, unaryOp, aggOp, in, out, aggType, opcode, str);
    }

    /**
     * Dispatches on the data type of the first input: matrix inputs go to the
     * matrix aggregate path, every other data type to the tensor path.
     *
     * @param executionContext context the instruction runs in
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        if (input1.getDataType() != Types.DataType.MATRIX) {
            processTensorAggregate(executionContext);
            return;
        }
        processMatrixAggregate(executionContext);
    }

    /**
     * Runs the unary aggregate over a matrix input.
     *
     * <p>For opcode {@code uaktrace} the input is first filtered with
     * {@link FilterDiagMatrixBlocksFunction}. The remaining work depends on the
     * aggregation type:
     * <ul>
     *   <li>{@code SINGLE_BLOCK}: all blocks are reduced to one {@link MatrixBlock}
     *       which is set as matrix output. If max-parallelize is enabled the
     *       reduction is submitted as an {@link RDDAggregateTask} and the pending
     *       result is registered together with an optional lineage item.</li>
     *   <li>{@code NONE}: each block is aggregated independently via {@code mapValues}.</li>
     *   <li>{@code MULTI_BLOCK}: blocks are aggregated, combined by key, and the
     *       correction is dropped if the operator reports one.</li>
     * </ul>
     * For the two RDD-producing cases the output characteristics, RDD handle
     * and lineage are registered on the execution context.
     *
     * @param executionContext context, cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException wrapping any exception raised while scheduling the asynchronous task
     */
    private void processMatrixAggregate(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        DataCharacteristics inputDc = sec.getDataCharacteristics(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());

        // Apply the trace filter once, up front, so that every branch below
        // (including the asynchronous one) aggregates the same set of blocks.
        if (getOpcode().equalsIgnoreCase("uaktrace")) {
            blocks = blocks.filter(new FilterDiagMatrixBlocksFunction());
        }

        AggregateUnaryOperator unaryOp = (AggregateUnaryOperator) this._optr;
        AggregateOperator aggOp = this._aop;

        if (this._aggtype != AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            if (this._aggtype == AggBinaryOp.SparkAggType.NONE) {
                blocks = blocks.mapValues(new RDDUAggValueFunction(unaryOp, inputDc.getBlocksize()));
            } else if (this._aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
                blocks = RDDAggregateUtils.aggByKeyStable(blocks.mapToPair(new RDDUAggFunction(unaryOp, inputDc.getBlocksize())), aggOp, false);
                if (unaryOp.aggOp.existsCorrection()) {
                    blocks = blocks.mapValues(new AggregateDropCorrectionFunction(aggOp));
                }
            }
            updateUnaryAggOutputDataCharacteristics(sec, unaryOp.indexFn);
            sec.setRDDHandleForVariable(this.output.getName(), blocks);
            sec.addLineageRDD(this.output.getName(), this.input1.getName());
            return;
        }

        if (ConfigurationManager.isMaxPrallelizeEnabled()) {
            try {
                // Fix: hand the task the (possibly trace-filtered) RDD. Previously the raw
                // input was passed here, so the uaktrace filter was skipped on this path
                // while the synchronous path below honoured it. For all other opcodes the
                // two references are the same object.
                sec.setMatrixOutputAndLineage(
                        this.output.getName(),
                        CommonThreadPool.getDynamicPool().submit(new RDDAggregateTask(this._optr, this._aop, blocks, inputDc)),
                        !LineageCacheConfig.ReuseCacheType.isNone() ? (LineageItem) getLineageItem(executionContext).getValue() : null);
                return;
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }

        // Synchronous single-block reduction.
        if (unaryOp.sparseSafe) {
            blocks = blocks.filter(new FilterNonEmptyBlocksFunction());
        }
        JavaRDD<MatrixBlock> partials = blocks.map(new RDDUAggFunction2(unaryOp, inputDc.getBlocksize()));
        MatrixBlock total = RDDAggregateUtils.aggStable(partials, aggOp);
        total.dropLastRowsOrColumns(aggOp.correction);
        sec.setMatrixOutput(this.output.getName(), total);
    }

    /**
     * Runs the unary aggregate over a tensor input.
     *
     * <p>{@code SINGLE_BLOCK} reduces all blocks to one tensor and stores cell
     * (0, 0) of it as a 1x1 tensor output. {@code NONE} aggregates each block
     * independently and registers the resulting RDD. {@code MULTI_BLOCK} is
     * rejected.
     *
     * @param executionContext context, cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the aggregation type is {@code MULTI_BLOCK}
     */
    private void processTensorAggregate(ExecutionContext executionContext) {
        final SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        final JavaPairRDD<TensorIndexes, TensorBlock> blocks = sec.getBinaryTensorBlockRDDHandleForVariable(input1.getName());
        final AggregateUnaryOperator unaryOp = (AggregateUnaryOperator) _optr;
        final AggregateOperator aggOp = _aop;

        if (_aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            final JavaRDD<TensorBlock> partials = blocks.map(new RDDUTensorAggFunction2(unaryOp));
            final TensorBlock reduced = RDDAggregateUtils.aggStableTensor(partials, aggOp);

            // Only the (0, 0) cell of the reduced tensor is carried into the output.
            final TensorBlock scalar = new TensorBlock(reduced.getValueType(), new int[]{1, 1});
            scalar.set(0, 0, reduced.get(0, 0));
            sec.setTensorOutput(output.getName(), scalar);
            return;
        }

        if (_aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
            throw new DMLRuntimeException("Multi block spark aggregations are not supported for tensors yet.");
        }

        // Any other aggregation type leaves the input RDD as the output unchanged.
        final JavaPairRDD<TensorIndexes, TensorBlock> result = (_aggtype == AggBinaryOp.SparkAggType.NONE)
                ? blocks.mapValues(new RDDUTensorAggValueFunction(unaryOp))
                : blocks;

        updateUnaryAggOutputDataCharacteristics(sec, unaryOp.indexFn);
        sec.setRDDHandleForVariable(output.getName(), result);
        sec.addLineageRDD(output.getName(), input1.getName());
    }

    /**
     * @return the Spark aggregation type this instruction was constructed with
     */
    public AggBinaryOp.SparkAggType getAggType() {
        return _aggtype;
    }
}
