package org.apache.sysds.runtime.instructions.spark;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.util.AccumulatorV2;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.io.FrameReaderFactory;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderComposite;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderRecode;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.transform.meta.TfOffsetMap;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.class */
public class MultiReturnParameterizedBuiltinSPInstruction extends ComputationSPInstruction {
    protected ArrayList<CPOperand> _outputs;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction$MaxLongAccumulator.class */
    public static class MaxLongAccumulator extends AccumulatorV2<Long, Long> {
        private static final long serialVersionUID = -3739727823287550826L;
        private long _value;

        public MaxLongAccumulator(long j) {
            this._value = Long.MIN_VALUE;
            this._value = j;
        }

        public void add(Long l) {
            this._value = Math.max(this._value, l.longValue());
        }

        public AccumulatorV2<Long, Long> copy() {
            return new MaxLongAccumulator(this._value);
        }

        public boolean isZero() {
            return this._value == Long.MIN_VALUE;
        }

        public void merge(AccumulatorV2<Long, Long> accumulatorV2) {
            this._value = Math.max(this._value, ((Long) accumulatorV2.value()).longValue());
        }

        public void reset() {
            this._value = Long.MIN_VALUE;
        }

        /* renamed from: value, reason: merged with bridge method [inline-methods] */
        public Long m634value() {
            return Long.valueOf(this._value);
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction$TransformEncodeBuild2Function.class */
    public static class TransformEncodeBuild2Function implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Integer, FrameBlock.ColumnMetadata> {
        private static final long serialVersionUID = 6336375833412029279L;
        private EncoderMVImpute _encoder;

        public TransformEncodeBuild2Function(EncoderMVImpute encoderMVImpute) {
            this._encoder = null;
            this._encoder = encoderMVImpute;
        }

        public Iterator<Tuple2<Integer, FrameBlock.ColumnMetadata>> call(Iterator<Tuple2<Long, FrameBlock>> it) throws Exception {
            while (it.hasNext()) {
                this._encoder.build((FrameBlock) it.next()._2());
            }
            ArrayList arrayList = new ArrayList();
            int[] colList = this._encoder.getColList();
            for (int i = 0; i < colList.length; i++) {
                if (this._encoder.getMethod(colList[i]) == EncoderMVImpute.MVMethod.GLOBAL_MODE) {
                    for (Map.Entry<String, Long> entry : this._encoder.getHistogram(colList[i]).entrySet()) {
                        arrayList.add(new Tuple2(Integer.valueOf(colList[i]), new FrameBlock.ColumnMetadata(entry.getValue().longValue(), entry.getKey())));
                    }
                } else if (this._encoder.getMethod(colList[i]) == EncoderMVImpute.MVMethod.GLOBAL_MEAN) {
                    arrayList.add(new Tuple2(Integer.valueOf(colList[i]), new FrameBlock.ColumnMetadata(this._encoder.getNonMVCount(colList[i]), String.valueOf(this._encoder.getMeans()[i]._sum))));
                } else if (this._encoder.getMethod(colList[i]) == EncoderMVImpute.MVMethod.CONSTANT) {
                    arrayList.add(new Tuple2(Integer.valueOf(colList[i]), new FrameBlock.ColumnMetadata(0L, this._encoder.getReplacement(colList[i]))));
                }
            }
            return arrayList.iterator();
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction$TransformEncodeBuildFunction.class */
    public static class TransformEncodeBuildFunction implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Integer, Object> {
        private static final long serialVersionUID = 6336375833412029279L;
        private EncoderRecode _raEncoder;

        public TransformEncodeBuildFunction(Encoder encoder) {
            this._raEncoder = null;
            for (Encoder encoder2 : ((EncoderComposite) encoder).getEncoders()) {
                if (encoder2 instanceof EncoderRecode) {
                    this._raEncoder = (EncoderRecode) encoder2;
                }
            }
        }

        public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>> it) throws Exception {
            if (this._raEncoder != null) {
                this._raEncoder.prepareBuildPartial();
                while (it.hasNext()) {
                    this._raEncoder.buildPartial((FrameBlock) it.next()._2());
                }
            }
            ArrayList arrayList = new ArrayList();
            for (Map.Entry<Integer, HashSet<Object>> entry : this._raEncoder.getCPRecodeMapsPartial().entrySet()) {
                Iterator<Object> it2 = entry.getValue().iterator();
                while (it2.hasNext()) {
                    arrayList.add(new Tuple2(entry.getKey(), it2.next()));
                }
            }
            if (this._raEncoder != null) {
                this._raEncoder.getCPRecodeMapsPartial().clear();
            }
            return arrayList.iterator();
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction$TransformEncodeGroup2Function.class */
    public static class TransformEncodeGroup2Function implements FlatMapFunction<Tuple2<Integer, Iterable<FrameBlock.ColumnMetadata>>, String> {
        private static final long serialVersionUID = 702100641492347459L;
        private EncoderMVImpute _encoder;

        public TransformEncodeGroup2Function(EncoderMVImpute encoderMVImpute) {
            this._encoder = null;
            this._encoder = encoderMVImpute;
        }

        public Iterator<String> call(Tuple2<Integer, Iterable<FrameBlock.ColumnMetadata>> tuple2) throws Exception {
            int i;
            int intValue = ((Integer) tuple2._1()).intValue();
            Iterator it = ((Iterable) tuple2._2()).iterator();
            ArrayList arrayList = new ArrayList();
            if (this._encoder.getMethod(intValue) == EncoderMVImpute.MVMethod.GLOBAL_MODE) {
                HashMap hashMap = new HashMap();
                while (it.hasNext()) {
                    FrameBlock.ColumnMetadata columnMetadata = (FrameBlock.ColumnMetadata) it.next();
                    Long l = (Long) hashMap.get(columnMetadata.getMvValue());
                    hashMap.put(columnMetadata.getMvValue(), Long.valueOf(columnMetadata.getNumDistinct() + (l != null ? l.longValue() : 0L)));
                }
                long j = Long.MIN_VALUE;
                String str = null;
                for (Map.Entry entry : hashMap.entrySet()) {
                    if (((Long) entry.getValue()).longValue() > j) {
                        str = (String) entry.getKey();
                        j = ((Long) entry.getValue()).longValue();
                    }
                }
                arrayList.add("-2 " + intValue + " " + str);
            } else if (this._encoder.getMethod(intValue) == EncoderMVImpute.MVMethod.GLOBAL_MEAN) {
                KahanObject kahanObject = new KahanObject(DataExpression.DEFAULT_DELIM_FILL_VALUE, DataExpression.DEFAULT_DELIM_FILL_VALUE);
                KahanPlus kahanPlusFnObject = KahanPlus.getKahanPlusFnObject();
                int i2 = 0;
                while (true) {
                    i = i2;
                    if (!it.hasNext()) {
                        break;
                    }
                    FrameBlock.ColumnMetadata columnMetadata2 = (FrameBlock.ColumnMetadata) it.next();
                    kahanPlusFnObject.execute2(kahanObject, Double.parseDouble(columnMetadata2.getMvValue()));
                    i2 = (int) (i + columnMetadata2.getNumDistinct());
                }
                if (i > 0) {
                    arrayList.add("-2 " + intValue + " " + String.valueOf(kahanObject._sum / i));
                }
            } else if (this._encoder.getMethod(intValue) == EncoderMVImpute.MVMethod.CONSTANT && it.hasNext()) {
                arrayList.add("-2 " + intValue + " " + ((FrameBlock.ColumnMetadata) it.next()).getMvValue());
            }
            return arrayList.iterator();
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction$TransformEncodeGroupFunction.class */
    public static class TransformEncodeGroupFunction implements FlatMapFunction<Tuple2<Integer, Iterable<Object>>, String> {
        private static final long serialVersionUID = -1034187226023517119L;
        private MaxLongAccumulator _accMax;

        public TransformEncodeGroupFunction(MaxLongAccumulator maxLongAccumulator) {
            this._accMax = null;
            this._accMax = maxLongAccumulator;
        }

        public Iterator<String> call(Tuple2<Integer, Iterable<Object>> tuple2) throws Exception {
            String valueOf = String.valueOf(tuple2._1());
            Iterator it = ((Iterable) tuple2._2()).iterator();
            ArrayList arrayList = new ArrayList();
            StringBuilder sb = new StringBuilder();
            long j = 1;
            while (true) {
                long j2 = j;
                if (!it.hasNext()) {
                    this._accMax.add(Long.valueOf(j2 - 1));
                    return arrayList.iterator();
                }
                sb.append(j2);
                sb.append(' ');
                sb.append(valueOf);
                sb.append(' ');
                sb.append(EncoderRecode.constructRecodeMapEntry(it.next().toString(), Long.valueOf(j2)));
                arrayList.add(sb.toString());
                sb.setLength(0);
                j = j2 + 1;
            }
        }
    }

    /**
     * Creates the instruction; instances are obtained through {@link #parseInstruction(String)}.
     * The first entry of {@code outputs} is handed to the superclass constructor as the output
     * operand, and the complete list is retained in {@code _outputs}.
     *
     * @param op      operator forwarded to the superclass
     * @param in1     first input operand
     * @param in2     second input operand
     * @param outputs output operands; must contain at least one element
     * @param opcode  instruction opcode
     * @param istr    full instruction string
     */
    private MultiReturnParameterizedBuiltinSPInstruction(Operator op, CPOperand in1, CPOperand in2, ArrayList<CPOperand> outputs, String opcode, String istr) {
        super(SPInstruction.SPType.MultiReturnBuiltin, op, in1, in2, outputs.get(0), opcode, istr);
        _outputs = outputs;
    }

    /**
     * Parses a {@code transformencode} instruction string.
     *
     * <p>After the opcode, parts 1 and 2 are read as the two input operands, part 3 as the FP64
     * matrix output and part 4 as the string frame output.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code transformencode} (case-insensitive)
     */
    public static MultiReturnParameterizedBuiltinSPInstruction parseInstruction(String str) {
        final String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        final ArrayList<CPOperand> outputs = new ArrayList<>();
        final String opcode = parts[0];

        if (opcode.equalsIgnoreCase("transformencode")) {
            // inputs first, then the two outputs in their fixed order
            final CPOperand in1 = new CPOperand(parts[1]);
            final CPOperand in2 = new CPOperand(parts[2]);
            outputs.add(new CPOperand(parts[3], Types.ValueType.FP64, Types.DataType.MATRIX));
            outputs.add(new CPOperand(parts[4], Types.ValueType.STRING, Types.DataType.FRAME));
            return new MultiReturnParameterizedBuiltinSPInstruction(null, in1, in2, outputs, opcode, str);
        }

        throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
    }

    /**
     * Executes the instruction in two phases: first the transform meta data is built from the
     * input frame, written as text to the meta frame's file and read back as a frame block; then
     * an encoder created from that meta data is applied to the input to produce the matrix output.
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     * @throws RuntimeException wrapping any {@link IOException} raised while processing
     */
    @Override
    public void processInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        try {
            // resolve inputs, outputs and their characteristics
            FrameObject inFrame = sec.getFrameObject(this.input1.getName());
            FrameObject metaFrame = sec.getFrameObject(this._outputs.get(1).getName());
            JavaPairRDD<?, ?> in = sec.getRDDHandleForFrameObject(inFrame, Types.FileFormat.BINARY);
            String spec = executionContext.getScalarInput(this.input2).getStringValue();
            DataCharacteristics mcIn = sec.getDataCharacteristics(this.input1.getName());
            DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());

            // column names are only fetched (from the block stored under key 1) for non-ID specs
            String[] colnames = null;
            if (!TfMetaUtils.isIDSpec(spec)) {
                colnames = ((FrameBlock) in.lookup(1L).get(0)).getColumnNames();
            }

            // phase 1a: build recode maps as text lines; the accumulator collects the max row count
            Encoder buildEncoder = EncoderFactory.createEncoder(spec, colnames, inFrame.getSchema(), (int) inFrame.getNumColumns(), null);
            MaxLongAccumulator accMax = registerMaxLongAccumulator(sec.getSparkContext());
            JavaRDD metaLines = in
                .mapPartitionsToPair(new TransformEncodeBuildFunction(buildEncoder))
                .distinct()
                .groupByKey()
                .flatMap(new TransformEncodeGroupFunction(accMax));

            // phase 1b: append the imputation meta data if the spec requires it
            if (containsMVImputeEncoder(buildEncoder)) {
                EncoderMVImpute mvEncoder = getMVImputeEncoder(buildEncoder);
                metaLines = metaLines.union(in
                    .mapPartitionsToPair(new TransformEncodeBuild2Function(mvEncoder))
                    .groupByKey()
                    .flatMap(new TransformEncodeGroup2Function(mvEncoder)));
            }

            // write the meta data; the accumulator is only read after this job has run
            metaLines.saveAsTextFile(metaFrame.getFileName());

            // read the meta data back as a frame block
            FrameBlock meta = FrameReaderFactory.createFrameReader(Types.FileFormat.TEXT)
                .readFrameFromHDFS(metaFrame.getFileName(), accMax.m634value().longValue(), inFrame.getNumColumns());
            meta.recomputeColumnCardinality();
            String[] metaNames = (colnames != null) ? colnames : meta.getColumnNames();
            meta.setColumnNames(metaNames);

            // optional row offsets, only collected when the spec contains an omit part
            TfOffsetMap offsets = null;
            if (TfMetaUtils.containsOmitSpec(spec, colnames)) {
                offsets = new TfOffsetMap(SparkUtils.toIndexedLong(
                    in.mapToPair(new ParameterizedBuiltinSPInstruction.RDDTransformApplyOffsetFunction(spec, colnames)).collect()));
            }

            // phase 2: create the apply encoder from the meta data and derive the output dimensions
            Encoder applyEncoder = EncoderFactory.createEncoder(spec, colnames, inFrame.getSchema(), (int) inFrame.getNumColumns(), meta);
            mcOut.setDimension(
                mcIn.getRows() - (offsets != null ? offsets.getNumRmRows() : 0L),
                applyEncoder.getNumCols());

            // register the encoded matrix output along with its lineage, then bind the meta frame
            sec.setRDDHandleForVariable(
                this._outputs.get(0).getName(),
                FrameRDDConverterUtils.binaryBlockToMatrixBlock(
                    in.mapToPair(new ParameterizedBuiltinSPInstruction.RDDTransformApplyFunction(
                        sec.getSparkContext().broadcast(applyEncoder),
                        offsets != null ? sec.getSparkContext().broadcast(offsets) : null)),
                    mcOut,
                    mcOut));
            sec.addLineageRDD(this._outputs.get(0).getName(), this.input1.getName());
            sec.setFrameOutput(this._outputs.get(1).getName(), meta);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Checks whether a composite encoder contains a missing-value imputation encoder.
     *
     * @param encoder encoder to inspect
     * @return {@code true} if {@code encoder} is an {@link EncoderComposite} holding at least one
     *         {@link EncoderMVImpute}; {@code false} otherwise, including for non-composite encoders
     */
    private static boolean containsMVImputeEncoder(Encoder encoder) {
        if (encoder instanceof EncoderComposite) {
            for (Encoder child : ((EncoderComposite) encoder).getEncoders()) {
                if (child instanceof EncoderMVImpute) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the first missing-value imputation encoder of a composite encoder.
     *
     * @param encoder encoder to inspect
     * @return the first contained {@link EncoderMVImpute}, or {@code null} if {@code encoder} is
     *         not an {@link EncoderComposite} or contains none
     */
    private static EncoderMVImpute getMVImputeEncoder(Encoder encoder) {
        EncoderMVImpute found = null;
        if (encoder instanceof EncoderComposite) {
            Iterator<Encoder> children = ((EncoderComposite) encoder).getEncoders().iterator();
            // stop at the first match
            while (found == null && children.hasNext()) {
                Encoder child = children.next();
                if (child instanceof EncoderMVImpute) {
                    found = (EncoderMVImpute) child;
                }
            }
        }
        return found;
    }

    /**
     * Creates a max accumulator initialised to {@code Long.MIN_VALUE} and registers it with the
     * underlying Spark context under the name {@code "max"}.
     *
     * @param javaSparkContext Java Spark context whose underlying context receives the registration
     * @return the registered accumulator
     */
    private static MaxLongAccumulator registerMaxLongAccumulator(JavaSparkContext javaSparkContext) {
        final MaxLongAccumulator accMax = new MaxLongAccumulator(Long.MIN_VALUE);
        javaSparkContext.sc().register(accMax, "max");
        return accMax;
    }
}
