package org.apache.sysds.runtime.transform.encode;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.ACompressedArray;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.DDCArray;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/CompressedEncode.class */
/**
 * Encodes the columns of a {@link FrameBlock} directly into column groups of a
 * {@link CompressedMatrixBlock}, one column group per column encoder.
 *
 * <p>
 * The only entry point is {@link #encode(MultiColumnEncoder, FrameBlock, int)}. Each
 * {@link ColumnEncoderComposite} of the given {@link MultiColumnEncoder} is translated into one
 * {@link AColGroup}; the groups are then shifted so that their column indexes are laid out
 * consecutively in the output.
 */
public class CompressedEncode {
    protected static final Log LOG = LogFactory.getLog(CompressedEncode.class.getName());

    /** The encoder whose column encoders drive the translation. */
    private final MultiColumnEncoder enc;
    /** The input frame that is encoded. */
    private final FrameBlock in;
    /** The parallelization degree requested by the caller. */
    private final int k;

    /** Task that encodes a single column through the enclosing instance; used by the parallel path. */
    public class EncodeTask implements Callable<AColGroup> {
        ColumnEncoderComposite c;

        protected EncodeTask(ColumnEncoderComposite columnEncoderComposite) {
            this.c = columnEncoderComposite;
        }

        @Override
        public AColGroup call() throws Exception {
            return CompressedEncode.this.encode(this.c);
        }
    }

    private CompressedEncode(MultiColumnEncoder multiColumnEncoder, FrameBlock frameBlock, int i) {
        this.enc = multiColumnEncoder;
        this.in = frameBlock;
        this.k = i;
    }

    /**
     * Encode the given frame into a compressed matrix block.
     *
     * @param multiColumnEncoder the encoder that supplies the per-column encoders
     * @param frameBlock         the frame to encode
     * @param i                  the parallelization degree; values above 1 enable the parallel path when
     *                           the encoder reports more than one encoder
     * @return the encoded block (a {@link CompressedMatrixBlock})
     * @throws InterruptedException if the parallel encoding is interrupted
     * @throws ExecutionException   if a parallel encode task fails
     */
    public static MatrixBlock encode(MultiColumnEncoder multiColumnEncoder, FrameBlock frameBlock, int i)
        throws InterruptedException, ExecutionException {
        return new CompressedEncode(multiColumnEncoder, frameBlock, i).apply();
    }

    /** Encode every column, lay the resulting groups out side by side and wrap them in a compressed block. */
    private MatrixBlock apply() throws InterruptedException, ExecutionException {
        final List<ColumnEncoderComposite> encoders = this.enc.getColumnEncoders();
        final List<AColGroup> groups = isParallel() ? multiThread(encoders) : singleThread(encoders);
        final int numCols = shiftGroups(groups);
        final CompressedMatrixBlock out = new CompressedMatrixBlock(this.in.getNumRows(), numCols, -1L, false, groups);
        out.recomputeNonZeros();
        logging(out);
        return out;
    }

    /** @return true when more than one thread was requested and there is more than one encoder */
    private boolean isParallel() {
        return this.k > 1 && this.enc.getEncoders().size() > 1;
    }

    /** Encode all columns sequentially on the calling thread, in encoder order. */
    private List<AColGroup> singleThread(List<ColumnEncoderComposite> list) {
        final List<AColGroup> groups = new ArrayList<>(list.size());
        for (ColumnEncoderComposite encoder : list) {
            groups.add(encode(encoder));
        }
        return groups;
    }

    /**
     * Encode all columns on a thread pool obtained from {@link CommonThreadPool}. The results are
     * collected in encoder order; the pool is shut down in all cases.
     */
    private List<AColGroup> multiThread(List<ColumnEncoderComposite> list)
        throws InterruptedException, ExecutionException {
        final ExecutorService pool = CommonThreadPool.get(this.k);
        try {
            final List<EncodeTask> tasks = new ArrayList<>(list.size());
            for (ColumnEncoderComposite encoder : list) {
                tasks.add(new EncodeTask(encoder));
            }
            final List<AColGroup> groups = new ArrayList<>(list.size());
            for (Future<AColGroup> result : pool.invokeAll(tasks)) {
                groups.add(result.get());
            }
            return groups;
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Shift every group after the first by the number of columns of all groups before it, replacing
     * the entries of the given list in place.
     *
     * @param list the column groups in output order; must contain at least one group, otherwise
     *             {@code list.get(0)} throws
     * @return the total number of columns over all groups
     */
    private int shiftGroups(List<AColGroup> list) {
        int numCols = list.get(0).getColIndices().size();
        for (int g = 1; g < list.size(); g++) {
            final AColGroup shifted = list.get(g).shiftColIndices(numCols);
            list.set(g, shifted);
            numCols += shifted.getColIndices().size();
        }
        return numCols;
    }

    /**
     * Dispatch to the encoding routine that matches the composite encoder. The checks are made in
     * a fixed order and the first match wins.
     *
     * @throws NotImplementedException if none of the supported encoder kinds matches
     */
    private AColGroup encode(ColumnEncoderComposite columnEncoderComposite) {
        if (columnEncoderComposite.isRecodeToDummy()) {
            return recodeToDummy(columnEncoderComposite);
        }
        if (columnEncoderComposite.isRecode()) {
            return recode(columnEncoderComposite);
        }
        if (columnEncoderComposite.isPassThrough()) {
            return passThrough(columnEncoderComposite);
        }
        if (columnEncoderComposite.isBin()) {
            return bin(columnEncoderComposite);
        }
        if (columnEncoderComposite.isBinToDummy()) {
            return binToDummy(columnEncoderComposite);
        }
        if (columnEncoderComposite.isHash()) {
            return hash(columnEncoderComposite);
        }
        if (columnEncoderComposite.isHashToDummy()) {
            return hashToDummy(columnEncoderComposite);
        }
        throw new NotImplementedException("Not supporting : " + columnEncoderComposite);
    }

    /**
     * Replace the first encoder of the composite with a recode encoder built from the given map.
     * The map is cast to {@link HashMap}, exactly as the call sites did before this helper existed.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void installRecodeEncoder(ColumnEncoderComposite composite, Map<?, Long> recodeMap) {
        composite.getEncoders().set(0, new ColumnEncoderRecode(composite._colID, (HashMap) recodeMap));
    }

    /**
     * Recode followed by dummy coding: one output column per distinct value, backed by an identity
     * dictionary. A column with no distinct values that contains nulls becomes an empty group; a
     * column with exactly one distinct value and no nulls becomes a constant group.
     */
    private AColGroup recodeToDummy(ColumnEncoderComposite columnEncoderComposite) {
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        final boolean containsNull = column.containsNull();
        final Map<?, Long> recodeMap = column.getRecodeMap();
        installRecodeEncoder(columnEncoderComposite, recodeMap);

        final int domain = recodeMap.size();
        if (containsNull && domain == 0) {
            return new ColGroupEmpty(ColIndexFactory.create(1));
        }
        final IColIndex colIndexes = ColIndexFactory.create(0, domain);
        if (domain == 1 && !containsNull) {
            return ColGroupConst.create(colIndexes, new double[] {1.0d});
        }
        return ColGroupDDC.create(colIndexes, new IdentityDictionary(colIndexes.size(), containsNull),
            createMappingAMapToData(column, recodeMap, containsNull), null);
    }

    /** Binning into a single output column whose dictionary holds the bin ids 1..numBin (plus NaN for nulls). */
    private AColGroup bin(ColumnEncoderComposite columnEncoderComposite) {
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        final boolean containsNull = column.containsNull();
        final ColumnEncoderBin binner = (ColumnEncoderBin) columnEncoderComposite.getEncoders().get(0);
        binner.build(this.in);
        return ColGroupDDC.create(ColIndexFactory.create(1), createIncrementingVector(binner._numBin, containsNull),
            binEncode(column, binner, containsNull), null);
    }

    /**
     * Build the row-to-bin mapping of a column.
     *
     * @param array            the column to bin
     * @param columnEncoderBin the (already built) bin encoder
     * @param z                whether the column contains nulls; if so, NaN cells are mapped to the
     *                         extra code {@code _numBin}
     * @return a mapping with one zero-based code per row
     */
    private AMapToData binEncode(Array<?> array, ColumnEncoderBin columnEncoderBin, boolean z) {
        final int numBins = columnEncoderBin._numBin;
        final int numRows = array.size();
        final AMapToData mapping = MapToFactory.create(numRows, numBins + (z ? 1 : 0));
        if (z) {
            for (int row = 0; row < numRows; row++) {
                final double value = array.getAsNaNDouble(row);
                mapping.set(row, Double.isNaN(value) ? numBins : binIndex(columnEncoderBin, value));
            }
        } else {
            for (int row = 0; row < numRows; row++) {
                mapping.set(row, binIndex(columnEncoderBin, array.getAsDouble(row)));
            }
        }
        return mapping;
    }

    /**
     * Zero-based bin code of a value. If the bin lookup fails, it is retried once with the value
     * lowered by 1e-5. The result is clamped to 0 on both paths so a negative code is never written
     * into the mapping.
     */
    private static int binIndex(ColumnEncoderBin binner, double value) {
        int code;
        try {
            code = ((int) binner.getCodeIndex(value)) - 1;
        } catch (RuntimeException e) {
            // NOTE(review): presumably the lookup rejects values on or just above the upper
            // boundary; the retry nudges the value back into range - confirm against ColumnEncoderBin.
            code = ((int) binner.getCodeIndex(value - 1.0E-5d)) - 1;
        }
        return Math.max(code, 0);
    }

    /**
     * Create a single-column dictionary with the values 1..i, followed by one NaN entry when z is
     * set.
     *
     * @param i the number of incrementing entries
     * @param z whether to append a NaN entry
     */
    private MatrixBlockDictionary createIncrementingVector(int i, boolean z) {
        final MatrixBlock values = new MatrixBlock(i + (z ? 1 : 0), 1, false);
        for (int row = 0; row < i; row++) {
            values.quickSetValue(row, 0, row + 1);
        }
        if (z) {
            values.quickSetValue(i, 0, Double.NaN);
        }
        return MatrixBlockDictionary.create(values);
    }

    /** Binning followed by dummy coding: one output column per bin, backed by an identity dictionary. */
    private AColGroup binToDummy(ColumnEncoderComposite columnEncoderComposite) {
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        final boolean containsNull = column.containsNull();
        final ColumnEncoderBin binner = (ColumnEncoderBin) columnEncoderComposite.getEncoders().get(0);
        binner.build(this.in);
        final IColIndex colIndexes = ColIndexFactory.create(0, binner._numBin);
        return ColGroupDDC.create(colIndexes, new IdentityDictionary(colIndexes.size(), containsNull),
            binEncode(column, binner, containsNull), null);
    }

    /**
     * Recode into a single output column whose dictionary holds the codes 1..domain (plus NaN for
     * nulls).
     *
     * <p>
     * The recode encoder is installed before any early return so that it is registered for
     * single-value columns as well. The constant shortcut is only taken when the column has no
     * nulls; otherwise the null rows would be encoded as the constant 1.
     */
    private AColGroup recode(ColumnEncoderComposite columnEncoderComposite) {
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        final Map<?, Long> recodeMap = column.getRecodeMap();
        final boolean containsNull = column.containsNull();
        installRecodeEncoder(columnEncoderComposite, recodeMap);

        final int domain = recodeMap.size();
        final IColIndex colIndexes = ColIndexFactory.create(1);
        if (domain == 1 && !containsNull) {
            return ColGroupConst.create(colIndexes, new double[] {1.0d});
        }
        return ColGroupDDC.create(colIndexes, createIncrementingVector(domain, containsNull),
            createMappingAMapToData(column, recodeMap, containsNull), null);
    }

    /**
     * Pass the numeric values of a column through unchanged.
     *
     * <p>
     * A DDC-compressed input column is converted by reusing its mapping and copying its dictionary
     * as doubles. Otherwise, if the number of distinct values reaches the configured default block
     * size the column is stored uncompressed; below that it is stored as a dictionary of the
     * distinct values (plus NaN for nulls) with a per-row mapping.
     *
     * @throws NotImplementedException for compressed input arrays other than DDC
     */
    @SuppressWarnings("rawtypes")
    private AColGroup passThrough(ColumnEncoderComposite columnEncoderComposite) {
        final IColIndex colIndexes = ColIndexFactory.create(1);
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        if (column instanceof ACompressedArray) {
            switch (column.getFrameArrayType()) {
                case DDC:
                    final DDCArray ddc = (DDCArray) column;
                    final Array<?> dict = ddc.getDict();
                    final double[] dictValues = new double[dict.size()];
                    for (int i = 0; i < dictValues.length; i++) {
                        dictValues[i] = dict.getAsDouble(i);
                    }
                    return ColGroupDDC.create(colIndexes, Dictionary.create(dictValues), ddc.getMap(), null);
                default:
                    throw new NotImplementedException();
            }
        }

        final boolean containsNull = column.containsNull();
        // Only Map operations are needed here, so the recode map is not narrowed to HashMap.
        final Map<?, Long> recodeMap = column.getRecodeMap();
        final int domain = recodeMap.size();
        if (domain >= ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE)) {
            final double[] dense = (double[]) column.changeType(Types.ValueType.FP64).get();
            final MatrixBlock uncompressed = new MatrixBlock(column.size(), 1, dense);
            uncompressed.recomputeNonZeros();
            return ColGroupUncompressed.create(colIndexes, uncompressed, false);
        }

        final double[] values = new double[domain + (containsNull ? 1 : 0)];
        if (containsNull) {
            values[domain] = Double.NaN;
        }
        final Types.ValueType valueType = column.getValueType();
        for (Map.Entry<?, Long> entry : recodeMap.entrySet()) {
            // Recode codes are one-based; the dictionary is zero-based.
            values[entry.getValue().intValue() - 1] = UtilFunctions.objectToDouble(valueType, entry.getKey());
        }
        return ColGroupDDC.create(colIndexes, Dictionary.create(values),
            createMappingAMapToData(column, recodeMap, containsNull), null);
    }

    /**
     * Build the row-to-code mapping of a column from its recode map.
     *
     * @param array the column to map
     * @param map   the recode map from cell value to one-based code
     * @param z     whether the column contains nulls; if so, null cells are mapped to the extra
     *              code {@code map.size()}
     * @return a mapping with one zero-based code per row
     * @throws RuntimeException wrapping any failure while constructing the mapping
     */
    private AMapToData createMappingAMapToData(Array<?> array, Map<?, Long> map, boolean z) {
        try {
            final int domain = map.size();
            final AMapToData mapping = MapToFactory.create(this.in.getNumRows(), domain + (z ? 1 : 0));
            final Array<?>.ArrayIterator it = array.getIterator();
            if (z) {
                while (it.hasNext()) {
                    final Object value = it.next();
                    if (value == null) {
                        mapping.set(it.getIndex(), domain);
                        continue;
                    }
                    try {
                        mapping.set(it.getIndex(), map.get(value).intValue() - 1);
                    } catch (Exception e) {
                        throw new RuntimeException("failed on " + value + " " + array.getValueType(), e);
                    }
                }
            } else {
                while (it.hasNext()) {
                    // Advance first and read the index afterwards, exactly as the null-aware loop
                    // does. Java evaluates arguments left to right, so inlining next() into the
                    // set(...) call would read the index before the iterator has moved.
                    final Object value = it.next();
                    mapping.set(it.getIndex(), map.get(value).intValue() - 1);
                }
            }
            return mapping;
        } catch (Exception e) {
            throw new RuntimeException("failed constructing map: " + map, e);
        }
    }

    /**
     * Build the row-to-bucket mapping of a column for feature hashing: the absolute hash of each
     * cell, truncated to int, modulo the number of buckets.
     *
     * @param array the column to hash
     * @param i     the number of hash buckets
     * @param z     whether the column contains nulls; if so, cells whose hash is NaN are mapped to
     *              the extra code {@code i}
     * @return a mapping with one zero-based code per row
     */
    private AMapToData createHashMappingAMapToData(Array<?> array, int i, boolean z) {
        final int numRows = array.size();
        final AMapToData mapping = MapToFactory.create(numRows, i + (z ? 1 : 0));
        if (z) {
            for (int row = 0; row < numRows; row++) {
                final double hashed = Math.abs(array.hashDouble(row));
                mapping.set(row, Double.isNaN(hashed) ? i : ((int) hashed) % i);
            }
        } else {
            for (int row = 0; row < numRows; row++) {
                mapping.set(row, ((int) Math.abs(array.hashDouble(row))) % i);
            }
        }
        return mapping;
    }

    /**
     * Feature hashing into a single output column whose dictionary holds the bucket ids 1..k (plus
     * NaN for nulls). A single bucket without nulls becomes a constant group.
     */
    private AColGroup hash(ColumnEncoderComposite columnEncoderComposite) {
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        final int buckets = (int) ((ColumnEncoderFeatureHash) columnEncoderComposite.getEncoders().get(0)).getK();
        final boolean containsNull = column.containsNull();
        final IColIndex colIndexes = ColIndexFactory.create(0, 1);
        if (buckets == 1 && !containsNull) {
            return ColGroupConst.create(colIndexes, new double[] {1.0d});
        }
        return ColGroupDDC.create(colIndexes, createIncrementingVector(buckets, containsNull),
            createHashMappingAMapToData(column, buckets, containsNull), null);
    }

    /**
     * Feature hashing followed by dummy coding: one output column per bucket, backed by an identity
     * dictionary. A single bucket without nulls becomes a constant group.
     */
    private AColGroup hashToDummy(ColumnEncoderComposite columnEncoderComposite) {
        final Array<?> column = this.in.getColumn(columnEncoderComposite._colID - 1);
        final int buckets = (int) ((ColumnEncoderFeatureHash) columnEncoderComposite.getEncoders().get(0)).getK();
        final boolean containsNull = column.containsNull();
        final IColIndex colIndexes = ColIndexFactory.create(0, buckets);
        if (buckets == 1 && !containsNull) {
            return ColGroupConst.create(colIndexes, new double[] {1.0d});
        }
        return ColGroupDDC.create(colIndexes, new IdentityDictionary(colIndexes.size(), containsNull),
            createHashMappingAMapToData(column, buckets, containsNull), null);
    }

    /**
     * Log the size estimates of the encoded block and the resulting compression ratios at debug
     * level. The ratios are computed in floating point so that the fractional part is kept and a
     * zero compressed size does not raise an arithmetic exception.
     */
    private void logging(MatrixBlock matrixBlock) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        final long denseSize = matrixBlock.estimateSizeDenseInMemory();
        final long sparseSize = matrixBlock.estimateSizeSparseInMemory();
        final long compressedSize = matrixBlock.estimateSizeInMemory();
        LOG.debug(String.format("Uncompressed transform encode Dense size:   %16d", denseSize));
        LOG.debug(String.format("Uncompressed transform encode Sparse size:  %16d", sparseSize));
        LOG.debug(String.format("Compressed transform encode size:           %16d", compressedSize));

        final double ratio = (double) Math.min(denseSize, sparseSize) / compressedSize;
        final double denseRatio = (double) denseSize / compressedSize;
        LOG.debug(String.format("Compression ratio: %10.3f", ratio));
        LOG.debug(String.format("Dense ratio:       %10.3f", denseRatio));
    }
}
