package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.CrossValidator;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.international.morph.MorphoFeatures;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.optimization.LineSearcher;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.MultiClassAccuracyStats;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.trees.international.negra.NegraLabel;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.SystemUtils;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/classify/SVMLightClassifierFactory.class */
public class SVMLightClassifierFactory<L, F> implements ClassifierFactory<L, F, SVMLightClassifier<L, F>> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 1;
    /** Value passed to the learner as {@code -c}; the option is only emitted when this is positive (constructors default it to -1). */
    protected double C;
    /** When true, trainClassifierBasic fits a sigmoid on the classifier's scores (fitSigmoid) and attaches it via setPlatt. */
    private boolean useSigmoid;
    /** When true, commands and tuning progress are printed to System.err / System.out. */
    protected boolean verbose;
    /** Learner executable used when the dataset has at most two classes. */
    private String svmLightLearn;
    /** Learner executable used when the dataset has more than two classes. */
    private String svmStructLearn;
    /** Classification executable for the two-class case; only run when doEval is true. */
    private String svmLightClassify;
    /** Classification executable for the multiclass case; only run when doEval is true. */
    private String svmStructClassify;
    /** When true, training passes {@code -a <new temp file>} and, if alphaFile is set, {@code -y <alphaFile>}. Toggled by the tuning methods. */
    private boolean useAlphaFile;
    /** Alpha file handed to the most recent training run that had useAlphaFile enabled; passed as {@code -y} to the next such run. */
    protected File alphaFile;
    /** When true, every temp file created during training is registered with deleteOnExit(). */
    private boolean deleteTempFilesOnExit;
    /** Value passed to the learner as {@code -v}. */
    private int svmLightVerbosity;
    /** When true, training also runs the classify executable and dumps the trained classifier's scores to a temp file. Only ever set to false in this class. */
    private boolean doEval;
    /** When true, trainClassifier tunes C on a held-out split before training. Takes precedence over tuneCV. */
    private boolean tuneHeldOut;
    /** When true (and tuneHeldOut is false), trainClassifier tunes C by cross-validation before training. */
    private boolean tuneCV;
    /** Scorer used by trainClassifier when tuning C; its score is negated so that the line search maximizes it. */
    private Scorer scorer;
    /** Line search used by trainClassifier when tuning C. */
    private LineSearcher tuneMinimizer;
    /** Fold count handed to crossValidateSetC by trainClassifier. No default is assigned in the constructors. */
    private int folds;
    /** Argument handed to GeneralDataset.split by trainClassifier when tuning on held-out data. No default is assigned in the constructors. */
    private double heldOutPercent;
    /** Splits a command line into arguments on runs of whitespace. */
    Pattern whitespacePattern;

    /**
     * Creates a factory with the default settings of the no-argument
     * constructor, but with caller-supplied learner executables.
     *
     * @param str  learner executable used for datasets with at most two classes
     * @param str2 learner executable used for datasets with more than two classes
     */
    public SVMLightClassifierFactory(String str, String str2) {
        // All defaults are established by the no-arg constructor; only the
        // two learner paths differ.
        this();
        svmLightLearn = str;
        svmStructLearn = str2;
    }

    /**
     * Creates a factory with default settings: no explicit C, no sigmoid
     * fitting, verbose output, the default executable locations, and no
     * tuning of C.
     */
    public SVMLightClassifierFactory() {
        // Learning parameters.
        C = -1.0d;
        useSigmoid = false;
        useAlphaFile = false;

        // Locations of the external executables.
        svmLightLearn = "/u/nlp/packages/svm_light/svm_learn";
        svmStructLearn = "/u/nlp/packages/svm_multiclass/svm_multiclass_learn";
        svmLightClassify = "/u/nlp/packages/svm_light/svm_classify";
        svmStructClassify = "/u/nlp/packages/svm_multiclass/svm_multiclass_classify";

        // Diagnostics and housekeeping.
        verbose = true;
        svmLightVerbosity = 0;
        deleteTempFilesOnExit = true;
        doEval = false;

        // Tuning of C.
        tuneHeldOut = false;
        tuneCV = false;
        scorer = new MultiClassAccuracyStats();
        tuneMinimizer = new GoldenSectionLineSearch(true);

        whitespacePattern = Pattern.compile("\\s+");
    }

    /**
     * Sets the value passed to the learner as {@code -c}. The option is only
     * emitted when the value is positive. The tuning methods overwrite it.
     */
    public void setC(double d) {
        C = d;
    }

    /** Returns the current C value (as set explicitly or by the last tuning run). */
    public double getC() {
        return C;
    }

    /** Enables or disables fitting a sigmoid on the classifier's scores after training. */
    public void setUseSigmoid(boolean z) {
        useSigmoid = z;
    }

    /** Returns the flag set by {@link #setUseSigmoid(boolean)} (note the differing name). */
    public boolean getUseSigma() {
        return useSigmoid;
    }

    /** Returns whether temp files created during training are registered for deletion on JVM exit. */
    public boolean getDeleteTempFilesOnExitFlag() {
        return deleteTempFilesOnExit;
    }

    /** Sets whether temp files created during training are registered for deletion on JVM exit. */
    public void setDeleteTempFilesOnExitFlag(boolean z) {
        deleteTempFilesOnExit = z;
    }

    /**
     * Parses a model file written by the external learner.
     *
     * <p>Layout consumed here: a fixed number of header lines (13 when
     * {@code z} is true, otherwise 10) that are skipped, one line whose first
     * whitespace-separated token is the threshold, then one line per support
     * vector: a coefficient followed by {@code index<delim>value} tokens, up
     * to an optional terminator token. Tokens whose key is {@code qid} are
     * ignored.
     *
     * @param file model file to read
     * @param z    true for the multiclass (struct) model layout, false for the binary one
     * @return the threshold paired with the combined weight vector, i.e. the
     *         sum of every support vector scaled by its coefficient
     * @throws RuntimeException if the file cannot be read or parsed; the
     *         message carries the line number and the underlying failure is
     *         attached as the cause
     */
    private static Pair<Double, ClassicCounter<Integer>> readModel(File file, boolean z) {
        int lineNumber = 0;
        // try-with-resources: the reader is closed even when parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            int headerLines = z ? 13 : 10;
            for (int skipped = 0; skipped < headerLines; skipped++) {
                lineNumber++;
                reader.readLine();
            }

            lineNumber++;
            String thresholdLine = reader.readLine();
            if (thresholdLine == null) {
                throw new IllegalStateException("model file ended before the threshold line");
            }
            double threshold = Double.parseDouble(thresholdLine.split("\\s+")[0]);

            List<Pair<Double, ClassicCounter<Integer>>> supportVectors =
                    new ArrayList<Pair<Double, ClassicCounter<Integer>>>();
            String line;
            // readLine() == null is the reliable end-of-file signal (ready() only
            // says whether a read would block).
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                String[] tokens = line.split("\\s+");
                double coefficient = Double.parseDouble(tokens[0]);
                ClassicCounter<Integer> features = new ClassicCounter<Integer>();
                for (int t = 1; t < tokens.length; t++) {
                    String token = tokens[t];
                    // NOTE(review): presumably "#", the end-of-vector marker of the
                    // SVM-light model format - confirm the constant's value.
                    if (token.equals(NegraLabel.FEATURE_SEP)) {
                        break;
                    }
                    // NOTE(review): presumably ":" (index:value) - confirm the constant's value.
                    String[] keyValue = token.split(MorphoFeatures.KEY_VAL_DELIM);
                    String key = keyValue[0];
                    if (!key.equals("qid")) {
                        features.incrementCount(Integer.valueOf(key), Double.parseDouble(keyValue[1]));
                    }
                }
                supportVectors.add(new Pair<Double, ClassicCounter<Integer>>(Double.valueOf(coefficient), features));
            }
            return new Pair<Double, ClassicCounter<Integer>>(Double.valueOf(threshold), getWeights(supportVectors));
        } catch (Exception e) {
            // Chain the cause instead of printing it, so callers see the real failure.
            throw new RuntimeException("Error reading SVM model (line " + lineNumber + " in file " + file.getAbsolutePath() + ")", e);
        }
    }

    /**
     * Collapses a list of (coefficient, feature vector) pairs into a single
     * weight vector: each vector is copied, scaled by its coefficient, and
     * added into the running total. The input counters are not modified.
     *
     * @param list coefficient / feature-vector pairs
     * @return the coefficient-weighted sum of all feature vectors
     */
    private static ClassicCounter<Integer> getWeights(List<Pair<Double, ClassicCounter<Integer>>> list) {
        ClassicCounter<Integer> total = new ClassicCounter<>();
        for (Pair<Double, ClassicCounter<Integer>> supportVector : list) {
            // Work on a copy so the caller's counter stays untouched.
            ClassicCounter<Integer> scaled = new ClassicCounter<Integer>(supportVector.second());
            double coefficient = supportVector.first().doubleValue();
            Counters.multiplyInPlace(scaled, coefficient);
            Counters.addInPlace(total, scaled);
        }
        return total;
    }

    /**
     * Maps learner feature ids back to (feature, label) pairs, dispatching on
     * the model layout.
     *
     * @param classicCounter weights keyed by the learner's 1-based feature id
     * @param index          feature index of the training data
     * @param index2         label index of the training data
     * @param z              true for the multiclass (struct) layout, false for the binary one
     * @return weights keyed by (feature, label)
     */
    private ClassicCounter<Pair<F, L>> convertWeights(ClassicCounter<Integer> classicCounter, Index<F> index, Index<L> index2, boolean z) {
        if (z) {
            return convertSVMStructWeights(classicCounter, index, index2);
        }
        return convertSVMLightWeights(classicCounter, index, index2);
    }

    /**
     * Converts binary-model weights: each learner feature id {@code k} maps
     * to feature {@code k - 1}; its weight is credited to the label at index
     * 0 and the negated weight to the label at index 1.
     *
     * @param classicCounter weights keyed by 1-based feature id
     * @param index          feature index
     * @param index2         label index (entries 0 and 1 are used)
     * @return weights keyed by (feature, label)
     */
    private ClassicCounter<Pair<F, L>> convertSVMLightWeights(ClassicCounter<Integer> classicCounter, Index<F> index, Index<L> index2) {
        ClassicCounter<Pair<F, L>> converted = new ClassicCounter<>();
        for (Integer svmFeatureId : classicCounter.keySet()) {
            int id = svmFeatureId.intValue();
            F feature = index.get(id - 1); // learner ids are 1-based
            double weight = classicCounter.getCount(Integer.valueOf(id));
            converted.incrementCount(new Pair<F, L>(feature, index2.get(0)), weight);
            converted.incrementCount(new Pair<F, L>(feature, index2.get(1)), -weight);
        }
        return converted;
    }

    /**
     * Converts multiclass-model weights. With {@code n = index.size()}, the
     * zero-based id {@code k - 1} is decomposed into label
     * {@code (k - 1) / n} and feature {@code (k - 1) % n}.
     *
     * @param classicCounter weights keyed by 1-based learner id
     * @param index          feature index
     * @param index2         label index
     * @return weights keyed by (feature, label)
     */
    private ClassicCounter<Pair<F, L>> convertSVMStructWeights(ClassicCounter<Integer> classicCounter, Index<F> index, Index<L> index2) {
        int featureCount = index.size();
        ClassicCounter<Pair<F, L>> converted = new ClassicCounter<>();
        for (Integer svmId : classicCounter.keySet()) {
            int id = svmId.intValue();
            int zeroBased = id - 1;
            L label = index2.get(zeroBased / featureCount);
            F feature = index.get(zeroBased % featureCount);
            double weight = classicCounter.getCount(Integer.valueOf(id));
            converted.incrementCount(new Pair<F, L>(feature, label), weight);
        }
        return converted;
    }

    /**
     * Trains a linear classifier on top of the SVM's scores. Each datum's
     * per-label scores become its features (plus one extra count under the
     * {@code null} key), labelled with the datum's original label. The
     * second-stage classifier is trained with a NULL-type log prior.
     *
     * @param sVMLightClassifier the trained SVM whose scores are used as features
     * @param generalDataset     the data to score
     * @return the linear classifier trained on the score features
     */
    private LinearClassifier<L, L> fitSigmoid(SVMLightClassifier<L, F> sVMLightClassifier, GeneralDataset<L, F> generalDataset) {
        RVFDataset scoreData = new RVFDataset();
        for (int d = 0; d < generalDataset.size(); d++) {
            RVFDatum<L, F> datum = generalDataset.getRVFDatum(d);
            Counter<L> scores = sVMLightClassifier.scoresOf((Datum) datum);
            // Extra feature under the null key, present on every datum.
            scores.incrementCount(null);
            scoreData.add(new RVFDatum(scores, datum.label()));
        }

        LinearClassifierFactory sigmoidFactory = new LinearClassifierFactory();
        sigmoidFactory.setPrior(new LogPrior(LogPrior.LogPriorType.NULL));
        return sigmoidFactory.trainClassifier((GeneralDataset) scoreData);
    }

    /**
     * Tunes C by cross-validation and stores the result in {@link #C}.
     *
     * <p>For every candidate C proposed by the line search, a classifier is
     * trained on each fold and scored on that fold's held-out part; the
     * negated average score is what the search minimizes. During tuning,
     * alpha files are enabled (each fold keeps its own alpha file in the
     * cross-validator's saved state) and sigmoid fitting is disabled. Both
     * flags are restored afterwards, including when tuning fails with an
     * exception.
     *
     * @param generalDataset data to cross-validate on
     * @param i              number of folds
     * @param scorer         scores a trained classifier on held-out data (higher is better)
     * @param lineSearcher   one-dimensional minimizer used to search over C
     */
    public void crossValidateSetC(GeneralDataset<L, F> generalDataset, int i, final Scorer<L> scorer, LineSearcher lineSearcher) {
        System.out.println("in Cross Validate");
        boolean savedUseSigmoid = this.useSigmoid;
        this.useAlphaFile = true;
        this.useSigmoid = false;
        try {
            final CrossValidator crossValidator = new CrossValidator(generalDataset, i);

            // Scores one fold; the fold's alpha file travels in the saved state
            // so the next training run on the same fold can start from it.
            final Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>, Double> scoreFold = new Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>, Double>() {
                @Override
                public Double apply(Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState> triple) {
                    GeneralDataset<L, F> trainSet = triple.first();
                    GeneralDataset<L, F> devSet = triple.second();
                    SVMLightClassifierFactory.this.alphaFile = (File) triple.third().state;
                    SVMLightClassifier<L, F> classifier = SVMLightClassifierFactory.this.trainClassifierBasic(trainSet);
                    triple.third().state = SVMLightClassifierFactory.this.alphaFile;
                    return Double.valueOf(scorer.score(classifier, devSet));
                }
            };

            // Objective handed to the line search: negated average fold score.
            Function<Double, Double> negativeAverageScore = new Function<Double, Double>() {
                @Override
                public Double apply(Double d) {
                    SVMLightClassifierFactory.this.C = d.doubleValue();
                    if (SVMLightClassifierFactory.this.verbose) {
                        System.out.print("C = " + d + " ");
                    }
                    Double average = Double.valueOf(crossValidator.computeAverage(scoreFold));
                    if (SVMLightClassifierFactory.this.verbose) {
                        System.out.println(" -> average Score: " + average);
                    }
                    return Double.valueOf(-average.doubleValue());
                }
            };

            this.C = lineSearcher.minimize(negativeAverageScore);
        } finally {
            // Restore even on failure so the factory is not left in tuning mode.
            this.useAlphaFile = false;
            this.useSigmoid = savedUseSigmoid;
        }
    }

    /**
     * Tunes C on a held-out portion of the given data: the dataset is split
     * with {@code generalDataset.split(d)} and the two parts are passed, in
     * order, as training and development data to the four-argument overload.
     *
     * @param generalDataset data to split
     * @param d              argument passed to {@code GeneralDataset.split}
     * @param scorer         scores a trained classifier on the development part
     * @param lineSearcher   one-dimensional minimizer used to search over C
     */
    public void heldOutSetC(GeneralDataset<L, F> generalDataset, double d, Scorer<L> scorer, LineSearcher lineSearcher) {
        Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> parts = generalDataset.split(d);
        GeneralDataset<L, F> trainPart = parts.first();
        GeneralDataset<L, F> devPart = parts.second();
        heldOutSetC(trainPart, devPart, scorer, lineSearcher);
    }

    /**
     * Tunes C using explicit training and development sets and stores the
     * result in {@link #C}.
     *
     * <p>For every candidate C proposed by the line search, a classifier is
     * trained on {@code generalDataset} and scored on {@code generalDataset2};
     * the negated score is what the search minimizes. During tuning, alpha
     * files are enabled and sigmoid fitting is disabled. Both flags are
     * restored afterwards, including when tuning fails with an exception.
     *
     * @param generalDataset  training data
     * @param generalDataset2 development data the scorer is evaluated on
     * @param scorer          scores a trained classifier (higher is better)
     * @param lineSearcher    one-dimensional minimizer used to search over C
     */
    public void heldOutSetC(final GeneralDataset<L, F> generalDataset, final GeneralDataset<L, F> generalDataset2, final Scorer<L> scorer, LineSearcher lineSearcher) {
        boolean savedUseSigmoid = this.useSigmoid;
        this.useAlphaFile = true;
        this.useSigmoid = false;
        try {
            Function<Double, Double> negativeDevScore = new Function<Double, Double>() {
                @Override
                public Double apply(Double d) {
                    SVMLightClassifierFactory.this.C = d.doubleValue();
                    SVMLightClassifier<L, F> classifier = SVMLightClassifierFactory.this.trainClassifierBasic(generalDataset);
                    return Double.valueOf(-scorer.score(classifier, generalDataset2));
                }
            };
            this.C = lineSearcher.minimize(negativeDevScore);
        } finally {
            // Restore even on failure so the factory is not left in tuning mode.
            this.useAlphaFile = false;
            this.useSigmoid = savedUseSigmoid;
        }
    }

    /**
     * Not implemented for this factory: performs no training and always
     * returns {@code null}. Callers must null-check the result or use
     * {@link #trainClassifier(GeneralDataset)} instead.
     *
     * @param list ignored
     * @return always {@code null}
     */
    @Override // edu.stanford.nlp.classify.ClassifierFactory
    @Deprecated
    public SVMLightClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> list) {
        // Intentionally a stub; see the Javadoc above.
        return null;
    }

    /** Returns the value passed to {@code GeneralDataset.split} when tuning on held-out data. */
    public double getHeldOutPercent() {
        return heldOutPercent;
    }

    /** Sets the value passed to {@code GeneralDataset.split} when tuning on held-out data. */
    public void setHeldOutPercent(double d) {
        heldOutPercent = d;
    }

    /** Returns the number of folds used when tuning by cross-validation. */
    public int getFolds() {
        return folds;
    }

    /** Sets the number of folds used when tuning by cross-validation. */
    public void setFolds(int i) {
        folds = i;
    }

    /** Returns the line search used to tune C. */
    public LineSearcher getTuneMinimizer() {
        return tuneMinimizer;
    }

    /** Sets the line search used to tune C. */
    public void setTuneMinimizer(LineSearcher lineSearcher) {
        tuneMinimizer = lineSearcher;
    }

    /** Returns the scorer used to evaluate candidate values of C. */
    public Scorer getScorer() {
        return scorer;
    }

    /** Sets the scorer used to evaluate candidate values of C. */
    public void setScorer(Scorer scorer) {
        this.scorer = scorer;
    }

    /** Returns whether {@code trainClassifier} tunes C by cross-validation (only consulted when held-out tuning is off). */
    public boolean getTuneCV() {
        return tuneCV;
    }

    /** Sets whether {@code trainClassifier} tunes C by cross-validation. */
    public void setTuneCV(boolean z) {
        tuneCV = z;
    }

    /** Returns whether {@code trainClassifier} tunes C on a held-out split. */
    public boolean getTuneHeldOut() {
        return tuneHeldOut;
    }

    /** Sets whether {@code trainClassifier} tunes C on a held-out split; takes precedence over cross-validation. */
    public void setTuneHeldOut(boolean z) {
        tuneHeldOut = z;
    }

    /** Returns the value passed to the learner as {@code -v}. */
    public int getSvmLightVerbosity() {
        return svmLightVerbosity;
    }

    /** Sets the value passed to the learner as {@code -v}. */
    public void setSvmLightVerbosity(int i) {
        svmLightVerbosity = i;
    }

    /**
     * Trains a classifier on the given dataset, first tuning C if requested:
     * on a held-out split when {@code tuneHeldOut} is set, otherwise by
     * cross-validation when {@code tuneCV} is set. The final classifier is
     * always trained on the full dataset.
     *
     * @param generalDataset training data
     * @return the trained classifier
     */
    @Override
    public SVMLightClassifier<L, F> trainClassifier(GeneralDataset<L, F> generalDataset) {
        if (tuneHeldOut) {
            // Held-out tuning wins when both flags are set.
            heldOutSetC(generalDataset, heldOutPercent, scorer, tuneMinimizer);
        } else if (tuneCV) {
            crossValidateSetC(generalDataset, folds, scorer, tuneMinimizer);
        }
        return trainClassifierBasic(generalDataset);
    }

    /**
     * Trains a classifier with the current settings (no tuning of C) by
     * running the external learner.
     *
     * <p>Steps: write the dataset to a temp file in SVM-light format; run the
     * learner (the multiclass executable when the dataset has more than two
     * classes, the binary one otherwise); read the resulting model file back
     * and convert it into (feature, label) weights. For the binary case the
     * model threshold is stored negated for the label at index 0 and as-is
     * for the label at index 1. When {@code doEval} is set, the classify
     * executable is also run and the classifier's scores are dumped to a
     * temp file. When {@code useSigmoid} is set, a sigmoid is fitted on top.
     *
     * <p>The command is passed to the process as an argument list, so temp
     * file paths containing whitespace stay intact. The executable setting
     * itself is still split on whitespace, so it may carry extra options.
     *
     * @param generalDataset training data
     * @return the trained classifier
     * @throws RuntimeException wrapping any failure while writing files,
     *         running the external programs, or reading the model
     */
    public SVMLightClassifier<L, F> trainClassifierBasic(GeneralDataset<L, F> generalDataset) {
        Index<L> labelIndex = generalDataset.labelIndex();
        Index<F> featureIndex = generalDataset.featureIndex;
        boolean multiclass = generalDataset.numClasses() > 2;
        try {
            File modelFile = createSvmTempFile(".model");
            File dataFile = createSvmTempFile(".data");
            // try-with-resources: the data file is closed even if writing fails.
            try (PrintWriter dataWriter = new PrintWriter(new FileWriter(dataFile))) {
                generalDataset.printSVMLightFormat(dataWriter);
            }

            List<String> learnCommand = new ArrayList<String>();
            addExecutable(learnCommand, multiclass ? this.svmStructLearn : this.svmLightLearn);
            learnCommand.add("-v");
            learnCommand.add(String.valueOf(this.svmLightVerbosity));
            learnCommand.add("-m");
            learnCommand.add("400");
            if (this.C > 0.0d) {
                learnCommand.add("-c");
                learnCommand.add(String.valueOf(this.C));
            }
            if (this.useAlphaFile) {
                // -a receives a fresh file; -y receives the file from the previous
                // run, if any (per the SVM-light docs: alphas written / initial alphas).
                File newAlphaFile = createSvmTempFile(".alphas");
                learnCommand.add("-a");
                learnCommand.add(newAlphaFile.getAbsolutePath());
                if (this.alphaFile != null) {
                    learnCommand.add("-y");
                    learnCommand.add(this.alphaFile.getAbsolutePath());
                }
                this.alphaFile = newAlphaFile;
            }
            learnCommand.add(dataFile.getAbsolutePath());
            learnCommand.add(modelFile.getAbsolutePath());
            echoCommand(learnCommand);
            SystemUtils.run(new ProcessBuilder(learnCommand), new PrintWriter(System.err), new PrintWriter(System.err));

            if (this.doEval) {
                File predictionFile = createSvmTempFile(".pred");
                List<String> classifyCommand = new ArrayList<String>();
                addExecutable(classifyCommand, multiclass ? this.svmStructClassify : this.svmLightClassify);
                classifyCommand.add(dataFile.getAbsolutePath());
                classifyCommand.add(modelFile.getAbsolutePath());
                classifyCommand.add(predictionFile.getAbsolutePath());
                echoCommand(classifyCommand);
                SystemUtils.run(new ProcessBuilder(classifyCommand), new PrintWriter(System.err), new PrintWriter(System.err));
            }

            Pair<Double, ClassicCounter<Integer>> model = readModel(modelFile, multiclass);
            double threshold = model.first().doubleValue();
            ClassicCounter<Pair<F, L>> weights = convertWeights(model.second(), featureIndex, labelIndex, multiclass);
            ClassicCounter<L> thresholds = new ClassicCounter<L>();
            if (!multiclass) {
                thresholds.setCount(labelIndex.get(0), -threshold);
                thresholds.setCount(labelIndex.get(1), threshold);
            }
            SVMLightClassifier<L, F> classifier = new SVMLightClassifier<L, F>(weights, thresholds);

            if (this.doEval) {
                File scoreFile = createSvmTempFile(".pred2");
                try (PrintWriter scoreWriter = new PrintWriter(scoreFile)) {
                    NumberFormat scoreFormat = NumberFormat.getNumberInstance();
                    scoreFormat.setMaximumFractionDigits(5);
                    Iterator<RVFDatum<L, F>> data = generalDataset.iterator();
                    while (data.hasNext()) {
                        scoreWriter.println(Counters.toString(classifier.scoresOf((Datum) data.next()), scoreFormat));
                    }
                }
            }

            if (this.useSigmoid) {
                if (this.verbose) {
                    System.out.print("fitting sigmoid...");
                }
                classifier.setPlatt(fitSigmoid(classifier, generalDataset));
                if (this.verbose) {
                    System.out.println("done");
                }
            }
            return classifier;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Creates a temp file named {@code svm-*<suffix>}, registered for deletion on exit when so configured. */
    private File createSvmTempFile(String suffix) throws java.io.IOException {
        File tempFile = File.createTempFile("svm-", suffix);
        if (this.deleteTempFilesOnExit) {
            tempFile.deleteOnExit();
        }
        return tempFile;
    }

    /** Appends the whitespace-separated tokens of an executable setting (program plus any embedded options) to a command. */
    private void addExecutable(List<String> command, String executable) {
        for (String token : this.whitespacePattern.split(executable.trim())) {
            if (!token.isEmpty()) {
                command.add(token);
            }
        }
    }

    /** Prints the command to System.err, space-joined between {@code <<} and {@code >>}, when verbose. */
    private void echoCommand(List<String> command) {
        if (!this.verbose) {
            return;
        }
        StringBuilder shown = new StringBuilder();
        for (String argument : command) {
            if (shown.length() > 0) {
                shown.append(' ');
            }
            shown.append(argument);
        }
        System.err.println("<< " + shown + " >>");
    }
}
