package edu.berkeley.nlp.lm.io;

import edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.map.HashNgramMap;
import edu.berkeley.nlp.lm.map.NgramMap;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.values.KneserNeyCountValueContainer;
import edu.berkeley.nlp.lm.values.ProbBackoffPair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/lm/io/KneserNeyLmReaderCallback.class */
public class KneserNeyLmReaderCallback<W> implements NgramOrderedLmReaderCallback<LongRef>, LmReader<ProbBackoffPair, ArpaLmReaderCallback<ProbBackoffPair>>, ArrayEncodedNgramLanguageModel<W>, Serializable {
    /** Serialization version identifier for this {@link Serializable} class. */
    protected static final long serialVersionUID = 1;
    /** Bound checked by the constructor: a requested order that is greater than or equal to this value is rejected. */
    protected static final int MAX_ORDER = 10;
    /** Default discount value; {@link #defaultDiscounts()} fills its array with the same number (as a double literal). */
    protected static final float DEFAULT_DISCOUNT = 0.75f;
    /** Order of the language model, as passed to the constructor. */
    protected final int lmOrder;
    /** Maps words to integer indices and supplies the start/end symbols. */
    protected final WordIndexer<W> wordIndexer;
    /** Hash map holding the Kneser-Ney counts for every n-gram added through {@code addNgram}. */
    protected final HashNgramMap<KneserNeyCountValueContainer.KneserNeyCounts> ngrams;
    /** Configuration; this class reads {@code kneserNeyMinCounts} and {@code kneserNeyDiscounts} from it. */
    protected final ConfigOptions opts;
    /** Index of the start symbol, resolved once in the constructor. */
    protected final int startIndex;
    /** Set by the static initializer; when {@code false}, the hand-written assertion checks in this class are active. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Creates a callback using a freshly constructed {@link ConfigOptions}.
     *
     * @param wordIndexer indexer used to map words to integer ids
     * @param i           order of the language model
     */
    public KneserNeyLmReaderCallback(WordIndexer<W> wordIndexer, int i) {
        this(wordIndexer, i, new ConfigOptions());
    }

    /**
     * Creates a callback that accumulates Kneser-Ney counts for n-grams up to the given order.
     *
     * @param wordIndexer   indexer used to map words to integer ids and to look up the start symbol
     * @param i             order of the language model; must be strictly less than {@link #MAX_ORDER}
     * @param configOptions options; {@code kneserNeyMinCounts} must be non-decreasing
     * @throws IllegalArgumentException if {@code i >= MAX_ORDER} or if
     *         {@code configOptions.kneserNeyMinCounts} decreases at any position
     */
    public KneserNeyLmReaderCallback(WordIndexer<W> wordIndexer, int i, ConfigOptions configOptions) {
        this.lmOrder = i;
        this.startIndex = wordIndexer.getIndexPossiblyUnk(wordIndexer.getStartSymbol());
        if (i >= MAX_ORDER) {
            // The check rejects MAX_ORDER itself, so the message must say "below", not "up to".
            throw new IllegalArgumentException("Requested n-grams of order " + i + " but we only allow orders below " + MAX_ORDER);
        }
        this.opts = configOptions;
        // Each minimum count must be at least as large as the one before it.
        double previous = Double.NEGATIVE_INFINITY;
        for (double minCount : configOptions.kneserNeyMinCounts) {
            if (minCount < previous) {
                throw new IllegalArgumentException("Please ensure that ConfigOptions.kneserNeyMinCounts is monotonic (value was " + Arrays.toString(configOptions.kneserNeyMinCounts) + ")");
            }
            previous = minCount;
        }
        this.wordIndexer = wordIndexer;
        this.ngrams = HashNgramMap.createExplicitWordHashNgramMap(new KneserNeyCountValueContainer(this.lmOrder, this.startIndex), configOptions, this.lmOrder, false);
    }

    /**
     * Indexes every word of the n-gram (adding unseen words to the indexer) and forwards the
     * resulting id array, covering its full length, to the array-based {@code call} overload.
     *
     * @param wArr    words of the n-gram
     * @param longRef count associated with the n-gram
     */
    public void call(W[] wArr, LongRef longRef) {
        final int[] wordIds = new int[wArr.length];
        int position = 0;
        for (final W word : wArr) {
            wordIds[position] = this.wordIndexer.getOrAddIndex(word);
            position++;
        }
        call(wordIds, 0, wordIds.length, longRef, "");
    }

    /**
     * Indexes every word of the n-gram (adding unseen words to the indexer) and adds it via
     * {@code addNgram} with the "just last" flag set, using the caller-supplied scratch table.
     *
     * @param wArr    words of the n-gram
     * @param longRef count associated with the n-gram
     * @param jArr    scratch table of offsets, passed through to {@code addNgram}
     */
    public void callJustLast(W[] wArr, LongRef longRef, long[][] jArr) {
        final int wordCount = wArr.length;
        final int[] wordIds = new int[wordCount];
        for (int pos = 0; pos != wordCount; ++pos) {
            final W word = wArr[pos];
            wordIds[pos] = this.wordIndexer.getOrAddIndex(word);
        }
        addNgram(wordIds, 0, wordCount, longRef, "", true, jArr);
    }

    /**
     * Adds the n-gram {@code iArr[i..i2)} with a freshly allocated scratch table of
     * {@code lmOrder} rows and {@code i2 - i} columns, with the "just last" flag cleared.
     *
     * @param iArr    word ids
     * @param i       start position (inclusive)
     * @param i2      end position (exclusive)
     * @param longRef count associated with the n-gram
     * @param str     textual form of the n-gram, passed through unchanged
     */
    @Override // edu.berkeley.nlp.lm.io.LmReaderCallback
    public void call(int[] iArr, int i, int i2, LongRef longRef, String str) {
        final int span = i2 - i;
        final long[][] scratch = new long[this.lmOrder][span];
        addNgram(iArr, i, i2, longRef, str, false, scratch);
    }

    /**
     * Inserts every sub-n-gram of {@code iArr[i..i2)} of length 1 to {@code lmOrder} into the
     * n-gram map, shortest first, recording each returned offset in {@code jArr} so that longer
     * n-grams can be inserted relative to their prefix and suffix.
     *
     * <p>All accesses to {@code jArr} use columns relative to {@code i} (column 0 is the
     * sub-n-gram starting at {@code i}). Previously the prefix/suffix reads used absolute
     * positions while the write used relative ones, which was only consistent when
     * {@code i == 0}.
     *
     * @param iArr    word ids
     * @param i       start position (inclusive)
     * @param i2      end position (exclusive)
     * @param longRef count stored as the token count of the inserted n-grams
     * @param str     textual form of the n-gram (unused here)
     * @param z       when {@code true}, only sub-n-grams ending at {@code i2} receive the counts;
     *                all others are inserted with a {@code null} value
     * @param jArr    scratch table with at least {@code lmOrder} rows and {@code i2 - i} columns
     */
    public void addNgram(int[] iArr, int i, int i2, LongRef longRef, String str, boolean z, long[][] jArr) {
        final KneserNeyCountValueContainer.KneserNeyCounts counts = new KneserNeyCountValueContainer.KneserNeyCounts();
        this.ngrams.rehashIfNecessary(i2 - i);
        for (int order = 0; order < this.lmOrder; order++) {
            // Sub-n-grams of this order exist only while their end stays within i2.
            for (int start = i; start + order + 1 <= i2; start++) {
                final int end = start + order + 1;
                final int column = start - i;
                counts.tokenCounts = longRef.value;
                long prefixOffset = 0L;
                long suffixOffset = 0L;
                if (order > 0) {
                    // Offsets of the (order-1)-grams starting at `start` and `start + 1`.
                    prefixOffset = jArr[order - 1][column];
                    suffixOffset = jArr[order - 1][column + 1];
                }
                if (!$assertionsDisabled && prefixOffset < 0) {
                    throw new AssertionError();
                }
                jArr[order][column] = this.ngrams.putWithOffsetAndSuffix(iArr, start, end, prefixOffset, suffixOffset, (!z || end == i2) ? counts : null);
            }
        }
    }

    /**
     * Recursively combines lower-order probabilities and backoff weights for
     * {@code iArr[i..i2)}: the lower-order probability of the current span plus the backoff of
     * its context times the interpolated probability of the span with its first word dropped.
     *
     * @return {@code 0} for an empty span, otherwise the interpolated value
     */
    protected float interpolateProb(int[] iArr, int i, int i2) {
        if (i != i2) {
            final float lowerOrderProb = getLowerOrderProb(iArr, i, i2);
            final float contextBackoff = getLowerOrderBackoff(iArr, i, i2 - 1);
            final float shorterSpan = interpolateProb(iArr, i + 1, i2);
            return lowerOrderProb + (contextBackoff * shorterSpan);
        }
        return 0.0f;
    }

    /**
     * Discounted relative frequency based on token counts: the token count of
     * {@code iArr[i..i2)} minus the discount for its order, divided by the token count of its
     * context {@code iArr[i..i2-1)}, clamped at zero.
     *
     * @return {@code 0} when the context token count is zero
     */
    protected float getHighestOrderProb(int[] iArr, int i, int i2) {
        final KneserNeyCountValueContainer.KneserNeyCounts ngramCounts = getCounts(iArr, i, i2, false);
        final KneserNeyCountValueContainer.KneserNeyCounts contextCounts = getCounts(iArr, i, i2 - 1, true);
        if (contextCounts.tokenCounts == 0) {
            return 0.0f;
        }
        final float discounted = ((float) ngramCounts.tokenCounts) - getDiscountForOrder((i2 - i) - 1);
        final float ratio = discounted / ((float) contextCounts.tokenCounts);
        return Math.max(0.0f, ratio);
    }

    /**
     * Discounted ratio based on type counts: the left-dot type count of {@code iArr[i..i2)}
     * minus the discount for its order (no discount for single-word spans), clamped at zero,
     * divided by the dot-dot type count of the context {@code iArr[i..i2-1)}.
     *
     * @return {@code 1} for an empty span, {@code 0} when the context's dot-dot type count is zero
     */
    protected float getLowerOrderProb(int[] iArr, int i, int i2) {
        if (i == i2) {
            return 1.0f;
        }
        final KneserNeyCountValueContainer.KneserNeyCounts ngramCounts = getCounts(iArr, i, i2, false);
        final KneserNeyCountValueContainer.KneserNeyCounts contextCounts = getCounts(iArr, i, i2 - 1, true);
        if (contextCounts.dotdotTypeCounts == 0) {
            return 0.0f;
        }
        float discount = 0.0f;
        if (i2 - i != 1) {
            discount = getDiscountForOrder((i2 - i) - 1);
        }
        final float numerator = Math.max(0.0f, ((float) ngramCounts.leftDotTypeCounts) - discount);
        return numerator / ((float) contextCounts.dotdotTypeCounts);
    }

    /**
     * Backoff weight for the context {@code iArr[i..i2)}: the discount for its length times its
     * right-dot type count, divided by a normalizer. The normalizer is the token count when the
     * span has length {@code lmOrder - 1} or begins with the start symbol, and the dot-dot type
     * count otherwise.
     *
     * @return {@code 1} for an empty span or when the normalizer is zero
     */
    protected float getLowerOrderBackoff(int[] iArr, int i, int i2) {
        if (i == i2) {
            return 1.0f;
        }
        final KneserNeyCountValueContainer.KneserNeyCounts counts = getCounts(iArr, i, i2, true);
        final long normalizer;
        if (i2 - i == this.lmOrder - 1 || iArr[i] == this.startIndex) {
            normalizer = counts.tokenCounts;
        } else {
            normalizer = counts.dotdotTypeCounts;
        }
        // Hand-written assertion: only active when assertions are enabled for this class.
        if (!$assertionsDisabled && normalizer < 0) {
            throw new AssertionError();
        }
        if (((float) normalizer) == 0.0f) {
            return 1.0f;
        }
        final float numerator = getDiscountForOrder(i2 - i) * ((float) counts.rightDotTypeCounts);
        return numerator / ((float) normalizer);
    }

    /**
     * Returns the discount for the given order index. If discounts are configured in the
     * options they are used directly; otherwise the discount is estimated as
     * {@code n1 / (n1 + 2 * n2)} from the number of n-grams of that order with count one
     * ({@code n1}) and count two ({@code n2}).
     *
     * @param i zero-based order index
     * @return the discount, or {@code 1e-5} when the estimate's denominator is zero
     */
    protected float getDiscountForOrder(int i) {
        if (this.opts.kneserNeyDiscounts == null) {
            final KneserNeyCountValueContainer values = (KneserNeyCountValueContainer) this.ngrams.getValues();
            final int oneCounts = values.getNumOneCountNgrams(i);
            final float denominator = oneCounts + (2.0f * values.getNumTwoCountNgrams(i));
            return denominator == 0.0f ? 1.0E-5f : oneCounts / denominator;
        }
        return (float) this.opts.kneserNeyDiscounts[i];
    }

    /** No-op: this callback performs no work when reading finishes. */
    @Override // edu.berkeley.nlp.lm.io.LmReaderCallback
    public void cleanup() {
    }

    /**
     * Looks up the counts stored for {@code iArr[i..i2)} and adjusts them for n-grams that
     * begin with the start symbol or end with the end symbol.
     *
     * <ul>
     * <li>Empty span: only {@code dotdotTypeCounts} is set, to the container's bigram type count.</li>
     * <li>N-gram not in the map: a freshly constructed counts object is returned.</li>
     * <li>Starts with the start symbol: {@code dotdotTypeCounts} becomes {@code rightDotTypeCounts};
     *     {@code tokenCounts} becomes {@code leftDotTypeCounts} when the span is shorter than
     *     {@code lmOrder - 1}, or exactly that long and {@code z} is {@code false}.</li>
     * <li>Ends with the end symbol: {@code rightDotTypeCounts} is set to 1 and
     *     {@code dotdotTypeCounts} becomes {@code leftDotTypeCounts}.</li>
     * </ul>
     *
     * @param iArr word ids
     * @param i    start position (inclusive)
     * @param i2   end position (exclusive)
     * @param z    see the start-symbol rule above
     * @return a new counts object; never {@code null}
     */
    private KneserNeyCountValueContainer.KneserNeyCounts getCounts(int[] iArr, int i, int i2, boolean z) {
        final KneserNeyCountValueContainer.KneserNeyCounts counts = new KneserNeyCountValueContainer.KneserNeyCounts();
        if (i == i2) {
            counts.dotdotTypeCounts = ((KneserNeyCountValueContainer) this.ngrams.getValues()).getBigramTypeCounts();
            return counts;
        }
        final long offset = this.ngrams.getOffsetForNgramInModel(iArr, i, i2);
        if (offset < 0) {
            return counts;
        }
        final int length = i2 - i;
        this.ngrams.getValues().getFromOffset(offset, length - 1, counts);
        final boolean startsWithStart = iArr[i] == this.startIndex;
        final boolean endsWithEnd = iArr[i2 - 1] == this.wordIndexer.getIndexPossiblyUnk(this.wordIndexer.getEndSymbol());
        if (startsWithStart) {
            counts.dotdotTypeCounts = counts.rightDotTypeCounts;
            if (length < this.lmOrder - 1 || (length == this.lmOrder - 1 && !z)) {
                counts.tokenCounts = counts.leftDotTypeCounts;
            }
        }
        if (endsWithEnd) {
            // Plain literal: this value is a count and must not be tied to serialVersionUID.
            counts.rightDotTypeCounts = 1L;
            counts.dotdotTypeCounts = counts.leftDotTypeCounts;
        }
        return counts;
    }

    /**
     * Builds the default discount table.
     *
     * @return a new array of {@link #MAX_ORDER} entries, each equal to {@code 0.75}
     */
    public static double[] defaultDiscounts() {
        final double[] discounts = new double[MAX_ORDER];
        Arrays.fill(discounts, 0.75d);
        return discounts;
    }

    /**
     * Builds the default minimum-count table.
     *
     * @return a new array of 11 entries: {@code 1.0} for the first three, {@code 2.0} for the rest
     */
    public static double[] defaultMinCounts() {
        final double[] minCounts = new double[11];
        Arrays.fill(minCounts, 2.0d);
        Arrays.fill(minCounts, 0, 3, 1.0d);
        return minCounts;
    }

    /**
     * Creates an array in which every element has the same value.
     *
     * @param i length of the array
     * @param d value stored in every element
     * @return a new array of length {@code i}
     */
    private static double[] constantArray(int i, double d) {
        final double[] filled = new double[i];
        for (int index = 0; index < filled.length; index++) {
            filled[index] = d;
        }
        return filled;
    }

    /**
     * Emits the model to the given ARPA callback in two passes. The first pass counts, per
     * order, how many stored n-grams will be emitted and reports those totals through
     * {@code initWithLengths}. The second pass walks each order again and calls the callback
     * with the probability/backoff pair of every emitted n-gram. Both passes use the same
     * filter ({@link #meetsMinCount(int, long)}) so the reported totals match what is written.
     *
     * @param arpaLmReaderCallback receiver of the n-grams and their probability/backoff pairs
     */
    @Override // edu.berkeley.nlp.lm.io.LmReader
    public void parse(ArpaLmReaderCallback<ProbBackoffPair> arpaLmReaderCallback) {
        Logger.startTrack("Writing Kneser-Ney probabilities", new Object[0]);
        final List<Long> lengths = new ArrayList<Long>();
        for (int order = 0; order < this.lmOrder; order++) {
            Logger.startTrack("Counting counts for order " + order, new Object[0]);
            long emitted = 0L;
            for (NgramMap.Entry<KneserNeyCountValueContainer.KneserNeyCounts> entry : this.ngrams.getNgramsForOrder(order)) {
                if (meetsMinCount(order, entry.value.tokenCounts)) {
                    emitted++;
                }
            }
            lengths.add(Long.valueOf(emitted));
            Logger.endTrack();
        }
        arpaLmReaderCallback.initWithLengths(lengths);
        for (int order = 0; order < this.lmOrder; order++) {
            arpaLmReaderCallback.handleNgramOrderStarted(order + 1);
            Logger.logss("On order " + (order + 1));
            int visited = 0;
            for (NgramMap.Entry<KneserNeyCountValueContainer.KneserNeyCounts> entry : this.ngrams.getNgramsForOrder(order)) {
                // Progress message every 10000 entries, numbered from 1.
                if (visited % 10000 == 0) {
                    Logger.logs("Writing line " + (visited + 1));
                }
                visited++;
                if (meetsMinCount(order, entry.value.tokenCounts)) {
                    final int[] ngram = entry.key;
                    final int length = ngram.length;
                    arpaLmReaderCallback.call(ngram, 0, length, getProbBackoff(ngram, 0, length), "");
                }
            }
            arpaLmReaderCallback.handleNgramOrderFinished(order + 1);
        }
        arpaLmReaderCallback.cleanup();
        Logger.endTrack();
    }

    /**
     * Decides whether an n-gram is emitted by {@code parse}: orders below {@code lmOrder - 2}
     * are always kept; for the remaining orders the token count must reach the configured
     * minimum count for that order.
     */
    private boolean meetsMinCount(int order, long tokenCount) {
        return order < this.lmOrder - 2 || tokenCount >= this.opts.kneserNeyMinCounts[order];
    }

    /**
     * Computes the log10 probability and log10 backoff weight for {@code iArr[i..i2)}.
     *
     * <p>The direct probability uses token counts when the span has length {@code lmOrder} or
     * begins with the start symbol, and type counts otherwise. It is interpolated with the
     * lower-order estimate of the span that follows the first word and any start symbols
     * directly after it. A span consisting only of the start symbol gets log probability
     * {@code -99}; a span of length {@code lmOrder} gets backoff {@code 0}.
     */
    private ProbBackoffPair getProbBackoff(int[] iArr, int i, int i2) {
        final int length = i2 - i;
        final boolean isHighestOrder = length - 1 == this.lmOrder - 1;
        final float directProb;
        if (isHighestOrder || iArr[i] == this.startIndex) {
            directProb = getHighestOrderProb(iArr, i, i2);
        } else {
            directProb = getLowerOrderProb(iArr, i, i2);
        }
        // Skip start symbols that immediately follow the first word.
        int lowerOrderStart = i + 1;
        while (lowerOrderStart < i2 && iArr[lowerOrderStart] == this.startIndex) {
            lowerOrderStart++;
        }
        final float logProb;
        if (length == 1 && iArr[i] == this.startIndex) {
            logProb = -99.0f;
        } else {
            final float contextBackoff = getLowerOrderBackoff(iArr, i, i2 - 1);
            final float lowerOrderProb = interpolateProb(iArr, lowerOrderStart, i2);
            logProb = (float) Math.log10(directProb + (contextBackoff * lowerOrderProb));
        }
        final float logBackoff;
        if (isHighestOrder) {
            logBackoff = 0.0f;
        } else {
            logBackoff = (float) Math.log10(getLowerOrderBackoff(iArr, i, i2));
        }
        return new ProbBackoffPair(logProb, logBackoff);
    }

    /** @return the word indexer supplied at construction */
    @Override // edu.berkeley.nlp.lm.NgramLanguageModel
    public WordIndexer<W> getWordIndexer() {
        final WordIndexer<W> indexer = this.wordIndexer;
        return indexer;
    }

    /** No-op: nothing is done when an n-gram order finishes. */
    @Override // edu.berkeley.nlp.lm.io.NgramOrderedLmReaderCallback
    public void handleNgramOrderFinished(int i) {
    }

    /** No-op: nothing is done when an n-gram order starts. */
    @Override // edu.berkeley.nlp.lm.io.NgramOrderedLmReaderCallback
    public void handleNgramOrderStarted(int i) {
    }

    /** @return the language model order supplied at construction */
    @Override // edu.berkeley.nlp.lm.NgramLanguageModel
    public int getLmOrder() {
        final int order = this.lmOrder;
        return order;
    }

    /**
     * Scores a sentence by delegating to the default array-encoded implementation.
     *
     * @param list words of the sentence
     * @return the score computed by {@code DefaultImplementations.scoreSentence}
     */
    @Override // edu.berkeley.nlp.lm.NgramLanguageModel
    public float scoreSentence(List<W> list) {
        final float sentenceScore = ArrayEncodedNgramLanguageModel.DefaultImplementations.scoreSentence(list, this);
        return sentenceScore;
    }

    /**
     * Computes the log probability of an n-gram given as a word list by delegating to the
     * default array-encoded implementation.
     *
     * @param list words of the n-gram
     * @return the value computed by {@code DefaultImplementations.getLogProb}
     */
    @Override // edu.berkeley.nlp.lm.NgramLanguageModel
    public float getLogProb(List<W> list) {
        final float logProb = ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(list, this);
        return logProb;
    }

    /**
     * Returns the log10 probability of {@code iArr[i..i2)}, taken from the
     * probability/backoff pair computed for that span.
     *
     * @param iArr word ids
     * @param i    start position (inclusive)
     * @param i2   end position (exclusive)
     */
    @Override // edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel
    public float getLogProb(int[] iArr, int i, int i2) {
        final ProbBackoffPair pair = getProbBackoff(iArr, i, i2);
        return pair.prob;
    }

    /**
     * Computes the log probability of a whole id array by delegating to the default
     * array-encoded implementation.
     *
     * @param iArr word ids of the n-gram
     * @return the value computed by {@code DefaultImplementations.getLogProb}
     */
    @Override // edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel
    public float getLogProb(int[] iArr) {
        final float logProb = ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(iArr, this);
        return logProb;
    }

    /** @return the total size reported by the underlying n-gram map */
    public long getTotalSize() {
        final long totalSize = this.ngrams.getTotalSize();
        return totalSize;
    }

    /**
     * Not supported by this implementation.
     *
     * @param f ignored
     * @throws UnsupportedOperationException always
     */
    @Override // edu.berkeley.nlp.lm.NgramLanguageModel
    public void setOovWordLogProb(float f) {
        final String message = "Method not yet implemented";
        throw new UnsupportedOperationException(message);
    }

    // Records whether assertions are disabled for this class; the hand-written
    // "!$assertionsDisabled && ..." checks in this class depend on this flag.
    static {
        $assertionsDisabled = !KneserNeyLmReaderCallback.class.desiredAssertionStatus();
    }
}
