package opennlp.tools.ml.maxent.quasinewton;

import java.io.IOException;
import java.util.ArrayList;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.maxent.quasinewton.QNMinimizer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.DataIndexer;

/**
 * Event trainer that fits a maximum-entropy model with the quasi-Newton {@link QNMinimizer}.
 * It is selected by the algorithm name {@value #MAXENT_QN_VALUE}.
 *
 * <p>When the trainer is built through the no-argument constructor, training settings are read
 * from the trainer configuration in {@link #isValid()}. The other constructors set them directly.
 */
public class QNTrainer extends AbstractEventTrainer {
    /** Algorithm name that selects this trainer. */
    public static final String MAXENT_QN_VALUE = "MAXENT_QN";
    /** Parameter name for the number of threads used to evaluate the objective. */
    public static final String THREADS_PARAM = "Threads";
    public static final int THREADS_DEFAULT = 1;
    /** Parameter name for the L1 cost handed to {@link QNMinimizer}. */
    public static final String L1COST_PARAM = "L1Cost";
    public static final double L1COST_DEFAULT = 0.1d;
    /** Parameter name for the L2 cost handed to {@link QNMinimizer}. */
    public static final String L2COST_PARAM = "L2Cost";
    public static final double L2COST_DEFAULT = 0.1d;
    /** Parameter name for the {@code m} value handed to {@link QNMinimizer}. */
    public static final String M_PARAM = "NumOfUpdates";
    public static final int M_DEFAULT = 15;
    /** Parameter name for the function-evaluation limit handed to {@link QNMinimizer}. */
    public static final String MAX_FCT_EVAL_PARAM = "MaxFctEval";
    public static final int MAX_FCT_EVAL_DEFAULT = 30000;

    // Start from the documented defaults so that trainModel() is usable even if isValid()
    // was never called (previously the no-arg constructor left these at 0).
    private int threads = THREADS_DEFAULT;
    private double l1Cost = L1COST_DEFAULT;
    private double l2Cost = L2COST_DEFAULT;
    private int m = M_DEFAULT;
    private int maxFctEval = MAX_FCT_EVAL_DEFAULT;
    private boolean verbose = true;

    /**
     * Reports training-set accuracy for a candidate parameter vector back to the minimizer.
     */
    public class ModelEvaluator implements QNMinimizer.Evaluator {
        private final DataIndexer indexer;

        /**
         * @param dataIndexer indexed training events to score against
         */
        public ModelEvaluator(DataIndexer dataIndexer) {
            this.indexer = dataIndexer;
        }

        /**
         * Computes the fraction of training events whose highest-scoring outcome matches the
         * observed outcome. Each event is weighted by the number of times it was seen.
         *
         * @param parameters flat parameter vector, laid out as {@code outcome * numPreds + pred}
         * @return accuracy in [0, 1], or 0 when the indexer contains no events
         */
        @Override
        public double evaluate(double[] parameters) {
            int[][] contexts = indexer.getContexts();
            float[][] values = indexer.getValues();
            int[] numTimesEventsSeen = indexer.getNumTimesEventsSeen();
            int[] outcomeList = indexer.getOutcomeList();
            int numOutcomes = indexer.getOutcomeLabels().length;
            int numPreds = indexer.getPredLabels().length;

            long correct = 0;
            long total = 0;
            for (int ei = 0; ei < contexts.length; ei++) {
                float[] contextValues = values == null ? null : values[ei];
                // A fresh buffer per event: do not hoist this without confirming that
                // QNModel.eval does not accumulate into the array it is given.
                double[] scores = new double[numOutcomes];
                QNModel.eval(contexts[ei], contextValues, scores, numOutcomes, numPreds, parameters);
                if (ArrayMath.maxIdx(scores) == outcomeList[ei]) {
                    correct += numTimesEventsSeen[ei];
                }
                total += numTimesEventsSeen[ei];
            }
            // Floating-point division: integer division truncated every accuracy below 1 to 0,
            // and threw ArithmeticException on an empty event set.
            return total == 0 ? 0.0d : (double) correct / total;
        }
    }

    /**
     * Creates a trainer with default settings.
     *
     * @param verbose whether progress output is produced
     */
    public QNTrainer(boolean verbose) {
        this(M_DEFAULT, verbose);
    }

    /**
     * Creates a verbose trainer.
     *
     * @param m value passed to {@link QNMinimizer}; negative selects {@value #M_DEFAULT}
     */
    public QNTrainer(int m) {
        this(m, true);
    }

    /**
     * @param m value passed to {@link QNMinimizer}; negative selects {@value #M_DEFAULT}
     * @param verbose whether progress output is produced
     */
    public QNTrainer(int m, boolean verbose) {
        this(m, MAX_FCT_EVAL_DEFAULT, verbose);
    }

    /**
     * @param m value passed to {@link QNMinimizer}; negative selects {@value #M_DEFAULT}
     * @param maxFctEval function-evaluation limit; negative selects {@value #MAX_FCT_EVAL_DEFAULT}
     * @param verbose whether progress output is produced
     */
    public QNTrainer(int m, int maxFctEval, boolean verbose) {
        this.verbose = verbose;
        this.m = m < 0 ? M_DEFAULT : m;
        this.maxFctEval = maxFctEval < 0 ? MAX_FCT_EVAL_DEFAULT : maxFctEval;
    }

    /**
     * Creates a verbose trainer with default settings. The settings are overridden from the
     * trainer configuration by {@link #isValid()}.
     */
    public QNTrainer() {
    }

    /**
     * Validates the configured algorithm and training parameters. If they are valid, it applies
     * them to this trainer.
     *
     * <p>No field is modified unless every parameter is valid.
     *
     * @return {@code false} if the superclass rejects the configuration, the algorithm is not
     *     {@value #MAXENT_QN_VALUE}, or any parameter is out of range
     */
    @Override
    public boolean isValid() {
        if (!super.isValid()) {
            return false;
        }

        String algorithm = getAlgorithm();
        if (algorithm != null && !MAXENT_QN_VALUE.equals(algorithm)) {
            return false;
        }

        int mParam = getIntParam(M_PARAM, M_DEFAULT);
        if (mParam < 0) {
            return false;
        }
        int maxFctEvalParam = getIntParam(MAX_FCT_EVAL_PARAM, MAX_FCT_EVAL_DEFAULT);
        if (maxFctEvalParam < 0) {
            return false;
        }
        int threadsParam = getIntParam(THREADS_PARAM, THREADS_DEFAULT);
        if (threadsParam < 1) {
            return false;
        }
        double l1Param = getDoubleParam(L1COST_PARAM, L1COST_DEFAULT);
        if (l1Param < 0.0d) {
            return false;
        }
        double l2Param = getDoubleParam(L2COST_PARAM, L2COST_DEFAULT);
        if (l2Param < 0.0d) {
            return false;
        }

        this.m = mParam;
        this.maxFctEval = maxFctEvalParam;
        this.threads = threadsParam;
        this.l1Cost = l1Param;
        this.l2Cost = l2Param;
        return true;
    }

    /** Always {@code true} for this trainer. */
    @Override
    public boolean isSortAndMerge() {
        return true;
    }

    /** Trains a model using the configured iteration count ({@link #getIterations()}). */
    @Override
    public AbstractModel doTrain(DataIndexer dataIndexer) throws IOException {
        return trainModel(getIterations(), dataIndexer);
    }

    /**
     * Minimizes the negative log-likelihood over the indexed events and packages the result
     * as a {@link QNModel}.
     *
     * @param iterations iteration limit passed to {@link QNMinimizer}
     * @param indexer indexed training events
     * @return the trained model, with one {@link Context} per predicate that holds a weight for
     *     every outcome
     */
    public QNModel trainModel(int iterations, DataIndexer indexer) {
        Function objectiveFunction;
        if (threads == 1) {
            if (verbose) {
                System.out.println("Computing model parameters ...");
            }
            objectiveFunction = new NegLogLikelihood(indexer);
        } else {
            if (verbose) {
                System.out.println("Computing model parameters in " + threads + " threads ...");
            }
            objectiveFunction = new ParallelNegLogLikelihood(indexer, threads);
        }

        QNMinimizer minimizer =
                new QNMinimizer(l1Cost, l2Cost, iterations, m, maxFctEval, verbose);
        minimizer.setEvaluator(new ModelEvaluator(indexer));
        double[] parameters = minimizer.minimize(objectiveFunction);

        String[] predLabels = indexer.getPredLabels();
        String[] outcomeLabels = indexer.getOutcomeLabels();
        int numPreds = predLabels.length;
        int numOutcomes = outcomeLabels.length;

        // The parameter vector is laid out outcome-major (outcome * numPreds + pred). Regroup
        // it per predicate; every predicate covers all outcomes, so its pattern is 0..numOutcomes-1.
        Context[] params = new Context[numPreds];
        for (int pi = 0; pi < numPreds; pi++) {
            int[] outcomePattern = new int[numOutcomes];
            double[] alpha = new double[numOutcomes];
            for (int oi = 0; oi < numOutcomes; oi++) {
                outcomePattern[oi] = oi;
                alpha[oi] = parameters[oi * numPreds + pi];
            }
            params[pi] = new Context(outcomePattern, alpha);
        }
        return new QNModel(params, predLabels, outcomeLabels);
    }
}
