package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.LabeledScoredTreeReaderFactory;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.SynchronizedTreeTransformer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeNormalizer;
import edu.stanford.nlp.trees.TreeReaderFactory;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileFilter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:edu/stanford/nlp/parser/dvparser/CacheParseHypotheses.class */
public class CacheParseHypotheses {
    /** Redwood logging channel for this class; the nested {@code CacheProcessor} logs through it as well. */
    private static Redwood.RedwoodChannels log = Redwood.channels(CacheParseHypotheses.class);
    /** Reader factory handed to {@code Tree.valueOf} when trees are rebuilt from their string form. */
    static final TreeReaderFactory trf = new LabeledScoredTreeReaderFactory(CoreLabel.factory(), new TreeNormalizer());
    /** Transformer applied to every hypothesis tree before it is filtered and serialized. */
    final BasicCategoryTreeTransformer treeBasicCategories;
    /** Predicate deciding which transformed trees are kept; set to a {@code FilterConfusingRules} in the constructor. */
    public final Predicate<Tree> treeFilter;

    /* loaded from: input_file:edu/stanford/nlp/parser/dvparser/CacheParseHypotheses$CacheProcessor.class */
    /**
     * Worker that turns one tree into its compressed list of parse hypotheses.
     *
     * <p>For each input tree it obtains hypotheses from {@code DVParser.getTopParsesForOneTree}, compresses them
     * with {@link CacheParseHypotheses#convertToBytes(List)}, and checks that decompressing those bytes yields
     * the same trees as transforming and filtering the hypotheses directly.
     */
    static class CacheProcessor implements ThreadsafeProcessor<Tree, Pair<Tree, byte[]>> {
        /** Supplies the category transformer, the tree filter and the compression routine. */
        CacheParseHypotheses cacher;
        /** Parser passed to {@code DVParser.getTopParsesForOneTree}. */
        LexicalizedParser parser;
        /** Hypothesis count passed to {@code DVParser.getTopParsesForOneTree}. */
        int dvKBest;
        /** Transformer passed to {@code DVParser.getTopParsesForOneTree}. */
        TreeTransformer transformer;

        /**
         * @param cacheParseHypotheses cacher used for transforming, filtering and compressing
         * @param lexicalizedParser parser used to generate hypotheses
         * @param i number of hypotheses requested per tree (stored as {@code dvKBest})
         * @param treeTransformer transformer handed to the hypothesis generator
         */
        public CacheProcessor(CacheParseHypotheses cacheParseHypotheses, LexicalizedParser lexicalizedParser, int i, TreeTransformer treeTransformer) {
            this.cacher = cacheParseHypotheses;
            this.parser = lexicalizedParser;
            this.dvKBest = i;
            this.transformer = treeTransformer;
        }

        /**
         * Generates, compresses and sanity-checks the hypotheses for one tree.
         *
         * @param tree the tree to produce hypotheses for
         * @return the input tree paired with the compressed hypotheses
         * @throws AssertionError if the compressed form does not decode back to the expected trees
         */
        @Override // edu.stanford.nlp.util.concurrent.ThreadsafeProcessor
        public Pair<Tree, byte[]> process(Tree tree) {
            List<Tree> topParses = DVParser.getTopParsesForOneTree(this.parser, this.dvKBest, tree, this.transformer);

            // Compress once and reuse the bytes for both the round-trip check and the result,
            // instead of serializing and gzipping the same list twice.
            // NOTE(review): assumes convertToBytes is deterministic for the same list - confirm.
            byte[] compressed = this.cacher.convertToBytes(topParses);
            List<Tree> decoded = CacheParseHypotheses.convertToTrees(compressed);

            // Recompute what the compressed form is supposed to contain.
            List<Tree> transformed = CollectionUtils.transformAsList(topParses, this.cacher.treeBasicCategories);
            List<Tree> expected = CollectionUtils.filterAsList(transformed, this.cacher.treeFilter);
            if (expected.size() != topParses.size()) {
                CacheParseHypotheses.log.info("Filtered " + (topParses.size() - expected.size()) + " trees");
                if (expected.isEmpty()) {
                    CacheParseHypotheses.log.info(" WARNING: filtered all trees for " + tree);
                }
            }

            checkRoundTrip(tree, expected, decoded);
            return Pair.makePair(tree, compressed);
        }

        /** Throws {@link AssertionError} unless {@code decoded} equals {@code expected}, printing the first mismatching pair. */
        private static void checkRoundTrip(Tree tree, List<Tree> expected, List<Tree> decoded) {
            if (expected.equals(decoded)) {
                return;
            }
            if (decoded.size() != expected.size()) {
                throw new AssertionError("horrible error: tree sizes not equal, " + decoded.size() + " vs " + expected.size());
            }
            for (int idx = 0; idx < decoded.size(); idx++) {
                if (!expected.get(idx).equals(decoded.get(idx))) {
                    System.out.println("=============================");
                    System.out.println(expected.get(idx));
                    System.out.println("=============================");
                    System.out.println(decoded.get(idx));
                    System.out.println("=============================");
                    throw new AssertionError("horrible error: tree " + idx + " not equal for base tree " + tree);
                }
            }
        }

        /**
         * Returns this same instance, so every worker shares one processor.
         */
        @Override // edu.stanford.nlp.util.concurrent.ThreadsafeProcessor
        /* renamed from: newInstance */
        public ThreadsafeProcessor<Tree, Pair<Tree, byte[]>> newInstance2() {
            return this;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:edu/stanford/nlp/parser/dvparser/CacheParseHypotheses$DecompressionProcessor.class */
    /**
     * Processor that decompresses one byte array into its list of trees by delegating to
     * {@link CacheParseHypotheses#convertToTrees(byte[])}.
     */
    public static class DecompressionProcessor implements ThreadsafeProcessor<byte[], List<Tree>> {
        /** Package-private constructor; the processor holds no state. */
        DecompressionProcessor() {
        }

        /**
         * @param bArr compressed tree list
         * @return the decoded trees
         */
        @Override // edu.stanford.nlp.util.concurrent.ThreadsafeProcessor
        public List<Tree> process(byte[] bArr) {
            List<Tree> decoded = CacheParseHypotheses.convertToTrees(bArr);
            return decoded;
        }

        /** Returns this same instance rather than a copy. */
        @Override // edu.stanford.nlp.util.concurrent.ThreadsafeProcessor
        /* renamed from: newInstance */
        public ThreadsafeProcessor<byte[], List<Tree>> newInstance2() {
            return this;
        }
    }

    /**
     * Builds a cacher whose category transformer comes from the parser's treebank language pack and whose
     * tree filter is a {@code FilterConfusingRules} built from the parser.
     *
     * @param lexicalizedParser parser that supplies the language pack and is handed to the filter
     */
    public CacheParseHypotheses(LexicalizedParser lexicalizedParser) {
        treeBasicCategories = new BasicCategoryTreeTransformer(lexicalizedParser.treebankLanguagePack());
        treeFilter = new FilterConfusingRules(lexicalizedParser);
    }

    /**
     * Transforms, filters and compresses a list of trees.
     *
     * <p>The trees are mapped through {@code treeBasicCategories}, filtered by {@code treeFilter}, and then
     * written to a gzipped object stream as the number of kept trees followed by each tree's
     * {@code toString()} form.
     *
     * @param list the trees to compress
     * @return the gzipped, serialized representation
     * @throws RuntimeIOException if writing to the in-memory streams fails
     */
    public byte[] convertToBytes(List<Tree> list) {
        List<Tree> transformed = CollectionUtils.transformAsList(list, this.treeBasicCategories);
        List<Tree> filtered = CollectionUtils.filterAsList(transformed, this.treeFilter);

        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // try-with-resources closes the gzip stream on failure too, not only on the happy path.
        try (GZIPOutputStream gzip = new GZIPOutputStream(buffer);
             ObjectOutputStream out = new ObjectOutputStream(gzip)) {
            out.writeObject(Integer.valueOf(filtered.size()));
            for (Tree tree : filtered) {
                out.writeObject(tree.toString());
            }
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
        // Read the buffer only after the streams are closed, so the gzip trailer has been written.
        return buffer.toByteArray();
    }

    /**
     * Compresses every value of the given map with {@link #convertToBytes(List)}.
     *
     * @param identityHashMap map from a tree to its list of trees
     * @return a new identity map with the same keys and the compressed form of each value
     */
    public IdentityHashMap<Tree, byte[]> convertToBytes(IdentityHashMap<Tree, List<Tree>> identityHashMap) {
        IdentityHashMap<Tree, byte[]> compressed = Generics.newIdentityHashMap();
        identityHashMap.forEach((key, hypotheses) -> compressed.put(key, convertToBytes(hypotheses)));
        return compressed;
    }

    /**
     * Decodes a byte array in the format written by {@link #convertToBytes(List)} back into trees.
     *
     * <p>Reads a tree count followed by that many tree strings from a gzipped object stream, parses each
     * string with {@code trf}, and calls {@code setSpans()} on the resulting tree.
     *
     * <p>NOTE(review): this uses Java native deserialization, so only feed it bytes from a trusted source.
     *
     * @param bArr gzipped, serialized tree list
     * @return the decoded trees, in stored order
     * @throws RuntimeIOException if the bytes cannot be read
     * @throws RuntimeException if a serialized class cannot be resolved
     */
    public static List<Tree> convertToTrees(byte[] bArr) {
        List<Tree> trees = new ArrayList<>();
        // try-with-resources closes the gzip stream even when decoding fails part-way.
        try (GZIPInputStream gzip = new GZIPInputStream(new ByteArrayInputStream(bArr));
             ObjectInputStream in = new ObjectInputStream(gzip)) {
            int count = (Integer) in.readObject();
            for (int idx = 0; idx < count; idx++) {
                String serialized = (String) in.readObject();
                Tree tree = Tree.valueOf(serialized, trf);
                tree.setSpans();
                trees.add(tree);
            }
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        } catch (ClassNotFoundException e2) {
            throw new RuntimeException(e2);
        }
        return trees;
    }

    /**
     * Decompresses every entry of the given map, using all of its keys.
     *
     * @param identityHashMap map from a tree to its compressed tree list
     * @param i thread count forwarded to the three-argument overload
     * @return a new identity map from each key to its decoded trees
     */
    public static IdentityHashMap<Tree, List<Tree>> convertToTrees(IdentityHashMap<Tree, byte[]> identityHashMap, int i) {
        Collection<Tree> allKeys = identityHashMap.keySet();
        return convertToTrees(allKeys, identityHashMap, i);
    }

    /**
     * Decompresses the entries for the given keys through a {@code MulticoreWrapper}.
     *
     * <p>The compressed value of each key is submitted in iteration order, the wrapper is joined, and the
     * results are then polled back in the same iteration order and stored under the corresponding key.
     *
     * @param collection keys whose entries should be decompressed
     * @param identityHashMap map from a tree to its compressed tree list
     * @param i thread count passed to the {@code MulticoreWrapper} constructor
     * @return a new identity map from each key to its decoded trees
     */
    public static IdentityHashMap<Tree, List<Tree>> convertToTrees(Collection<Tree> collection, IdentityHashMap<Tree, byte[]> identityHashMap, int i) {
        IdentityHashMap<Tree, List<Tree>> decompressed = Generics.newIdentityHashMap();
        // Parameterized (not raw) so that poll() yields List<Tree> and type-checks against the result map.
        MulticoreWrapper<byte[], List<Tree>> wrapper = new MulticoreWrapper<>(i, new DecompressionProcessor());
        for (Tree key : collection) {
            wrapper.put(identityHashMap.get(key));
        }
        // Join unconditionally: the wrapper is always finished, including for an empty key
        // collection, and every result is available before polling starts.
        wrapper.join();
        for (Tree key : collection) {
            decompressed.put(key, wrapper.poll());
        }
        return decompressed;
    }

    /**
     * Command-line entry point: generates hypothesis trees for every tree in the given treebanks and writes
     * the resulting list of (tree, compressed hypotheses) pairs to a file.
     *
     * <p>Recognized flags (matched case-insensitively):
     * <ul>
     *   <li>{@code -parser} / {@code -model} &lt;path&gt; - parser model to load (required)</li>
     *   <li>{@code -output} &lt;path&gt; - output file (required)</li>
     *   <li>{@code -treebank} ... - treebank description parsed by {@code ArgUtils}; at least one is required</li>
     *   <li>{@code -dvKBest} &lt;n&gt; - hypotheses per tree, default 200</li>
     *   <li>{@code -numThreads} &lt;n&gt; - thread count for the {@code MulticoreWrapper}, default 1</li>
     * </ul>
     *
     * @param strArr command-line arguments
     * @throws IOException if the output file cannot be written
     * @throws IllegalArgumentException on an unknown flag, a flag without a value, or a missing required flag
     */
    public static void main(String[] strArr) throws IOException {
        String parserFile = null;
        String outputFile = null;
        List<Pair<String, FileFilter>> treebanks = Generics.newArrayList();
        int dvKBest = 200;
        int numThreads = 1;

        int argIndex = 0;
        while (argIndex < strArr.length) {
            String flag = strArr[argIndex];
            if (flag.equalsIgnoreCase("-dvKBest")) {
                dvKBest = Integer.parseInt(requireValue(strArr, argIndex));
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-parser") || flag.equalsIgnoreCase("-model")) {
                parserFile = requireValue(strArr, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-output")) {
                outputFile = requireValue(strArr, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-treebank")) {
                Pair<String, FileFilter> description = ArgUtils.getTreebankDescription(strArr, argIndex, "-treebank");
                argIndex = argIndex + ArgUtils.numSubArgs(strArr, argIndex) + 1;
                treebanks.add(description);
            } else if (flag.equalsIgnoreCase("-numThreads")) {
                numThreads = Integer.parseInt(requireValue(strArr, argIndex));
                argIndex += 2;
            } else {
                throw new IllegalArgumentException("Unknown argument " + flag);
            }
        }

        if (parserFile == null) {
            throw new IllegalArgumentException("Need to supply a parser model with -model");
        }
        if (outputFile == null) {
            throw new IllegalArgumentException("Need to supply an output filename with -output");
        }
        if (treebanks.isEmpty()) {
            throw new IllegalArgumentException("Need to supply a treebank with -treebank");
        }

        log.info("Writing output to " + outputFile);
        log.info("Loading parser model " + parserFile);
        log.info("Writing " + dvKBest + " hypothesis trees for each tree");

        LexicalizedParser parser = LexicalizedParser.loadModel(parserFile, "-dvKBest", Integer.toString(dvKBest));
        CacheParseHypotheses cacher = new CacheParseHypotheses(parser);
        TreeTransformer transformer = DVParser.buildTrainTransformer(parser.getOp());

        List<Tree> sentences = new ArrayList<>();
        for (Pair<String, FileFilter> description : treebanks) {
            log.info("Reading trees from " + description.first);
            MemoryTreebank treebank = parser.getOp().tlpParams.memoryTreebank();
            treebank.loadPath(description.first, description.second);
            sentences.addAll(treebank.transform(transformer));
        }
        log.info("Processing " + sentences.size() + " trees");

        List<Pair<Tree, byte[]>> cache = Generics.newArrayList();
        // The workers share one transformer, so it is wrapped in a SynchronizedTreeTransformer.
        CacheProcessor processor = new CacheProcessor(cacher, parser, dvKBest, new SynchronizedTreeTransformer(transformer));
        MulticoreWrapper<Tree, Pair<Tree, byte[]>> wrapper = new MulticoreWrapper<>(numThreads, processor);
        for (Tree sentence : sentences) {
            wrapper.put(sentence);
            drainFinished(wrapper, cache);
        }
        wrapper.join();
        drainFinished(wrapper, cache);

        System.out.println("Finished processing " + cache.size() + " trees");
        IOUtils.writeObjectToFile(cache, outputFile);
    }

    /** Returns the value following the flag at {@code flagIndex}, or throws if the flag is the last argument. */
    private static String requireValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for argument " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }

    /** Moves every currently available result from the wrapper into {@code cache}, reporting progress every 10 trees. */
    private static void drainFinished(MulticoreWrapper<Tree, Pair<Tree, byte[]>> wrapper, List<Pair<Tree, byte[]>> cache) {
        while (wrapper.peek()) {
            cache.add(wrapper.poll());
            if (cache.size() % 10 == 0) {
                System.out.println("Processed " + cache.size() + " trees");
            }
        }
    }
}
