/*
 * Decompiled with CFR 0.152.
 */
package org.fastfilter.bloom.count;

import java.util.Arrays;
import org.fastfilter.Filter;
import org.fastfilter.bloom.count.Select;
import org.fastfilter.utils.Hash;

/**
 * A blocked counting Bloom filter whose per-bit counters are stored succinctly, so that keys
 * can be removed as well as added.
 *
 * <p>The filter is an array of 64-bit blocks. Each key touches two blocks: a start block chosen
 * from the hash and a second block 1 to 16 positions after it. In each block it increments two
 * (possibly equal) bit positions. For block {@code g}:
 * <ul>
 *   <li>{@code data[g]} holds the Bloom filter bits. A bit is set exactly when its counter is
 *       non-zero, and this is the only array {@link #mayContain} reads.</li>
 *   <li>In compact mode (top bit clear), {@code counts[g]} holds the counters of the set bits of
 *       {@code data[g]} in bit order, starting at the least significant bit. A counter with value
 *       {@code n} is stored as {@code n - 1} zero bits followed by a one bit (the "terminator").
 *       At most 63 bits are used.</li>
 *   <li>Once the compact encoding is full, the block switches to overflow mode. The top bit of
 *       {@code counts[g]} is set, bits 32..59 hold the total of all counters in the block, and
 *       bits 0..27 hold the offset of an 8-long chunk of {@code overflow} that stores one 8-bit
 *       counter per bit position. When the total drops to 62 or less, the block returns to
 *       compact mode and the chunk is freed.</li>
 * </ul>
 *
 * <p>Limitations: an overflow counter is a single byte, so more than 255 increments of the same
 * position carry into the neighbouring counter. {@link #remove} assumes the key was previously
 * added. This class performs no synchronization.
 */
public class SuccinctCountingBlockedBloom
implements Filter {
    /** Longs per overflow chunk: 64 one-byte counters, one per bit of a block. */
    private static final int OVERFLOW_CHUNK = 8;
    /** Chunk offsets are packed into 28 bits of a counts word, so the pool may not exceed this. */
    private static final int MAX_OVERFLOW_LENGTH = 1 << 28;
    /** Top bit of a counts word: the block keeps its counters in the overflow pool. */
    private static final long OVERFLOW_FLAG = Long.MIN_VALUE;
    /** A counts word with either of these bits set cannot take another compact increment. */
    private static final long COMPACT_FULL = 0xC000000000000000L;
    /** Mask for the 28-bit chunk offset (bits 0..27) and, after {@code >>> 32}, the total count. */
    private static final long LOW_28_BITS = 0xFFFFFFFL;
    /** Increment for the total count stored in bits 32..59 of an overflow-mode counts word. */
    private static final long COUNT_ONE = 1L << 32;

    private final int buckets;
    private final long seed;
    /** Bloom filter bits, one 64-bit block per element. */
    private final long[] data;
    /** Per-block counters: compact unary encoding, or an overflow header (see class doc). */
    private final long[] counts;
    /** Head of the free list of overflow chunks; equals {@code overflow.length} when none is free. */
    private int nextFreeOverflow;
    /** Pool of overflow chunks; replaced by a larger copy when it runs out. */
    private long[] overflow;

    /**
     * Creates a filter sized for {@code keys.length} entries and adds all keys to it.
     *
     * @param keys the keys to add
     * @param bitsPerKey number of {@code data} bits to allocate per key
     * @return the populated filter
     */
    public static SuccinctCountingBlockedBloom construct(long[] keys, int bitsPerKey) {
        int k = getBestK(bitsPerKey);
        SuccinctCountingBlockedBloom f = new SuccinctCountingBlockedBloom(keys.length, bitsPerKey, k);
        for (long x : keys) {
            f.add(x);
        }
        return f;
    }

    /** Returns {@code round(bitsPerKey * ln 2)}, but at least 1. */
    private static int getBestK(double bitsPerKey) {
        return Math.max(1, (int) Math.round(bitsPerKey * Math.log(2.0)));
    }

    /** Returns the memory held by the data, counts and overflow arrays, in bits. */
    @Override
    public long getBitCount() {
        return 64L * data.length + 64L * counts.length + 64L * overflow.length;
    }

    /**
     * Creates an empty filter.
     *
     * @param entryCount expected number of keys; values below 1 are treated as 1
     * @param bitsPerKey number of {@code data} bits per expected key
     * @param k not used: every key increments at most four counters (see {@link #add})
     * @throws IllegalArgumentException if the resulting number of blocks does not fit an array
     */
    SuccinctCountingBlockedBloom(int entryCount, int bitsPerKey, int k) {
        entryCount = Math.max(1, entryCount);
        this.seed = Hash.randomSeed();
        long bits = (long) entryCount * bitsPerKey;
        // Divide before narrowing: casting first wrapped bit counts >= 2^31 to wrong or negative values.
        long bucketCount = bits / 64;
        if (bucketCount < 0 || bucketCount > Integer.MAX_VALUE - 17) {
            throw new IllegalArgumentException(
                    "Unsupported size: entryCount=" + entryCount + ", bitsPerKey=" + bitsPerKey);
        }
        this.buckets = (int) bucketCount;
        // The second block of a key lies up to 16 blocks after its start block.
        int arrayLength = buckets + 16 + 1;
        this.data = new long[arrayLength];
        this.counts = new long[arrayLength];
        // Whole chunks only, so that the last chunk on the free list lies inside the array.
        long poolLength = roundUpToChunk(100 + (long) arrayLength * 10 / 100);
        this.overflow = new long[(int) Math.min(MAX_OVERFLOW_LENGTH, poolLength)];
        linkFreeChunks(0, overflow.length);
        this.nextFreeOverflow = 0;
    }

    @Override
    public boolean supportsAdd() {
        return true;
    }

    /**
     * Adds a key. This increments two counters in the key's start block and two in its second
     * block. A position that appears twice within one block is incremented only once.
     */
    @Override
    public void add(long key) {
        long hash = Hash.hash64(key, seed);
        int start = Hash.reduce((int) hash, buckets);
        hash ^= Long.rotateLeft(hash, 32);
        int a1 = (int) (hash & 0x3FL);
        int a2 = (int) ((hash >> 6) & 0x3FL);
        increment(start, a1);
        if (a2 != a1) {
            increment(start, a2);
        }
        int second = start + 1 + (int) (hash >>> 60);
        int a3 = (int) ((hash >> 12) & 0x3FL);
        int a4 = (int) ((hash >> 18) & 0x3FL);
        increment(second, a3);
        if (a4 != a3) {
            increment(second, a4);
        }
    }

    @Override
    public boolean supportsRemove() {
        return true;
    }

    /**
     * Removes a key by decrementing exactly the counters {@link #add} incremented. The key must
     * have been added before.
     */
    @Override
    public void remove(long key) {
        long hash = Hash.hash64(key, seed);
        int start = Hash.reduce((int) hash, buckets);
        hash ^= Long.rotateLeft(hash, 32);
        int a1 = (int) (hash & 0x3FL);
        int a2 = (int) ((hash >> 6) & 0x3FL);
        decrement(start, a1);
        if (a2 != a1) {
            decrement(start, a2);
        }
        int second = start + 1 + (int) (hash >>> 60);
        int a3 = (int) ((hash >> 12) & 0x3FL);
        int a4 = (int) ((hash >> 18) & 0x3FL);
        decrement(second, a3);
        if (a4 != a3) {
            decrement(second, a4);
        }
    }

    /**
     * Returns the total number of set bits in the data and counts arrays.
     * NOTE(review): overflow-mode counts words contribute their header bits (flag, total,
     * offset) to this sum.
     */
    @Override
    public long cardinality() {
        long sum = 0L;
        for (long x : data) {
            sum += Long.bitCount(x);
        }
        for (long x : counts) {
            sum += Long.bitCount(x);
        }
        return sum;
    }

    /** Returns true if all four bits {@link #add} would set for {@code key} are set. */
    @Override
    public boolean mayContain(long key) {
        long hash = Hash.hash64(key, seed);
        int start = Hash.reduce((int) hash, buckets);
        hash ^= Long.rotateLeft(hash, 32);
        long a = data[start];
        long b = data[start + 1 + (int) (hash >>> 60)];
        // A long shift uses only the low 6 bits of its distance, so these are the same positions
        // that add() derives with "& 0x3F".
        long m1 = (1L << (int) hash) | (1L << (int) (hash >> 6));
        long m2 = (1L << (int) (hash >> 12)) | (1L << (int) (hash >> 18));
        return (m1 & a) == m1 && (m2 & b) == m2;
    }

    /**
     * Increments the counter of bit {@code x} (0..63) in block {@code group}, switching the block
     * to overflow mode if its compact encoding is full.
     */
    private void increment(int group, int x) {
        long m = data[group];
        long d = (m >>> x) & 1L;
        long c = counts[group];
        if ((c & COMPACT_FULL) != 0L) {
            int index;
            if ((c & OVERFLOW_FLAG) == 0L) {
                // Compact encoding uses 63 bits: move every counter into a fresh overflow chunk.
                index = allocateOverflow();
                for (int i = 0; i < 64; ++i) {
                    int n = readCount(group, i);
                    overflow[index + (i >>> 3)] += (long) n * getBit(i);
                }
                // 63 counted so far plus the increment being applied now.
                c = OVERFLOW_FLAG | (64L << 32) | index;
            } else {
                index = (int) (c & LOW_28_BITS);
                c += COUNT_ONE;
            }
            counts[group] = c;
            int bitIndex = x & 0x3F;
            overflow[index + (bitIndex >>> 3)] += getBit(bitIndex);
            data[group] |= 1L << x;
            return;
        }
        data[group] = m | (1L << x);
        // Set bits up to and including x. The select on (c << 1 | 1) finds the bit just above the
        // terminator of the last run that belongs to a bit at or below x (0 if there is none).
        int bitsBefore = Long.bitCount(m & (-1L >>> (63 - x)));
        int before = Select.selectInLong((c << 1) | 1L, bitsBefore);
        // A new bit (d == 0) gets a run "1" inserted there. An existing bit gets a 0 inserted at
        // its own terminator (before - 1), which pushes the terminator up and lengthens the run.
        int insertAt = before - (int) d;
        long mask = (1L << insertAt) - 1L;
        long left = c & ~mask;
        long right = c & mask;
        counts[group] = (left << 1) | ((1L ^ d) << insertAt) | right;
    }

    /**
     * Takes a chunk from the free list, growing the pool first if none is free.
     *
     * @return the offset of a zeroed 8-long chunk in {@code overflow}
     * @throws IllegalStateException if the pool has reached its maximum size
     */
    private int allocateOverflow() {
        if (nextFreeOverflow >= overflow.length) {
            growOverflow();
        }
        int result = nextFreeOverflow;
        nextFreeOverflow = (int) overflow[result];
        Arrays.fill(overflow, result, result + OVERFLOW_CHUNK, 0L);
        return result;
    }

    /** Enlarges the exhausted overflow pool by about half and puts the new chunks on the free list. */
    private void growOverflow() {
        int oldLength = overflow.length;
        if (oldLength >= MAX_OVERFLOW_LENGTH) {
            throw new IllegalStateException("Overflow pool exhausted at " + oldLength + " longs");
        }
        long wanted = roundUpToChunk((long) oldLength + Math.max(OVERFLOW_CHUNK, oldLength / 2));
        int newLength = (int) Math.min(MAX_OVERFLOW_LENGTH, wanted);
        overflow = Arrays.copyOf(overflow, newLength);
        linkFreeChunks(oldLength, newLength);
        // The free list was empty, so it now consists of exactly the new chunks.
        nextFreeOverflow = oldLength;
    }

    /** Links the chunks in {@code overflow[from, to)} into a free list whose last link is {@code to}. */
    private void linkFreeChunks(int from, int to) {
        for (int i = from; i < to; i += OVERFLOW_CHUNK) {
            overflow[i] = i + OVERFLOW_CHUNK;
        }
    }

    /** Rounds {@code length} up to a whole number of overflow chunks. */
    private static long roundUpToChunk(long length) {
        return (length + OVERFLOW_CHUNK - 1) / OVERFLOW_CHUNK * OVERFLOW_CHUNK;
    }

    /**
     * Decrements the counter of bit {@code x} (0..63) in block {@code group}. The data bit is
     * cleared when the counter reaches zero. A block in overflow mode returns to compact mode
     * once its total drops to 62 or less.
     */
    private void decrement(int group, int x) {
        long m = data[group];
        long c = counts[group];
        if ((c & OVERFLOW_FLAG) != 0L) {
            // Total of the block before this decrement.
            int count = (int) (c >>> 32) & 0xFFFFFFF;
            c -= COUNT_ONE;
            counts[group] = c;
            int index = (int) (c & LOW_28_BITS);
            int bitIndex = x & 0x3F;
            int slot = index + (bitIndex >>> 3);
            long word = overflow[slot];
            overflow[slot] = word - getBit(bitIndex);
            if (((word >>> ((bitIndex & 7) << 3)) & 0xFFL) == 1L) {
                data[group] &= ~(1L << x);
            }
            if (count < 64) {
                // Rebuild the compact encoding, highest bit position first so its run ends on top.
                long compact = 0L;
                for (int j = 63; j >= 0; --j) {
                    int cj = (int) ((overflow[index + (j >>> 3)] >>> ((j & 7) << 3)) & 0xFFL);
                    if (cj > 0) {
                        compact = ((compact << 1) | 1L) << (cj - 1);
                    }
                }
                counts[group] = compact;
                freeOverflow(index);
            }
            return;
        }
        int bitsBefore = Long.bitCount(m & (-1L >>> (63 - x)));
        // Position of x's terminator in c.
        int before = Select.selectInLong((c << 1) | 1L, bitsBefore) - 1;
        // Drop the bit just below the terminator. For a count >= 2 that is a 0 of x's own run.
        // For a count of 1 it is the previous run's terminator, and dropping either adjacent 1
        // yields the same word. "removed" tells whether x's counter reached zero.
        int removeAt = Math.max(0, before - 1);
        long mask = (1L << removeAt) - 1L;
        long left = (c >>> 1) & ~mask;
        long right = c & mask;
        counts[group] = left | right;
        long removed = (c >> removeAt) & 1L;
        data[group] = m & ~(removed << x);
    }

    /** Returns the chunk at {@code index} to the head of the free list. */
    private void freeOverflow(int index) {
        overflow[index] = nextFreeOverflow;
        nextFreeOverflow = index;
    }

    /** Returns the value that adds 1 to counter {@code index} within its overflow chunk word. */
    private static long getBit(int index) {
        return 1L << ((index & 7) << 3);
    }

    /** Returns the counter of bit {@code bit} (0..63) of block {@code group}, or 0 if the bit is clear. */
    private int readCount(int group, int bit) {
        long m = data[group];
        if (((m >>> bit) & 1L) == 0L) {
            return 0;
        }
        long c = counts[group];
        if ((c & OVERFLOW_FLAG) != 0L) {
            int index = (int) (c & LOW_28_BITS);
            return (int) ((overflow[index + (bit >>> 3)] >>> ((bit & 7) << 3)) & 0xFFL);
        }
        int bitsBefore = Long.bitCount(m & (-1L >>> (63 - bit)));
        // Terminator of this bit's run.
        int bitPos = Select.selectInLong(c, bitsBefore - 1);
        // Shift the run's zeros to the top. The previous terminator, or the sentinel if this is
        // the first run, bounds the leading-zero count.
        long y = ((c << (63 - bitPos)) << 1) | (1L << (63 - bitPos));
        return Long.numberOfLeadingZeros(y) + 1;
    }
}

