package org.apache.sysds.runtime.compress.colgroup.mapping;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.columns.BitSetArray;
import org.apache.sysds.utils.MemoryEstimates;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.class */
public class MapToBit extends AMapToData {
    private static final long serialVersionUID = -8065234231282619903L;
    private final long[] _data;
    private final int _size;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit$JoinBitSets.class */
    /**
     * Tallies how the bits of two bit maps combine, position by position.
     *
     * <p>The first letter of each counter refers to the first map handed to the constructor and the second
     * letter to the second map: {@code t} means the bit is set, {@code f} means it is clear.
     */
    public static class JoinBitSets {
        /** Positions set in both maps. */
        int tt = 0;
        /** Positions set in the first map and clear in the second. */
        int ft = 0;
        /** Positions clear in the first map and set in the second. */
        int tf = 0;
        /** Positions clear in both maps. */
        int ff = 0;

        /**
         * Counts the four bit combinations over the backing words of both maps.
         *
         * @param mapToBit  first map
         * @param mapToBit2 second map
         * @param i         number of logical entries; used to subtract the padding bits from {@code ff}
         */
        protected JoinBitSets(MapToBit mapToBit, MapToBit mapToBit2, int i) {
            final long[] first = mapToBit._data;
            final long[] second = mapToBit2._data;
            final int shared = Math.min(first.length, second.length);
            final int longest = Math.max(first.length, second.length);

            // Words present in both maps.
            for (int w = 0; w < shared; w++) {
                final long a = first[w];
                final long b = second[w];
                tt += Long.bitCount(a & b);
                ft += Long.bitCount(a & ~b);
                tf += Long.bitCount(~a & b);
                ff += Long.bitCount(~(a | b));
            }

            // At most one map has words beyond the shared prefix; the missing side counts as all zeros.
            final boolean firstIsLonger = first.length > shared;
            final long[] tail = firstIsLonger ? first : second;
            for (int w = shared; w < tail.length; w++) {
                final int ones = Long.bitCount(tail[w]);
                if (firstIsLonger) {
                    ft += ones;
                } else {
                    tf += ones;
                }
                ff += 64 - ones;
            }

            // Every word contributed 64 positions; remove those that lie beyond the i logical entries.
            ff += i - (longest * 64);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a zero-initialised bit map, delegating to {@link #MapToBit(int, int)} with 2 as the
     * unique-value count.
     *
     * @param i number of entries in the map
     */
    public MapToBit(int i) {
        this(2, i);
    }

    /**
     * Creates a zero-initialised bit map.
     *
     * @param i  number of unique values; capped at 2 before being passed to the super constructor
     * @param i2 number of entries in the map
     */
    public MapToBit(int i, int i2) {
        super(i < 2 ? i : 2);
        _size = i2;
        _data = new long[longSize(i2)];
    }

    /**
     * Creates a bit map from the contents of a {@link BitSet}.
     *
     * @param i      number of unique values passed to the super constructor
     * @param bitSet source of the bits; its long array is used directly when it already has the expected length
     * @param i2     number of entries in the map
     */
    private MapToBit(int i, BitSet bitSet, int i2) {
        super(i);
        final int expected = longSize(i2);
        long[] words = bitSet.toLongArray();
        if (words.length != expected) {
            // BitSet drops trailing zero words, so copy into an array of the expected length.
            final long[] padded = new long[expected];
            System.arraycopy(words, 0, padded, 0, words.length);
            words = padded;
        }
        _data = words;
        _size = i2;
    }

    private MapToBit(int i, long[] jArr, int i2) {
        super(i);
        if (jArr.length == longSize(i2)) {
            this._data = jArr;
        } else {
            this._data = new long[longSize(i2)];
            System.arraycopy(jArr, 0, this._data, 0, jArr.length);
        }
        this._size = i2;
    }

    /**
     * Exposes the backing word array (not a copy).
     *
     * @return the packed bits of this map
     */
    protected long[] getData() {
        return _data;
    }

    /**
     * Identifies this map implementation.
     *
     * @return always {@code MAP_TYPE.BIT}
     */
    @Override
    public MapToFactory.MAP_TYPE getType() {
        final MapToFactory.MAP_TYPE type = MapToFactory.MAP_TYPE.BIT;
        return type;
    }

    /**
     * Reads the bit stored at position {@code i}.
     *
     * @param i position to read
     * @return 1 if the bit is set, otherwise 0
     */
    @Override
    public int getIndex(int i) {
        // The mask must be a long: an int shift only uses the low 5 bits of the distance, so positions
        // 32..63 inside a word would test the wrong bit. A long shift uses the low 6 bits (i % 64).
        final long mask = 1L << i;
        return (_data[i >> 6] & mask) != 0 ? 1 : 0;
    }

    /**
     * Sets every entry of the map to the same value.
     *
     * <p>Bits of the backing array that lie beyond {@code size()} are always left at zero, because
     * {@code getCounts} and {@code isEmpty} look at whole words.
     *
     * @param i value to fill with; 0 clears every bit, any other value sets every in-range bit
     */
    @Override
    public void fill(int i) {
        final long fillWord = i == 0 ? 0L : -1L;
        // Computed in long so that very large maps cannot overflow the int multiplication.
        final long unused = (_data.length * 64L) - _size;
        if (unused == 0 || i == 0) {
            Arrays.fill(_data, fillWord);
            return;
        }
        final int last = _data.length - 1;
        Arrays.fill(_data, 0, last, fillWord);
        // Only the low (64 - unused) bits of the final word are in range. The shift must be done on a long,
        // and a shift by 64 would be a no-op, so a completely unused final word is zeroed explicitly.
        _data[last] = unused >= 64 ? 0L : -1L >>> (int) unused;
    }

    /**
     * Estimates the in-memory footprint of this map.
     *
     * @return the estimate produced by {@link #getInMemorySize(int)} for this map's size
     */
    @Override
    public long getInMemorySize() {
        return MapToBit.getInMemorySize(_size);
    }

    /**
     * Estimates the in-memory footprint of a bit map with {@code i} entries.
     *
     * @param i number of entries
     * @return a fixed overhead of 20 plus the cost of a long array of {@code i >> 7} elements
     */
    public static long getInMemorySize(int i) {
        // NOTE(review): the backing array is allocated with (i >> 6) + 1 words, whereas this estimate
        // uses i >> 7 words - confirm whether the difference is intentional.
        final int estimatedWords = i >> 7;
        return (long) (20 + MemoryEstimates.longArrayCost(estimatedWords));
    }

    /**
     * Writes one entry of the map.
     *
     * @param i  position to write
     * @param i2 value to store; exactly 1 sets the bit, every other value clears it
     */
    @Override
    public void set(int i, int i2) {
        final int word = i >> 6;
        // The mask must be a long: an int shift only uses the low 5 bits of the distance, so positions
        // 32..63 inside a word would touch the wrong bit.
        final long mask = 1L << i;
        if (i2 == 1) {
            _data[word] |= mask;
        } else {
            _data[word] &= ~mask;
        }
    }

    /**
     * Writes one entry and reports the value that was actually stored.
     *
     * @param i  position to write
     * @param i2 value to store; exactly 1 sets the bit, every other value clears it
     * @return 1 if the bit was set, 0 if it was cleared
     */
    @Override
    public int setAndGet(int i, int i2) {
        set(i, i2);
        // Mirror the encoding applied by set(): only the value 1 is stored as a set bit.
        return i2 == 1 ? 1 : 0;
    }

    /**
     * Reports the number of logical entries in this map.
     *
     * @return the entry count given at construction
     */
    @Override
    public int size() {
        return _size;
    }

    /**
     * Replaces a value by overwriting the whole map: when {@code i} is 0 every entry becomes 1, otherwise
     * every entry becomes 0.
     *
     * @param i  value to be replaced
     * @param i2 replacement value; not read by this implementation
     */
    @Override
    public void replace(int i, int i2) {
        final int fillValue = (i == 0) ? 1 : 0;
        fill(fillValue);
    }

    /**
     * Computes the number of bytes produced by {@code write}.
     *
     * @return one type byte, two ints (size and word count) and eight bytes per backing word
     */
    @Override
    public long getExactSizeOnDisk() {
        final long header = 1 + 4 + 4;
        return header + 8L * _data.length;
    }

    /**
     * Serialises this map: the type ordinal as a byte, the entry count, the word count, then every word.
     *
     * @param dataOutput destination of the serialised form
     * @throws IOException if the underlying output fails
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        final long[] words = _data;
        dataOutput.writeByte(MapToFactory.MAP_TYPE.BIT.ordinal());
        dataOutput.writeInt(_size);
        dataOutput.writeInt(words.length);
        for (final long word : words) {
            dataOutput.writeLong(word);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Reads a bit map from its serialised form: the entry count, the word count, then that many words.
     * The leading type byte emitted by {@code write} is not read here - presumably the caller consumes it
     * to pick the implementation.
     *
     * @param dataInput source of the serialised form
     * @return a new map with 2 as its unique-value count
     * @throws IOException if the underlying input fails
     */
    public static MapToBit readFields(DataInput dataInput) throws IOException {
        final int size = dataInput.readInt();
        final int wordCount = dataInput.readInt();
        final long[] words = new long[wordCount];
        for (int w = 0; w < wordCount; w++) {
            words[w] = dataInput.readLong();
        }
        return new MapToBit(2, words, size);
    }

    /**
     * Reports the largest value an entry of this map can hold.
     *
     * @return always 1
     */
    @Override
    public int getUpperBoundValue() {
        final int largest = 1;
        return largest;
    }

    /**
     * Counts how many entries hold each value.
     *
     * <p>The number of set bits is added to whatever {@code iArr[1]} already contains; {@code iArr[0]} is
     * then overwritten with {@code size()} minus the resulting {@code iArr[1]}.
     *
     * @param iArr output array with at least two cells
     * @return the same array that was passed in
     */
    @Override
    public int[] getCounts(int[] iArr) {
        final int total = size();
        int ones = iArr[1];
        for (final long word : _data) {
            ones += Long.bitCount(word);
        }
        iArr[1] = ones;
        iArr[0] = total - ones;
        return iArr;
    }

    /**
     * Pre-aggregates single-column values; uses the bit-counting path when the other map is also a
     * {@code MapToBit} and otherwise defers to the super implementation.
     *
     * @param aMapToData the other map
     * @param dArr       input values, indexed by the other map's entries
     * @param dArr2      accumulator, indexed by this map's entries
     */
    @Override
    public void preAggregateDDC_DDCSingleCol(AMapToData aMapToData, double[] dArr, double[] dArr2) {
        if (!(aMapToData instanceof MapToBit)) {
            super.preAggregateDDC_DDCSingleCol(aMapToData, dArr, dArr2);
            return;
        }
        preAggregateDDCSingleColBitBit((MapToBit) aMapToData, dArr, dArr2);
    }

    /**
     * Single-column pre-aggregation for two bit maps, driven by the four combination counts instead of a
     * per-entry loop.
     *
     * @param mapToBit the other map; its value selects the cell of {@code dArr}
     * @param dArr     two input values
     * @param dArr2    two accumulators; the cell is selected by this map's value
     */
    private void preAggregateDDCSingleColBitBit(MapToBit mapToBit, double[] dArr, double[] dArr2) {
        final JoinBitSets joined = new JoinBitSets(mapToBit, this, _size);
        // Statement order is kept as-is so the floating point sums are accumulated in the same sequence.
        dArr2[1] += dArr[1] * joined.tt; // other 1, this 1
        dArr2[0] += dArr[1] * joined.ft; // other 1, this 0
        dArr2[1] += dArr[0] * joined.tf; // other 0, this 1
        dArr2[0] += dArr[0] * joined.ff; // other 0, this 0
    }

    /**
     * Pre-aggregates multi-column dictionary values; uses the bit-counting path when the other map is
     * also a {@code MapToBit} and otherwise defers to the super implementation.
     *
     * @param aMapToData  the other map
     * @param iDictionary dictionary supplying the values
     * @param dArr        accumulator
     * @param i           number of columns
     */
    @Override
    public void preAggregateDDC_DDCMultiCol(AMapToData aMapToData, IDictionary iDictionary, double[] dArr, int i) {
        if (!(aMapToData instanceof MapToBit)) {
            super.preAggregateDDC_DDCMultiCol(aMapToData, iDictionary, dArr, i);
            return;
        }
        preAggregateDDCMultiColBitBit((MapToBit) aMapToData, iDictionary, dArr, i);
    }

    /**
     * Multi-column pre-aggregation for two bit maps, driven by the four combination counts.
     *
     * <p>Both the dictionary values and {@code dArr} are read as two consecutive rows of {@code i} cells:
     * cells {@code [0, i)} belong to value 0 and cells {@code [i, 2i)} to value 1.
     *
     * @param mapToBit    the other map; its value selects the dictionary row
     * @param iDictionary dictionary supplying the values
     * @param dArr        accumulator; the row is selected by this map's value
     * @param i           number of columns
     */
    private void preAggregateDDCMultiColBitBit(MapToBit mapToBit, IDictionary iDictionary, double[] dArr, int i) {
        final JoinBitSets joined = new JoinBitSets(mapToBit, this, _size);
        final double[] dictValues = iDictionary.getValues();
        for (int col = 0; col < i; col++) {
            final int row0 = col;
            final int row1 = i + col;
            // Statement order is kept as-is so the floating point sums are accumulated in the same sequence.
            dArr[row0] += dictValues[row0] * joined.ff;
            dArr[row1] += dictValues[row0] * joined.tf;
            dArr[row1] += dictValues[row1] * joined.tt;
            dArr[row0] += dictValues[row1] * joined.ft;
        }
    }

    /**
     * Checks whether no bit is set anywhere in the backing array.
     *
     * @return {@code true} when every backing word is zero
     */
    public boolean isEmpty() {
        for (final long word : _data) {
            if (word != 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Copies the entries of another map into this one.
     *
     * @param aMapToData source map; a {@code MapToInt} is handed to {@code copyInt}, any other type is
     *                   copied entry by entry through {@code getIndex} and {@code set}
     */
    @Override
    public void copy(AMapToData aMapToData) {
        if (aMapToData instanceof MapToInt) {
            copyInt((MapToInt) aMapToData);
        } else {
            final int entries = size();
            for (int pos = 0; pos < entries; pos++) {
                set(pos, aMapToData.getIndex(pos));
            }
        }
    }

    /**
     * Copies the first {@code size()} values of an int array into this map via {@code set}.
     *
     * @param iArr source values
     */
    @Override
    public void copyInt(int[] iArr) {
        final int entries = _size;
        for (int pos = 0; pos < entries; pos++) {
            final int value = iArr[pos];
            set(pos, value);
        }
    }

    /**
     * Overwrites the backing words with the contents of a {@link BitSet}.
     *
     * @param bitSet source bits; its long array must not be longer than this map's backing array
     */
    @Override
    public void copyBit(BitSet bitSet) {
        final long[] source = bitSet.toLongArray();
        final int copied = source.length;
        System.arraycopy(source, 0, _data, 0, copied);
        // BitSet omits trailing zero words, so clear whatever was not overwritten (an empty range is a no-op).
        Arrays.fill(_data, copied, _data.length, 0L);
    }

    /**
     * Adapts the map to a different number of unique values.
     *
     * @param i requested number of unique values
     * @return a new {@code MapToZero} of the same size when {@code i} is at most 1, otherwise this instance
     */
    @Override
    public AMapToData resize(int i) {
        if (i > 1) {
            return this;
        }
        return new MapToZero(size());
    }

    /**
     * Counts the runs of equal consecutive entries, i.e. one plus the number of positions whose value
     * differs from the previous position.
     *
     * <p>Only the words holding in-range bits are inspected. In the final such word the bits beyond
     * {@code size()} are overwritten with the value of the last entry so they can never add a transition.
     * All masks are built with long shifts; an int shift only honours the low 5 bits of the distance.
     *
     * @return the number of runs; 0 for a map without entries
     */
    @Override
    public int countRuns() {
        if (_size <= 0) {
            return 0;
        }
        final long[] words = _data;
        final int finalBit = _size - 1;
        final int finalWord = finalBit >> 6;
        final int validInFinal = (finalBit & 63) + 1;

        int runs = 1;
        // Pretend the bit before position 0 equals bit 0 so the first position is never a transition.
        long previousBit = words[0] & 1L;
        for (int w = 0; w <= finalWord; w++) {
            long word = words[w];
            if (w == finalWord && validInFinal < 64) {
                final long padding = -1L << validInFinal;
                final boolean lastEntrySet = (word & (1L << finalBit)) != 0;
                word = lastEntrySet ? (word | padding) : (word & ~padding);
            }
            // Compare every bit with its predecessor; the lowest bit is compared with the previous word's top bit.
            runs += Long.bitCount(word ^ ((word << 1) | previousBit));
            previousBit = word >>> 63;
        }
        return runs;
    }

    /**
     * Extracts the entries in {@code [i, i2)} into a new bit map with the same unique-value count.
     *
     * @param i  first position (inclusive)
     * @param i2 end position (exclusive)
     * @return a new map with {@code i2 - i} entries
     */
    @Override
    public AMapToData slice(int i, int i2) {
        final int length = i2 - i;
        return new MapToBit(getUnique(), BitSetArray.sliceVectorized(_data, i, i2), length);
    }

    /**
     * Concatenates another bit map after this one.
     *
     * @param aMapToData map to append; must be a {@code MapToBit}
     * @return a new map holding this map's entries followed by the other map's entries
     * @throws NotImplementedException if the other map is not a {@code MapToBit}
     */
    @Override
    public AMapToData append(AMapToData aMapToData) {
        if (aMapToData instanceof MapToBit) {
            final int combined = _size + aMapToData.size();
            final long[] words = new long[longSize(combined)];
            // This map's words go first; the other map's bits are then written starting at position _size.
            System.arraycopy(_data, 0, words, 0, _data.length);
            BitSetArray.setVectorizedLongs(_size, combined, words, ((MapToBit) aMapToData)._data);
            return new MapToBit(2, words, combined);
        }
        throw new NotImplementedException("Not implemented append on Bit map different type");
    }

    /**
     * Compares this map with another map.
     *
     * @param aMapToData map to compare against
     * @return {@code true} only when the other map is a {@code MapToBit} with the same unique-value count,
     *         the same size and identical backing words
     */
    @Override
    public boolean equals(AMapToData aMapToData) {
        if (!(aMapToData instanceof MapToBit)) {
            return false;
        }
        if (aMapToData.getUnique() != getUnique()) {
            return false;
        }
        final MapToBit other = (MapToBit) aMapToData;
        if (other._size != _size) {
            return false;
        }
        return Arrays.equals(other._data, _data);
    }

    /**
     * Concatenates the maps of all given groups, in order, into one new bit map.
     *
     * <p>Groups whose map is empty are skipped; every non-empty map is cast to {@code MapToBit}.
     *
     * @param iMapToDataGroupArr groups whose maps are concatenated
     * @return a new map with this map's unique-value count and the summed size
     */
    @Override
    public AMapToData appendN(IMapToDataGroup[] iMapToDataGroupArr) {
        int total = 0;
        for (IMapToDataGroup group : iMapToDataGroupArr) {
            total += group.getMapToData().size();
        }
        final long[] words = new long[longSize(total)];
        int offset = 0;
        for (IMapToDataGroup group : iMapToDataGroupArr) {
            final AMapToData map = group.getMapToData();
            final int length = map.size();
            if (length > 0) {
                BitSetArray.setVectorizedLongs(offset, offset + length, words, ((MapToBit) map)._data);
                offset += length;
            }
        }
        // The array already has the required length, so hand it over directly instead of converting it
        // to a BitSet and back, which copied the data twice for the same result.
        return new MapToBit(getUnique(), words, offset);
    }

    /**
     * Computes the length of the backing array for a map with {@code i} entries.
     *
     * @param i number of entries
     * @return {@code i / 64} (never below zero) plus one spare word
     */
    private static int longSize(int i) {
        final int fullWords = i >> 6;
        return (fullWords > 0 ? fullWords : 0) + 1;
    }

    /**
     * Reports the number of distinct values this map type can represent.
     *
     * @return always 2
     */
    @Override
    public int getMaxPossible() {
        final int distinctValues = 2;
        return distinctValues;
    }

    /**
     * Renders the super class description followed by the entry count and the backing array length.
     *
     * @return a human-readable description of this map
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" size: ").append(_size);
        sb.append(" longLength:[").append(_data.length).append("]");
        return sb.toString();
    }
}
