package org.apache.sysds.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.instructions.spark.functions.CopyFrameBlockFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CreateSparseBlockFunction;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.class */
/**
 * Spark checkpoint instruction: persists the RDD backing {@code input1} at a configured
 * {@link StorageLevel} and binds the (possibly re-wrapped) data object to {@code output}.
 *
 * <p>The input is bound to the output unchanged, without persisting, when any of these holds:
 * <ul>
 *   <li>the input is already cached;</li>
 *   <li>{@link Recompiler#checkCPCheckpoint} accepts its data characteristics;</li>
 *   <li>the input is a federated matrix.</li>
 * </ul>
 * A missing or boolean-placeholder input sets both the input and output variables to
 * {@code FALSE}.
 */
public class CheckpointSPInstruction extends UnarySPInstruction {
    /** Storage level the input RDD is persisted at; parsed from the instruction string. */
    private final StorageLevel _level;

    private CheckpointSPInstruction(Operator op, CPOperand in, CPOperand out, StorageLevel level,
            String opcode, String istr) {
        super(SPInstruction.SPType.Checkpoint, op, in, out, opcode, istr);
        _level = level;
    }

    /**
     * Parses a checkpoint instruction.
     *
     * <p>After the leading opcode field, the instruction carries exactly three fields: input
     * operand, output operand and Spark storage-level name.
     *
     * @param str the full instruction string
     * @return the parsed instruction
     */
    public static CheckpointSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);
        String opcode = parts[0];
        CPOperand in = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        StorageLevel level = StorageLevel.fromString(parts[3]);
        return new CheckpointSPInstruction(null, in, out, level, opcode, str);
    }

    /**
     * Executes the checkpoint.
     *
     * <p>If the input RDD is not already at the requested storage level, a persisted checkpoint
     * RDD is created. That RDD is registered as the data object's new RDD handle, with the
     * previous handle kept as its lineage child.
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws IllegalStateException if a copy is required for an input data type other than
     *     MATRIX or FRAME
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        String inName = input1.getName();

        // Early abort when the input variable does not exist or is a boolean placeholder: mark
        // both input and output as FALSE.
        // NOTE(review): presumably checkpoints can be emitted for variables that are never
        // bound at runtime (e.g. untaken branches) -- confirm.
        Object inVar = sec.getVariable(inName);
        if (inVar == null || inVar instanceof BooleanObject) {
            sec.setVariable(inName, new BooleanObject(false));
            sec.setVariable(output.getName(), new BooleanObject(false));
            return;
        }

        CacheableData<?> cd = sec.getCacheableData(inName);
        DataCharacteristics dc = sec.getDataCharacteristics(inName);

        // Pass-through without persisting; this skipped instruction is removed from the
        // executed-Spark-instruction statistics.
        if (cd.isCached(true) || Recompiler.checkCPCheckpoint(dc) || isFederatedMatrix(cd)) {
            sec.setVariable(output.getName(), cd);
            Statistics.decrementNoOfExecutedSPInst();
            return;
        }

        JavaPairRDD<?, ?> in = sec.getRDDHandleForVariable(inName, Types.FileFormat.BINARY, -1, true);
        // An RDD already persisted at the requested level is reused as-is.
        JavaPairRDD<?, ?> out = in.getStorageLevel().equals(_level) ? in : createCheckpointRDD(in, dc);

        // Register the new checkpoint RDD, keeping the previous handle as its lineage child.
        if (out != in) {
            RDDObject inHandle = cd.getRDDHandle();
            RDDObject outHandle = new RDDObject(out);
            outHandle.setCheckpointRDD(true);
            outHandle.addLineageChild(inHandle);
            cd.setRDDHandle(outHandle);
        }
        sec.setVariable(output.getName(), cd);
    }

    /**
     * Builds the RDD to persist, then persists it at {@link #_level}.
     *
     * <p>The RDD is prepared in one of these ways:
     * <ul>
     *   <li>coalesced, when it has too many partitions;</li>
     *   <li>repartitioned, for ultra-sparse data with too few partitions;</li>
     *   <li>shallow-copied, otherwise.</li>
     * </ul>
     * MCSR matrix blocks are converted to CSR when the conversion check passes and the level is
     * not the serialized checkpoint level.
     */
    private JavaPairRDD<?, ?> createCheckpointRDD(JavaPairRDD<?, ?> in, DataCharacteristics dc) {
        int numPartitions = SparkUtils.getNumPreferredPartitions(dc, in);
        // Merge partitions (no shuffle) when there are clearly more than preferred.
        boolean coalesce = 1.2 * numPartitions < in.getNumPartitions()
            && !SparkUtils.isHashPartitioned(in)
            && in.getNumPartitions() > SparkExecutionContext.getDefaultParallelism(true);
        boolean repartition = dc.dimsKnown(true) && dc.isUltraSparse()
            && numPartitions > in.getNumPartitions();
        boolean mcsr2csr = input1.getDataType() == Types.DataType.MATRIX
            && OptimizerUtils.checkSparseBlockCSRConversion(dc)
            && !_level.equals(Checkpoint.SER_STORAGE_LEVEL);

        JavaPairRDD<?, ?> out;
        if (coalesce) {
            out = in.coalesce(numPartitions);
        } else if (repartition) {
            // Round the preferred partition count to a multiple of the default parallelism.
            out = in.repartition(UtilFunctions.roundToNext(numPartitions,
                SparkExecutionContext.getDefaultParallelism(true)));
        } else if (mcsr2csr) {
            // The CSR mapValues below already yields a new RDD, so no extra copy is needed.
            out = in;
        } else if (input1.getDataType() == Types.DataType.MATRIX) {
            // Spark's persist() sets the storage level on the RDD itself. Presumably a shallow
            // copy (deep=false) is made so that the input RDD's own level stays untouched.
            out = SparkUtils.copyBinaryBlockMatrix(asTyped(in), false);
        } else if (input1.getDataType() == Types.DataType.FRAME) {
            out = copyFrameBlocks(in);
        } else {
            // Previously this path left the RDD null and failed with an NPE on persist().
            throw new IllegalStateException(
                "Unsupported data type for checkpoint: " + input1.getDataType());
        }

        if (mcsr2csr) {
            out = toCsrBlocks(out);
        }

        out = out.persist(_level);

        // For matrices with known dims but unknown nnz that exceed CP dimension limits, compute
        // nnz eagerly. NOTE(review): presumably so that later recompilation sees the actual
        // sparsity despite Spark's lazy evaluation -- confirm.
        if (input1.isMatrix() && dc.dimsKnown() && !dc.dimsKnown(true)
            && !OptimizerUtils.isValidCPDimensions(dc)) {
            dc.setNonZeros(SparkUtils.getNonZeros(asTyped(out)));
        }
        return out;
    }

    /** Returns true if {@code cd} is a matrix object that is federated. */
    private static boolean isFederatedMatrix(CacheableData<?> cd) {
        return cd instanceof MatrixObject && ((MatrixObject) cd).isFederated();
    }

    /**
     * Re-types a wildcard RDD to the key/value types expected at the call site.
     *
     * <p>Unchecked: the caller must know the RDD's actual element types (e.g. binary-block
     * matrix for MATRIX inputs).
     */
    @SuppressWarnings("unchecked")
    private static <K, V> JavaPairRDD<K, V> asTyped(JavaPairRDD<?, ?> rdd) {
        return (JavaPairRDD<K, V>) rdd;
    }

    /** Shallow-copies frame blocks via {@link CopyFrameBlockFunction}; values assumed FrameBlocks. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static JavaPairRDD<?, ?> copyFrameBlocks(JavaPairRDD<?, ?> in) {
        return ((JavaPairRDD) in).mapValues(new CopyFrameBlockFunction(false));
    }

    /** Converts matrix blocks to CSR sparse blocks via {@link CreateSparseBlockFunction}. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static JavaPairRDD<?, ?> toCsrBlocks(JavaPairRDD<?, ?> in) {
        return ((JavaPairRDD) in).mapValues(new CreateSparseBlockFunction(SparseBlock.Type.CSR));
    }
}
