package org.apache.sysds.runtime.matrix.data.sketch.countdistinct;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.class */
/**
 * Count-distinct sketch that keeps, per matrix block, a map from a 12-bit "exponent" key to the
 * set of 52-bit "fraction" keys observed for that exponent. The map is carried around as the
 * correction part of a {@link CorrMatrixBlock}, laid out one row per exponent:
 *
 * <pre>
 *   column 0      : exponent key
 *   column 1      : number of fractions stored in this row
 *   columns 2..n  : the fractions themselves (unused trailing cells are left at 0)
 * </pre>
 *
 * <p>NOTE(review): the 12/52 split matches the sign+exponent / mantissa widths of an IEEE-754
 * double, but cells are converted with a narrowing {@code (long)} cast rather than
 * {@code Double.doubleToLongBits}, so any fractional part of a cell is discarded before keying.
 * Confirm whether that is intended.
 */
public class CountDistinctFunctionSketch extends CountDistinctSketch {
    /** Number of low-order bits used as the "fraction" key. */
    private static final int FRACTION_BITS = 52;

    /** Number of bits, directly above the fraction bits, used as the "exponent" key. */
    private static final int EXPONENT_BITS = 12;

    /**
     * Lower bound on the number of fraction columns allocated by {@link #create(MatrixBlock)}
     * (1000 squared). NOTE(review): presumably derived from a default block size of 1000;
     * because it is combined with {@code Math.max}, every multi-cell sketch is at least this
     * wide — confirm this is intended.
     */
    private static final int MIN_FRACTION_COLUMNS = 1000 * 1000;

    public CountDistinctFunctionSketch(Operator operator) {
        super(operator);
    }

    /**
     * Not implemented by this sketch.
     *
     * @param matrixBlock ignored
     * @return always {@code null}
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public MatrixBlock getValue(MatrixBlock matrixBlock) {
        return null;
    }

    /**
     * Computes the distinct count represented by a sketch by summing the per-row counts stored
     * in column 1 of its correction block.
     *
     * @param corrMatrixBlock sketch whose correction block uses the layout described on this class
     * @return a 1x1 block holding the summed count
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public MatrixBlock getValueFromSketch(CorrMatrixBlock corrMatrixBlock) {
        MatrixBlock correction = corrMatrixBlock.getCorrection();
        long total = 0;
        for (int row = 0; row < correction.getNumRows(); row++) {
            total = (long) (total + correction.getValue(row, 1));
        }
        MatrixBlock result = new MatrixBlock(1, 1, false);
        result.setValue(0, 0, total);
        return result;
    }

    /**
     * Builds the sketch for a single input block.
     *
     * <ul>
     *   <li>A 1x1 block yields a one-row sketch with count 1 that also stores the cell's keys,
     *       so that the sketch can later be read back by {@link #union}.</li>
     *   <li>An empty block yields a one-row sketch with count 0.</li>
     *   <li>Otherwise every cell is split into exponent and fraction keys and grouped.</li>
     * </ul>
     *
     * @param matrixBlock input block
     * @return the input block paired with its serialized sketch as correction
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public CorrMatrixBlock create(MatrixBlock matrixBlock) {
        int numRows = matrixBlock.getNumRows();
        int numColumns = matrixBlock.getNumColumns();

        if (numRows == 1 && numColumns == 1) {
            // Store the single value's keys too: a row that claims count == 1 but has no
            // column 2 would make deserialize() read past the last column.
            long bits = (long) matrixBlock.getValue(0, 0);
            MatrixBlock single = new MatrixBlock(1, 3, false);
            single.setValue(0, 0, (short) extractRightKBitsFromIndex(bits, FRACTION_BITS, EXPONENT_BITS));
            single.setValue(0, 1, 1.0d);
            single.setValue(0, 2, extractRightKBitsFromIndex(bits, 0, FRACTION_BITS));
            return new CorrMatrixBlock(matrixBlock, single);
        }
        if (matrixBlock.isEmpty()) {
            return new CorrMatrixBlock(matrixBlock, new MatrixBlock(1, 2, false));
        }

        Map<Short, Set<Long>> fractionsByExponent = new HashMap<>();
        int maxColumns = MIN_FRACTION_COLUMNS;
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numColumns; col++) {
                long bits = (long) matrixBlock.getValue(row, col);
                short exponent = (short) extractRightKBitsFromIndex(bits, FRACTION_BITS, EXPONENT_BITS);
                long fraction = extractRightKBitsFromIndex(bits, 0, FRACTION_BITS);
                Set<Long> fractions = fractionsByExponent.computeIfAbsent(exponent, key -> new HashSet<>());
                fractions.add(fraction);
                maxColumns = Math.max(maxColumns, fractions.size());
            }
        }
        return new CorrMatrixBlock(matrixBlock, serialize(fractionsByExponent, maxColumns));
    }

    /**
     * Extracts {@code k} bits of {@code n}, starting at bit {@code startingIndex} (0 = least
     * significant).
     *
     * @param n value to extract from
     * @param startingIndex index of the lowest bit to keep
     * @param k number of bits to keep; expected to be below 64
     * @return the selected bits, right-aligned
     */
    private long extractRightKBitsFromIndex(long n, int startingIndex, int k) {
        // The mask must be built from a long: with an int literal the shift distance is taken
        // modulo 32, so a 52-bit request would silently become a 20-bit mask.
        return ((1L << k) - 1) & (n >> startingIndex);
    }

    /**
     * Writes the exponent-to-fractions map into a dense block using the row layout described on
     * this class.
     *
     * @param map exponent key to set of fraction keys
     * @param maxFractions number of fraction columns to allocate; must be at least the size of
     *     the largest set in {@code map}
     * @return block with {@code map.size()} rows and {@code maxFractions + 2} columns
     */
    private MatrixBlock serialize(Map<Short, Set<Long>> map, int maxFractions) {
        MatrixBlock out = new MatrixBlock(map.size(), maxFractions + 2, false);
        int row = 0;
        for (Map.Entry<Short, Set<Long>> entry : map.entrySet()) {
            Set<Long> fractions = entry.getValue();
            out.setValue(row, 0, entry.getKey().shortValue());
            out.setValue(row, 1, fractions.size());
            int col = 2;
            for (long fraction : fractions) {
                out.setValue(row, col, fraction);
                col++;
            }
            row++;
        }
        return out;
    }

    /**
     * Reads a block written by {@link #serialize(Map, int)} back into an exponent-to-fractions
     * map. Rows sharing the same exponent key are merged into one set.
     *
     * @param matrixBlock serialized sketch
     * @return mutable map from exponent key to the fractions stored for it
     */
    private Map<Short, Set<Long>> deserialize(MatrixBlock matrixBlock) {
        int numRows = matrixBlock.getNumRows();
        Map<Short, Set<Long>> fractionsByExponent = new HashMap<>();
        for (int row = 0; row < numRows; row++) {
            short exponent = (short) matrixBlock.getValue(row, 0);
            Set<Long> fractions = fractionsByExponent.computeIfAbsent(exponent, key -> new HashSet<>());
            int count = (int) matrixBlock.getValue(row, 1);
            for (int k = 0; k < count; k++) {
                fractions.add((long) matrixBlock.getValue(row, k + 2));
            }
        }
        return fractionsByExponent;
    }

    /**
     * Merges two sketches: both corrections are deserialized, sets with the same exponent key
     * are combined by {@link BitMapValueCombiner}, and the result is serialized with just enough
     * columns for the largest combined set.
     *
     * @param corrMatrixBlock first sketch; its value block is reused for the result
     * @param corrMatrixBlock2 second sketch
     * @return merged sketch
     * @throws IllegalArgumentException if neither correction contains any row
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public CorrMatrixBlock union(CorrMatrixBlock corrMatrixBlock, CorrMatrixBlock corrMatrixBlock2) {
        Map<Short, Set<Long>> left = deserialize(corrMatrixBlock.getCorrection());
        Map<Short, Set<Long>> right = deserialize(corrMatrixBlock2.getCorrection());

        Map<Short, Set<Long>> merged = Stream.concat(left.entrySet().stream(), right.entrySet().stream())
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, new BitMapValueCombiner()));

        OptionalInt largestSet = merged.values().stream().mapToInt(Set::size).max();
        if (largestSet.isEmpty()) {
            throw new IllegalArgumentException("Corrupt sketch: metadata is invalid");
        }
        return new CorrMatrixBlock(corrMatrixBlock.getValue(), serialize(merged, largestSet.getAsInt()));
    }

    /**
     * Not implemented by this sketch.
     *
     * @param corrMatrixBlock ignored
     * @param corrMatrixBlock2 ignored
     * @return always {@code null}
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public CorrMatrixBlock intersection(CorrMatrixBlock corrMatrixBlock, CorrMatrixBlock corrMatrixBlock2) {
        return null;
    }
}
