package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/ColumnEncoder.class */
/**
 * Base class for encoders that turn one input column of a {@link CacheBlock} into values written to a
 * column of an output {@link MatrixBlock}.
 *
 * <p>Subclasses provide the per-row codes ({@link #getCode}, {@link #getCodeCol}) and their
 * {@link TransformType}. This class provides the row-tiled dense and sparse apply loops, optional
 * timing statistics, and the construction of (optionally row-partitioned) build and apply tasks for
 * {@link DependencyThreadPool}.
 */
public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder> {
    protected static final Log LOG = LogFactory.getLog(ColumnEncoder.class.getName());
    // NOTE(review): not read inside this class; presumably global defaults for row partitioning,
    // with -1 as "unset" -- confirm against the configuring code.
    public static int APPLY_ROW_BLOCKS_PER_COLUMN = -1;
    public static int BUILD_ROW_BLOCKS_PER_COLUMN = -1;
    private static final long serialVersionUID = 2299156350718979064L;

    /** Number of consecutive rows processed per tile by the dense and sparse apply loops. */
    private static final int APPLY_TILE_SIZE = 32;

    /**
     * Input column this encoder applies to. A value of -1 marks the encoder as not applicable
     * (see {@link #isApplicable()}). applySparse uses {@code _colID - 1} as a per-row offset,
     * which suggests the id is 1-based -- TODO confirm.
     */
    protected int _colID;
    /** Row indexes collected via {@link #addSparseRowsWZeros}; lazily created, null until first add. */
    protected ArrayList<Integer> _sparseRowsWZeros = null;
    protected long _estMetaSize = 0;
    protected int _estNumDistincts = 0;
    /** Requested number of row partitions for building; passed to {@code UtilFunctions.getBlockSizes}. */
    protected int _nBuildPartitions = 0;
    /** Requested number of row partitions for applying; passed to {@code UtilFunctions.getBlockSizes}. */
    protected int _nApplyPartitions = 0;

    /**
     * Task that applies an encoder to a row range of the input, writing into one output column.
     *
     * @param <T> concrete encoder type
     */
    public static class ColumnApplyTask<T extends ColumnEncoder> implements Callable<Object> {
        protected final T _encoder;
        protected final CacheBlock _input;
        protected final MatrixBlock _out;
        protected final int _outputCol;
        protected final int _startRow;
        protected final int _blk;

        /** Creates a task over the row range starting at row 0 with block length -1. */
        public ColumnApplyTask(T encoder, CacheBlock input, MatrixBlock out, int outputCol) {
            this(encoder, input, out, outputCol, 0, -1);
        }

        /**
         * Creates a task over an explicit row range.
         *
         * @param encoder   encoder to apply
         * @param input     input block
         * @param out       output matrix
         * @param outputCol output column to write
         * @param startRow  first row of the range
         * @param blk       block length; interpreted by {@code UtilFunctions.getEndIndex}
         */
        public ColumnApplyTask(T encoder, CacheBlock input, MatrixBlock out, int outputCol, int startRow, int blk) {
            _encoder = encoder;
            _input = input;
            _out = out;
            _outputCol = outputCol;
            _startRow = startRow;
            _blk = blk;
        }

        /** Runs {@link ColumnEncoder#apply(CacheBlock, MatrixBlock, int, int, int)}; always returns null. */
        @Override
        public Object call() throws Exception {
            assert _outputCol >= 0;
            _encoder.apply(_input, _out, _outputCol, _startRow, _blk);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<Encoder: " + _encoder.getClass().getSimpleName() + "; ColId: " + _encoder._colID + ">";
        }
    }

    /** Kinds of column encoders. */
    public enum EncoderType {
        Recode,
        FeatureHash,
        PassThrough,
        Bin,
        Dummycode,
        Omit,
        MVImpute,
        Composite
    }

    /** Transform categories; used by {@link #apply} to pick the statistics counter to update. */
    public enum TransformType {
        BIN,
        RECODE,
        DUMMYCODE,
        FEATURE_HASH,
        PASS_THROUGH,
        UDF,
        N_A
    }

    /**
     * Creates an encoder for the given input column.
     *
     * @param colID input column id (-1 marks the encoder as not applicable)
     */
    public ColumnEncoder(int colID) {
        _colID = colID;
    }

    /** Applies this encoder starting at row 0 with block length -1. */
    @Override
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol) {
        return apply(in, out, outputCol, 0, -1);
    }

    /**
     * Applies this encoder to a row range of {@code in}, writing codes into column {@code outputCol}
     * of {@code out}.
     *
     * <p>Dispatches to {@link #applySparse} or {@link #applyDense} depending on whether {@code out}
     * is in sparse format. When {@code DMLScript.STATISTICS} is enabled, the elapsed time is added
     * to the {@link TransformStatistics} counter for this encoder's {@link TransformType}. UDF and
     * N_A are not recorded.
     *
     * @param in        input block
     * @param out       output matrix (modified in place)
     * @param outputCol output column to write
     * @param rowStart  first row of the range
     * @param blk       block length; interpreted by {@code UtilFunctions.getEndIndex}
     * @return {@code out}
     */
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (out.isInSparseFormat()) {
            applySparse(in, out, outputCol, rowStart, blk);
        } else {
            applyDense(in, out, outputCol, rowStart, blk);
        }
        if (DMLScript.STATISTICS) {
            long elapsed = System.nanoTime() - t0;
            switch (getTransformType()) {
                case RECODE:
                    TransformStatistics.incRecodeApplyTime(elapsed);
                    break;
                case BIN:
                    TransformStatistics.incBinningApplyTime(elapsed);
                    break;
                case DUMMYCODE:
                    TransformStatistics.incDummyCodeApplyTime(elapsed);
                    break;
                case FEATURE_HASH:
                    TransformStatistics.incFeatureHashingApplyTime(elapsed);
                    break;
                case PASS_THROUGH:
                    TransformStatistics.incPassThroughApplyTime(elapsed);
                    break;
                default:
                    // UDF / N_A: no dedicated statistics counter
                    break;
            }
        }
        return out;
    }

    /** Returns the code for row {@code row} of {@code in}. */
    protected abstract double getCode(CacheBlock in, int row);

    /**
     * Returns codes for a row range of {@code in}. Index {@code k} of the result is consumed for row
     * {@code rowStart + k} by the apply loops.
     */
    protected abstract double[] getCodeCol(CacheBlock in, int rowStart, int blk);

    /**
     * Writes codes for a row range directly into the CSR arrays of {@code out}'s sparse block.
     *
     * <p>For each row {@code r}, the entry at {@code rowPointers[r] + (_colID - 1)} receives column
     * index {@code outputCol} and value {@code codes[r - rowStart]}. This assumes the CSR block was
     * pre-allocated with a slot at that position for every row in the range -- TODO confirm with
     * the allocating code. A non-CSR sparse block fails with {@link ClassCastException}.
     */
    public void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        double[] codes = getCodeCol(in, rowStart, blk);
        int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        if (rowStart >= rowEnd) {
            return; // empty range: leave the sparse block untouched
        }
        // Loop-invariant lookups hoisted out of the row loop.
        SparseBlockCSR csr = (SparseBlockCSR) out.getSparseBlock();
        int[] rowPointers = csr.rowPointers();
        int[] colIndexes = csr.indexes();
        double[] values = csr.values();
        int offset = _colID - 1;
        for (int tileStart = rowStart; tileStart < rowEnd; tileStart += APPLY_TILE_SIZE) {
            int tileEnd = Math.min(tileStart + APPLY_TILE_SIZE, rowEnd);
            for (int r = tileStart; r < tileEnd; r++) {
                int pos = rowPointers[r] + offset;
                colIndexes[pos] = outputCol;
                values[pos] = codes[r - rowStart];
            }
        }
    }

    /** Writes codes for a row range into column {@code outputCol} of a dense {@code out} via quickSetValue. */
    protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        double[] codes = getCodeCol(in, rowStart, blk);
        int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        for (int tileStart = rowStart; tileStart < rowEnd; tileStart += APPLY_TILE_SIZE) {
            int tileEnd = Math.min(tileStart + APPLY_TILE_SIZE, rowEnd);
            for (int r = tileStart; r < tileEnd; r++) {
                out.quickSetValue(r, outputCol, codes[r - rowStart]);
            }
        }
    }

    /** Returns the transform category of this encoder. */
    protected abstract TransformType getTransformType();

    /** Returns true unless the column id is -1. */
    public boolean isApplicable() {
        return _colID != -1;
    }

    /** Returns true iff {@code colID} equals this encoder's column id. */
    public boolean isApplicable(int colID) {
        return colID == _colID;
    }

    /** No-op in the base class. */
    @Override
    public void prepareBuildPartial() {
    }

    /** No-op in the base class. */
    @Override
    public void buildPartial(FrameBlock in) {
    }

    /** No-op in the base class; overridden by encoders that require building. */
    public void build(CacheBlock in, double[] values) {
    }

    /** No-op in the base class; overridden by encoders that require building. */
    public void build(CacheBlock in, Map<Integer, double[]> values) {
    }

    /**
     * Merges another encoder into this one.
     *
     * @throws DMLRuntimeException always, in the base class
     */
    public void mergeAt(ColumnEncoder other) {
        throw new DMLRuntimeException(getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
    }

    /** No-op in the base class. */
    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int offset) {
    }

    /** Returns null in the base class. */
    public MatrixBlock getColMapping(FrameBlock meta) {
        return null;
    }

    /** Writes the column id. */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeInt(_colID);
    }

    /** Reads the column id. */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        _colID = in.readInt();
    }

    public int getColID() {
        return _colID;
    }

    public void setColID(int colID) {
        _colID = colID;
    }

    /** Adds {@code shift} to the column id. */
    public void shiftCol(int shift) {
        _colID += shift;
    }

    public void setEstMetaSize(long estSize) {
        _estMetaSize = estSize;
    }

    public long getEstMetaSize() {
        return _estMetaSize;
    }

    public void setEstNumDistincts(int numDistincts) {
        _estNumDistincts = numDistincts;
    }

    public int getEstNumDistincts() {
        return _estNumDistincts;
    }

    /**
     * Orders encoders by {@code EncoderFactory.getEncoderType}. NOTE: only the encoder type is
     * compared, so distinct encoders of the same type compare as 0.
     */
    @Override
    public int compareTo(ColumnEncoder other) {
        return Integer.compare(EncoderFactory.getEncoderType(this), EncoderFactory.getEncoderType(other));
    }

    /**
     * Creates the build tasks for this encoder.
     *
     * <p>If {@code getBlockSizes} yields a single row block, one task from {@link #getBuildTask} is
     * created. Otherwise there is one partial build task per row block, all sharing one result
     * map, plus a final merge task that depends on every partial task.
     */
    public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
        List<Callable<Object>> tasks = new ArrayList<>();
        List<List<? extends Callable<?>>> dep = null;
        int[] blockSizes = UtilFunctions.getBlockSizes(in.getNumRows(), _nBuildPartitions);
        if (blockSizes.length == 1) {
            tasks.add(getBuildTask(in));
        } else {
            HashMap<Integer, Object> partialResults = new HashMap<>();
            for (int b = 0, startRow = 0; b < blockSizes.length; startRow += blockSizes[b], b++) {
                tasks.add(getPartialBuildTask(in, startRow, blockSizes[b], partialResults));
            }
            tasks.add(getPartialMergeBuildTask(partialResults));
            // Partial tasks have no dependencies; the merge task depends on all of them.
            dep = new ArrayList<>(Collections.nCopies(tasks.size() - 1, null));
            dep.add(tasks.subList(0, tasks.size() - 1));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** @throws DMLRuntimeException always, in the base class */
    public Callable<Object> getBuildTask(CacheBlock in) {
        throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building");
    }

    /** @throws DMLRuntimeException always, in the base class */
    public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow, int blockSize, HashMap<Integer, Object> ret) {
        throw new DMLRuntimeException("Trying to get the PartialBuild task of an Encoder which does not support  partial building");
    }

    /** @throws DMLRuntimeException always, in the base class */
    public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
        throw new DMLRuntimeException("Trying to get the BuildMergeTask task of an Encoder which does not support partial building");
    }

    /**
     * Creates one apply task per row block from {@code getBlockSizes}, using the sparse or dense
     * task factory according to {@code out}'s format. With more than one block, a trailing no-op
     * task depending on all apply tasks is appended (presumably as a single completion point for
     * downstream dependencies -- confirm with callers).
     */
    public List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock out, int outputCol) {
        List<Callable<Object>> tasks = new ArrayList<>();
        List<List<? extends Callable<?>>> dep = null;
        int[] blockSizes = UtilFunctions.getBlockSizes(in.getNumRows(), _nApplyPartitions);
        for (int b = 0, startRow = 0; b < blockSizes.length; startRow += blockSizes[b], b++) {
            if (out.isInSparseFormat()) {
                tasks.add(getSparseTask(in, out, outputCol, startRow, blockSizes[b]));
            } else {
                tasks.add(getDenseTask(in, out, outputCol, startRow, blockSizes[b]));
            }
        }
        if (tasks.size() > 1) {
            dep = new ArrayList<>(Collections.nCopies(tasks.size(), null));
            tasks.add(() -> null);
            dep.add(tasks.subList(0, tasks.size() - 1));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** Returns a {@link ColumnApplyTask} for the given row range; overridable by subclasses. */
    protected ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new ColumnApplyTask<>(this, in, out, outputCol, startRow, blk);
    }

    /** Returns a {@link ColumnApplyTask} for the given row range; overridable by subclasses. */
    protected ColumnApplyTask<? extends ColumnEncoder> getDenseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new ColumnApplyTask<>(this, in, out, outputCol, startRow, blk);
    }

    /** Returns a copy of the collected row indexes as a set, or null if none were ever added. */
    public Set<Integer> getSparseRowsWZeros() {
        if (_sparseRowsWZeros != null) {
            return new HashSet<>(_sparseRowsWZeros);
        }
        return null;
    }

    /** Appends row indexes under a lock on this encoder, creating the list on first use. */
    public void addSparseRowsWZeros(ArrayList<Integer> rows) {
        synchronized (this) {
            if (_sparseRowsWZeros == null) {
                _sparseRowsWZeros = new ArrayList<>();
            }
            _sparseRowsWZeros.addAll(rows);
        }
    }

    public void setBuildRowBlocksPerColumn(int nPart) {
        _nBuildPartitions = nPart;
    }

    public void setApplyRowBlocksPerColumn(int nPart) {
        _nApplyPartitions = nPart;
    }
}
