package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.ProgramConverter;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.class */
/**
 * Column encoder that combines several {@link ColumnEncoder}s which all operate on the same input
 * column.
 *
 * <p>All wrapped encoders share one column id; the list constructor rejects an empty list or
 * mismatching ids. When applying, the first encoder reads from the input block and every subsequent
 * encoder reads from and writes to the output matrix, so the encoders take effect in list order.
 */
public class ColumnEncoderComposite extends ColumnEncoder {
    private static final long serialVersionUID = -8473768154646831882L;

    /** The wrapped encoders, in the order in which they are built and applied. */
    private List<ColumnEncoder> _columnEncoders;

    /**
     * Optional metadata supplied at construction or deserialization. When non-null,
     * {@link #allocateMetaData} does nothing and {@link #getMetaData} returns it directly instead of
     * delegating to the wrapped encoders.
     */
    private FrameBlock _meta;

    /**
     * Task that refreshes the domain sizes of the dummy-code encoder (and, if present, the UDF
     * encoder) of a composite; scheduled as the final task by {@link #getBuildTasks}.
     */
    private static class ColumnCompositeUpdateDCTask implements Callable<Object> {
        private final ColumnEncoderComposite _encoder;

        protected ColumnCompositeUpdateDCTask(ColumnEncoderComposite encoder) {
            _encoder = encoder;
        }

        @Override
        public Object call() throws Exception {
            _encoder.updateAllDCEncoders();
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
        }
    }

    /**
     * Creates an empty composite with column id -1 and no encoders; the encoder list stays null
     * until {@link #readExternal} populates it (presumably used for deserialization).
     */
    public ColumnEncoderComposite() {
        super(-1);
        _columnEncoders = null;
        _meta = null;
    }

    /**
     * Creates a composite over the given encoders.
     *
     * @param columnEncoders non-empty list of encoders that all share the same column id; the list
     *                       is used directly (not copied)
     * @param meta           optional precomputed metadata, may be null
     * @throws DMLRuntimeException if the list is empty or the column ids differ
     */
    public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, FrameBlock meta) {
        super(-1);
        if (columnEncoders.isEmpty()
            || !columnEncoders.stream().allMatch(e -> e._colID == columnEncoders.get(0)._colID)) {
            throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columIDs");
        }
        _colID = columnEncoders.get(0)._colID;
        _meta = meta;
        _columnEncoders = columnEncoders;
    }

    /** Creates a composite over the given encoders without precomputed metadata. */
    public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders) {
        this(columnEncoders, null);
    }

    /** Creates a composite wrapping a single encoder, adopting its column id. */
    public ColumnEncoderComposite(ColumnEncoder columnEncoder) {
        super(columnEncoder._colID);
        _meta = null;
        _columnEncoders = new ArrayList<>();
        _columnEncoders.add(columnEncoder);
    }

    /** Returns the live (mutable) list of wrapped encoders. */
    public List<ColumnEncoder> getEncoders() {
        return _columnEncoders;
    }

    /**
     * Returns the first wrapped encoder whose runtime class is exactly {@code cls} (subclasses do not
     * match), or null if there is none.
     */
    public <T extends ColumnEncoder> T getEncoder(Class<T> cls) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            if (columnEncoder.getClass().equals(cls)) {
                return cls.cast(columnEncoder);
            }
        }
        return null;
    }

    /** Returns true if a wrapped encoder has exactly class {@code cls} and column id {@code colID}. */
    public boolean isEncoder(int colID, Class<?> cls) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            if (columnEncoder.getClass().equals(cls) && columnEncoder._colID == colID) {
                return true;
            }
        }
        return false;
    }

    /** Builds every wrapped encoder on {@code in}, in list order. */
    @Override
    public void build(CacheBlock in) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.build(in);
        }
    }

    /**
     * Builds every wrapped encoder on {@code in}. Bin encoders using
     * {@link ColumnEncoderBin.BinMethod#EQUI_HEIGHT} additionally receive the array stored under their
     * column id in {@code equiHeightData}; all other encoders use the plain {@link #build(CacheBlock)}.
     */
    @Override
    public void build(CacheBlock in, Map<Integer, double[]> equiHeightData) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            if (columnEncoder instanceof ColumnEncoderBin
                && ((ColumnEncoderBin) columnEncoder).getBinMethod() == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
                columnEncoder.build(in, equiHeightData.get(columnEncoder.getColID()));
            } else {
                columnEncoder.build(in);
            }
        }
    }

    /**
     * Collects the apply tasks of all wrapped encoders. The first encoder reads from {@code in}; all
     * later encoders read from {@code out}. Every task of an encoder depends on the last task of the
     * preceding encoder that produced tasks, so encoders are applied one after another.
     *
     * @return the dependency tasks; an empty list if no encoder produced apply tasks
     */
    @Override
    public List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock out, int outputCol) {
        List<DependencyTask<?>> tasks = new ArrayList<>();
        List<Integer> sizes = new ArrayList<>();
        for (int i = 0; i < _columnEncoders.size(); i++) {
            // the first encoder reads the input; all later ones re-encode the output column
            CacheBlock src = (i == 0) ? in : out;
            List<DependencyTask<?>> t = _columnEncoders.get(i).getApplyTasks(src, out, outputCol);
            if (t == null) {
                continue;
            }
            sizes.add(t.size());
            tasks.addAll(t);
        }
        if (sizes.isEmpty()) {
            return tasks; // no encoder produced apply tasks
        }

        // `tasks` is final from here on, so the subList views below stay valid.
        List<List<? extends Callable<?>>> dep = new ArrayList<>(Collections.nCopies(tasks.size(), null));
        int blockStart = sizes.get(0);
        for (int c = 1; c < sizes.size(); c++) {
            int blockEnd = blockStart + sizes.get(c);
            if (blockStart > 0) { // there is a preceding task to wait on
                List<DependencyTask<?>> predecessor = tasks.subList(blockStart - 1, blockStart);
                for (int k = blockStart; k < blockEnd; k++) {
                    dep.set(k, predecessor);
                }
            }
            blockStart = blockEnd;
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    @Override
    protected ColumnEncoder.ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock in,
        MatrixBlock out, int outputCol, int startRow, int blk) {
        throw new NotImplementedException();
    }

    /**
     * Collects the build tasks of all wrapped encoders so that the encoders are built one after
     * another. If a dummy-code encoder is present, a {@link ColumnCompositeUpdateDCTask} is appended;
     * when the first encoder is a recode encoder, that task waits for the last build task.
     */
    @Override
    public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
        List<DependencyTask<?>> tasks = new ArrayList<>();
        Map<Integer[], Integer[]> depMap = null;
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            List<DependencyTask<?>> t = columnEncoder.getBuildTasks(in);
            if (t == null) {
                continue;
            }
            if (!tasks.isEmpty()) {
                // index ranges (resolved by DependencyThreadPool.createDependencyList): the new tasks
                // [size, size + |t|) wait on the previous encoder's last task [size - 1, size)
                if (depMap == null) {
                    depMap = new HashMap<>();
                }
                depMap.put(new Integer[] {tasks.size(), tasks.size() + t.size()},
                    new Integer[] {tasks.size() - 1, tasks.size()});
            }
            tasks.addAll(t);
        }
        final int numBuildTasks = tasks.size();

        final boolean updateDC = hasEncoder(ColumnEncoderDummycode.class);
        if (updateDC) {
            tasks.add(DependencyThreadPool.createDependencyTask(new ColumnCompositeUpdateDCTask(this)));
        }

        // Build the dependency list only after `tasks` is final: subList views of an ArrayList
        // become undefined once the backing list is structurally modified.
        List<List<? extends Callable<?>>> dep = new ArrayList<>(Collections.nCopies(tasks.size(), null));
        DependencyThreadPool.createDependencyList(tasks, depMap, dep);
        if (updateDC && numBuildTasks > 0 && _columnEncoders.get(0) instanceof ColumnEncoderRecode) {
            // the domain-size update waits for the last build task
            dep.set(tasks.size() - 1, tasks.subList(numBuildTasks - 1, numBuildTasks));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** Delegates {@code prepareBuildPartial} to every wrapped encoder. */
    @Override
    public void prepareBuildPartial() {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.prepareBuildPartial();
        }
    }

    /** Delegates {@code buildPartial} to every wrapped encoder. */
    @Override
    public void buildPartial(FrameBlock in) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.buildPartial(in);
        }
    }

    /**
     * Applies the wrapped encoders in list order: the first reads from {@code in}, all later ones
     * read from {@code out}. The remaining arguments are forwarded unchanged. Failures are logged
     * together with this composite's description and rethrown.
     *
     * @return {@code out}
     */
    @Override
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        for (int i = 0; i < _columnEncoders.size(); i++) {
            try {
                CacheBlock src = (i == 0) ? in : out;
                _columnEncoders.get(i).apply(src, out, outputCol, rowStart, blk);
            } catch (Exception e) {
                LOG.error("Failed to transform-apply frame with \n" + this);
                throw e;
            }
        }
        return out;
    }

    @Override
    protected double getCode(CacheBlock in, int row) {
        throw new DMLRuntimeException("CompositeEncoder does not have a Code");
    }

    @Override
    protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) {
        throw new DMLRuntimeException("CompositeEncoder does not have a Code");
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.N_A;
    }

    /** Two composites are equal if their encoder lists and metadata are equal (column id ignored). */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        ColumnEncoderComposite other = (ColumnEncoderComposite) obj;
        return _columnEncoders.equals(other._columnEncoders) && Objects.equals(_meta, other._meta);
    }

    @Override
    public int hashCode() {
        return Objects.hash(_columnEncoders, _meta);
    }

    /**
     * Merges {@code other} into this composite: the encoders of another composite are added one by
     * one, a plain encoder is added directly. Afterwards dummy-code domain sizes are refreshed.
     */
    @Override
    public void mergeAt(ColumnEncoder other) {
        if (other instanceof ColumnEncoderComposite) {
            ColumnEncoderComposite otherComposite = (ColumnEncoderComposite) other;
            assert otherComposite._colID == _colID;
            for (ColumnEncoder columnEncoder : otherComposite.getEncoders()) {
                addEncoder(columnEncoder);
            }
        } else {
            addEncoder(other);
        }
        updateAllDCEncoders();
    }

    /**
     * If a dummy-code encoder is present, updates its domain sizes from the wrapped encoders, and
     * then does the same for a UDF encoder if one is present. Does nothing without a dummy-code
     * encoder.
     */
    public void updateAllDCEncoders() {
        ColumnEncoderDummycode dc = getEncoder(ColumnEncoderDummycode.class);
        if (dc == null) {
            return;
        }
        dc.updateDomainSizes(_columnEncoders);
        ColumnEncoderUDF udf = getEncoder(ColumnEncoderUDF.class);
        if (udf != null) {
            udf.updateDomainSizes(_columnEncoders);
        }
    }

    /**
     * Adds an encoder for this column. If an encoder of exactly the same class already exists, the
     * new one is merged into it; otherwise it is appended and the list is re-sorted by the encoders'
     * natural ordering.
     */
    public void addEncoder(ColumnEncoder columnEncoder) {
        ColumnEncoder existing = getEncoder(columnEncoder.getClass());
        assert _colID == columnEncoder._colID;
        if (existing != null) {
            existing.mergeAt(columnEncoder);
        } else {
            _columnEncoders.add(columnEncoder);
            _columnEncoders.sort(null);
        }
    }

    /** Delegates {@code updateIndexRanges} to every wrapped encoder. */
    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int offset) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.updateIndexRanges(beginDims, endDims, offset);
        }
    }

    /** Delegates metadata allocation to every wrapped encoder, unless metadata was supplied. */
    @Override
    public void allocateMetaData(FrameBlock meta) {
        if (_meta != null) {
            return;
        }
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.allocateMetaData(meta);
        }
    }

    /**
     * Returns the supplied metadata if present; otherwise lets every wrapped encoder fill
     * {@code meta} and returns it.
     */
    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        if (_meta != null) {
            return _meta;
        }
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.getMetaData(meta);
        }
        return meta;
    }

    /** Delegates {@code initMetaData} to every wrapped encoder. */
    @Override
    public void initMetaData(FrameBlock meta) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.initMetaData(meta);
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("CompositeEncoder(").append(_columnEncoders.size()).append("):\n");
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            sb.append("-- ");
            sb.append(columnEncoder.getClass().getSimpleName());
            sb.append(": ");
            sb.append(columnEncoder._colID);
            sb.append(ProgramConverter.NEWLINE);
        }
        return sb.toString();
    }

    /**
     * Writes the encoder count, then per encoder its column id, type byte and own serialized form,
     * followed by a presence flag and the optional metadata.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeInt(_columnEncoders.size());
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            out.writeInt(columnEncoder._colID);
            out.writeByte(EncoderFactory.getEncoderType(columnEncoder));
            columnEncoder.writeExternal(out);
        }
        out.writeBoolean(_meta != null);
        if (_meta != null) {
            _meta.write(out);
        }
    }

    /**
     * Reads the format produced by {@link #writeExternal}. The composite's own column id is not part
     * of the stream, so it is restored from the (shared) id of the deserialized encoders.
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        int numEncoders = in.readInt();
        _columnEncoders = new ArrayList<>();
        for (int i = 0; i < numEncoders; i++) {
            int colID = in.readInt();
            ColumnEncoder columnEncoder = EncoderFactory.createInstance(in.readByte());
            columnEncoder.readExternal(in);
            columnEncoder.setColID(colID);
            _columnEncoders.add(columnEncoder);
        }
        if (!_columnEncoders.isEmpty()) {
            // otherwise the composite keeps the -1 set by the no-arg constructor
            _colID = _columnEncoders.get(0)._colID;
        }
        if (in.readBoolean()) {
            FrameBlock meta = new FrameBlock();
            meta.readFields(in);
            _meta = meta;
        }
    }

    /** Returns true if some wrapped encoder has exactly class {@code cls}. */
    public <T extends ColumnEncoder> boolean hasEncoder(Class<T> cls) {
        return _columnEncoders.stream().anyMatch(e -> e.getClass().equals(cls));
    }

    /** Returns true if a recode, dummy-code or bin encoder (exact class match) is wrapped. */
    public <T extends ColumnEncoder> boolean hasBuild() {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            Class<?> c = columnEncoder.getClass();
            if (c.equals(ColumnEncoderRecode.class) || c.equals(ColumnEncoderDummycode.class)
                || c.equals(ColumnEncoderBin.class)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes size estimates for wrapped recode encoders. Afterwards, sets this composite's estimated
     * metadata size to the sum over all wrapped encoders and its estimated distinct count to that of
     * the last recode encoder (0 if there is none).
     */
    public void computeRCDMapSizeEstimate(CacheBlock in, int[] sampleIndices) {
        int estNumDistincts = 0;
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            if (columnEncoder.getClass().equals(ColumnEncoderRecode.class)) {
                ((ColumnEncoderRecode) columnEncoder).computeRCDMapSizeEstimate(in, sampleIndices);
                estNumDistincts = columnEncoder.getEstNumDistincts();
            }
        }
        setEstMetaSize(_columnEncoders.stream().mapToLong(ColumnEncoder::getEstMetaSize).sum());
        setEstNumDistincts(estNumDistincts);
    }

    /**
     * Sets the build row-partition count of every wrapped encoder to {@code nBuild} and the apply
     * row-partition count to {@code nApply}, except UDF encoders, which always get 1 for apply.
     */
    public void setNumPartitions(int nBuild, int nApply) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.setBuildRowBlocksPerColumn(nBuild);
            boolean isUDF = columnEncoder.getClass().equals(ColumnEncoderUDF.class);
            columnEncoder.setApplyRowBlocksPerColumn(isUDF ? 1 : nApply);
        }
    }

    /** Shifts the column id of this composite and of every wrapped encoder. */
    @Override
    public void shiftCol(int columnOffset) {
        super.shiftCol(columnOffset);
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.shiftCol(columnOffset);
        }
    }

    /** Returns the union of the wrapped encoders' sets; encoders returning null contribute nothing. */
    @Override
    public Set<Integer> getSparseRowsWZeros() {
        return _columnEncoders.stream()
            .map(ColumnEncoder::getSparseRowsWZeros)
            .filter(Objects::nonNull)
            .flatMap(Set::stream)
            .collect(Collectors.toSet());
    }
}
