package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.fsm.TransducerGraph;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/GrammarCompactor.class */
/**
 * Base class for grammar compactors.
 *
 * <p>A compaction run ({@link #compactGrammar}) converts the rules of a unary/binary grammar pair
 * into one {@link TransducerGraph} per top category, hands each graph to the subclass hook
 * {@link #doCompaction}, and converts the resulting graphs back into a grammar pair that is
 * indexed by a freshly built state index ({@link #newStateIndex}).
 *
 * <p>Rules that involve no synthetic state (a state whose name starts with {@code '@'}) are not
 * put into any graph; they are carried over to the output grammar unchanged apart from being
 * renumbered against the new state index.
 */
public abstract class GrammarCompactor {
    /** Graphs returned by {@link #doCompaction} during the most recent {@link #compactGrammar} call. */
    Set<TransducerGraph> compactedGraphs;

    /** State index supplied to {@link #compactGrammar}; resolves rule state ids to state names. */
    protected Index<String> stateIndex;

    /** State index built by {@link #convertGraphsToGrammar} for the compacted grammar. */
    protected Index<String> newStateIndex;

    /** Smoothed distribution over path symbols, computed by {@link #computeInputPrior}. */
    protected Distribution<String> inputPrior;

    /** Arc input used for a unary rule whose child (but not parent) is a synthetic state. */
    private static final String END = "END";

    /** Arc input that is treated the same way as {@link #END} when arcs are read back. */
    private static final String EPSILON = "EPSILON";

    protected final Options op;

    private static final Redwood.RedwoodChannels log = Redwood.channels(GrammarCompactor.class);

    /** {@link #outputType} marker: rule scores are normalised per parent and log-transformed on output. */
    public static final Object RAW_COUNTS = new Object();

    /** {@link #outputType} marker: rule scores are negated on the way into and out of the graphs. */
    public static final Object NORMALIZED_LOG_PROBABILITIES = new Object();

    /** Either {@link #RAW_COUNTS} (the default) or {@link #NORMALIZED_LOG_PROBABILITIES}. */
    public Object outputType = RAW_COUNTS;

    /** When true, progress information is printed to standard output. */
    protected boolean verbose = false;

    public GrammarCompactor(Options options) {
        this.op = options;
    }

    /**
     * Compacts the graph of a single category.
     *
     * @param transducerGraph the graph built for one category
     * @param list the entry for this category taken from the first path map given to
     *     {@link #compactGrammar}, or an empty list if there was none
     * @param list2 the entry for this category taken from the second path map, or an empty list
     * @return the compacted graph
     */
    protected abstract TransducerGraph doCompaction(TransducerGraph transducerGraph, List<List<String>> list, List<List<String>> list2);

    /**
     * Compacts a grammar without any path information (both path maps are empty).
     *
     * @param pair the unary and binary grammar to compact
     * @param index the state index the rules of {@code pair} refer to
     * @return the new state index together with the compacted unary and binary grammars
     */
    public Triple<Index<String>, UnaryGrammar, BinaryGrammar> compactGrammar(Pair<UnaryGrammar, BinaryGrammar> pair, Index<String> index) {
        Map<String, List<List<String>>> noPaths = Generics.<String, List<List<String>>>newHashMap();
        Map<String, List<List<String>>> noPaths2 = Generics.<String, List<List<String>>>newHashMap();
        return compactGrammar(pair, noPaths, noPaths2, index);
    }

    /**
     * Compacts a grammar category by category.
     *
     * <p>Note that this method mutates its map arguments: the entry of every category that has a
     * graph is removed from both {@code map} and {@code map2} as it is processed.
     *
     * @param pair the unary and binary grammar to compact
     * @param map paths keyed by category; also the input to {@link #computeInputPrior}
     * @param map2 a second set of paths keyed by category
     * @param index the state index the rules of {@code pair} refer to
     * @return the new state index together with the compacted unary and binary grammars
     */
    public Triple<Index<String>, UnaryGrammar, BinaryGrammar> compactGrammar(Pair<UnaryGrammar, BinaryGrammar> pair, Map<String, List<List<String>>> map, Map<String, List<List<String>>> map2, Index<String> index) {
        this.inputPrior = computeInputPrior(map);
        this.stateIndex = index;
        // Rules that cannot be represented in a graph are collected here and passed through.
        Set<UnaryRule> unaryRules = Generics.newHashSet();
        Set<BinaryRule> binaryRules = Generics.newHashSet();
        Map<String, TransducerGraph> graphsByCategory = convertGrammarToGraphs(pair, unaryRules, binaryRules);
        this.compactedGraphs = Generics.newHashSet();
        if (this.verbose) {
            System.out.println("There are " + graphsByCategory.size() + " categories to compact.");
        }
        int numCompacted = 0;
        // An explicit iterator is required because each entry is removed once it has been processed.
        Iterator<Map.Entry<String, TransducerGraph>> it = graphsByCategory.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<String, TransducerGraph> entry = it.next();
            String category = entry.getKey();
            TransducerGraph graph = entry.getValue();
            if (this.verbose) {
                System.out.println("About to compact grammar for " + category + " with numNodes=" + graph.getNodes().size());
            }
            List<List<String>> paths = map.remove(category);
            if (paths == null) {
                paths = new ArrayList<>();
            }
            List<List<String>> paths2 = map2.remove(category);
            if (paths2 == null) {
                paths2 = new ArrayList<>();
            }
            TransducerGraph compacted = doCompaction(graph, paths, paths2);
            numCompacted++;
            if (this.verbose) {
                System.out.println(numCompacted + ". Compacted grammar for " + category + " from " + graph.getArcs().size() + " arcs to " + compacted.getArcs().size() + " arcs.");
            }
            it.remove();
            this.compactedGraphs.add(compacted);
        }
        Pair<UnaryGrammar, BinaryGrammar> grammars = convertGraphsToGrammar(this.compactedGraphs, unaryRules, binaryRules);
        return new Triple<>(this.newStateIndex, grammars.first(), grammars.second());
    }

    /**
     * Counts every symbol occurring in any path of {@code map} and returns a Laplace-smoothed
     * distribution over those counts (number of keys: twice the number of distinct symbols,
     * lambda: 0.5).
     *
     * @param map paths keyed by category; only the values are read
     * @return the smoothed symbol distribution
     */
    protected static Distribution<String> computeInputPrior(Map<String, List<List<String>>> map) {
        ClassicCounter<String> symbolCounts = new ClassicCounter<>();
        for (List<List<String>> paths : map.values()) {
            for (List<String> path : paths) {
                for (String symbol : path) {
                    symbolCounts.incrementCount(symbol);
                }
            }
        }
        return Distribution.laplaceSmoothedDistribution(symbolCounts, symbolCounts.size() * 2, 0.5d);
    }

    /** Negates {@code d} when the output type is {@link #NORMALIZED_LOG_PROBABILITIES}; otherwise returns it as is. */
    private double smartNegate(double d) {
        return this.outputType == NORMALIZED_LOG_PROBABILITIES ? -d : d;
    }

    /**
     * Writes the DOT representation of a graph to {@code <str>/<str2>.dot}, creating the directory
     * if it does not exist. This is a best-effort operation: failures are reported through the
     * return value rather than by an exception.
     *
     * @param transducerGraph the graph to write
     * @param str the output directory
     * @param str2 the file name without the {@code .dot} extension
     * @return true if the file was written; false if {@code str} exists but is not a directory,
     *     the directory could not be created, or writing failed
     */
    public static boolean writeFile(TransducerGraph transducerGraph, String str, String str2) {
        try {
            File dir = new File(str);
            if (dir.exists()) {
                if (!dir.isDirectory()) {
                    return false;
                }
            } else if (!dir.mkdirs()) {
                return false;
            }
            File dotFile = new File(dir, str2 + ".dot");
            // try-with-resources closes the writer even if asDOTString() or print() throws.
            try (PrintWriter out = new PrintWriter(new FileWriter(dotFile))) {
                out.print(transducerGraph.asDOTString());
                // PrintWriter swallows IOExceptions; checkError() flushes and reports them.
                if (out.checkError()) {
                    log.info("Failed to write file in writeToDOTfile: " + dotFile);
                    return false;
                }
                return true;
            } catch (IOException e) {
                log.info("Failed to open file in writeToDOTfile: " + dotFile);
                return false;
            }
        } catch (Exception e) {
            // Deliberately broad: any other failure is reported and turned into a false result.
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Builds one graph per top category from the rules of a grammar pair. Rules that could not be
     * added to a graph are added to {@code set} (unary) or {@code set2} (binary) instead.
     *
     * @param pair the unary and binary grammar to convert
     * @param set receives the unary rules that were not added to any graph
     * @param set2 receives the binary rules that were not added to any graph
     * @return the graphs, keyed by category name
     */
    protected Map<String, TransducerGraph> convertGrammarToGraphs(Pair<UnaryGrammar, BinaryGrammar> pair, Set<UnaryRule> set, Set<BinaryRule> set2) {
        int numRules = 0;
        UnaryGrammar unaryGrammar = pair.first;
        BinaryGrammar binaryGrammar = pair.second;
        Map<String, TransducerGraph> graphs = Generics.newHashMap();
        for (BinaryRule rule : binaryGrammar) {
            numRules++;
            if (!addOneBinaryRule(rule, graphs)) {
                set2.add(rule);
            }
        }
        for (UnaryRule rule : unaryGrammar) {
            numRules++;
            if (!addOneUnaryRule(rule, graphs)) {
                set.add(rule);
            }
        }
        if (this.verbose) {
            System.out.println("Number of raw rules: " + numRules);
            System.out.println("Number of raw states: " + this.stateIndex.size());
        }
        return graphs;
    }

    /**
     * Returns the graph stored under {@code str}, creating and registering a new graph (with
     * {@code str} as its end node) if there is none yet.
     */
    protected static TransducerGraph getGraphFromMap(Map<String, TransducerGraph> map, String str) {
        TransducerGraph graph = map.get(str);
        if (graph == null) {
            graph = new TransducerGraph();
            graph.setEndNode(str);
            map.put(str, graph);
        }
        return graph;
    }

    /**
     * Extracts the top category from a synthetic state name, i.e. the text between the leading
     * {@code '@'} and the first {@code '|'}.
     *
     * @param str a state name
     * @return the top category, or null if {@code str} is not a synthetic state
     * @throws RuntimeException if {@code str} is synthetic but contains no {@code '|'}
     */
    protected static String getTopCategoryOfSyntheticState(String str) {
        if (!isSyntheticState(str)) {
            return null;
        }
        int bar = str.indexOf('|');
        if (bar < 0) {
            throw new RuntimeException("Grammar format error. Expected bar in state name: " + str);
        }
        return str.substring(1, bar);
    }

    /**
     * Adds a unary rule to the graph of its category if its parent or its child is a synthetic
     * state.
     *
     * @param unaryRule the rule to add
     * @param map the graphs built so far, keyed by category
     * @return true if the rule was added to a graph; false if neither state is synthetic
     */
    protected boolean addOneUnaryRule(UnaryRule unaryRule, Map<String, TransducerGraph> map) {
        String parent = this.stateIndex.get(unaryRule.parent);
        String child = this.stateIndex.get(unaryRule.child);
        Double output = Double.valueOf(smartNegate(unaryRule.score()));
        if (isSyntheticState(parent)) {
            // Synthetic parent: arc from the graph's start node to the parent, with the child as input.
            TransducerGraph graph = getGraphFromMap(map, getTopCategoryOfSyntheticState(parent));
            graph.addArc(graph.getStartNode(), parent, child, output);
            return true;
        }
        if (isSyntheticState(child)) {
            // Synthetic child only: arc from the child to the parent with END as input; the parent
            // is (re)set as the end node of its own graph.
            TransducerGraph graph = getGraphFromMap(map, parent);
            graph.addArc(child, parent, END, output);
            graph.setEndNode(parent);
            return true;
        }
        return false;
    }

    /**
     * Adds a binary rule to the graph of its category if one of its children is a synthetic state.
     * The synthetic child becomes the source node of the arc, the parent its target, and the other
     * child, with a one-character suffix appended, its input. The suffix is the last character of
     * the parent name when {@code op.trainOptions.markFinalStates} is set; otherwise it is
     * {@code ">"} if the left child is the synthetic one and {@code "<"} if the right child is.
     *
     * @param binaryRule the rule to add
     * @param map the graphs built so far, keyed by category
     * @return true if the rule was added to a graph; false if neither child is synthetic
     * @throws RuntimeException if no top category can be determined for the synthetic child
     */
    protected boolean addOneBinaryRule(BinaryRule binaryRule, Map<String, TransducerGraph> map) {
        String parent = this.stateIndex.get(binaryRule.parent);
        String left = this.stateIndex.get(binaryRule.leftChild);
        String right = this.stateIndex.get(binaryRule.rightChild);
        String finalStateMark = null;
        if (this.op.trainOptions.markFinalStates) {
            finalStateMark = parent.substring(parent.length() - 1);
        }
        String source;
        String input;
        if (isSyntheticState(left)) {
            source = left;
            input = right + (finalStateMark == null ? ">" : finalStateMark);
        } else if (isSyntheticState(right)) {
            source = right;
            input = left + (finalStateMark == null ? "<" : finalStateMark);
        } else {
            return false;
        }
        Double output = Double.valueOf(smartNegate(binaryRule.score()));
        String topCategory = getTopCategoryOfSyntheticState(source);
        if (topCategory == null) {
            throw new RuntimeException("can't have null topcat");
        }
        getGraphFromMap(map, topCategory).addArc(source, parent, input, output);
        return true;
    }

    /**
     * Tells whether a state name denotes a synthetic state, i.e. starts with {@code '@'}.
     * An empty name is not synthetic.
     */
    protected static boolean isSyntheticState(String str) {
        return str.startsWith("@");
    }

    /**
     * Converts compacted graphs back into grammars.
     *
     * <p>The rules already present in {@code set2} and {@code set3} are renumbered in place
     * against a new state index, every arc of every graph is turned into a unary or binary rule
     * and added to the respective set, and finally all rules are added to new grammars. When the
     * output type is {@link #RAW_COUNTS}, each score is divided by the summed scores of all rules
     * with the same parent and log-transformed (binary scores are first reduced by
     * {@code op.trainOptions.ruleDiscount}).
     *
     * @param set the compacted graphs
     * @param set2 the pass-through unary rules; modified in place and extended
     * @param set3 the pass-through binary rules; modified in place and extended
     * @return the resulting unary and binary grammars
     * @throws RuntimeException if an arc input does not end in one of {@code < [ > ]}
     */
    protected Pair<UnaryGrammar, BinaryGrammar> convertGraphsToGrammar(Set<TransducerGraph> set, Set<UnaryRule> set2, Set<BinaryRule> set3) {
        this.newStateIndex = new HashIndex<>();
        // Renumber the pass-through rules against the new index.
        for (UnaryRule rule : set2) {
            rule.parent = this.newStateIndex.addToIndex(this.stateIndex.get(rule.parent));
            rule.child = this.newStateIndex.addToIndex(this.stateIndex.get(rule.child));
        }
        for (BinaryRule rule : set3) {
            rule.parent = this.newStateIndex.addToIndex(this.stateIndex.get(rule.parent));
            rule.leftChild = this.newStateIndex.addToIndex(this.stateIndex.get(rule.leftChild));
            rule.rightChild = this.newStateIndex.addToIndex(this.stateIndex.get(rule.rightChild));
        }
        // Read every arc back as a rule; this mirrors addOneUnaryRule / addOneBinaryRule.
        for (TransducerGraph graph : set) {
            Object startNode = graph.getStartNode();
            for (TransducerGraph.Arc arc : graph.getArcs()) {
                String source = arc.getSourceNode().toString();
                String target = arc.getTargetNode().toString();
                String input = arc.getInput().toString();
                double score = smartNegate(((Double) arc.getOutput()).doubleValue());
                if (source.equals(startNode)) {
                    set2.add(new UnaryRule(this.newStateIndex.addToIndex(target), this.newStateIndex.addToIndex(input), score));
                } else if (input.equals(END) || input.equals(EPSILON)) {
                    set2.add(new UnaryRule(this.newStateIndex.addToIndex(target), this.newStateIndex.addToIndex(source), score));
                } else {
                    // The last character of the input says on which side the source node sits.
                    int length = input.length();
                    char annotation = input.charAt(length - 1);
                    String child = input.substring(0, length - 1);
                    BinaryRule rule;
                    if (annotation == '<' || annotation == '[') {
                        rule = new BinaryRule(this.newStateIndex.addToIndex(target), this.newStateIndex.addToIndex(child), this.newStateIndex.addToIndex(source), score);
                    } else if (annotation == '>' || annotation == ']') {
                        rule = new BinaryRule(this.newStateIndex.addToIndex(target), this.newStateIndex.addToIndex(source), this.newStateIndex.addToIndex(child), score);
                    } else {
                        throw new RuntimeException("Arc input is in unexpected format: " + arc);
                    }
                    set3.add(rule);
                }
            }
        }
        // For raw counts, sum the scores of all rules sharing a parent so they can be normalised.
        ClassicCounter<String> parentTotals = new ClassicCounter<>();
        if (this.outputType == RAW_COUNTS) {
            for (UnaryRule rule : set2) {
                parentTotals.incrementCount(this.newStateIndex.get(rule.parent), rule.score);
            }
            for (BinaryRule rule : set3) {
                parentTotals.incrementCount(this.newStateIndex.get(rule.parent), rule.score);
            }
        }
        int numRules = 0;
        UnaryGrammar unaryGrammar = new UnaryGrammar(this.newStateIndex);
        BinaryGrammar binaryGrammar = new BinaryGrammar(this.newStateIndex);
        for (UnaryRule rule : set2) {
            if (this.outputType == RAW_COUNTS) {
                double total = parentTotals.getCount(this.newStateIndex.get(rule.parent));
                rule.score = (float) Math.log(rule.score / total);
            }
            unaryGrammar.addRule(rule);
            numRules++;
        }
        for (BinaryRule rule : set3) {
            if (this.outputType == RAW_COUNTS) {
                double total = parentTotals.getCount(this.newStateIndex.get(rule.parent));
                rule.score = (float) Math.log((rule.score - this.op.trainOptions.ruleDiscount) / total);
            }
            binaryGrammar.addRule(rule);
            numRules++;
        }
        if (this.verbose) {
            System.out.println("Number of minimized rules: " + numRules);
            System.out.println("Number of minimized states: " + this.newStateIndex.size());
        }
        unaryGrammar.purgeRules();
        binaryGrammar.splitRules();
        return new Pair<>(unaryGrammar, binaryGrammar);
    }
}
