package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.DependencyWrapperTask;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.utils.stats.TransformStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.class */
public class MultiColumnEncoder implements Encoder {
    // Logger for this class. It is static final without an inline initializer, so it is presumably
    // assigned in a static initializer outside this view.
    protected static final Log LOG;
    // When true, encode() does not build the single full task DAG and instead runs separate build and
    // apply stages, even with more than one thread.
    public static boolean MULTI_THREADED_STAGES;
    // When true, applyMT() submits and waits for the apply tasks of one column encoder at a time
    // instead of submitting all apply tasks at once.
    public static boolean APPLY_ENCODER_SEPARATE_STAGES;
    // Per-column composite encoders; build/apply iterate them in list order.
    private List<ColumnEncoderComposite> _columnEncoders;
    // Legacy missing-value-impute encoder; built and applied after the column encoders when non-null.
    private EncoderMVImpute _legacyMVImpute;
    // Legacy omit encoder; built and applied after the column encoders when non-null.
    private EncoderOmit _legacyOmit;
    // Column offset; not read or written in the visible part of this file.
    private int _colOffset;
    // Metadata frame written by the encoders (assigned in encode()/getEncodeTasks()).
    private FrameBlock _meta;
    // Set once deriveNumRowPartitions() has configured the row partitions of all encoders.
    private boolean _partitionDone;
    // Compiler-generated flag backing Java 'assert' statements (decompiler artifact).
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$AllocMetaTask.class */
    /**
     * Dependency-graph task that allocates the metadata frame of a {@link MultiColumnEncoder}.
     */
    public static class AllocMetaTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _meta;

        /**
         * @param encoder encoder whose {@code allocateMetaData} is invoked
         * @param meta    metadata frame handed to the encoder
         */
        private AllocMetaTask(MultiColumnEncoder encoder, FrameBlock meta) {
            _encoder = encoder;
            _meta = meta;
        }

        /** Runs the allocation; the task carries no result, so it yields {@code null}. */
        @Override
        public Object call() throws Exception {
            _encoder.allocateMetaData(_meta);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$ApplyTasksWrapperTask.class */
    /**
     * Wrapper task that expands into the apply tasks of one column encoder once its output column
     * offset is known. The offset is supplied via {@link #setOffset(int)} (for dummycoded pipelines by
     * {@code UpdateOutputColTask}); executing the wrapper before that is reported as an error.
     */
    public static class ApplyTasksWrapperTask extends DependencyWrapperTask<Object> {
        private final ColumnEncoder _encoder;
        private final MatrixBlock _out;
        private final CacheBlock _in;
        // Shift added to the encoder's own output column; -1 marks "not yet assigned".
        private int _offset = -1;

        private ApplyTasksWrapperTask(ColumnEncoder encoder, CacheBlock in, MatrixBlock out, DependencyThreadPool pool) {
            super(pool);
            _encoder = encoder;
            _out = out;
            _in = in;
        }

        /** Creates the wrapped apply tasks, writing from output column {@code colID - 1 + offset}. */
        @Override
        public List<DependencyTask<?>> getWrappedTasks() {
            int firstOutputCol = _encoder._colID - 1 + _offset;
            return _encoder.getApplyTasks(_in, _out, firstOutputCol);
        }

        /**
         * Verifies that the offset has been assigned, then delegates to the wrapper logic.
         *
         * @throws DMLRuntimeException if {@link #setOffset(int)} was never called
         */
        @Override
        public Object call() throws Exception {
            boolean offsetUnset = _offset == -1;
            if (offsetUnset) {
                throw new DMLRuntimeException("OutputCol for apply task wrapper has not been updated!, Most likely some concurrency issues");
            }
            return super.call();
        }

        /** @param offset shift added to this encoder's output column */
        public void setOffset(int offset) {
            _offset = offset;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$ColumnMetaDataTask.class */
    /**
     * Task that lets a single column encoder write its metadata into a shared frame.
     *
     * @param <T> concrete column encoder type
     */
    public static class ColumnMetaDataTask<T extends ColumnEncoder> implements Callable<Object> {
        private final T _colEncoder;
        private final FrameBlock _out;

        /**
         * @param colEncoder encoder whose {@code getMetaData} is invoked
         * @param out        metadata frame the encoder writes into
         */
        protected ColumnMetaDataTask(T colEncoder, FrameBlock out) {
            _colEncoder = colEncoder;
            _out = out;
        }

        /** Writes the column metadata; yields {@code null}. */
        @Override
        public Object call() throws Exception {
            _colEncoder.getMetaData(_out);
            return null;
        }

        @Override
        public String toString() {
            String name = getClass().getSimpleName();
            return name + "<ColId: " + _colEncoder._colID + ">";
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$InitOutputMatrixTask.class */
    /**
     * Task that shapes and allocates the encode output matrix.
     *
     * <p>The output gets {@code input columns + encoder.getNumExtraCols()} columns. Sparse format is
     * chosen via {@code MatrixBlock.evalSparseFormatInMemory}, but never when a UDF encoder is present.
     */
    public static class InitOutputMatrixTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final CacheBlock _input;
        private final MatrixBlock _output;

        private InitOutputMatrixTask(MultiColumnEncoder multiColumnEncoder, CacheBlock cacheBlock, MatrixBlock matrixBlock) {
            this._encoder = multiColumnEncoder;
            this._input = cacheBlock;
            this._output = matrixBlock;
        }

        @Override
        public Object call() throws Exception {
            boolean hasUDF = _encoder.getColumnEncoders().stream()
                .anyMatch(enc -> enc.hasEncoder(ColumnEncoderUDF.class));
            int numCols = _input.getNumColumns() + _encoder.getNumExtraCols();
            boolean hasDC = _encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
            int numRows = _input.getNumRows();
            // Widen before multiplying: rows * cols easily exceeds Integer.MAX_VALUE for large inputs,
            // and an overflowed (negative) estimate corrupts the sparsity decision and allocation.
            long estNNz = (long) numRows * (hasUDF ? numCols : _input.getNumColumns());
            boolean sparse = MatrixBlock.evalSparseFormatInMemory(numRows, numCols, estNNz) && !hasUDF;
            _output.reset(numRows, numCols, sparse, estNNz);
            MultiColumnEncoder.outputMatrixPreProcessing(_output, _input, hasDC);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$MultiColumnLegacyBuildTask.class */
    /**
     * Task that builds the legacy (omit / MV-impute) encoders of a {@link MultiColumnEncoder}.
     */
    private static class MultiColumnLegacyBuildTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _input;

        protected MultiColumnLegacyBuildTask(MultiColumnEncoder multiColumnEncoder, FrameBlock frameBlock) {
            this._encoder = multiColumnEncoder;
            this._input = frameBlock;
        }

        /**
         * Runs {@code legacyBuild} on the input frame and yields {@code null}.
         * The method must be named {@code call} to implement {@link Callable}; the decompiled
         * {@code call2} neither overrode nor implemented it.
         */
        @Override
        public Object call() throws Exception {
            _encoder.legacyBuild(_input);
            return null;
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$MultiColumnLegacyMVImputeMetaPrepareTask.class */
    /**
     * Task that materializes the encoder metadata into a fresh string frame (one column per input
     * column), stores it in the encoder and initializes the encoders from it.
     */
    private static class MultiColumnLegacyMVImputeMetaPrepareTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _input;

        protected MultiColumnLegacyMVImputeMetaPrepareTask(MultiColumnEncoder multiColumnEncoder, FrameBlock frameBlock) {
            this._encoder = multiColumnEncoder;
            this._input = frameBlock;
        }

        /**
         * Prepares the metadata and yields {@code null}. Named {@code call} so it actually implements
         * {@link Callable} (the decompiled {@code call2} did not).
         */
        @Override
        public Object call() throws Exception {
            _encoder._meta = _encoder.getMetaData(new FrameBlock(_input.getNumColumns(), Types.ValueType.STRING));
            _encoder.initMetaData(_encoder._meta);
            return null;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/MultiColumnEncoder$UpdateOutputColTask.class */
    /**
     * Task that assigns each apply wrapper its output column offset: the sum of
     * {@code domainSize - 1} over all dummycoded encoders located before the wrapper's column in the
     * encoder list. The prefix sum is only recomputed when the column index grows.
     */
    public static class UpdateOutputColTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final List<DependencyTask<?>> _applyTasksWrappers;

        private UpdateOutputColTask(MultiColumnEncoder encoder, List<DependencyTask<?>> applyTasksWrappers) {
            _encoder = encoder;
            _applyTasksWrappers = applyTasksWrappers;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }

        @Override
        public Object call() throws Exception {
            int lastCol = -1;
            int offset = 0;
            for (DependencyTask<?> task : _applyTasksWrappers) {
                ApplyTasksWrapperTask wrapper = (ApplyTasksWrapperTask) task;
                int col = wrapper._encoder._colID - 1;
                if (col > lastCol) {
                    lastCol = col;
                    // Extra output columns produced by the encoders at list positions [0, col).
                    offset = 0;
                    for (ColumnEncoderComposite preceding : _encoder._columnEncoders.subList(0, col)) {
                        ColumnEncoderDummycode dc = (ColumnEncoderDummycode) preceding.getEncoder(ColumnEncoderDummycode.class);
                        if (dc != null) {
                            offset += dc._domainSize - 1;
                        }
                    }
                }
                wrapper.setOffset(offset);
            }
            return null;
        }
    }

    /**
     * Creates an encoder over the given composite column encoders, without legacy encoders.
     *
     * @param list composite encoders, one per column; the list is stored by reference, not copied
     */
    public MultiColumnEncoder(List<ColumnEncoderComposite> list) {
        _columnEncoders = list;
        _meta = null;
        _legacyOmit = null;
        _legacyMVImpute = null;
        _partitionDone = false;
        _colOffset = 0;
    }

    /** Creates an encoder with an empty, mutable list of column encoders. */
    public MultiColumnEncoder() {
        this(new ArrayList<>());
    }

    /**
     * Builds and applies all encoders on a single thread.
     *
     * @param cacheBlock input block to encode
     * @return the encoded matrix
     */
    public MatrixBlock encode(CacheBlock cacheBlock) {
        final int singleThread = 1;
        return encode(cacheBlock, singleThread);
    }

    /**
     * Builds the column encoders on {@code cacheBlock} and applies them, returning the encoded matrix.
     *
     * <p>With {@code k > 1}, {@link #MULTI_THREADED_STAGES} disabled and no legacy encoders, build,
     * metadata and apply run as one dependency-task DAG. Otherwise build and apply run as separate
     * (possibly multi-threaded) stages.
     *
     * @param cacheBlock input block to encode
     * @param k          degree of parallelism
     * @return the encoded matrix
     * @throws DMLRuntimeException if the multi-threaded DAG fails or is interrupted
     */
    public MatrixBlock encode(CacheBlock cacheBlock, int k) {
        deriveNumRowPartitions(cacheBlock, k);
        try {
            if (k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) {
                MatrixBlock out = new MatrixBlock();
                DependencyThreadPool pool = new DependencyThreadPool(k);
                LOG.debug("Encoding with full DAG on " + k + " Threads");
                try {
                    pool.submitAllAndWait(getEncodeTasks(cacheBlock, out, pool));
                } catch (InterruptedException | ExecutionException e) {
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                    LOG.error("MT Column encode failed", e);
                    // Fail loudly: the output of a failed DAG is only partially encoded.
                    throw new DMLRuntimeException(e);
                } finally {
                    // Release the worker threads on every path, including runtime failures.
                    pool.shutdown();
                }
                outputMatrixPostProcessing(out);
                return out;
            }
            LOG.debug("Encoding with staged approach on: " + k + " Threads");
            long buildStart = System.nanoTime();
            build(cacheBlock, k);
            LOG.debug("Elapsed time for build phase: " + ((System.nanoTime() - buildStart) / 1000000.0d) + " ms");
            if (_legacyMVImpute != null) {
                // The legacy MV-impute encoder consumes the metadata produced by the build phase.
                _meta = getMetaData(new FrameBlock(cacheBlock.getNumColumns(), Types.ValueType.STRING));
                initMetaData(_meta);
            }
            long applyStart = System.nanoTime();
            MatrixBlock out = apply(cacheBlock, k);
            LOG.debug("Elapsed time for apply phase: " + ((System.nanoTime() - applyStart) / 1000000.0d) + " ms");
            return out;
        } catch (RuntimeException e) {
            LOG.error("Failed transform-encode frame with \n" + this);
            throw e;
        }
    }

    /**
     * Creates the complete task DAG for a multi-threaded encode. The DAG contains:
     * <ul>
     * <li>output allocation (InitOutputMatrixTask, index 0);</li>
     * <li>metadata allocation (AllocMetaTask, index 1);</li>
     * <li>for every column: its build tasks, an ApplyTasksWrapperTask and a ColumnMetaDataTask;</li>
     * <li>when any column is dummycoded, a trailing UpdateOutputColTask that sets the apply offsets.</li>
     * </ul>
     *
     * <p>Dependencies are collected in {@code hashMap} as pairs of half-open index ranges
     * {@code [from, to)} into {@code arrayList}. Judging from the wiring below, the key range holds the
     * dependent tasks and the value range the tasks they wait for (TODO confirm against
     * {@code DependencyThreadPool.createDependencyList}). The keys are fresh {@code Integer[]} arrays,
     * which hash by identity, so several entries with equal contents (e.g. {@code {1, 2}}) coexist.
     * Negative indices ({@code {-2, -1}}) presumably address the UpdateOutputColTask appended last,
     * resolved relative to the end of the list — TODO confirm.
     *
     * @param cacheBlock           input block to encode
     * @param matrixBlock          output matrix, shaped by InitOutputMatrixTask
     * @param dependencyThreadPool pool passed to every ApplyTasksWrapperTask
     * @return the tasks with their dependencies attached
     */
    private List<DependencyTask<?>> getEncodeTasks(CacheBlock cacheBlock, MatrixBlock matrixBlock, DependencyThreadPool dependencyThreadPool) {
        // arrayList: all tasks in creation order; arrayList2: apply wrappers whose offsets are set later
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = null;
        HashMap hashMap = new HashMap();
        // z: some column is dummycoded
        boolean z = getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
        // z2: a dummycoded column has been seen so far (never reset, so later columns also wait for offsets)
        boolean z2 = false;
        // z3: a dummycoded column's last two build tasks were independent of each other.
        // NOTE(review): z3 is never reset, so once set it also applies to all later columns — confirm intended.
        boolean z3 = false;
        this._meta = new FrameBlock(cacheBlock.getNumColumns(), Types.ValueType.STRING);
        arrayList.add(DependencyThreadPool.createDependencyTask(new InitOutputMatrixTask(this, cacheBlock, matrixBlock)));
        arrayList.add(DependencyThreadPool.createDependencyTask(new AllocMetaTask(this, this._meta)));
        for (ColumnEncoderComposite columnEncoderComposite : this._columnEncoders) {
            List<DependencyTask<?>> buildTasks = columnEncoderComposite.getBuildTasks(cacheBlock);
            arrayList.addAll(buildTasks);
            // From here on, index size() is this column's apply wrapper and size()+1 its metadata task.
            if (buildTasks.size() > 0) {
                if (columnEncoderComposite.hasEncoder(ColumnEncoderDummycode.class) && buildTasks.size() > 1 && !buildTasks.get(buildTasks.size() - 2).hasDependency(buildTasks.get(buildTasks.size() - 1))) {
                    z3 = true;
                }
                if (z3) {
                    // apply wrapper and metadata task wait for the last two build tasks
                    hashMap.put(new Integer[]{Integer.valueOf(arrayList.size()), Integer.valueOf(arrayList.size() + 1)}, new Integer[]{Integer.valueOf(arrayList.size() - 2), Integer.valueOf(arrayList.size() - 1)});
                    hashMap.put(new Integer[]{Integer.valueOf(arrayList.size() + 1), Integer.valueOf(arrayList.size() + 2)}, new Integer[]{Integer.valueOf(arrayList.size() - 2), Integer.valueOf(arrayList.size() - 1)});
                } else {
                    // apply wrapper and metadata task wait for the last build task
                    hashMap.put(new Integer[]{Integer.valueOf(arrayList.size()), Integer.valueOf(arrayList.size() + 1)}, new Integer[]{Integer.valueOf(arrayList.size() - 1), Integer.valueOf(arrayList.size())});
                    hashMap.put(new Integer[]{Integer.valueOf(arrayList.size() + 1), Integer.valueOf(arrayList.size() + 2)}, new Integer[]{Integer.valueOf(arrayList.size() - 1), Integer.valueOf(arrayList.size())});
                }
                // AllocMetaTask (index 1) waits for this column's last (or, for multi-task dummycoding, second-to-last) build task
                if (!columnEncoderComposite.hasEncoder(ColumnEncoderDummycode.class) || buildTasks.size() <= 1) {
                    hashMap.put(new Integer[]{1, 2}, new Integer[]{Integer.valueOf(arrayList.size() - 1), Integer.valueOf(arrayList.size())});
                } else {
                    hashMap.put(new Integer[]{1, 2}, new Integer[]{Integer.valueOf(arrayList.size() - 2), Integer.valueOf(arrayList.size() - 1)});
                }
            }
            // metadata task waits for AllocMetaTask; apply wrapper waits for InitOutputMatrixTask
            hashMap.put(new Integer[]{Integer.valueOf(arrayList.size() + 1), Integer.valueOf(arrayList.size() + 2)}, new Integer[]{1, 2});
            hashMap.put(new Integer[]{Integer.valueOf(arrayList.size()), Integer.valueOf(arrayList.size() + 1)}, new Integer[]{0, 1});
            ApplyTasksWrapperTask applyTasksWrapperTask = new ApplyTasksWrapperTask(columnEncoderComposite, cacheBlock, matrixBlock, dependencyThreadPool);
            if (columnEncoderComposite.hasEncoder(ColumnEncoderDummycode.class)) {
                // Output allocation and the trailing UpdateOutputColTask wait for this column's last
                // build task (the output width depends on the dummycode domain size via getNumExtraCols).
                hashMap.put(new Integer[]{0, 1}, new Integer[]{Integer.valueOf(arrayList.size() - 1), Integer.valueOf(arrayList.size())});
                hashMap.put(new Integer[]{-2, -1}, new Integer[]{Integer.valueOf(arrayList.size() - 1), Integer.valueOf(arrayList.size())});
                // Set a priority of 5 on these build tasks (its meaning is defined by DependencyTask, not visible here).
                buildTasks.forEach(dependencyTask -> {
                    dependencyTask.setPriority(5);
                });
                z2 = true;
            }
            if (z && z2) {
                // Offset unknown until the preceding dummycode domains are built: wait for UpdateOutputColTask.
                hashMap.put(new Integer[]{Integer.valueOf(arrayList.size()), Integer.valueOf(arrayList.size() + 1)}, new Integer[]{-2, -1});
                arrayList2 = arrayList2 == null ? new ArrayList() : arrayList2;
                arrayList2.add(applyTasksWrapperTask);
            } else {
                // No dummycoded column so far, so no shift: the offset is known now.
                applyTasksWrapperTask.setOffset(0);
            }
            arrayList.add(applyTasksWrapperTask);
            arrayList.add(DependencyThreadPool.createDependencyTask(new ColumnMetaDataTask(columnEncoderComposite, this._meta)));
        }
        if (z) {
            // Appended last so the negative index range {-2, -1} above refers to it.
            arrayList.add(DependencyThreadPool.createDependencyTask(new UpdateOutputColTask(this, arrayList2)));
        }
        // One (initially null) dependency slot per task, filled from hashMap.
        ArrayList arrayList3 = new ArrayList(Collections.nCopies(arrayList.size(), null));
        DependencyThreadPool.createDependencyList(arrayList, hashMap, arrayList3);
        return DependencyThreadPool.createDependencyTasks(arrayList, arrayList3);
    }

    /**
     * Builds all column encoders (and legacy encoders, if any) on a single thread.
     *
     * @param cacheBlock input block
     */
    @Override
    public void build(CacheBlock cacheBlock) {
        final int singleThread = 1;
        build(cacheBlock, singleThread);
    }

    /**
     * Builds all column encoders, followed by the legacy encoders if present.
     *
     * @param cacheBlock input block; must be a FrameBlock when legacy encoders are present
     * @param k          degree of parallelism; values greater than 1 use {@code buildMT}
     * @throws DMLRuntimeException if legacy encoders are present and the input is not a FrameBlock
     */
    public void build(CacheBlock cacheBlock, int k) {
        boolean frameInput = cacheBlock instanceof FrameBlock;
        if (hasLegacyEncoder() && !frameInput) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        if (!_partitionDone) {
            deriveNumRowPartitions(cacheBlock, k);
        }
        if (k <= 1) {
            for (ColumnEncoderComposite enc : _columnEncoders) {
                enc.build(cacheBlock);
                enc.updateAllDCEncoders();
            }
        } else {
            buildMT(cacheBlock, k);
        }
        if (hasLegacyEncoder()) {
            legacyBuild((FrameBlock) cacheBlock);
        }
    }

    /**
     * Variant of {@code build(CacheBlock, int)} that forwards {@code map} to each column encoder's
     * build in the single-threaded path.
     * NOTE(review): for k &gt; 1 the map is not passed to {@code buildMT} and is thus ignored — confirm intended.
     *
     * @param cacheBlock input block; must be a FrameBlock when legacy encoders are present
     * @param k          degree of parallelism
     * @param map        per-column values handed to {@code ColumnEncoderComposite.build(CacheBlock, Map)}
     * @throws DMLRuntimeException if legacy encoders are present and the input is not a FrameBlock
     */
    public void build(CacheBlock cacheBlock, int k, Map<Integer, double[]> map) {
        boolean frameInput = cacheBlock instanceof FrameBlock;
        if (hasLegacyEncoder() && !frameInput) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        if (!_partitionDone) {
            deriveNumRowPartitions(cacheBlock, k);
        }
        if (k <= 1) {
            for (ColumnEncoderComposite enc : _columnEncoders) {
                enc.build(cacheBlock, map);
                enc.updateAllDCEncoders();
            }
        } else {
            buildMT(cacheBlock, k);
        }
        if (hasLegacyEncoder()) {
            legacyBuild((FrameBlock) cacheBlock);
        }
    }

    /** Collects the build tasks of all column encoders, in encoder order. */
    private List<DependencyTask<?>> getBuildTasks(CacheBlock cacheBlock) {
        List<DependencyTask<?>> tasks = new ArrayList<>();
        for (ColumnEncoderComposite enc : _columnEncoders) {
            tasks.addAll(enc.getBuildTasks(cacheBlock));
        }
        return tasks;
    }

    /**
     * Runs all column build tasks on a dependency thread pool of size {@code k}.
     *
     * @throws DMLRuntimeException if a build task fails or the wait is interrupted
     */
    private void buildMT(CacheBlock cacheBlock, int k) {
        DependencyThreadPool pool = new DependencyThreadPool(k);
        try {
            pool.submitAllAndWait(getBuildTasks(cacheBlock));
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            LOG.error("MT Column build failed", e);
            // Propagate: continuing with partially built encoders would yield a wrong encoding.
            throw new DMLRuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Builds the configured legacy encoders on the frame: omit first, then MV-impute.
     *
     * @param frameBlock input frame
     */
    public void legacyBuild(FrameBlock frameBlock) {
        if (_legacyOmit != null) {
            _legacyOmit.build(frameBlock);
        }
        if (_legacyMVImpute == null) {
            return;
        }
        _legacyMVImpute.build(frameBlock);
    }

    /**
     * Applies the built encoders on a single thread into a newly allocated output.
     *
     * @param cacheBlock input block
     * @return the encoded matrix
     */
    public MatrixBlock apply(CacheBlock cacheBlock) {
        final int singleThread = 1;
        return apply(cacheBlock, singleThread);
    }

    /**
     * Applies the built encoders to {@code cacheBlock} into a newly allocated output matrix.
     *
     * <p>The output has {@code input columns + getNumExtraCols()} columns. Sparse format is chosen via
     * {@code MatrixBlock.evalSparseFormatInMemory}, but never when a UDF encoder is present.
     *
     * @param cacheBlock input block
     * @param k          degree of parallelism
     * @return the encoded matrix
     */
    public MatrixBlock apply(CacheBlock cacheBlock, int k) {
        boolean hasUDF = _columnEncoders.stream().anyMatch(enc -> enc.hasEncoder(ColumnEncoderUDF.class));
        // Refresh dummycode state before getNumExtraCols() derives the output width from it.
        for (ColumnEncoderComposite enc : _columnEncoders) {
            enc.updateAllDCEncoders();
        }
        int numRows = cacheBlock.getNumRows();
        int numCols = cacheBlock.getNumColumns() + getNumExtraCols();
        // Widen before multiplying: rows * cols can exceed Integer.MAX_VALUE.
        long estNNz = (long) numRows * (hasUDF ? numCols : cacheBlock.getNumColumns());
        boolean sparse = MatrixBlock.evalSparseFormatInMemory(numRows, numCols, estNNz) && !hasUDF;
        return apply(cacheBlock, new MatrixBlock(numRows, numCols, sparse, estNNz), 0, k);
    }

    /**
     * Applies the encoders on a single thread into the given, already shaped output.
     *
     * @param cacheBlock  input block
     * @param matrixBlock output matrix
     * @param outputCol   output column offset of the first input column
     * @return the encoded matrix
     */
    @Override
    public MatrixBlock apply(CacheBlock cacheBlock, MatrixBlock matrixBlock, int outputCol) {
        final int singleThread = 1;
        return apply(cacheBlock, matrixBlock, outputCol, singleThread);
    }

    /**
     * Applies all column encoders to {@code cacheBlock}, writing into the pre-shaped {@code matrixBlock}.
     *
     * <p>The encoder with column id {@code c} writes starting at output column
     * {@code c - 1 + outputCol + shift}. Here {@code shift} sums {@code domainSize - 1} over all
     * preceding dummycoded encoders. The legacy omit and MV-impute encoders, if present, are applied
     * afterwards and may return a different matrix.
     *
     * @param cacheBlock  input block; must be a FrameBlock when legacy encoders are present
     * @param matrixBlock output matrix, already reset to the final dimensions
     * @param outputCol   output column offset of the first input column
     * @param k           degree of parallelism
     * @return the encoded matrix
     * @throws DMLRuntimeException if legacy encoders get a non-FrameBlock input, or if the number of
     *                             input columns differs from the number of composite encoders
     */
    public MatrixBlock apply(CacheBlock cacheBlock, MatrixBlock matrixBlock, int outputCol, int k) {
        if (hasLegacyEncoder() && !(cacheBlock instanceof FrameBlock)) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        if (cacheBlock.getNumColumns() != getFromAll(ColumnEncoderComposite.class, enc -> enc.getColID()).size()) {
            throw new DMLRuntimeException("Not every column in has a CompositeEncoder. Please make sure every column has a encoder or slice the input accordingly");
        }
        // Any dummycoded column counts, not just the last encoder in the list.
        boolean hasDC = _columnEncoders.stream().anyMatch(enc -> enc.hasEncoder(ColumnEncoderDummycode.class));
        outputMatrixPreProcessing(matrixBlock, cacheBlock, hasDC);
        if (k > 1) {
            if (!_partitionDone) {
                deriveNumRowPartitions(cacheBlock, k);
            }
            applyMT(cacheBlock, matrixBlock, outputCol, k);
        } else {
            int shift = outputCol;
            for (ColumnEncoderComposite enc : _columnEncoders) {
                enc.apply(cacheBlock, matrixBlock, (enc._colID - 1) + shift);
                if (enc.hasEncoder(ColumnEncoderDummycode.class)) {
                    shift += ((ColumnEncoderDummycode) enc.getEncoder(ColumnEncoderDummycode.class))._domainSize - 1;
                }
            }
        }
        outputMatrixPostProcessing(matrixBlock);
        MatrixBlock out = matrixBlock;
        if (_legacyOmit != null) {
            out = _legacyOmit.apply((FrameBlock) cacheBlock, out);
        }
        if (_legacyMVImpute != null) {
            out = _legacyMVImpute.apply((FrameBlock) cacheBlock, out);
        }
        return out;
    }

    /**
     * Collects the apply tasks of all column encoders. The output column of each encoder is shifted
     * by {@code domainSize - 1} of every preceding dummycoded encoder.
     */
    private List<DependencyTask<?>> getApplyTasks(CacheBlock cacheBlock, MatrixBlock matrixBlock, int outputCol) {
        List<DependencyTask<?>> tasks = new ArrayList<>();
        int shift = outputCol;
        for (ColumnEncoderComposite enc : _columnEncoders) {
            int startCol = enc._colID - 1 + shift;
            tasks.addAll(enc.getApplyTasks(cacheBlock, matrixBlock, startCol));
            if (!enc.hasEncoder(ColumnEncoderDummycode.class)) {
                continue;
            }
            ColumnEncoderDummycode dc = (ColumnEncoderDummycode) enc.getEncoder(ColumnEncoderDummycode.class);
            shift += dc._domainSize - 1;
        }
        return tasks;
    }

    /**
     * Runs the apply tasks on a dependency thread pool of size {@code k}. With
     * {@link #APPLY_ENCODER_SEPARATE_STAGES}, each encoder's tasks are submitted and awaited in turn;
     * otherwise all tasks are submitted at once.
     *
     * @throws DMLRuntimeException if an apply task fails or the wait is interrupted
     */
    private void applyMT(CacheBlock cacheBlock, MatrixBlock matrixBlock, int outputCol, int k) {
        DependencyThreadPool pool = new DependencyThreadPool(k);
        try {
            if (APPLY_ENCODER_SEPARATE_STAGES) {
                int shift = outputCol;
                for (ColumnEncoderComposite enc : _columnEncoders) {
                    pool.submitAllAndWait(enc.getApplyTasks(cacheBlock, matrixBlock, (enc._colID - 1) + shift));
                    if (enc.hasEncoder(ColumnEncoderDummycode.class)) {
                        shift += ((ColumnEncoderDummycode) enc.getEncoder(ColumnEncoderDummycode.class))._domainSize - 1;
                    }
                }
            } else {
                pool.submitAllAndWait(getApplyTasks(cacheBlock, matrixBlock, outputCol));
            }
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            LOG.error("MT Column apply failed", e);
            // Propagate: a failed apply leaves the output only partially written.
            throw new DMLRuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Chooses how many row partitions the column encoders use for build ({@code iArr[0]}) and apply
     * ({@code iArr[1]}), and sets them on every encoder via {@code setNumPartitions}.
     *
     * <p>Order of precedence for each count:
     * <ol>
     * <li>the ColumnEncoder static settings;</li>
     * <li>the ConfigurationManager settings;</li>
     * <li>a heuristic based on the transform thread count.</li>
     * </ol>
     * Counts are then reduced until each partition has at least 16000 rows. For recode encoders, the
     * build count is further reduced (to {@code i3}) until the estimated recode-map memory fits the
     * local memory budget.
     *
     * @param cacheBlock input block (its row count and in-memory size are used)
     * @param i          degree of parallelism; 1 forces a single partition for build and apply
     */
    private void deriveNumRowPartitions(CacheBlock cacheBlock, int i) {
        int[] iArr = new int[2];
        if (i == 1) {
            iArr[0] = 1;
            iArr[1] = 1;
            this._columnEncoders.forEach(columnEncoderComposite -> {
                columnEncoderComposite.setNumPartitions(1, 1);
            });
            this._partitionDone = true;
            return;
        }
        // 1) explicit static overrides
        if (ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN > 0) {
            iArr[0] = ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN;
        }
        if (ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN > 0) {
            iArr[1] = ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN;
        }
        // 2) configuration settings, if no override was given
        if (iArr[0] == 0 && ConfigurationManager.getParallelBuildBlocks() > 0) {
            iArr[0] = ConfigurationManager.getParallelBuildBlocks();
        }
        if (iArr[1] == 0 && ConfigurationManager.getParallelApplyBlocks() > 0) {
            iArr[1] = ConfigurationManager.getParallelApplyBlocks();
        }
        int numRows = cacheBlock.getNumRows();
        int transformNumThreads = OptimizerUtils.getTransformNumThreads();
        // i2: number of encoders with a build phase; arrayList: those that also recode
        ArrayList arrayList = new ArrayList();
        int i2 = 0;
        for (ColumnEncoderComposite columnEncoderComposite2 : this._columnEncoders) {
            if (columnEncoderComposite2.hasBuild()) {
                i2++;
                if (columnEncoderComposite2.hasEncoder(ColumnEncoderRecode.class)) {
                    arrayList.add(columnEncoderComposite2);
                }
            }
        }
        int numColumns = cacheBlock.getNumColumns();
        // 3) heuristics: spread the threads over the columns when there are fewer columns than threads.
        // NOTE(review): transformNumThreads / i2 is integer division, so Math.round is a no-op here,
        // unlike the float division used for apply below — confirm intended.
        if (iArr[0] == 0 && i2 > 0 && i2 < transformNumThreads) {
            iArr[0] = Math.round(transformNumThreads / i2);
        }
        if (iArr[1] == 0 && numColumns > 0 && numColumns < transformNumThreads * 2) {
            iArr[1] = Math.round((transformNumThreads * 2.0f) / numColumns);
        }
        // Keep at least 16000 rows per partition.
        while (iArr[0] > 1 && numRows / iArr[0] < 16000) {
            iArr[0] = iArr[0] - 1;
        }
        while (iArr[1] > 1 && numRows / iArr[1] < 16000) {
            iArr[1] = iArr[1] - 1;
        }
        // i3: build partitions for recode encoders, lowered while the estimated map memory exceeds
        // the local budget minus the input's in-memory size.
        int i3 = iArr[0];
        if (iArr[0] > 1 && arrayList.size() > 0) {
            estimateRCMapSize(cacheBlock, arrayList);
            long localMemBudget = (long) (OptimizerUtils.getLocalMemBudget() - cacheBlock.getInMemorySize());
            long totalMemOverhead = getTotalMemOverhead(cacheBlock, i3, arrayList);
            while (true) {
                long j = totalMemOverhead;
                if (i3 <= 1 || j <= localMemBudget) {
                    break;
                }
                i3--;
                totalMemOverhead = getTotalMemOverhead(cacheBlock, i3, arrayList);
            }
        }
        // Unset counts default to a single partition.
        for (int i4 = 0; i4 < 2; i4++) {
            if (iArr[i4] == 0) {
                iArr[i4] = 1;
            }
        }
        this._partitionDone = true;
        this._columnEncoders.forEach(columnEncoderComposite3 -> {
            columnEncoderComposite3.setNumPartitions(iArr[0], iArr[1]);
        });
        // i3 was captured before the zero-to-one defaulting, so 0 means "no build partitioning".
        if (i3 <= 0 || i3 == iArr[0]) {
            return;
        }
        int i5 = i3;
        arrayList.forEach(columnEncoderComposite4 -> {
            columnEncoderComposite4.setNumPartitions(i5, iArr[1]);
        });
    }

    /**
     * Estimates the recode-map size of each given encoder from a sorted 10% row sample. The
     * estimation runs in parallel on a pool of {@code OptimizerUtils.getTransformNumThreads()} threads.
     *
     * @param cacheBlock input block
     * @param rcEncoders encoders whose {@code computeRCDMapSizeEstimate} is invoked
     * @throws DMLRuntimeException if the estimation fails or is interrupted
     */
    private void estimateRCMapSize(CacheBlock cacheBlock, List<ColumnEncoderComposite> rcEncoders) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        int k = OptimizerUtils.getTransformNumThreads();
        int[] sampleRows = CompressedSizeEstimatorSample.getSortedSample(cacheBlock.getNumRows(), (int) (0.1d * cacheBlock.getNumRows()), (int) System.nanoTime(), 1);
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            // The parallel stream runs inside a task submitted to the pool.
            pool.submit(() -> {
                rcEncoders.stream().parallel().forEach(enc -> enc.computeRCDMapSizeEstimate(cacheBlock, sampleRows));
            }).get();
            if (DMLScript.STATISTICS) {
                LOG.debug("Elapsed time for RC map size estimation: " + ((System.nanoTime() - t0) / 1000000.0d) + " ms");
                TransformStatistics.incMapSizeEstimationTime(System.nanoTime() - t0);
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new DMLRuntimeException(e);
        } finally {
            // Release the pool obtained above.
            pool.shutdown();
        }
    }

    /**
     * Estimates the total memory of the recode maps when building with {@code nParts} row partitions.
     *
     * <p>With one partition this is the sum of the estimated map sizes. Otherwise each partition holds
     * at most {@code min(rows / nParts, distincts)} entries of the average entry size
     * {@code metaSize / distincts}, and there are {@code nParts} such maps per encoder.
     *
     * @param cacheBlock input block (its row count is used)
     * @param nParts     number of build row partitions
     * @param rcEncoders recode encoders with computed size estimates
     * @return estimated memory overhead in the unit of {@code getEstMetaSize()}
     */
    private long getTotalMemOverhead(CacheBlock cacheBlock, int nParts, List<ColumnEncoderComposite> rcEncoders) {
        if (nParts == 1) {
            return rcEncoders.stream().mapToLong(enc -> enc.getEstMetaSize()).sum();
        }
        long total = 0;
        int rowsPerPart = cacheBlock.getNumRows() / nParts;
        for (ColumnEncoderComposite enc : rcEncoders) {
            if (enc.getEstNumDistincts() <= 0) {
                // No distinct values means no map entries; also avoids dividing by zero below.
                continue;
            }
            // A partition cannot hold more distinct values than it has rows. Widen to long so the
            // product cannot overflow int.
            total += (long) Math.min(rowsPerPart, enc.getEstNumDistincts())
                * (enc.getEstMetaSize() / enc.getEstNumDistincts()) * nParts;
        }
        return total;
    }

    /**
     * Allocates the output storage before encoders write into it.
     *
     * <p>Dense outputs are allocated directly. Sparse outputs get a CSR block pre-sized to
     * {@code input columns} entries per row, with row pointers set to consecutive slices of that width.
     *
     * @param matrixBlock output matrix, already reset to its final shape
     * @param cacheBlock  input block (its column count sizes the per-row CSR slots)
     * @param z           whether any column is dummycoded; not used by the CSR-only allocation path
     * @throws RuntimeException    if the default sparse block type is neither CSR nor MCSR
     * @throws DMLRuntimeException if the CSR pre-allocation would exceed Integer.MAX_VALUE entries
     */
    private static void outputMatrixPreProcessing(MatrixBlock matrixBlock, CacheBlock cacheBlock, boolean z) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (!matrixBlock.isInSparseFormat()) {
            matrixBlock.allocateBlock();
        } else {
            if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR && MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR) {
                throw new RuntimeException("Transformapply is only supported for MCSR and CSR output matrix");
            }
            // The output is always pre-allocated as CSR. The former MCSR branch was guarded by a
            // constant false condition and was unreachable.
            int rows = matrixBlock.getNumRows();
            int inCols = cacheBlock.getNumColumns();
            long capacity = (long) rows * inCols;
            if (capacity > Integer.MAX_VALUE) {
                // An int product would silently wrap and corrupt the CSR arrays and row pointers.
                throw new DMLRuntimeException("Transformapply CSR output of " + rows + " x " + inCols
                    + " pre-allocated entries exceeds the maximum capacity of " + Integer.MAX_VALUE);
            }
            SparseBlockCSR csr = new SparseBlockCSR(rows, (int) capacity, (int) capacity);
            int[] rowPointers = csr.rowPointers();
            for (int r = 0; r < rowPointers.length - 1; r++) {
                rowPointers[r + 1] = rowPointers[r] + inCols;
            }
            matrixBlock.setSparseBlock(csr);
        }
        if (DMLScript.STATISTICS) {
            LOG.debug("Elapsed time for allocation: " + ((System.nanoTime() - t0) / 1000000.0d) + " ms");
            TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime() - t0);
        }
    }

    /**
     * Finalizes the output matrix after the apply phase. It compacts every sparse row that a
     * column encoder reported as containing explicit zeros, then recomputes the non-zero count.
     *
     * <p>With one transform thread the work runs inline; otherwise it runs on a thread pool that
     * is always shut down, including on failure. Both paths compact exactly the same rows.
     *
     * @param matrixBlock the encoded output matrix
     * @throws DMLRuntimeException if a parallel task fails or the calling thread is interrupted
     */
    private void outputMatrixPostProcessing(MatrixBlock matrixBlock) {
        long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        int transformNumThreads = OptimizerUtils.getTransformNumThreads();
        if (transformNumThreads == 1) {
            Set<Integer> rows = collectSparseRowsWZeros(this._columnEncoders.stream());
            for (Integer row : rows) {
                matrixBlock.getSparseBlock().get(row).compact();
            }
        } else {
            ExecutorService executorService = CommonThreadPool.get(transformNumThreads);
            try {
                Set<Integer> rows = executorService
                    .submit(() -> collectSparseRowsWZeros(this._columnEncoders.parallelStream()))
                    .get();
                // Previously this branch lacked the negation used in the single-threaded path,
                // so rows with explicit zeros were never compacted when running multi-threaded.
                if (!rows.isEmpty()) {
                    executorService.submit(() -> rows.parallelStream()
                        .forEach(row -> matrixBlock.getSparseBlock().get(row).compact())).get();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(e);
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            } finally {
                executorService.shutdown();
            }
        }
        matrixBlock.recomputeNonZeros();
        if (DMLScript.STATISTICS) {
            TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime() - nanoTime);
        }
    }

    /**
     * Unions the row indexes with explicit zeros reported by the given encoders. Encoders that
     * report {@code null} and {@code null} row ids are skipped.
     */
    private static Set<Integer> collectSparseRowsWZeros(Stream<ColumnEncoderComposite> encoders) {
        return encoders
            .map(ColumnEncoderComposite::getSparseRowsWZeros)
            .filter(Objects::nonNull)
            .flatMap(rows -> rows.stream())
            .filter(Objects::nonNull)
            .collect(Collectors.toSet());
    }

    /**
     * Asks each column encoder, in list order, to allocate its share of the metadata frame.
     *
     * @param frameBlock the metadata frame to allocate into
     */
    @Override // org.apache.sysds.runtime.transform.encode.Encoder
    public void allocateMetaData(FrameBlock frameBlock) {
        for (ColumnEncoderComposite encoder : _columnEncoders) {
            encoder.allocateMetaData(frameBlock);
        }
    }

    /**
     * Builds the metadata single-threaded by delegating to {@code getMetaData(frameBlock, 1)}.
     *
     * <p>The delegate's return value is ignored: this method always returns the frame that was
     * passed in. If cached metadata already exists, the delegate returns early and
     * {@code frameBlock} is returned unchanged.
     *
     * @param frameBlock the frame to fill
     * @return {@code frameBlock}
     */
    @Override // org.apache.sysds.runtime.transform.encode.Encoder
    public FrameBlock getMetaData(FrameBlock frameBlock) {
        final int singleThreaded = 1;
        getMetaData(frameBlock, singleThreaded);
        return frameBlock;
    }

    /**
     * Builds the metadata of all column encoders into the given frame. It also builds the
     * metadata of the legacy omit and MV-impute encoders, when present.
     *
     * <p>If this encoder already holds metadata ({@code _meta}), that frame is returned
     * immediately and {@code frameBlock} is left untouched.
     *
     * @param frameBlock the frame to allocate and fill
     * @param i          degree of parallelism; values {@code > 1} build the per-column metadata
     *                   on a thread pool, which is shut down even if a task fails
     * @return the cached metadata frame if present, otherwise {@code frameBlock}
     * @throws DMLRuntimeException if a metadata task fails or the thread is interrupted
     */
    public FrameBlock getMetaData(FrameBlock frameBlock, int i) {
        if (this._meta != null) {
            return this._meta;
        }
        long nanoTime = System.nanoTime();
        allocateMetaData(frameBlock);
        if (i > 1) {
            ExecutorService executorService = CommonThreadPool.get(i);
            try {
                List<Future<?>> futures = new ArrayList<>(this._columnEncoders.size());
                for (ColumnEncoderComposite encoder : this._columnEncoders) {
                    futures.add(executorService.submit(new ColumnMetaDataTask(encoder, frameBlock)));
                }
                // Surface the first task failure, if any.
                for (Future<?> future : futures) {
                    future.get();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(e);
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            } finally {
                // Previously only reached on success, leaking the pool when a task failed.
                executorService.shutdown();
            }
        } else {
            for (ColumnEncoderComposite encoder : this._columnEncoders) {
                encoder.getMetaData(frameBlock);
            }
        }
        if (this._legacyOmit != null) {
            this._legacyOmit.getMetaData(frameBlock);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.getMetaData(frameBlock);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Time spent getting metadata " + ((System.nanoTime() - nanoTime) / 1000000.0d) + " ms");
        }
        return frameBlock;
    }

    /**
     * Initializes every column encoder from the given metadata frame, followed by the legacy
     * omit and MV-impute encoders when they exist.
     *
     * @param frameBlock the metadata frame to read from
     */
    @Override // org.apache.sysds.runtime.transform.encode.Encoder
    public void initMetaData(FrameBlock frameBlock) {
        for (ColumnEncoderComposite encoder : _columnEncoders) {
            encoder.initMetaData(frameBlock);
        }
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.initMetaData(frameBlock);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.initMetaData(frameBlock);
        }
    }

    /**
     * Forwards the partial-build preparation step to each column encoder in list order.
     */
    @Override // org.apache.sysds.runtime.transform.encode.Encoder
    public void prepareBuildPartial() {
        for (ColumnEncoderComposite encoder : _columnEncoders) {
            encoder.prepareBuildPartial();
        }
    }

    /**
     * Forwards a partial build over the given frame to each column encoder in list order.
     *
     * @param frameBlock the input frame chunk to build from
     */
    @Override // org.apache.sysds.runtime.transform.encode.Encoder
    public void buildPartial(FrameBlock frameBlock) {
        for (ColumnEncoderComposite encoder : _columnEncoders) {
            encoder.buildPartial(frameBlock);
        }
    }

    /**
     * Builds the mapping from input columns to output column ranges.
     *
     * <p>The result has one row per input column with three values: the 1-based input column id,
     * the 1-based first output column, and the 1-based last output column. A column with a
     * dummycode encoder spans {@code getNumDistinct()} output columns, taken from the column
     * metadata of {@code frameBlock}; every other column spans exactly one output column.
     *
     * @param frameBlock frame whose column metadata supplies the number of distinct values
     * @return a {@code frameBlock.getNumColumns() x 3} matrix describing the mapping
     */
    public MatrixBlock getColMapping(FrameBlock frameBlock) {
        MatrixBlock matrixBlock = new MatrixBlock(frameBlock.getNumColumns(), 3, false);
        // Count dummycode encoders per column id once, instead of rescanning them for every column.
        Map<Integer, Integer> dcPerColumn = new HashMap<>();
        for (ColumnEncoderDummycode dc : getColumnEncoders(ColumnEncoderDummycode.class)) {
            dcPerColumn.merge(dc.getColID(), 1, Integer::sum);
        }
        int lastOutCol = 0;
        for (int col = 0; col < matrixBlock.getNumRows(); col++) {
            int colId = col + 1;
            int firstOutCol = lastOutCol + 1;
            int numDC = dcPerColumn.getOrDefault(colId, 0);
            if (!$assertionsDisabled && numDC > 1) {
                throw new AssertionError();
            }
            lastOutCol = numDC == 1
                ? (int) (lastOutCol + frameBlock.getColumnMetadata(col).getNumDistinct())
                : lastOutCol + 1;
            matrixBlock.quickSetValue(col, 0, colId);
            matrixBlock.quickSetValue(col, 1, firstOutCol);
            matrixBlock.quickSetValue(col, 2, lastOutCol);
        }
        return matrixBlock;
    }

    /**
     * Propagates an index-range update to every column encoder, then to the legacy omit and
     * MV-impute encoders when present. The legacy encoders receive only the two range arrays.
     *
     * @param jArr  first range array, forwarded unchanged
     * @param jArr2 second range array, forwarded unchanged
     * @param i     extra argument forwarded to the column encoders only
     */
    @Override // org.apache.sysds.runtime.transform.encode.Encoder
    public void updateIndexRanges(long[] jArr, long[] jArr2, int i) {
        for (ColumnEncoderComposite composite : _columnEncoders) {
            composite.updateIndexRanges(jArr, jArr2, i);
        }
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.updateIndexRanges(jArr, jArr2);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.updateIndexRanges(jArr, jArr2);
        }
    }

    /**
     * Serializes this encoder. The stream layout, in order, is:
     * <ol>
     *   <li>a presence flag, then the legacy MV-impute encoder if present;</li>
     *   <li>a presence flag, then the legacy omit encoder if present;</li>
     *   <li>the column offset;</li>
     *   <li>the number of column encoders, then for each one its column id followed by its payload;</li>
     *   <li>a presence flag, then the cached metadata frame if present.</li>
     * </ol>
     * The order must match {@code readExternal}.
     *
     * @param objectOutput the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override // java.io.Externalizable
    public void writeExternal(ObjectOutput objectOutput) throws IOException {
        // Legacy MV-impute encoder: presence flag + optional payload.
        objectOutput.writeBoolean(this._legacyMVImpute != null);
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.writeExternal(objectOutput);
        }
        // Legacy omit encoder: presence flag + optional payload.
        objectOutput.writeBoolean(this._legacyOmit != null);
        if (this._legacyOmit != null) {
            this._legacyOmit.writeExternal(objectOutput);
        }
        objectOutput.writeInt(this._colOffset);
        objectOutput.writeInt(this._columnEncoders.size());
        for (ColumnEncoderComposite columnEncoderComposite : this._columnEncoders) {
            // The column id is written ahead of the payload; the reader restores it via setColID.
            objectOutput.writeInt(columnEncoderComposite._colID);
            columnEncoderComposite.writeExternal(objectOutput);
        }
        // Cached metadata frame: presence flag + optional payload.
        objectOutput.writeBoolean(this._meta != null);
        if (this._meta != null) {
            this._meta.write(objectOutput);
        }
    }

    /**
     * Deserializes this encoder from the layout produced by {@code writeExternal}: the optional
     * legacy MV-impute encoder, the optional legacy omit encoder, the column offset, the column
     * encoders (id followed by payload), and the optional cached metadata frame.
     *
     * <p>NOTE(review): when a presence flag is false, the corresponding field is left as it was
     * rather than being reset to {@code null}. This assumes the method runs on a freshly created
     * instance; confirm with callers.
     *
     * @param objectInput the stream to read from
     * @throws IOException            if reading from the stream fails
     * @throws ClassNotFoundException declared by {@link java.io.Externalizable}
     */
    @Override // java.io.Externalizable
    public void readExternal(ObjectInput objectInput) throws IOException, ClassNotFoundException {
        // Legacy MV-impute encoder (presence flag first).
        if (objectInput.readBoolean()) {
            this._legacyMVImpute = new EncoderMVImpute();
            this._legacyMVImpute.readExternal(objectInput);
        }
        // Legacy omit encoder (presence flag first).
        if (objectInput.readBoolean()) {
            this._legacyOmit = new EncoderOmit();
            this._legacyOmit.readExternal(objectInput);
        }
        this._colOffset = objectInput.readInt();
        int readInt = objectInput.readInt();
        this._columnEncoders = new ArrayList();
        for (int i = 0; i < readInt; i++) {
            // The column id precedes the payload; it is applied after the payload has been read.
            int readInt2 = objectInput.readInt();
            ColumnEncoderComposite columnEncoderComposite = new ColumnEncoderComposite();
            columnEncoderComposite.readExternal(objectInput);
            columnEncoderComposite.setColID(readInt2);
            this._columnEncoders.add(columnEncoderComposite);
        }
        // Cached metadata frame (presence flag first).
        if (objectInput.readBoolean()) {
            FrameBlock frameBlock = new FrameBlock();
            frameBlock.readFields(objectInput);
            this._meta = frameBlock;
        }
    }

    /**
     * Collects all encoders whose exact runtime class is {@code cls}.
     *
     * <p>Composite encoders are unwrapped via {@code getEncoder(cls)} unless {@code cls} is the
     * composite class itself, in which case the composites are matched directly.
     *
     * @param cls the exact encoder class to look for
     * @param <T> the encoder type
     * @return a new list of matching encoders, in column-encoder order
     */
    public <T extends ColumnEncoder> List<T> getColumnEncoders(Class<T> cls) {
        boolean wantComposite = cls == ColumnEncoderComposite.class;
        List<T> matches = new ArrayList<>();
        for (ColumnEncoderComposite composite : this._columnEncoders) {
            Encoder candidate = composite;
            if (composite.getClass().equals(ColumnEncoderComposite.class) && !wantComposite) {
                candidate = composite.getEncoder(cls);
            }
            if (candidate == null || !candidate.getClass().equals(cls)) {
                continue;
            }
            matches.add(cls.cast(candidate));
        }
        return matches;
    }

    /**
     * Returns the first encoder of exact class {@code cls} whose column id equals {@code i}.
     *
     * @param i   the column id to match
     * @param cls the exact encoder class to look for
     * @param <T> the encoder type
     * @return the first matching encoder, or {@code null} if there is none
     */
    public <T extends ColumnEncoder> T getColumnEncoder(int i, Class<T> cls) {
        return getColumnEncoders(cls).stream()
            .filter(candidate -> candidate._colID == i)
            .findFirst()
            .orElse(null);
    }

    /**
     * Applies {@code function} to every encoder of exact class {@code cls} and returns the results.
     *
     * @param cls      the exact encoder class to select
     * @param function mapping applied to each selected encoder
     * @param <T>      the encoder type
     * @param <E>      the result element type
     * @return a new list of mapped values, in encoder order
     */
    public <T extends ColumnEncoder, E> List<E> getFromAll(Class<T> cls, Function<? super T, ? extends E> function) {
        List<E> results = new ArrayList<>();
        for (T encoder : getColumnEncoders(cls)) {
            results.add(function.apply(encoder));
        }
        return results;
    }

    /**
     * Unboxed variant of {@code getFromAll} for integer-valued extractors.
     *
     * @param cls      the exact encoder class to select
     * @param function extractor applied to each selected encoder; must not yield {@code null}
     * @param <T>      the encoder type
     * @return the extracted values, in encoder order
     */
    public <T extends ColumnEncoder> int[] getFromAllIntArray(Class<T> cls, Function<? super T, ? extends Integer> function) {
        List<Integer> boxed = getFromAll(cls, function);
        int[] values = new int[boxed.size()];
        int pos = 0;
        for (Integer value : boxed) {
            values[pos++] = value.intValue();
        }
        return values;
    }

    /**
     * Unboxed variant of {@code getFromAll} for double-valued extractors.
     *
     * @param cls      the exact encoder class to select
     * @param function extractor applied to each selected encoder; must not yield {@code null}
     * @param <T>      the encoder type
     * @return the extracted values, in encoder order
     */
    public <T extends ColumnEncoder> double[] getFromAllDoubleArray(Class<T> cls, Function<? super T, ? extends Double> function) {
        List<Double> boxed = getFromAll(cls, function);
        double[] values = new double[boxed.size()];
        int pos = 0;
        for (Double value : boxed) {
            values[pos++] = value.doubleValue();
        }
        return values;
    }

    /**
     * Returns the internal list of composite column encoders (not a copy, so changes made
     * through it are visible to this encoder).
     *
     * @return the live list of column encoders
     */
    public List<ColumnEncoderComposite> getColumnEncoders() {
        final List<ColumnEncoderComposite> encoders = _columnEncoders;
        return encoders;
    }

    /**
     * Returns all composite encoders whose column id equals {@code i}.
     *
     * @param i the column id to match
     * @return a new list of matching composites, in encoder order
     */
    public List<ColumnEncoderComposite> getCompositeEncodersForID(int i) {
        List<ColumnEncoderComposite> matches = new ArrayList<>();
        for (ColumnEncoderComposite composite : _columnEncoders) {
            if (composite._colID == i) {
                matches.add(composite);
            }
        }
        return matches;
    }

    /**
     * Returns the distinct runtime classes of the encoders wrapped by the composites for column
     * {@code i}; {@code -1} selects all columns. The list order follows {@link HashSet} iteration.
     *
     * @param i the column id to inspect, or {@code -1} for every column
     * @return a new list of distinct encoder classes
     */
    public List<Class<? extends ColumnEncoder>> getEncoderTypes(int i) {
        boolean allColumns = i == -1;
        Set<Class<? extends ColumnEncoder>> types = new HashSet<>();
        for (ColumnEncoderComposite composite : _columnEncoders) {
            if (composite._colID != i && !allColumns) {
                continue;
            }
            for (ColumnEncoder inner : composite.getEncoders()) {
                types.add(inner.getClass());
            }
        }
        return new ArrayList<>(types);
    }

    /**
     * Returns the distinct encoder classes used across all columns.
     *
     * @return a new list of distinct encoder classes
     */
    public List<Class<? extends ColumnEncoder>> getEncoderTypes() {
        final int allColumns = -1;
        return getEncoderTypes(allColumns);
    }

    /**
     * Returns how many output columns dummycoding adds beyond the input columns: the sum of all
     * dummycode domain sizes minus the number of dummycode encoders.
     *
     * @return the number of extra columns, {@code 0} if there are no dummycode encoders
     * @throws DMLRuntimeException if any dummycode encoder has a negative (not yet known) domain size
     */
    public int getNumExtraCols() {
        List<ColumnEncoderDummycode> dcEncoders = getColumnEncoders(ColumnEncoderDummycode.class);
        for (ColumnEncoderDummycode dc : dcEncoders) {
            if (dc.getDomainSize() < 0) {
                throw new DMLRuntimeException("Trying to get extra columns when DC encoders are not ready");
            }
        }
        int totalDomain = 0;
        for (ColumnEncoderDummycode dc : dcEncoders) {
            totalDomain += dc.getDomainSize();
        }
        return totalDomain - dcEncoders.size();
    }

    /**
     * Like {@code getNumExtraCols()}, but only considers dummycode encoders whose column id lies
     * in the column range of {@code indexRange}. Domain sizes are not validated here.
     *
     * @param indexRange the range whose columns are considered
     * @return the sum of the selected domain sizes minus the number of selected encoders
     */
    public int getNumExtraCols(IndexRange indexRange) {
        int totalDomain = 0;
        int selected = 0;
        for (ColumnEncoderDummycode dc : getColumnEncoders(ColumnEncoderDummycode.class)) {
            if (!indexRange.inColRange(dc._colID)) {
                continue;
            }
            totalDomain += dc.getDomainSize();
            selected++;
        }
        return totalDomain - selected;
    }

    /**
     * Tests whether an encoder of exact class {@code cls} exists for column {@code i}.
     *
     * @param i   the column id to match
     * @param cls the exact encoder class to look for
     * @param <T> the encoder type
     * @return {@code true} if at least one such encoder exists
     */
    public <T extends ColumnEncoder> boolean containsEncoderForID(int i, Class<T> cls) {
        for (T encoder : getColumnEncoders(cls)) {
            if (encoder.getColID() == i) {
                return true;
            }
        }
        return false;
    }

    /**
     * Invokes {@code consumer} on every encoder of exact class {@code cls}, in encoder order.
     *
     * @param cls      the exact encoder class to select
     * @param consumer the action to apply
     * @param <T>      the encoder type
     * @param <E>      unused type parameter, kept for source compatibility
     */
    public <T extends ColumnEncoder, E> void applyToAll(Class<T> cls, Consumer<? super T> consumer) {
        List<T> selected = getColumnEncoders(cls);
        for (T encoder : selected) {
            consumer.accept(encoder);
        }
    }

    /**
     * Invokes {@code consumer} on every composite column encoder, in list order.
     *
     * @param consumer the action to apply
     * @param <T>      unused type parameter, kept for source compatibility
     * @param <E>      unused type parameter, kept for source compatibility
     */
    public <T extends ColumnEncoder, E> void applyToAll(Consumer<? super ColumnEncoderComposite> consumer) {
        for (ColumnEncoderComposite composite : getColumnEncoders()) {
            consumer.accept(composite);
        }
    }

    /**
     * Creates an encoder restricted to the columns {@code [colStart, colEnd)} of
     * {@code indexRange}. It shares the matching composite encoders and sub-ranges any legacy
     * omit / MV-impute encoders. The new encoder's column offset is set to
     * {@code 1 - colStart} (computed in int).
     *
     * @param indexRange the column range to keep
     * @return a new encoder over the selected columns
     */
    public MultiColumnEncoder subRangeEncoder(IndexRange indexRange) {
        List<ColumnEncoderComposite> selected = new ArrayList<>();
        for (long col = indexRange.colStart; col < indexRange.colEnd; col++) {
            selected.addAll(getCompositeEncodersForID((int) col));
        }
        MultiColumnEncoder sub = new MultiColumnEncoder(selected);
        sub._colOffset = ((int) (-indexRange.colStart)) + 1;
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            sub.addReplaceLegacyEncoder(omit.subRangeEncoder(indexRange));
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            sub.addReplaceLegacyEncoder(impute.subRangeEncoder(indexRange));
        }
        return sub;
    }

    /**
     * Creates an encoder containing, for each column in {@code [colStart, colEnd)}, the first
     * encoder of exact class {@code cls}. Non-composite encoders are wrapped in a new
     * {@link ColumnEncoderComposite}.
     *
     * <p>Columns without an encoder of that class are skipped. Previously they produced a
     * {@code null} entry: for composites it ended up in the encoder list, and otherwise
     * {@code new ColumnEncoderComposite(null)} was called. Both led to a NullPointerException.
     *
     * @param indexRange the column range to select
     * @param cls        the exact encoder class to select
     * @param <T>        the encoder type
     * @return a new encoder over the selected columns
     */
    public <T extends ColumnEncoder> MultiColumnEncoder subRangeEncoder(IndexRange indexRange, Class<T> cls) {
        // Index encoders by column id once (first match wins, like getColumnEncoder) instead of
        // rescanning every encoder for each column in the range.
        Map<Integer, T> firstById = new HashMap<>();
        for (T encoder : getColumnEncoders(cls)) {
            firstById.putIfAbsent(encoder._colID, encoder);
        }
        boolean isComposite = cls.equals(ColumnEncoderComposite.class);
        List<ColumnEncoderComposite> selected = new ArrayList<>();
        for (long col = indexRange.colStart; col < indexRange.colEnd; col++) {
            T encoder = firstById.get((int) col);
            if (encoder == null) {
                continue; // no encoder of this class for the column
            }
            selected.add(isComposite ? (ColumnEncoderComposite) encoder : new ColumnEncoderComposite(encoder));
        }
        return new MultiColumnEncoder(selected);
    }

    /**
     * Merges the column encoders of {@code multiColumnEncoder} into this encoder. For each
     * incoming composite, an existing encoder of the same class and column id is removed first,
     * and the incoming composite is appended.
     *
     * @param multiColumnEncoder the encoder whose column encoders are taken over
     */
    public void mergeReplace(MultiColumnEncoder multiColumnEncoder) {
        for (ColumnEncoderComposite incoming : multiColumnEncoder._columnEncoders) {
            int colId = incoming._colID;
            ColumnEncoderComposite existing = (ColumnEncoderComposite) getColumnEncoder(colId, incoming.getClass());
            if (existing != null) {
                _columnEncoders.remove(existing);
            }
            _columnEncoders.add(incoming);
        }
    }

    /**
     * Merges {@code encoder} into this encoder with its column ids shifted by {@code i}.
     *
     * <p>A {@code MultiColumnEncoder} contributes each of its composites and then its legacy
     * encoders; any other encoder is treated as a single {@link ColumnEncoder}.
     *
     * @param encoder the encoder to merge in
     * @param i       column shift applied to the merged column encoders
     * @param i2      argument forwarded as the first index to the legacy merge
     */
    public void mergeAt(Encoder encoder, int i, int i2) {
        if (encoder instanceof MultiColumnEncoder) {
            MultiColumnEncoder other = (MultiColumnEncoder) encoder;
            for (ColumnEncoderComposite composite : other._columnEncoders) {
                addEncoder(composite, i);
            }
            legacyMergeAt(other, i2, i + 1);
        } else {
            addEncoder((ColumnEncoder) encoder, i);
        }
    }

    /**
     * Merges the legacy omit / MV-impute encoders of {@code multiColumnEncoder} into this one.
     * Each incoming legacy encoder is first shifted by {@code i2 - 1} columns.
     *
     * <p>An incoming omit encoder is merged into this one, creating an empty
     * {@link EncoderOmit} first if needed. For MV-impute, if this encoder has none it simply
     * adopts the incoming one (possibly {@code null}); otherwise the two are merged.
     *
     * @param multiColumnEncoder the encoder whose legacy encoders are merged
     * @param i                  first index forwarded to the legacy {@code mergeAt}
     * @param i2                 second index forwarded to {@code mergeAt}; also defines the shift
     */
    private void legacyMergeAt(MultiColumnEncoder multiColumnEncoder, int i, int i2) {
        int shift = i2 - 1;
        EncoderOmit otherOmit = multiColumnEncoder._legacyOmit;
        if (otherOmit != null) {
            otherOmit.shiftCols(shift);
            if (_legacyOmit == null) {
                _legacyOmit = new EncoderOmit();
            }
            _legacyOmit.mergeAt(otherOmit, i, i2);
        }
        EncoderMVImpute otherImpute = multiColumnEncoder._legacyMVImpute;
        if (otherImpute != null) {
            otherImpute.shiftCols(shift);
        }
        if (_legacyMVImpute == null) {
            _legacyMVImpute = otherImpute;
        } else if (otherImpute != null) {
            _legacyMVImpute.mergeAt(otherImpute, i, i2);
        }
    }

    /**
     * Adds {@code columnEncoder} at its column id shifted by {@code i}.
     *
     * <p>Resolution order, with both lookups done before the encoder is shifted:
     * <ol>
     *   <li>merge into an existing encoder of the same class at the target column;</li>
     *   <li>otherwise merge into an existing composite at the target column;</li>
     *   <li>otherwise append it, wrapped in a composite unless it already is one.</li>
     * </ol>
     *
     * @param columnEncoder the encoder to add; it is shifted in place
     * @param i             the column shift
     */
    private void addEncoder(ColumnEncoder columnEncoder, int i) {
        int targetCol = columnEncoder._colID + i;
        ColumnEncoder sameType = getColumnEncoder(targetCol, columnEncoder.getClass());
        ColumnEncoderComposite composite = sameType == null
            ? (ColumnEncoderComposite) getColumnEncoder(targetCol, ColumnEncoderComposite.class)
            : null;
        columnEncoder.shiftCol(i);
        if (sameType != null) {
            sameType.mergeAt(columnEncoder);
        } else if (composite != null) {
            composite.mergeAt(columnEncoder);
        } else if (columnEncoder instanceof ColumnEncoderComposite) {
            _columnEncoders.add((ColumnEncoderComposite) columnEncoder);
        } else {
            _columnEncoders.add(new ColumnEncoderComposite(columnEncoder));
        }
    }

    /**
     * Sets (replacing any previous one) the legacy MV-impute or omit encoder, chosen by the exact
     * runtime class of {@code t}.
     *
     * @param t   the legacy encoder to install
     * @param <T> the legacy encoder type
     * @throws DMLRuntimeException if {@code t} is neither an {@code EncoderMVImpute} nor an {@code EncoderOmit}
     */
    public <T extends LegacyEncoder> void addReplaceLegacyEncoder(T t) {
        Class<?> type = t.getClass();
        if (type == EncoderMVImpute.class) {
            _legacyMVImpute = (EncoderMVImpute) t;
            return;
        }
        if (!type.equals(EncoderOmit.class)) {
            throw new DMLRuntimeException("Tried to add non legacy Encoder");
        }
        _legacyOmit = (EncoderOmit) t;
    }

    /**
     * Tells whether any legacy encoder (MV-impute or omit) is present.
     *
     * @param <T> unused type parameter, kept for source compatibility
     * @return {@code true} if a legacy MV-impute or omit encoder is set
     */
    public <T extends LegacyEncoder> boolean hasLegacyEncoder() {
        if (hasLegacyEncoder(EncoderMVImpute.class)) {
            return true;
        }
        return hasLegacyEncoder(EncoderOmit.class);
    }

    /**
     * Tells whether the legacy encoder of the given class is present.
     *
     * @param cls {@code EncoderMVImpute.class} or {@code EncoderOmit.class}
     * @param <T> the legacy encoder type
     * @return whether the corresponding field is non-null; {@code false} for any other class
     *         when assertions are disabled
     * @throws AssertionError for any other class when assertions are enabled
     */
    public <T extends LegacyEncoder> boolean hasLegacyEncoder(Class<T> cls) {
        if (cls.equals(EncoderMVImpute.class)) {
            return _legacyMVImpute != null;
        } else if (cls.equals(EncoderOmit.class)) {
            return _legacyOmit != null;
        }
        if (!$assertionsDisabled) {
            throw new AssertionError();
        }
        return false;
    }

    /**
     * Returns the legacy encoder of the given class.
     *
     * @param cls {@code EncoderMVImpute.class} or {@code EncoderOmit.class}
     * @param <T> the legacy encoder type
     * @return the encoder (possibly {@code null}); {@code null} for any other class when
     *         assertions are disabled
     * @throws AssertionError for any other class when assertions are enabled
     */
    public <T extends LegacyEncoder> T getLegacyEncoder(Class<T> cls) {
        if (cls.equals(EncoderMVImpute.class)) {
            return cls.cast(_legacyMVImpute);
        } else if (cls.equals(EncoderOmit.class)) {
            return cls.cast(_legacyOmit);
        }
        if (!$assertionsDisabled) {
            throw new AssertionError();
        }
        return null;
    }

    /**
     * Shifts every column encoder, and the legacy omit / MV-impute encoders if present, by this
     * encoder's column offset.
     */
    public void applyColumnOffset() {
        Consumer<ColumnEncoderComposite> shiftByOffset = composite -> composite.shiftCol(_colOffset);
        applyToAll(shiftByOffset);
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.shiftCols(_colOffset);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.shiftCols(_colOffset);
        }
    }

    // Static initialization of the class-level flags and logger.
    static {
        // Explicit assertion-status flag, read by getColMapping, hasLegacyEncoder and
        // getLegacyEncoder in place of Java `assert` statements.
        $assertionsDisabled = !MultiColumnEncoder.class.desiredAssertionStatus();
        LOG = LogFactory.getLog(MultiColumnEncoder.class.getName());
        // Taken from the configuration (ConfigurationManager.isStagedParallelTransform()).
        MULTI_THREADED_STAGES = ConfigurationManager.isStagedParallelTransform();
        APPLY_ENCODER_SEPARATE_STAGES = false;
    }
}
