/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class HierarchicalBinaryRule
extends BinaryRule {
    private static final long serialVersionUID = 1L;
    List<double[][][]> scoreHierarchy = new ArrayList<double[][][]>();
    public int lastLevel = -1;

    public HierarchicalBinaryRule(HierarchicalBinaryRule b) {
        super(b);
        for (double[][][] scores : b.scoreHierarchy) {
            this.scoreHierarchy.add(ArrayUtil.clone(scores));
        }
        this.lastLevel = b.lastLevel;
        this.scores = null;
    }

    public HierarchicalBinaryRule(BinaryRule b) {
        super(b);
        double[][][] scoreThisLevel = new double[1][1][1];
        scoreThisLevel[0][0][0] = Math.log(b.scores[0][0][0]);
        this.scoreHierarchy.add(scoreThisLevel);
        this.lastLevel = 0;
        this.scores = null;
    }

    public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates) {
        int newMaxStates = (int)Math.pow(2.0, finalLevel + 1);
        int newPStates = Math.min(newMaxStates, newNumSubStates[this.parentState]);
        int newLStates = Math.min(newMaxStates, newNumSubStates[this.leftChildState]);
        int newRStates = Math.min(newMaxStates, newNumSubStates[this.rightChildState]);
        this.scores = new double[newLStates][newRStates][newPStates];
        int level = 0;
        while (level <= this.lastLevel) {
            double[][][] scoresThisLevel = this.scoreHierarchy.get(level);
            if (scoresThisLevel != null) {
                int divisorL = newLStates / scoresThisLevel.length;
                int divisorR = newRStates / scoresThisLevel[0].length;
                int divisorP = newPStates / scoresThisLevel[0][0].length;
                int lChild = 0;
                while (lChild < newLStates) {
                    int rChild = 0;
                    while (rChild < newRStates) {
                        int parent = 0;
                        while (parent < newPStates) {
                            double[] dArray = this.scores[lChild][rChild];
                            int n = parent;
                            dArray[n] = dArray[n] + scoresThisLevel[lChild / divisorL][rChild / divisorR][parent / divisorP];
                            ++parent;
                        }
                        ++rChild;
                    }
                    ++lChild;
                }
            }
            ++level;
        }
        int lChild = 0;
        while (lChild < newLStates) {
            int rChild = 0;
            while (rChild < newRStates) {
                int parent = 0;
                while (parent < newPStates) {
                    this.scores[lChild][rChild][parent] = Math.exp(this.scores[lChild][rChild][parent]);
                    ++parent;
                }
                ++rChild;
            }
            ++lChild;
        }
    }

    public double[][][] getLastLevel() {
        return this.scoreHierarchy.get(this.lastLevel);
    }

    public HierarchicalBinaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness, boolean doNotNormalize, int mode) {
        if (mode != 2) {
            throw new Error("Can't split hiereachical rule in this mode!");
        }
        int newMaxStates = (int)Math.pow(2.0, this.lastLevel + 1);
        int newPStates = Math.min(newMaxStates, newNumSubStates[this.parentState]);
        int newLStates = Math.min(newMaxStates, newNumSubStates[this.leftChildState]);
        int newRStates = Math.min(newMaxStates, newNumSubStates[this.rightChildState]);
        double[][][] newScores = new double[newLStates][newRStates][newPStates];
        int lChild = 0;
        while (lChild < newLStates) {
            int rChild = 0;
            while (rChild < newRStates) {
                int parent = 0;
                while (parent < newPStates) {
                    newScores[lChild][rChild][parent] = random.nextDouble() / 100.0;
                    ++parent;
                }
                ++rChild;
            }
            ++lChild;
        }
        HierarchicalBinaryRule newRule = new HierarchicalBinaryRule(this);
        newRule.scoreHierarchy.add(newScores);
        ++newRule.lastLevel;
        return newRule;
    }

    public int mergeRule() {
        double[][][] scoresFinalLevel = this.scoreHierarchy.get(this.lastLevel);
        boolean allZero = true;
        int lChild = 0;
        while (lChild < scoresFinalLevel.length) {
            int rChild = 0;
            while (rChild < scoresFinalLevel[0].length) {
                int parent = 0;
                while (parent < scoresFinalLevel[0][0].length) {
                    allZero = allZero && scoresFinalLevel[lChild][rChild][parent] == 0.0;
                    ++parent;
                }
                ++rChild;
            }
            ++lChild;
        }
        if (allZero) {
            scoresFinalLevel = null;
            this.scoreHierarchy.remove(this.lastLevel);
            --this.lastLevel;
            return 1;
        }
        return 0;
    }

    public String toString() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String lState = (String)n.object(this.leftChildState);
        String rState = (String)n.object(this.rightChildState);
        String pState = (String)n.object(this.parentState);
        StringBuilder sb = new StringBuilder();
        if (this.scores == null) {
            return String.valueOf(pState) + " -> " + lState + " " + rState + "\n";
        }
        sb.append(String.valueOf(pState) + " -> " + lState + " " + rState + "\n");
        sb.append(String.valueOf(ArrayUtil.toString(this.scores)) + "\n");
        for (double[][][] s : this.scoreHierarchy) {
            sb.append(String.valueOf(ArrayUtil.toString(s)) + "\n");
        }
        sb.append("\n");
        return sb.toString();
    }

    public int countNonZeroFeatures() {
        int total = 0;
        int level = 0;
        while (level <= this.lastLevel) {
            double[][][] scoresThisLevel = this.scoreHierarchy.get(level);
            if (scoresThisLevel != null) {
                int lChild = 0;
                while (lChild < scoresThisLevel.length) {
                    int rChild = 0;
                    while (rChild < scoresThisLevel.length) {
                        int parent = 0;
                        while (parent < scoresThisLevel.length) {
                            if (scoresThisLevel[lChild][rChild][parent] != 0.0) {
                                ++total;
                            }
                            ++parent;
                        }
                        ++rChild;
                    }
                    ++lChild;
                }
            }
            ++level;
        }
        return total;
    }
}

