package org.apache.mahout.math.random;

import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.list.DoubleArrayList;

/**
 * Samples from a discrete distribution over distinct, non-null values, each with its own weight.
 *
 * <p>Weights live in a binary tree stored heap-style in {@code weight}. The children of node
 * {@code i} are {@code 2i} and {@code 2i + 1}. Leaves hold the weights of individual values, and
 * every internal node holds the sum of its children, so node 1 holds the total weight. Index 0 is
 * an unused sentinel. {@code values} is parallel to {@code weight}. Internal nodes and the leaves
 * of deleted values keep stale entries there, so {@code items} (value to leaf index) is the
 * authoritative record of which values are live.
 *
 * <p>Weights are expected to be non-negative; {@link #add} does not validate this.
 */
public final class Multinomial<T> implements Sampler<T>, Iterable<T> {
    /** Heap-ordered tree of weights; index 0 is a sentinel, index 1 is the root (total weight). */
    private final DoubleArrayList weight;
    /** Value recorded at each tree index; only indices referenced from {@code items} are live. */
    private final List<T> values;
    /** Maps each live value to the index of its leaf in {@code weight} and {@code values}. */
    private final Map<T, Integer> items;
    /** Source of uniform draws for {@link #sample()}. */
    private final Random rand;

    /** Creates an empty sampler; at least one value must be {@link #add added} before sampling. */
    public Multinomial() {
        this.weight = new DoubleArrayList();
        this.values = Lists.newArrayList();
        this.items = Maps.newHashMap();
        this.rand = RandomUtils.getRandom();
        // Sentinel at index 0 so the root is at index 1 and the children of i are 2i and 2i+1.
        this.weight.add(0.0d);
        this.values.add(null);
    }

    /**
     * Creates a sampler whose weights are the element counts of {@code multiset}.
     *
     * @param multiset non-empty multiset; each distinct element is added with weight equal to its
     *     count
     * @throws IllegalArgumentException if {@code multiset} is empty
     */
    public Multinomial(Multiset<T> multiset) {
        this();
        Preconditions.checkArgument(!multiset.isEmpty(), "Need some data to build sampler");
        for (T value : multiset.elementSet()) {
            add(value, multiset.count(value));
        }
    }

    /**
     * Creates a sampler from explicit (value, weight) pairs.
     *
     * @param iterable pairs to add in iteration order; values must be distinct and non-null
     */
    public Multinomial(Iterable<WeightedThing<T>> iterable) {
        this();
        for (WeightedThing<T> weightedThing : iterable) {
            add(weightedThing.getValue(), weightedThing.getWeight());
        }
    }

    /**
     * Adds a new value with the given weight.
     *
     * @param t value to add; must be non-null and not currently live in this sampler
     * @param d weight of the value
     * @throws NullPointerException if {@code t} is null
     * @throws IllegalArgumentException if {@code t} is already live
     */
    public void add(T t, double d) {
        Preconditions.checkNotNull(t);
        Preconditions.checkArgument(!this.items.containsKey(t));
        int size = this.weight.size();
        if (size == 1) {
            // First value: it becomes the root, which is also the only leaf.
            this.weight.add(d);
            this.values.add(t);
            this.items.put(t, 1);
            return;
        }
        // The leaf at size / 2 turns into an internal node. Its value and weight move down into
        // the new left child at index size. The internal node keeps its weight, which still equals
        // the sum of its children until the new right child's weight is added below.
        int parent = size / 2;
        T displaced = this.values.get(parent);
        this.weight.add(this.weight.get(parent));
        this.values.add(displaced);
        // Only move the map entry if the displaced leaf is still the live leaf for that value.
        // A deleted value (or one deleted and re-added elsewhere) must not be resurrected or
        // redirected to this stale leaf.
        Integer displacedIndex = this.items.get(displaced);
        if (displacedIndex != null && displacedIndex == parent) {
            this.items.put(displaced, size);
        }
        // The new value becomes the right child.
        int index = size + 1;
        this.items.put(t, index);
        this.weight.add(d);
        this.values.add(t);
        // Add the new weight to every ancestor, up to and including the root.
        while (index > 1) {
            index /= 2;
            this.weight.set(index, this.weight.get(index) + d);
        }
    }

    /**
     * Returns the weight of {@code t}, or 0 if it is not live in this sampler.
     *
     * @param t value to look up
     * @return the value's current weight, or 0.0 if absent or deleted
     */
    public double getWeight(T t) {
        Integer index = this.items.get(t);
        return index == null ? 0.0d : this.weight.get(index);
    }

    /**
     * Returns the weight of {@code t} divided by the total weight, or 0 if {@code t} is not live.
     *
     * @param t value to look up
     * @return normalized weight of {@code t}, or 0.0 if absent or deleted
     */
    public double getProbability(T t) {
        Integer index = this.items.get(t);
        return index == null ? 0.0d : this.weight.get(index) / this.weight.get(1);
    }

    /**
     * Returns the total weight of all values.
     *
     * @return the root's weight, or 0.0 if nothing has been added
     */
    public double getWeight() {
        if (this.weight.size() > 1) {
            return this.weight.get(1);
        }
        return 0.0d;
    }

    /**
     * Removes {@code t} by setting its weight to zero.
     *
     * @param t live value to remove
     * @throws IllegalArgumentException if {@code t} is not live
     */
    public void delete(T t) {
        set(t, 0.0d);
    }

    /**
     * Changes the weight of a live value. A weight of zero or less removes the value.
     *
     * @param t live value to update
     * @param d new weight; values {@code <= 0} remove {@code t} and store a weight of 0
     * @throws IllegalArgumentException if {@code t} is not live
     */
    public void set(T t, double d) {
        Preconditions.checkArgument(this.items.containsKey(t));
        int index = this.items.get(t);
        double newWeight = d;
        if (d <= 0.0d) {
            this.items.remove(t);
            // A removed value must contribute nothing; a negative leaf would corrupt every sum.
            newWeight = 0.0d;
        }
        double oldWeight = this.weight.get(index);
        // Propagate the change from the leaf up to the root (index 1); index 0 is the sentinel.
        while (index > 0) {
            this.weight.set(index, (this.weight.get(index) - oldWeight) + newWeight);
            index /= 2;
        }
    }

    /**
     * Draws a value with probability proportional to its weight.
     *
     * @return the sampled value
     * @throws IllegalArgumentException if no value has ever been added
     */
    @Override // org.apache.mahout.math.random.Sampler
    public T sample() {
        Preconditions.checkArgument(this.weight.size() > 1, "Cannot sample from an empty Multinomial");
        return sample(this.rand.nextDouble());
    }

    /**
     * Deterministically maps a uniform draw to a value.
     *
     * @param d position in the cumulative distribution, expected in {@code [0, 1)}
     * @return the value whose leaf covers {@code d * totalWeight}
     * @throws IllegalArgumentException if no value has ever been added
     */
    public T sample(double d) {
        Preconditions.checkArgument(this.weight.size() > 1, "Cannot sample from an empty Multinomial");
        // Scale to the total weight, then descend. Go left while the target falls within the left
        // subtree's weight; otherwise subtract that weight and go right.
        double target = d * this.weight.get(1);
        int node = 1;
        while (2 * node < this.weight.size()) {
            double leftWeight = this.weight.get(2 * node);
            if (target <= leftWeight) {
                node = 2 * node;
            } else {
                target -= leftWeight;
                node = 2 * node + 1;
            }
        }
        return this.values.get(node);
    }

    /**
     * Returns the weights of all leaves. This includes zero-weight leaves left behind by deleted
     * values. Leaves on the deepest tree level come first, then the leaves on the level above.
     *
     * @return leaf weights; empty if nothing has been added
     */
    List<Double> getWeights() {
        List<Double> result = Lists.newArrayList();
        int size = this.weight.size();
        if (size <= 1) {
            // Only the sentinel exists; there are no leaves.
            return result;
        }
        int deepestLevelStart = Integer.highestOneBit(size);
        for (int i = deepestLevelStart; i < size; i++) {
            result.add(this.weight.get(i));
        }
        // Leaves on the level above start at size / 2 (every index i with 2i >= size is a leaf).
        for (int i = size / 2; i < deepestLevelStart; i++) {
            result.add(this.weight.get(i));
        }
        return result;
    }

    /**
     * Iterates over the live values, each exactly once, in tree-index order. Stale copies held by
     * internal nodes and by deleted leaves are skipped.
     *
     * @return iterator over live values
     */
    @Override // java.lang.Iterable
    public Iterator<T> iterator() {
        return new AbstractIterator<T>() {
            /** Next tree index to examine; index 0 is the sentinel. */
            private int index = 1;

            @Override
            protected T computeNext() {
                while (index < Multinomial.this.values.size()) {
                    int current = index++;
                    T candidate = Multinomial.this.values.get(current);
                    // Only the leaf that items points at is the live occurrence of the value.
                    Integer live = Multinomial.this.items.get(candidate);
                    if (live != null && live == current) {
                        return candidate;
                    }
                }
                return endOfData();
            }
        };
    }
}
