package org.apache.sysds.runtime.instructions.spark;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.functions.ComputeBinaryBlockNnzFunction;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.FileFormatPropertiesLIBSVM;
import org.apache.sysds.runtime.lineage.LineageDedupUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/WriteSPInstruction.class */
/**
 * Spark instruction for the {@code write} opcode.
 *
 * <p>Writes the matrix or frame bound to {@link #input1} to the file name given by the second
 * operand, in the file format named by the third operand, and afterwards writes a companion
 * {@code <file>.mtd} metadata file via {@link HDFSTool#writeMetaDataFile}.
 */
public class WriteSPInstruction extends SPInstruction implements LineageTraceable {
    /** Shared prefix of the errors raised when the single-file (merged) text output fails. */
    private static final String MERGE_ERROR_PREFIX = "Cannot merge the output into single file: ";

    /** Operand naming the matrix/frame variable to write; its data type selects the write path. */
    public CPOperand input1;
    /** Scalar string operand holding the output file name. */
    private CPOperand input2;
    /** Operand whose name is the file format (e.g. {@code csv}, {@code libsvm}). */
    private CPOperand input3;
    /**
     * Trailing operand: read as a scalar string and stored as the format description at execution
     * time; for binary matrix writes its name is additionally parsed as the target block size.
     */
    private CPOperand input4;
    /** Format-specific properties; set by {@link #parseInstruction(String)}. */
    private FileFormatProperties formatProperties;

    /**
     * Creates the instruction with its first three operands; the fourth operand and the format
     * properties are filled in by {@link #parseInstruction(String)}.
     *
     * @param in1 variable to write
     * @param in2 output file name
     * @param in3 file format
     * @param opcode instruction opcode
     * @param istr full instruction string
     */
    private WriteSPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, String opcode, String istr) {
        super(SPInstruction.SPType.Write, opcode, istr);
        this.input1 = in1;
        this.input2 = in2;
        this.input3 = in3;
        this.input4 = null;
        this.formatProperties = null;
    }

    /**
     * Parses a serialized {@code write} instruction.
     *
     * <p>Part 0 is the opcode, parts 1-3 are the input, file-name and format operands. For
     * {@code csv} and {@code libsvm} parts 4-6 carry the format properties and part 8 the trailing
     * operand; for every other format part 5 is the trailing operand.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code write} or the number of parts does
     *     not match the layout required by the format
     */
    public static WriteSPInstruction parseInstruction(String str) {
        final String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        final String opcode = parts[0];
        if (!opcode.equals("write")) {
            throw new DMLRuntimeException("Unsupported opcode");
        }
        if (parts.length != 6 && parts.length != 10) {
            throw new DMLRuntimeException("Invalid number of operands in write instruction: " + str);
        }

        final CPOperand in1 = new CPOperand(parts[1]);
        final CPOperand in2 = new CPOperand(parts[2]);
        final CPOperand in3 = new CPOperand(parts[3]);
        final WriteSPInstruction inst = new WriteSPInstruction(in1, in2, in3, opcode, str);

        final boolean csv = in3.getName().equalsIgnoreCase("csv");
        final boolean libsvm = !csv && in3.getName().equalsIgnoreCase("libsvm");
        if (csv || libsvm) {
            // Both layouts read parts[4..8]; a short instruction would otherwise fail with an
            // ArrayIndexOutOfBoundsException instead of a meaningful error.
            if (parts.length != 10) {
                throw new DMLRuntimeException("Invalid number of operands in write instruction: " + str);
            }
            if (csv) {
                inst.setFormatProperties(new FileFormatPropertiesCSV(
                    Boolean.parseBoolean(parts[4]), parts[5], Boolean.parseBoolean(parts[6])));
            } else {
                inst.setFormatProperties(new FileFormatPropertiesLIBSVM(
                    parts[4], parts[5], Boolean.parseBoolean(parts[6])));
            }
            inst.input4 = new CPOperand(parts[8]);
        } else {
            final FileFormatProperties props = new FileFormatProperties();
            inst.input4 = new CPOperand(parts[5]);
            inst.setFormatProperties(props);
        }
        return inst;
    }

    /** @return the format properties attached to this instruction */
    public FileFormatProperties getFormatProperties() {
        return this.formatProperties;
    }

    /** @param fileFormatProperties the format properties to attach to this instruction */
    public void setFormatProperties(FileFormatProperties fileFormatProperties) {
        this.formatProperties = fileFormatProperties;
    }

    /** @return the operand naming the variable to write */
    public CPOperand getInput1() {
        return this.input1;
    }

    /** @return the operand holding the output file name */
    public CPOperand getInput2() {
        return this.input2;
    }

    /**
     * Executes the write: resolves the output file name, stores the description on the format
     * properties, deletes any existing output and dispatches on the data type of the input.
     *
     * @param executionContext the execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the data type is neither matrix nor frame, or if an
     *     {@link IOException} occurs while writing
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        final SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        final String fname = executionContext
            .getScalarInput(this.input2.getName(), Types.ValueType.STRING, this.input2.isLiteral())
            .getStringValue();
        final String description = executionContext
            .getScalarInput(this.input4.getName(), Types.ValueType.STRING, this.input4.isLiteral())
            .getStringValue();
        this.formatProperties.setDescription(description);

        // The schema only exists for frames; matrices pass null through to the frame writer unused.
        final Types.ValueType[] schema = this.input1.getDataType() == Types.DataType.FRAME
            ? sec.getFrameObject(this.input1.getName()).getSchema()
            : null;
        try {
            // Existing output is removed up front so the Spark save does not hit a present path.
            HDFSTool.deleteFileIfExistOnHDFS(fname);
            final Types.FileFormat fmt = Types.FileFormat.safeValueOf(this.input3.getName());
            switch (this.input1.getDataType()) {
                case MATRIX:
                    processMatrixWriteInstruction(sec, fname, fmt);
                    break;
                case FRAME:
                    processFrameWriteInstruction(sec, fname, fmt, schema);
                    break;
                default:
                    throw new DMLRuntimeException("Unsupported data type " + this.input1.getDataType() + " in WriteSPInstruction.");
            }
        } catch (IOException e) {
            throw new DMLRuntimeException("Failed to process write instruction", e);
        }
    }

    /**
     * Writes the binary-block matrix RDD of {@link #input1} in the requested format and then
     * writes the metadata file.
     *
     * <p>If the number of non-zeros is not known, a counting function is mapped over the blocks so
     * that the count is obtained from the accumulator after the write and stored on the data
     * characteristics.
     *
     * @param sparkExecutionContext Spark execution context providing the RDD and characteristics
     * @param str output file name
     * @param fileFormat target format; MM, TEXT, CSV, BINARY and LIBSVM are supported
     * @throws IOException if a CSV/LIBSVM write is requested for a matrix with zero rows or columns
     * @throws DMLRuntimeException if the format is not supported
     */
    protected void processMatrixWriteInstruction(SparkExecutionContext sparkExecutionContext, String str, Types.FileFormat fileFormat) throws IOException {
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks =
            sparkExecutionContext.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
        final DataCharacteristics dc = sparkExecutionContext.getDataCharacteristics(this.input1.getName());

        // Reject unsupported formats/shapes before any accumulator or transformation is created.
        checkMatrixWritable(fileFormat, dc);

        if (fileFormat == Types.FileFormat.BINARY) {
            // Reblock only when the requested block size differs from the configured one.
            final int blen = Integer.parseInt(this.input4.getName());
            if (ConfigurationManager.getBlocksize() != blen) {
                final DataCharacteristics target = new MatrixCharacteristics(dc).setBlocksize(blen);
                blocks = RDDConverterUtils.binaryBlockToBinaryBlock(blocks, dc, target);
            }
        }

        // Piggyback the nnz computation on the write itself.
        LongAccumulator nnzCounter = null;
        if (!dc.nnzKnown()) {
            nnzCounter = sparkExecutionContext.getSparkContext().sc().longAccumulator(DataExpression.READNNZPARAM);
            blocks = blocks.mapValues(new ComputeBinaryBlockNnzFunction(nnzCounter));
        }

        switch (fileFormat) {
            case MM: {
                // NOTE(review): the header embeds dc.getNonZeros() as it is at this point, i.e.
                // before the piggybacked count above has been read back - confirm this is intended
                // when nnz is unknown.
                final ArrayList<String> header = new ArrayList<>(1);
                header.add("%%MatrixMarket matrix coordinate real general\n" + dc.getRows() + " " + dc.getCols() + " " + dc.getNonZeros());
                final JavaRDD<String> headerRdd = sparkExecutionContext.getSparkContext().parallelize(header);
                final JavaRDD<String> cells = RDDConverterUtils.binaryBlockToTextCell(blocks, dc);
                customSaveTextFile(headerRdd.union(cells), str, true);
                break;
            }
            case TEXT:
                customSaveTextFile(RDDConverterUtils.binaryBlockToTextCell(blocks, dc), str, false);
                break;
            case CSV:
                customSaveTextFile(
                    RDDConverterUtils.binaryBlockToCsv(blocks, dc, (FileFormatPropertiesCSV) this.formatProperties, true),
                    str, false);
                break;
            case BINARY:
                blocks.saveAsHadoopFile(str, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class);
                break;
            case LIBSVM:
                customSaveTextFile(
                    RDDConverterUtils.binaryBlockToLibsvm(blocks, dc, (FileFormatPropertiesLIBSVM) this.formatProperties, true),
                    str, false);
                break;
            default:
                // Unreachable after checkMatrixWritable; kept so the switch stays exhaustive.
                throw new DMLRuntimeException("Unexpected data format: " + fileFormat.toString());
        }

        // Only overwrite when we actually counted and nothing else has set the value meanwhile.
        if (nnzCounter != null && !dc.nnzKnown()) {
            dc.setNonZeros(nnzCounter.value().longValue());
        }
        HDFSTool.writeMetaDataFile(str + ".mtd", Types.ValueType.FP64, dc, fileFormat, this.formatProperties);
    }

    /**
     * Validates that a matrix with the given characteristics can be written in the given format.
     *
     * @throws IOException if CSV/LIBSVM is requested for a matrix with zero rows or columns
     * @throws DMLRuntimeException if the format is not one of MM, TEXT, CSV, BINARY, LIBSVM
     */
    private static void checkMatrixWritable(Types.FileFormat fileFormat, DataCharacteristics dc) throws IOException {
        switch (fileFormat) {
            case MM:
            case TEXT:
            case BINARY:
                return;
            case CSV:
            case LIBSVM:
                if (dc.getRows() == 0 || dc.getCols() == 0) {
                    throw new IOException("Write of matrices with zero rows or columns not supported (" + dc.getRows() + "x" + dc.getCols() + ").");
                }
                return;
            default:
                throw new DMLRuntimeException("Unexpected data format: " + fileFormat.toString());
        }
    }

    /**
     * Writes the binary-block frame RDD of {@link #input1} in the requested format and then writes
     * the metadata file.
     *
     * @param sparkExecutionContext Spark execution context providing the RDD and characteristics
     * @param str output file name
     * @param fileFormat target format
     * @param valueTypeArr frame schema recorded in the metadata file
     * @throws IOException declared for the write and metadata calls made here
     * @throws DMLRuntimeException if the format is not handled
     */
    protected void processFrameWriteInstruction(SparkExecutionContext sparkExecutionContext, String str, Types.FileFormat fileFormat, Types.ValueType[] valueTypeArr) throws IOException {
        final JavaPairRDD<Long, FrameBlock> frames =
            sparkExecutionContext.getFrameBinaryBlockRDDHandleForVariable(this.input1.getName());
        final DataCharacteristics dc = sparkExecutionContext.getDataCharacteristics(this.input1.getName());

        switch (fileFormat) {
            case TEXT:
                customSaveTextFile(FrameRDDConverterUtils.binaryBlockToTextCell(frames, dc), str, false);
                break;
            case CSV: {
                final FileFormatPropertiesCSV csvProps =
                    this.formatProperties != null ? (FileFormatPropertiesCSV) this.formatProperties : null;
                customSaveTextFile(FrameRDDConverterUtils.binaryBlockToCsv(frames, dc, csvProps, true), str, false);
                break;
            }
            case LIBSVM:
                // NOTE(review): no data is written for LIBSVM frames - only the metadata file
                // below is produced. Confirm whether this is an intentional placeholder.
                break;
            case BINARY:
                frames.mapToPair(new FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction())
                    .saveAsHadoopFile(str, LongWritable.class, FrameBlock.class, SequenceFileOutputFormat.class);
                break;
            default:
                throw new DMLRuntimeException("Unexpected data format: " + fileFormat.toString());
        }
        HDFSTool.writeMetaDataFile(str + ".mtd", this.input1.getValueType(), valueTypeArr, Types.DataType.FRAME, dc, fileFormat, this.formatProperties);
    }

    /**
     * Saves a text RDD either directly under {@code fname} or, when {@code inSingleFile} is set,
     * into a randomly named temporary path that is then merged into the single file {@code fname}.
     *
     * <p>The temporary path is removed in a {@code finally} block so that it is cleaned up whether
     * the save/merge succeeds or fails.
     *
     * @param rdd lines to write
     * @param fname target file name
     * @param inSingleFile whether the output must be merged into one file
     * @throws DMLRuntimeException if merging or removing the temporary path fails with an
     *     {@link IOException}; the original exception is attached as the cause
     */
    private static void customSaveTextFile(JavaRDD<String> rdd, String fname, boolean inSingleFile) {
        if (!inSingleFile) {
            rdd.saveAsTextFile(fname);
            return;
        }

        final Random rand = new Random();
        String tmpName = fname + LineageDedupUtils.DEDUP_DELIM + rand.nextLong() + LineageDedupUtils.DEDUP_DELIM + rand.nextLong();
        try {
            // Re-draw until the temporary name does not collide with an existing path.
            while (HDFSTool.existsFileOnHDFS(tmpName)) {
                tmpName = fname + LineageDedupUtils.DEDUP_DELIM + rand.nextLong() + LineageDedupUtils.DEDUP_DELIM + rand.nextLong();
            }
            rdd.saveAsTextFile(tmpName);
            HDFSTool.mergeIntoSingleFile(tmpName, fname);
        } catch (IOException e) {
            throw new DMLRuntimeException(MERGE_ERROR_PREFIX + e.getMessage(), e);
        } finally {
            try {
                // Never leave the randomly named intermediate output behind.
                HDFSTool.deleteFileIfExistOnHDFS(tmpName);
            } catch (IOException e) {
                throw new DMLRuntimeException(MERGE_ERROR_PREFIX + e.getMessage(), e);
            }
        }
    }

    /**
     * Builds the lineage item for this write from the four operands, appending the format
     * description as an extra input when one is set and non-empty.
     *
     * @param executionContext execution context used to resolve operand lineage
     * @return pair of the written variable's name and its lineage item
     */
    @Override // org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        LineageItem[] inputs = LineageItemUtils.getLineage(executionContext, this.input1, this.input2, this.input3, this.input4);
        final String description = this.formatProperties != null ? this.formatProperties.getDescription() : null;
        if (description != null && !description.isEmpty()) {
            inputs = (LineageItem[]) ArrayUtils.add(inputs, new LineageItem(description));
        }
        return Pair.of(this.input1.getName(), new LineageItem(getOpcode(), inputs));
    }
}
