package opennlp.tools.ml;

import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.BeamSearchContextGenerator;
import opennlp.tools.util.Cache;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceValidator;

/* loaded from: input_file:opennlp/tools/ml/BeamSearch.class */
/**
 * Beam search over outcome sequences, scored position by position with a {@link MaxentModel}.
 *
 * <p>For every position of the input the search extends at most {@link #size} partial
 * sequences taken from a priority queue, keeps the extensions accepted by the
 * {@link SequenceValidator} whose score exceeds the minimum sequence score, and carries them
 * over to the next position.
 *
 * <p>A single score buffer is reused for every model evaluation and no synchronization is
 * performed, so one instance must not be used by several threads at the same time.
 *
 * @param <T> the type of the elements of the input sequence
 */
public class BeamSearch<T> implements SequenceClassificationModel<T> {
    /** Name of the beam size parameter. */
    public static final String BEAM_SIZE_PARAMETER = "BeamSize";

    /** Substituted when the caller passes {@code null} as additional context. */
    private static final Object[] EMPTY_ADDITIONAL_CONTEXT = new Object[0];

    /** Minimum sequence score used by the overload that takes no explicit minimum. */
    private static final int ZERO_LOG = -100000;

    /** Beam width: how many partial sequences are extended per position. */
    protected int size;

    /** Model that scores the outcomes for a context. */
    protected MaxentModel model;

    /** Output buffer handed to {@link MaxentModel#eval} on every evaluation. */
    private final double[] probs;

    /** Optional cache from contexts to scores; {@code null} when caching is disabled. */
    private final Cache<String[], double[]> contextsCache;

    /**
     * Creates a beam search without a contexts cache.
     *
     * @param size  the beam width
     * @param model the model used to score outcomes
     */
    public BeamSearch(int size, MaxentModel model) {
        this(size, model, 0);
    }

    /**
     * Creates a beam search.
     *
     * @param size      the beam width
     * @param model     the model used to score outcomes
     * @param cacheSize size passed to the contexts cache; no cache is created unless it is
     *                  greater than zero
     */
    public BeamSearch(int size, MaxentModel model, int cacheSize) {
        this.size = size;
        this.model = model;
        if (cacheSize > 0) {
            this.contextsCache = new Cache<>(cacheSize);
        } else {
            this.contextsCache = null;
        }
        this.probs = new double[model.getNumOutcomes()];
    }

    /**
     * Searches for the best outcome sequences for the given input.
     *
     * @param numSequences      the maximum number of sequences to return
     * @param sequence          the input elements; one outcome is assigned per element
     * @param additionalContext extra context handed to the context generator; {@code null} is
     *                          replaced by an empty array
     * @param minSequenceScore  a candidate is kept only if its score is strictly greater
     * @param cg                produces the context for each position
     * @param validator         decides whether an outcome may extend a partial sequence
     * @return at most {@code numSequences} sequences, in the order the priority queue yields
     *         them (presumably best first; this depends on the ordering defined by
     *         {@code Sequence}); empty when no candidate survived
     */
    @Override
    public Sequence[] bestSequences(int numSequences, T[] sequence, Object[] additionalContext,
            double minSequenceScore, BeamSearchContextGenerator<T> cg,
            SequenceValidator<T> validator) {
        PriorityQueue<Sequence> prev = new PriorityQueue<>(this.size);
        PriorityQueue<Sequence> next = new PriorityQueue<>(this.size);
        prev.add(new Sequence());

        final Object[] extraContext =
                additionalContext != null ? additionalContext : EMPTY_ADDITIONAL_CONTEXT;

        for (int position = 0; position < sequence.length; position++) {
            int toExpand = Math.min(this.size, prev.size());

            for (int expanded = 0; !prev.isEmpty() && expanded < toExpand; expanded++) {
                Sequence top = prev.remove();
                List<String> outcomeList = top.getOutcomes();
                String[] outcomes = outcomeList.toArray(new String[outcomeList.size()]);
                String[] contexts = cg.getContext(position, sequence, outcomes, extraContext);
                double[] scores = scoresFor(contexts);

                // The size-th largest score is the cut-off for entering the beam
                // (the smallest score when there are fewer outcomes than the beam width).
                double[] sorted = scores.clone();
                Arrays.sort(sorted);
                double cutOff = sorted[Math.max(0, scores.length - this.size)];

                for (int p = 0; p < scores.length; p++) {
                    if (scores[p] >= cutOff) {
                        String outcome = this.model.getOutcome(p);
                        if (validator.validSequence(position, sequence, outcomes, outcome)) {
                            Sequence candidate = new Sequence(top, outcome, scores[p]);
                            if (candidate.getScore() > minSequenceScore) {
                                next.add(candidate);
                            }
                        }
                    }
                }

                // Nothing has advanced so far at this position: retry with every outcome,
                // ignoring the cut-off.
                if (next.isEmpty()) {
                    for (int p = 0; p < scores.length; p++) {
                        String outcome = this.model.getOutcome(p);
                        if (validator.validSequence(position, sequence, outcomes, outcome)) {
                            Sequence candidate = new Sequence(top, outcome, scores[p]);
                            if (candidate.getScore() > minSequenceScore) {
                                next.add(candidate);
                            }
                        }
                    }
                }
            }

            // Swap the queues; the emptied one is reused for the following position.
            prev.clear();
            PriorityQueue<Sequence> emptied = prev;
            prev = next;
            next = emptied;
        }

        int resultCount = Math.min(numSequences, prev.size());
        Sequence[] best = new Sequence[resultCount];
        for (int k = 0; k < resultCount; k++) {
            best[k] = prev.remove();
        }
        return best;
    }

    /**
     * Searches for the best outcome sequences using the default minimum sequence score of
     * {@code -100000}.
     *
     * @param numSequences      the maximum number of sequences to return
     * @param sequence          the input elements
     * @param additionalContext extra context handed to the context generator, may be
     *                          {@code null}
     * @param cg                produces the context for each position
     * @param validator         decides whether an outcome may extend a partial sequence
     * @return at most {@code numSequences} sequences
     */
    @Override
    public Sequence[] bestSequences(int numSequences, T[] sequence, Object[] additionalContext,
            BeamSearchContextGenerator<T> cg, SequenceValidator<T> validator) {
        return bestSequences(numSequences, sequence, additionalContext, ZERO_LOG, cg, validator);
    }

    /**
     * Searches for the single best outcome sequence.
     *
     * @param sequence          the input elements
     * @param additionalContext extra context handed to the context generator, may be
     *                          {@code null}
     * @param cg                produces the context for each position
     * @param validator         decides whether an outcome may extend a partial sequence
     * @return the first sequence found by the search, or {@code null} when there is none
     */
    @Override
    public Sequence bestSequence(T[] sequence, Object[] additionalContext,
            BeamSearchContextGenerator<T> cg, SequenceValidator<T> validator) {
        Sequence[] best = bestSequences(1, sequence, additionalContext, cg, validator);
        if (best.length > 0) {
            return best[0];
        }
        return null;
    }

    /**
     * Lists every outcome of the underlying model, in model index order.
     *
     * @return a newly allocated array of outcome names
     */
    @Override
    public String[] getOutcomes() {
        int outcomeCount = this.model.getNumOutcomes();
        String[] outcomes = new String[outcomeCount];
        for (int k = 0; k < outcomeCount; k++) {
            outcomes[k] = this.model.getOutcome(k);
        }
        return outcomes;
    }

    /**
     * Returns the model scores for a context, consulting the contexts cache when one exists.
     *
     * @param contexts the context to score
     * @return the scores, indexed by outcome
     */
    private double[] scoresFor(String[] contexts) {
        if (this.contextsCache == null) {
            return this.model.eval(contexts, this.probs);
        }
        double[] scores = this.contextsCache.get(contexts);
        if (scores == null) {
            // The shared buffer is passed to eval as its output array. If the model returns
            // that same buffer, caching it directly would let the next evaluation overwrite
            // the cached entry, so a private copy is stored instead.
            // NOTE(review): keys are arrays; whether lookups match by content or by identity
            // depends on the Cache implementation - confirm.
            scores = this.model.eval(contexts, this.probs).clone();
            this.contextsCache.put(contexts, scores);
        }
        return scores;
    }
}
