package com.googlecode.clearnlp.beam;

import com.googlecode.clearnlp.classification.prediction.StringPrediction;
import com.googlecode.clearnlp.reader.AbstractColumnReader;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/googlecode/clearnlp/beam/BeamTree.class */
/**
 * A fixed-width beam of {@link BeamNode}s that is advanced one step at a time
 * from per-node prediction lists.
 *
 * <p>The beam holds at most {@code n_size} nodes. After each step the candidates
 * are sorted by {@link BeamNode}'s natural ordering and only the head of that
 * sorted list is kept, so the natural ordering is expected to rank better nodes
 * first (presumably descending score — TODO confirm against {@code BeamNode}).
 *
 * @param <T> type parameter passed through to {@link BeamNode}
 */
public class BeamTree<T> {
    /** Current beam contents; replaced wholesale by {@link #setBeam(List)}. */
    private List<BeamNode<T>> b_nodes;
    /** Beam width: the maximum number of nodes retained after each step. */
    private final int n_size;

    /**
     * Creates an empty beam.
     *
     * @param i the beam width, i.e. the maximum number of nodes kept per step
     * @throws IllegalArgumentException if {@code i} is negative (raised by the
     *         backing {@link ArrayList} constructor)
     */
    public BeamTree(int i) {
        this.b_nodes = new ArrayList<BeamNode<T>>(i);
        this.n_size = i;
    }

    /**
     * Returns the {@code i}-th node of the current beam.
     *
     * @param i index into the current beam
     * @return the node at that index
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    public BeamNode<T> getNode(int i) {
        return this.b_nodes.get(i);
    }

    /**
     * Returns the current beam. This is the live internal list, not a copy.
     *
     * @return the nodes currently in the beam
     */
    public List<BeamNode<T>> getCurrNodes() {
        return this.b_nodes;
    }

    /**
     * Advances the beam by one step.
     *
     * <p>{@code list.get(k)} holds the predictions used to expand the {@code k}-th
     * node of the current beam. If the current beam is empty, every new node gets
     * a {@code null} predecessor. Each inner list is assumed to be sorted
     * best-first (descending score): once the beam is full, scanning a list stops
     * at the first prediction scoring below the current worst node.
     *
     * <p>After this call the beam holds at most {@code n_size} nodes. An empty
     * {@code list} yields an empty beam.
     *
     * @param list per-node prediction lists, aligned with the current beam
     * @throws IndexOutOfBoundsException if the beam is non-empty and
     *         {@code list} has more entries than the beam has nodes
     */
    public void setBeam(List<List<StringPrediction>> list) {
        List<BeamNode<T>> beam = new ArrayList<BeamNode<T>>();
        if (list.isEmpty()) {
            this.b_nodes = beam;
            return;
        }

        // Seed from the first node's predictions, capped at the beam width.
        List<StringPrediction> seed = list.get(0);
        BeamNode<T> seedPrev = getPrevNode(0);
        int seedCount = Math.min(seed.size(), this.n_size);
        for (int i = 0; i < seedCount; i++) {
            beam.add(new BeamNode<T>(seedPrev, seed.get(i)));
        }

        int size = list.size();
        for (int k = 1; k < size; k++) {
            BeamNode<T> prevNode = getPrevNode(k);
            // Prune against the worst node only when the beam is already full.
            // While it has free slots, every candidate must be admitted; an
            // empty beam has no worst node to read at all.
            double threshold = (beam.isEmpty() || beam.size() < this.n_size)
                    ? Double.NEGATIVE_INFINITY
                    : beam.get(beam.size() - 1).getScore();
            for (StringPrediction prediction : list.get(k)) {
                if (prediction.score < threshold) {
                    break; // remaining predictions score no higher (sorted input)
                }
                beam.add(new BeamNode<T>(prevNode, prediction));
            }
            Collections.sort(beam);
            if (beam.size() > this.n_size) {
                beam.subList(this.n_size, beam.size()).clear();
            }
        }
        this.b_nodes = beam;
    }

    /**
     * Returns the node sequence (via {@link BeamNode#getSequence()}) whose summed
     * node scores are highest among the current beam. Ties keep the earliest
     * node in beam order.
     *
     * @return the best-scoring sequence, or an empty list if the beam is empty
     */
    public List<BeamNode<T>> getBestSequence() {
        if (this.b_nodes.isEmpty()) {
            return new ArrayList<BeamNode<T>>();
        }
        List<BeamNode<T>> best = this.b_nodes.get(0).getSequence();
        double bestScore = getOverallScore(best);
        int size = this.b_nodes.size();
        for (int i = 1; i < size; i++) {
            List<BeamNode<T>> candidate = this.b_nodes.get(i).getSequence();
            double candidateScore = getOverallScore(candidate);
            if (candidateScore > bestScore) {
                best = candidate;
                bestScore = candidateScore;
            }
        }
        return best;
    }

    /**
     * Sums the scores of all nodes in a sequence.
     *
     * @param list the sequence to score
     * @return the sum of {@link BeamNode#getScore()} over {@code list}
     */
    protected double getOverallScore(List<BeamNode<T>> list) {
        double total = 0.0d;
        for (BeamNode<T> node : list) {
            total += node.getScore();
        }
        return total;
    }

    /**
     * Returns the predecessor to attach to predictions from the {@code i}-th
     * list: {@code null} when the beam is empty, otherwise the {@code i}-th node.
     */
    private BeamNode<T> getPrevNode(int i) {
        if (this.b_nodes.isEmpty()) {
            return null;
        }
        return this.b_nodes.get(i);
    }

    /**
     * Renders every node's path with {@link #toString(BeamNode)}, joined by
     * {@link AbstractColumnReader#DELIM_SENTENCE}.
     *
     * @return the rendered beam, or an empty string if the beam is empty
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (BeamNode<T> beamNode : this.b_nodes) {
            // Insert the delimiter only between entries, so any delimiter
            // length works and an empty beam renders as "".
            if (!first) {
                sb.append(AbstractColumnReader.DELIM_SENTENCE);
            }
            sb.append(toString(beamNode));
            first = false;
        }
        return sb.toString();
    }

    /**
     * Renders the path from the root to {@code beamNode} as
     * {@code "label:score -> label:score -> ..."}, root first.
     *
     * @param beamNode the last node of the path; may be {@code null}
     * @return the rendered path, or an empty string when {@code beamNode} is null
     */
    public String toString(BeamNode<T> beamNode) {
        // Walk back to the root, stacking nodes so they pop out root-first.
        ArrayDeque<BeamNode<T>> stack = new ArrayDeque<BeamNode<T>>();
        for (BeamNode<T> node = beamNode; node != null; node = node.getPrevNode()) {
            stack.push(node);
        }
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        while (!stack.isEmpty()) {
            BeamNode<T> node = stack.pop();
            if (!first) {
                sb.append(" -> ");
            }
            sb.append(node.getLabel() + ":" + node.getScore());
            first = false;
        }
        return sb.toString();
    }
}
