package org.apache.sysds.runtime.controlprogram.parfor;

import java.util.Arrays;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.JobConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.instructions.spark.functions.CopyMatrixBlockPairFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.io.InputOutputInfo;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.utils.Statistics;

/**
 * Result merge for parfor outputs executed as a Spark job.
 *
 * <p>The result files of all parfor workers are read through a single Hadoop RDD and merged
 * into a newly created output {@link MatrixObject}. When the original output already contains
 * non-zeros, blocks are merged against it (compare mode); otherwise blocks are merged directly
 * by key, summing them when accumulation is enabled.
 */
public class ResultMergeRemoteSpark extends ResultMergeMatrix {
    private static final long serialVersionUID = -6924566953903424820L;

    /** Name used for the Hadoop job conf and for heavy-hitter statistics. */
    private static final String JOB_NAME = "ParFor-RMSP";

    /** Execution context; cast to {@link SparkExecutionContext} when the merge runs. */
    private final ExecutionContext _ec;
    /** Value passed as parallelism by {@link #executeSerialMerge()}. */
    private final int _numMappers;
    /** Upper bound on the number of partitions used when grouping result blocks. */
    private final int _numReducers;

    /**
     * Creates a Spark-based result merge.
     *
     * @param out            the parfor output matrix that results are merged into
     * @param in             the result matrices produced by the parfor workers
     * @param outputFilename file name for the merged output matrix
     *                       (presumably stored by the superclass as {@code _outputFName})
     * @param accum          accumulation flag forwarded to the superclass
     *                       (presumably backs {@code _isAccum})
     * @param ec             execution context; must be a {@link SparkExecutionContext} at merge time
     * @param numMappers     parallelism forwarded by {@link #executeSerialMerge()}
     * @param numReducers    upper bound for the number of partitions used while merging
     */
    public ResultMergeRemoteSpark(MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum,
            ExecutionContext ec, int numMappers, int numReducers) {
        super(out, in, outputFilename, accum);
        this._ec = ec;
        this._numMappers = numMappers;
        this._numReducers = numReducers;
    }

    /**
     * Spark merges are always distributed, so the serial entry point delegates to
     * {@link #executeParallelMerge(int)} with the configured number of mappers.
     *
     * @return the merged output matrix
     */
    @Override
    public MatrixObject executeSerialMerge() {
        return executeParallelMerge(_numMappers);
    }

    /**
     * Merges all worker results into a new output matrix backed by an RDD handle.
     *
     * @param par requested parallelism; not used by this implementation (the partition count
     *            is derived from the matrix dimensions and {@code _numReducers})
     * @return a new {@link MatrixObject} holding the merged RDD, or the existing output when there
     *         are no inputs to merge
     * @throws DMLRuntimeException wrapping any failure during the merge
     */
    @Override
    public MatrixObject executeParallelMerge(int par) {
        if (LOG.isTraceEnabled()) {
            MatrixObject out = (MatrixObject) _output;
            LOG.trace("ResultMerge (remote, spark): Execute parallel merge for output " + out.hashCode()
                + " (fname=" + out.getFileName() + ")");
        }
        try {
            MatrixObject output = (MatrixObject) _output;
            MatrixObject[] inputs = (MatrixObject[]) _inputs;
            if (inputs == null || inputs.length == 0) {
                // nothing to merge: hand back the existing output instead of copying it
                return output;
            }

            MetaDataFormat metadata = (MetaDataFormat) output.getMetaData();
            DataCharacteristics mcOld = metadata.getDataCharacteristics();
            // an output without non-zeros needs no compare pass; otherwise merge against it
            MatrixObject compare = (mcOld.getNonZeros() == 0) ? null : output;
            RDDObject merged = executeMerge(compare, inputs, mcOld.getRows(), mcOld.getCols(),
                mcOld.getBlocksize());

            // always create a new matrix object rather than mutating the existing output
            MatrixObject moNew = new MatrixObject(output.getValueType(), _outputFName);
            MatrixCharacteristics mc = new MatrixCharacteristics(mcOld);
            // with accumulation the nnz count is not computed here (-1 marks it unknown)
            mc.setNonZeros(_isAccum ? -1L : computeNonZeros(output, Arrays.asList(inputs)));
            moNew.setMetaData(new MetaDataFormat(mc, metadata.getFileFormat()));
            moNew.setRDDHandle(merged);
            return moNew;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Builds the Spark RDD that merges all worker result blocks.
     *
     * @param compare output matrix to merge against, or {@code null} for a direct merge by key
     * @param inputs  worker result matrices; must be non-empty
     * @param rlen    number of rows of the output
     * @param clen    number of columns of the output
     * @param blen    block size of the output
     * @return an RDD handle for the merged blocks, with the inputs (and compare) as lineage children
     * @throws DMLRuntimeException if {@code inputs} is empty or the merge cannot be set up
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected RDDObject executeMerge(MatrixObject compare, MatrixObject[] inputs, long rlen, long clen, int blen) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        SparkExecutionContext sec = (SparkExecutionContext) _ec;
        boolean withCompare = (compare != null);
        int numRed = determineNumReducers(rlen, clen, blen, _numReducers);

        if (inputs == null || inputs.length == 0) {
            throw new DMLRuntimeException("Execute merge should never be called with no inputs.");
        }

        RDDObject ret;
        try {
            // Step 1: one Hadoop RDD over the result files of all workers (a single job conf
            // holding every input path)
            InputOutputInfo ii = InputOutputInfo.get(Types.DataType.MATRIX, Types.FileFormat.BINARY);
            // class-based constructor: JobConf(String) would register its argument as a
            // configuration resource file to load
            JobConf job = new JobConf(ResultMergeRemoteSpark.class);
            job.setJobName(JOB_NAME);
            job.setInputFormat(ii.inputFormatClass);
            Path[] paths = new Path[inputs.length];
            for (int k = 0; k < inputs.length; k++) {
                // export each worker result so it can be read back by its file name
                inputs[k].exportData();
                paths[k] = new Path(inputs[k].getFileName());
                setRDDHandleForMerge(inputs[k], sec);
            }
            FileInputFormat.setInputPaths(job, paths);
            // deep-copy keys and blocks, as read by the Hadoop input format
            JavaPairRDD rdd = sec.getSparkContext()
                .hadoopRDD(job, ii.inputFormatClass, ii.keyClass, ii.valueClass)
                .mapPartitionsToPair(new CopyMatrixBlockPairFunction(true), true);

            // Step 2: merge, either against the compare matrix or directly by key
            JavaPairRDD out;
            if (withCompare) {
                out = rdd.groupByKey(numRed)
                    .join(sec.getRDDHandleForMatrixObject(compare, Types.FileFormat.BINARY))
                    .mapToPair(new ResultMergeRemoteSparkWCompare(_isAccum));
            } else {
                out = _isAccum
                    ? RDDAggregateUtils.sumByKeyStable(rdd, false)
                    : RDDAggregateUtils.mergeByKey(rdd, false);
            }

            // Step 3: output handle with lineage to every input handle (and the compare matrix)
            ret = new RDDObject(out);
            for (MatrixObject in : inputs) {
                ret.addLineageChild(in.getRDDHandle());
            }
            if (withCompare) {
                ret.addLineageChild(compare.getRDDHandle());
            }
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }

        Statistics.incrementNoOfCompiledSPInst();
        Statistics.incrementNoOfExecutedSPInst();
        if (DMLScript.STATISTICS) {
            Statistics.maintainCPHeavyHitters(JOB_NAME, System.nanoTime() - t0);
        }
        return ret;
    }

    /**
     * Determines the number of partitions for grouping result blocks: the number of
     * {@code blen x blen} blocks of an {@code rlen x clen} matrix, capped at {@code numRed}
     * and never below 1 (a non-positive partition count is invalid for {@code groupByKey}).
     */
    private static int determineNumReducers(long rlen, long clen, int blen, long numRed) {
        long blocks = numBlocks(rlen, blen) * numBlocks(clen, blen);
        return (int) Math.max(1L, Math.min(numRed, blocks));
    }

    /**
     * Returns the number of blocks of size {@code blen} needed to cover {@code len} cells,
     * counting a partial edge block. Unknown or non-positive sizes are treated as a single block.
     */
    private static long numBlocks(long len, int blen) {
        if (len <= 0 || blen <= 0) {
            return 1L;
        }
        return len / blen + (len % blen != 0 ? 1L : 0L);
    }

    /**
     * Replaces the RDD handle of {@code mo} with one that reads its exported binary file. These
     * handles are attached as lineage children of the merged RDD in {@link #executeMerge}.
     */
    private static void setRDDHandleForMerge(MatrixObject mo, SparkExecutionContext sec) {
        InputOutputInfo iinfo = InputOutputInfo.get(Types.DataType.MATRIX, Types.FileFormat.BINARY);
        RDDObject handle = new RDDObject(sec.getSparkContext()
            .hadoopFile(mo.getFileName(), iinfo.inputFormatClass, iinfo.keyClass, iinfo.valueClass));
        handle.setHDFSFile(true);
        mo.setRDDHandle(handle);
    }
}
