/*
 * Decompiled with CFR 0.152.
 */
package org.fastfilter.bloom.count;

import org.fastfilter.Filter;
import org.fastfilter.utils.Hash;

/**
 * A counting Bloom-style filter that sets two bits per key and keeps a compact
 * ("succinct") counter for every set bit.
 *
 * <p>Storage is organised in 64-bit groups. For group {@code g}:
 * <ul>
 *   <li>{@code data[g]} holds one membership bit per position (0..63);</li>
 *   <li>{@code counts[g]} holds the counters of all set positions, packed level by
 *       level: level 0 has one bit per set bit of {@code data[g]} (in rank order),
 *       and a 1 means "the counter is larger than this level"; level {@code n+1}
 *       follows directly and has one bit per 1-bit of level {@code n};</li>
 *   <li>when that packing no longer fits into 63 bits, the group is switched to an
 *       overflow slot: bit 63 of {@code counts[g]} is set, bits 0..27 hold the slot
 *       index into {@code overflow}, and bits 32 and up hold the total count of the
 *       group. An overflow slot is 8 longs, i.e. one byte counter per position.</li>
 * </ul>
 *
 * <p>NOTE(review): overflow counters are single bytes and nothing in this class
 * guards against one of them exceeding 255 - confirm this cannot happen for the
 * intended workloads.
 */
public class SuccinctCountingBloomRanked
implements Filter {
    /** Diagnostic switch; not referenced by any code in this class. */
    private static final boolean VERIFY_COUNTS = false;

    /** Bit 63 of a {@code counts} entry: the group uses an overflow slot. */
    private static final long OVERFLOW_FLAG = Long.MIN_VALUE;
    /** Bits 0..27 of an overflowed {@code counts} entry: index of the slot in {@code overflow}. */
    private static final long OVERFLOW_INDEX_MASK = 0xFFFFFFFL;
    /** Adding this to an overflowed {@code counts} entry increases its total count (bits 32+) by one. */
    private static final long OVERFLOW_COUNT_UNIT = 0x100000000L;
    /** Number of longs in one overflow slot (64 byte-sized counters). */
    private static final int OVERFLOW_SLOT_LONGS = 8;

    private final int buckets;
    private final long seed;
    private final long[] data;
    private final long[] counts;
    /** Head of the free list of overflow slots; a free slot stores the next free index in its first long. */
    private int nextFreeOverflow;
    private final long[] overflow;
    /** Always {@code null} in this class; only assigned in the constructor. */
    private final byte[] realCounts;

    /**
     * Builds a filter sized for {@code keys.length} entries and adds every key.
     *
     * @param keys the keys to add
     * @param bitsPerKey number of membership bits to reserve per key
     * @return the populated filter
     */
    public static SuccinctCountingBloomRanked construct(long[] keys, int bitsPerKey) {
        int k = SuccinctCountingBloomRanked.getBestK(bitsPerKey);
        SuccinctCountingBloomRanked f = new SuccinctCountingBloomRanked(keys.length, bitsPerKey, k);
        for (long x : keys) {
            f.add(x);
        }
        return f;
    }

    /**
     * Returns {@code round(bitsPerKey * ln 2)}, but at least 1.
     */
    private static int getBestK(double bitsPerKey) {
        return Math.max(1, (int)Math.round(bitsPerKey * Math.log(2.0)));
    }

    /**
     * Returns the number of bits held by the {@code data}, {@code counts} and
     * {@code overflow} arrays together.
     */
    @Override
    public long getBitCount() {
        return 64L * (long)this.data.length + 64L * (long)this.counts.length + 64L * (long)this.overflow.length;
    }

    /**
     * Allocates the arrays for the given capacity and initialises the overflow free list.
     *
     * @param entryCount expected number of keys; values below 1 are treated as 1
     * @param bitsPerKey number of membership bits per key
     * @param k accepted for signature compatibility; not used by this implementation
     * @throws IllegalArgumentException if the requested size needs more groups than an array can hold
     */
    SuccinctCountingBloomRanked(int entryCount, int bitsPerKey, int k) {
        entryCount = Math.max(1, entryCount);
        this.seed = Hash.randomSeed();
        // Widen before multiplying, and divide before narrowing, so that large
        // capacities are neither wrapped nor truncated.
        long bits = (long)entryCount * bitsPerKey;
        long bucketCount = bits / 64;
        if (bucketCount > Integer.MAX_VALUE - 16) {
            throw new IllegalArgumentException(
                    "Filter too large: " + entryCount + " entries * " + bitsPerKey + " bits per key");
        }
        this.buckets = (int)bucketCount;
        int arrayLength = this.buckets + 16;
        this.data = new long[arrayLength];
        this.counts = new long[arrayLength];
        // 12% of the group count plus a fixed reserve; computed in long to avoid int overflow.
        this.overflow = new long[(int)(100L + arrayLength * 12L / 100L)];
        // Chain the slots into a free list: each slot's first long points at the next slot.
        for (int i = 0; i < this.overflow.length; i += OVERFLOW_SLOT_LONGS) {
            this.overflow[i] = i + OVERFLOW_SLOT_LONGS;
        }
        this.realCounts = null;
    }

    @Override
    public boolean supportsAdd() {
        return true;
    }

    /**
     * Adds a key by incrementing two counters: one in the group selected from the
     * low 32 hash bits at position {@code hash & 63}, and one in the group selected
     * from the high 32 hash bits at position {@code (hash >>> 6) & 63}.
     *
     * @throws IllegalStateException if a group has to be moved to the overflow area and no slot is free
     */
    @Override
    public void add(long key) {
        long hash = Hash.hash64(key, this.seed);
        int start = Hash.reduce((int)hash, this.buckets);
        int a1 = (int)(hash & 0x3FL);
        this.increment(start, a1);
        int second = Hash.reduce((int)(hash >>> 32), this.buckets);
        int a2 = (int)(hash >>> 6 & 0x3FL);
        this.increment(second, a2);
    }

    @Override
    public boolean supportsRemove() {
        return true;
    }

    /**
     * Removes a key by decrementing the same two counters that {@link #add(long)}
     * increments. No check is made that the key was added before.
     */
    @Override
    public void remove(long key) {
        long hash = Hash.hash64(key, this.seed);
        int start = Hash.reduce((int)hash, this.buckets);
        int a1 = (int)(hash & 0x3FL);
        this.decrement(start, a1);
        int second = Hash.reduce((int)(hash >>> 32), this.buckets);
        int a2 = (int)(hash >>> 6 & 0x3FL);
        this.decrement(second, a2);
    }

    /**
     * Returns the number of 1-bits in {@code data} plus the number of 1-bits in
     * {@code counts}.
     *
     * <p>NOTE(review): for groups in overflow mode {@code counts} holds a flag, a
     * slot index and a total, so its bit count is not a counter encoding - confirm
     * that this sum is what callers expect.
     */
    @Override
    public long cardinality() {
        long sum = 0L;
        for (long x : this.data) {
            sum += (long)Long.bitCount(x);
        }
        for (long x : this.counts) {
            sum += (long)Long.bitCount(x);
        }
        return sum;
    }

    /**
     * Returns {@code true} if both membership bits of the key are set.
     */
    @Override
    public boolean mayContain(long key) {
        long hash = Hash.hash64(key, this.seed);
        long first = this.data[Hash.reduce((int)hash, this.buckets)];
        long second = this.data[Hash.reduce((int)(hash >>> 32), this.buckets)];
        // Shift distances are taken modulo 64, so these select bits (hash & 63)
        // and ((hash >>> 6) & 63), the same positions add() uses.
        return ((first >>> (int)hash) & (second >>> (int)(hash >> 6)) & 1L) == 1L;
    }

    /**
     * Increments the counter of position {@code x} in {@code group}, setting the
     * membership bit if necessary.
     *
     * @param group index into {@code data} / {@code counts}
     * @param x bit position within the group, 0..63
     */
    private void increment(int group, int x) {
        long m = this.data[group];
        long c = this.counts[group];
        if ((c & OVERFLOW_FLAG) != 0L) {
            // Group already lives in an overflow slot: bump the total and the byte counter.
            int index = (int)(c & OVERFLOW_INDEX_MASK);
            this.counts[group] = c + OVERFLOW_COUNT_UNIT;
            int bitIndex = x & 0x3F;
            this.overflow[index + bitIndex / 8] += SuccinctCountingBloomRanked.getBit(bitIndex);
            this.data[group] |= 1L << x;
            return;
        }
        long d = m >>> x & 1L;
        if (d == 0L && c == 0L) {
            // New position in a group where every counter is 1: all-zero encoding stays valid.
            this.data[group] |= 1L << x;
            return;
        }
        int bitsSet = Long.bitCount(m);
        // Rank of x: number of set membership bits below position x.
        int bitsBefore = x == 0 ? 0 : Long.bitCount(m << 64 - x);
        int insertAt = bitsBefore;
        if (d == 1L) {
            // Position already present: walk up the levels while this counter's bit
            // is 1, then set the first 0 and add a fresh 0 entry on the level above.
            long bitsForLevel;
            int startLevel = 0;
            while (true) {
                long levelMask = (1L << bitsSet) - 1L << startLevel;
                bitsForLevel = c & levelMask;
                if ((c >>> insertAt & 1L) == 0L) {
                    break;
                }
                startLevel += bitsSet;
                if (startLevel >= 64) {
                    // Ran out of room; forces the overflow conversion below.
                    insertAt = 64;
                    break;
                }
                bitsSet = Long.bitCount(bitsForLevel);
                bitsBefore = insertAt == 0 ? 0 : Long.bitCount(bitsForLevel << 64 - insertAt);
                insertAt = startLevel + bitsBefore;
            }
            c |= 1L << insertAt;
            int bitsBeforeLevel = insertAt == 0 ? 0 : Long.bitCount(bitsForLevel << 64 - insertAt);
            insertAt = startLevel + bitsSet + bitsBeforeLevel;
        }
        // Open a gap at insertAt: everything at or above it moves up by one bit.
        long mask = (1L << insertAt) - 1L;
        long left = c & ~mask;
        long right = c & mask;
        c = left << 1 | right;
        if (insertAt >= 64 || (c & OVERFLOW_FLAG) != 0L) {
            // The packed encoding no longer fits: move the whole group to an overflow slot.
            // data/counts still hold the state before this increment, so readCount()
            // returns the old counters; the new increment is added afterwards.
            int index = this.allocateOverflow();
            long count = 1L;
            for (int i = 0; i < 64; ++i) {
                int n = this.readCount(group, i);
                count += (long)n;
                this.overflow[index + i / 8] += (long)n * SuccinctCountingBloomRanked.getBit(i);
            }
            this.data[group] |= 1L << x;
            this.counts[group] = OVERFLOW_FLAG | count << 32 | (long)index;
            int bitIndex = x & 0x3F;
            this.overflow[index + bitIndex / 8] += SuccinctCountingBloomRanked.getBit(bitIndex);
            return;
        }
        this.data[group] |= 1L << x;
        this.counts[group] = c;
    }

    /**
     * Takes a slot from the overflow free list and zeroes it.
     *
     * @return index of the first long of the slot
     * @throws IllegalStateException if no complete slot is left; the allocator state is not modified in that case
     */
    private int allocateOverflow() {
        int result = this.nextFreeOverflow;
        if (result < 0 || result > this.overflow.length - OVERFLOW_SLOT_LONGS) {
            throw new IllegalStateException(
                    "Overflow area exhausted (" + this.overflow.length + " longs)");
        }
        this.nextFreeOverflow = (int)this.overflow[result];
        for (int i = 0; i < OVERFLOW_SLOT_LONGS; ++i) {
            this.overflow[result + i] = 0L;
        }
        return result;
    }

    /**
     * Decrements the counter of position {@code x} in {@code group}; clears the
     * membership bit when the counter reaches zero.
     *
     * @param group index into {@code data} / {@code counts}
     * @param x bit position within the group, 0..63
     */
    private void decrement(int group, int x) {
        long m = this.data[group];
        long c = this.counts[group];
        if ((c & OVERFLOW_FLAG) != 0L) {
            // Total count of the group before this decrement.
            int count = (int)(c >>> 32) & 0xFFFFFFF;
            c -= OVERFLOW_COUNT_UNIT;
            this.counts[group] = c;
            int index = (int)(c & OVERFLOW_INDEX_MASK);
            int bitIndex = x & 0x3F;
            long n = this.overflow[index + bitIndex / 8];
            this.overflow[index + bitIndex / 8] = n - SuccinctCountingBloomRanked.getBit(bitIndex);
            n >>>= 8 * (bitIndex & 7);
            if ((n & 0xFFL) == 1L) {
                // The byte counter was 1 and is now 0: the position is no longer present.
                this.data[group] &= ~(1L << x);
            }
            if (count < 64) {
                // Small enough again: rebuild the packed encoding and release the slot.
                int count2 = 0;
                int[] temp = new int[64];
                for (int j = 63; j >= 0; --j) {
                    // Shift distance 8 * j is taken modulo 64, i.e. 8 * (j & 7).
                    int cj = (int)(this.overflow[index + j / 8] >>> 8 * j & 0xFFL);
                    temp[j] = cj;
                    count2 += cj;
                }
                long c2 = 0L;
                int off = 0;
                // Each pass over the positions emits one level of the encoding.
                while (count2 > 0) {
                    for (int i = 0; i < 64; ++i) {
                        int t = temp[i];
                        if (t <= 0) {
                            continue;
                        }
                        temp[i] = t - 1;
                        --count2;
                        c2 |= (t > 1 ? 1L : 0L) << off;
                        ++off;
                    }
                }
                this.counts[group] = c2;
                this.freeOverflow(index);
            }
            return;
        }
        int bitsSet = Long.bitCount(m);
        // Rank of x: number of set membership bits below position x.
        int bitsBefore = x == 0 ? 0 : Long.bitCount(m << 64 - x);
        int removeAt = bitsBefore;
        long d = c >>> bitsBefore & 1L;
        if (d == 1L) {
            // Counter is larger than 1: find the highest level on which its bit is 1,
            // clear that bit and drop the (zero) entry it owns on the level above.
            int startLevel = 0;
            int resetAt = removeAt;
            do {
                long levelMask = (1L << bitsSet) - 1L << startLevel;
                long bitsForLevel = c & levelMask;
                if ((c >>> removeAt & 1L) == 0L) {
                    break;
                }
                // Advance by the width of the CURRENT level before bitsSet is replaced
                // by the width of the next one; this matches increment() and readCount().
                startLevel += bitsSet;
                bitsSet = Long.bitCount(bitsForLevel);
                bitsBefore = removeAt == 0 ? 0 : Long.bitCount(bitsForLevel << 64 - removeAt);
                resetAt = removeAt;
                removeAt = startLevel + bitsBefore;
            } while (removeAt <= 63);
            c ^= 1L << resetAt;
        }
        if (removeAt < 64) {
            // Close the gap at removeAt: everything above it moves down by one bit.
            long mask = (1L << removeAt) - 1L;
            long left = c >>> 1 & ~mask;
            long right = c & mask;
            c = left | right;
        }
        this.counts[group] = c;
        // The membership bit is cleared only when the counter was exactly 1 (d == 0).
        this.data[group] = m & ~((d == 0L ? 1L : 0L) << x);
    }

    /**
     * Returns an overflow slot to the head of the free list.
     */
    private void freeOverflow(int index) {
        this.overflow[index] = this.nextFreeOverflow;
        this.nextFreeOverflow = index;
    }

    /**
     * Returns the value that adds one to the byte counter of {@code index} inside
     * its overflow long. The shift distance is taken modulo 64, so the result is
     * {@code 1L << (8 * (index & 7))}.
     */
    private static long getBit(int index) {
        return 1L << index * 8;
    }

    /**
     * Returns the counter of the absolute bit position {@code x}
     * (group {@code x >>> 6}, position {@code x & 63}).
     */
    private int readCount(int x) {
        return this.readCount(x >>> 6, x & 0x3F);
    }

    /**
     * Returns the counter of position {@code bit} in {@code group}.
     *
     * @param group index into {@code data} / {@code counts}
     * @param bit bit position within the group, 0..63
     * @return 0 if the membership bit is clear, otherwise the stored count
     * @throws AssertionError if the packed encoding has more than 16 levels for this position
     */
    private int readCount(int group, int bit) {
        long m = this.data[group];
        if ((m >>> bit & 1L) == 0L) {
            return 0;
        }
        long c = this.counts[group];
        if ((c & OVERFLOW_FLAG) != 0L) {
            long n = this.overflow[(int)(c & OVERFLOW_INDEX_MASK) + bit / 8];
            return (int)(n >>> 8 * (bit & 7) & 0xFFL);
        }
        if (c == 0L) {
            return 1;
        }
        int bitsSet = Long.bitCount(m);
        int insertAt = bit == 0 ? 0 : Long.bitCount(m << 64 - bit);
        int count = 1;
        do {
            long bitsForLevel = c & ((1L << bitsSet) - 1L);
            if ((c >>> insertAt & 1L) != 1L) {
                return count;
            }
            // Drop the current level so that the next one starts at bit 0.
            c >>>= bitsSet;
            bitsSet = Long.bitCount(bitsForLevel);
            insertAt = insertAt == 0 ? 0 : Long.bitCount(bitsForLevel << 64 - insertAt);
        } while (++count <= 16);
        throw new AssertionError();
    }

    /** Empty in this build; performs no verification. */
    private void verifyCounts(int from, int to) {
    }

    /**
     * Formats {@code x} as a 64-character binary string, left-padded with zeros.
     */
    static String getBitsNumber(long x) {
        String s = "0".repeat(64) + Long.toBinaryString(x);
        return s.substring(s.length() - 64);
    }
}

