/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.crf.InstanceSequence;
import edu.berkeley.nlp.crf.ScoreCalculator;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.DoubleMatrices;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.PriorityQueue;
import java.io.Serializable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class Inference<V, E, F, L>
implements Serializable {
    private static final long serialVersionUID = 1948395432745606240L;
    private final Encoding<F, L> encoding;
    private final ScoreCalculator<V, E, F, L> scoreCalculator;

    /**
     * Creates an inference helper for a linear-chain model.
     *
     * @param encoding        label/feature encoding; its label count sizes every chart
     * @param vertexExtractor feature extractor for vertices, passed to the {@link ScoreCalculator}
     * @param edgeExtractor   feature extractor for edges, passed to the {@link ScoreCalculator}
     */
    public Inference(Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor) {
        this.encoding = encoding;
        this.scoreCalculator = new ScoreCalculator<V, E, F, L>(encoding, vertexExtractor, edgeExtractor);
    }

    /**
     * Computes the forward (alpha) chart.
     *
     * <p>{@code alpha[0]} is the vertex-score vector at position 0. For {@code i >= 1},
     * {@code alpha[i]} is the product of the row vector {@code alpha[i - 1]} with the score
     * matrix for position {@code i} (indexed {@code [previousLabel][currentLabel]}).
     *
     * <p>NOTE(review): no rescaling is performed, so for long sequences these products may
     * underflow or overflow; presumably the score matrices hold exponentiated potentials
     * (the separate "linear" variants are used for k-best decoding) - confirm.
     *
     * @param sequence the instance sequence
     * @param w        model weights
     * @return one alpha vector per position; an empty array for an empty sequence
     */
    public double[][] getAlphas(InstanceSequence<V, E, L> sequence, double[] w) {
        int n = sequence.getSequenceLength();
        double[][] alpha = new double[n][];
        if (n == 0) {
            // Nothing to fill; previously this threw ArrayIndexOutOfBoundsException.
            return alpha;
        }
        alpha[0] = this.scoreCalculator.getVertexScores(sequence, 0, w);
        for (int i = 1; i < n; ++i) {
            double[][] scoreMatrix = this.scoreCalculator.getScoreMatrix(sequence, i, w);
            alpha[i] = DoubleMatrices.product(alpha[i - 1], scoreMatrix);
        }
        return alpha;
    }

    /**
     * Computes the backward (beta) chart.
     *
     * <p>{@code beta[n - 1]} is all ones. For {@code i < n - 1}, {@code beta[i]} is the score
     * matrix for position {@code i + 1} multiplied by the column vector {@code beta[i + 1]}.
     * No rescaling is performed (see {@link #getAlphas}).
     *
     * @param sequence the instance sequence
     * @param w        model weights
     * @return one beta vector per position; an empty array for an empty sequence
     */
    public double[][] getBetas(InstanceSequence<V, E, L> sequence, double[] w) {
        int n = sequence.getSequenceLength();
        double[][] beta = new double[n][];
        if (n == 0) {
            // Nothing to fill; previously this threw ArrayIndexOutOfBoundsException.
            return beta;
        }
        beta[n - 1] = DoubleArrays.constantArray(1.0, this.encoding.getNumLabels());
        for (int i = n - 2; i >= 0; --i) {
            // The transition into position i + 1 links beta[i] to beta[i + 1].
            double[][] scoreMatrix = this.scoreCalculator.getScoreMatrix(sequence, i + 1, w);
            beta[i] = DoubleMatrices.product(scoreMatrix, beta[i + 1]);
        }
        return beta;
    }

    /**
     * Builds a k-best Viterbi chart with backpointers using the linear (additive) scores.
     *
     * <p>For position {@code i} and label {@code l}, {@code scores[i][l][c]} is the score of
     * the c-th candidate and {@code labels[i][l][c]} is the pair
     * {@code {previousLabel, previousCandidateIndex}} it extends. Candidates appear in the
     * order the {@link PriorityQueue} yields them (presumably best-first - confirm the queue
     * is a max-queue). Position 0 holds one candidate per label with the sentinel backpointer
     * {@code {-1, 0}}.
     *
     * @param sequence the instance sequence
     * @param w        model weights
     * @param k        maximum number of candidates kept per (position, label) cell
     * @return pair of (backpointer chart, score chart); both are empty for an empty sequence
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public Pair<int[][][][], double[][][]> getKBestChartAndBacktrace(InstanceSequence<V, E, L> sequence, double[] w, int k) {
        if (k < 0) {
            // Previously surfaced as NegativeArraySizeException, and only when n > 1.
            throw new IllegalArgumentException("k must be non-negative: " + k);
        }
        int n = sequence.getSequenceLength();
        int numLabels = this.encoding.getNumLabels();
        int[][][][] bestLabels = new int[n][numLabels][][];
        double[][][] bestScores = new double[n][numLabels][];
        if (n == 0) {
            return Pair.makePair(bestLabels, bestScores);
        }
        double[] startScores = this.scoreCalculator.getLinearVertexScores(sequence, 0, w);
        for (int l = 0; l < numLabels; ++l) {
            bestScores[0][l] = new double[]{startScores[l]};
            // Sentinel backpointer: position 0 has no predecessor.
            bestLabels[0][l] = new int[][]{{-1, 0}};
        }
        for (int i = 1; i < n; ++i) {
            double[][] scoreMatrix = this.scoreCalculator.getLinearScoreMatrix(sequence, i, w);
            for (int l = 0; l < numLabels; ++l) {
                // Every (previous label, previous candidate) combination competes for cell (i, l).
                PriorityQueue<Pair<Integer, Integer>> pq = new PriorityQueue<Pair<Integer, Integer>>();
                for (int pl = 0; pl < numLabels; ++pl) {
                    double edgeScore = scoreMatrix[pl][l];
                    double[] previousScores = bestScores[i - 1][pl];
                    for (int c = 0; c < previousScores.length; ++c) {
                        pq.add(Pair.makePair(pl, c), edgeScore + previousScores[c]);
                    }
                }
                int cands = Math.min(k, pq.size());
                bestScores[i][l] = new double[cands];
                bestLabels[i][l] = new int[cands][2];
                for (int c = 0; c < cands; ++c) {
                    // Read the top priority before next() removes that element.
                    bestScores[i][l][c] = pq.getPriority();
                    Pair<Integer, Integer> backtrace = pq.next();
                    bestLabels[i][l][c][0] = backtrace.getFirst();
                    bestLabels[i][l][c][1] = backtrace.getSecond();
                }
            }
        }
        return Pair.makePair(bestLabels, bestScores);
    }

    /**
     * Computes per-position label posteriors as the normalized element-wise product of
     * {@code alpha[i]} and {@code beta[i]}.
     *
     * @param alpha forward chart, e.g. from {@link #getAlphas}
     * @param beta  backward chart, e.g. from {@link #getBetas}
     * @return {@code p[i][l]}, normalized per position by {@code ArrayUtil.normalize}
     */
    public double[][] getVertexPosteriors(double[][] alpha, double[][] beta) {
        double[][] p = new double[alpha.length][this.encoding.getNumLabels()];
        for (int i = 0; i < p.length; ++i) {
            for (int l = 0; l < p[i].length; ++l) {
                p[i][l] = alpha[i][l] * beta[i][l];
            }
            ArrayUtil.normalize(p[i]);
        }
        return p;
    }

    /**
     * Computes edge posteriors {@code p[i][prev][cur]} for the transition into position
     * {@code i}, proportional to {@code alpha[i - 1][prev] * M_i[prev][cur] * beta[i][cur]}
     * and normalized per position. Position 0 has no incoming edge and is left all zeros.
     *
     * @param sequence the instance sequence
     * @param w        model weights
     * @param alpha    forward chart for {@code sequence}
     * @param beta     backward chart for {@code sequence}
     * @return an {@code n x numLabels x numLabels} array of edge posteriors
     */
    public double[][][] getEdgePosteriors(InstanceSequence<V, E, L> sequence, double[] w, double[][] alpha, double[][] beta) {
        int numLabels = this.encoding.getNumLabels();
        int n = sequence.getSequenceLength();
        double[][][] p = new double[n][numLabels][numLabels];
        for (int i = 1; i < p.length; ++i) {
            double[][] scoreMatrix = this.scoreCalculator.getScoreMatrix(sequence, i, w);
            for (int lp = 0; lp < numLabels; ++lp) {
                for (int lc = 0; lc < numLabels; ++lc) {
                    p[i][lp][lc] = alpha[i - 1][lp] * scoreMatrix[lp][lc] * beta[i][lc];
                }
            }
            ArrayUtil.normalize(p[i]);
        }
        return p;
    }

    /**
     * Computes the normalization constant as {@code sum_l alpha[0][l] * beta[0][l]}.
     *
     * <p>Any position gives the same value in exact arithmetic; position 0 is used. An empty
     * chart (empty sequence) yields 1.0, the empty product, rather than throwing.
     *
     * @param alpha forward chart
     * @param beta  backward chart for the same sequence
     * @return the normalization constant
     */
    public double getNormalizationConstant(double[][] alpha, double[][] beta) {
        if (alpha.length == 0) {
            return 1.0;
        }
        int anyIndex = 0;
        double[] p = new double[alpha[anyIndex].length];
        for (int l = 0; l < p.length; ++l) {
            p[l] = alpha[anyIndex][l] * beta[anyIndex][l];
        }
        return ArrayUtil.sum(p);
    }
}

