package opennlp.tools.ml.maxent;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Prior;
import opennlp.tools.ml.model.UniformPrior;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/* loaded from: input_file:opennlp/tools/ml/maxent/GISTrainer.class */
/**
 * Event trainer that iteratively fits the parameters of a {@link GISModel}
 * (GIS presumably stands for Generalized Iterative Scaling -- the name is not
 * spelled out in this file).
 *
 * <p>Training proceeds in two phases: {@link #trainModel(int, DataIndexer, Prior, int)}
 * copies the indexed training data into this instance and allocates the
 * expectation tables, then an internal loop repeatedly evaluates the current
 * model over all events (optionally split across several worker threads) and
 * updates the parameters until the iteration limit is reached, the
 * log-likelihood gain falls below the configured threshold, or the
 * log-likelihood decreases.
 *
 * <p>An instance keeps the training data in fields while training, so a single
 * instance must not run several trainings concurrently.
 */
public class GISTrainer extends AbstractEventTrainer {

    /** Legacy name of the log-likelihood threshold parameter. */
    @Deprecated
    public static final String OLD_LL_THRESHOLD_PARAM = "llthreshold";

    /** Name of the training parameter holding the log-likelihood convergence threshold. */
    public static final String LOG_LIKELIHOOD_THRESHOLD_PARAM = "LLThreshold";

    /** Threshold used when {@link #LOG_LIKELIHOOD_THRESHOLD_PARAM} is not set. */
    public static final double LOG_LIKELIHOOD_THRESHOLD_DEFAULT = 1.0E-4d;

    public static final String MAXENT_VALUE = "MAXENT";

    private static final String SMOOTHING_PARAM = "Smoothing";
    private static final boolean SMOOTHING_DEFAULT = false;
    private static final String SMOOTHING_OBSERVATION_PARAM = "SmoothingObservation";
    private static final double SMOOTHING_OBSERVATION = 0.1d;
    private static final String GAUSSIAN_SMOOTHING_PARAM = "GaussianSmoothing";
    private static final boolean GAUSSIAN_SMOOTHING_DEFAULT = false;
    private static final String GAUSSIAN_SMOOTHING_SIGMA_PARAM = "GaussianSmoothingSigma";
    private static final double GAUSSIAN_SMOOTHING_SIGMA_DEFAULT = 2.0d;

    /** Maximum number of Newton steps taken per parameter in {@link #gaussianUpdate}. */
    private static final int GAUSSIAN_MAX_NEWTON_STEPS = 50;

    /** Step size below which the Newton iteration in {@link #gaussianUpdate} stops. */
    private static final double GAUSSIAN_CONVERGENCE = 1.0E-6d;

    /** Minimum log-likelihood gain between two iterations required to keep iterating. */
    private double llThreshold = LOG_LIKELIHOOD_THRESHOLD_DEFAULT;

    /** When set, every predicate gets a parameter for every outcome (see {@link #setSmoothing}). */
    private boolean useSimpleSmoothing = SMOOTHING_DEFAULT;

    /** When set, parameter updates are computed by {@link #gaussianUpdate}. */
    private boolean useGaussianSmoothing = GAUSSIAN_SMOOTHING_DEFAULT;

    /** Sigma used by {@link #gaussianUpdate}. */
    private double sigma = GAUSSIAN_SMOOTHING_SIGMA_DEFAULT;

    /** Observed count assigned to unseen predicate/outcome pairs under simple smoothing. */
    private double smoothingObservation = SMOOTHING_OBSERVATION;

    /** Number of entries in {@link #contexts}. */
    private int numUniqueEvents;

    /** Number of predicate labels. */
    private int numPreds;

    /** Number of outcome labels. */
    private int numOutcomes;

    /** Predicate ids per event, as returned by the data indexer. */
    private int[][] contexts;

    /** Predicate values per event; {@code null} (or {@code null} rows) when the indexer supplies none. */
    private float[][] values;

    /** Outcome id per event. */
    private int[] outcomeList;

    /** Multiplicity of each event. */
    private int[] numTimesEventsSeen;

    private String[] outcomeLabels;
    private String[] predLabels;

    /** Observed expectation per predicate, indexed by active-outcome position. */
    private MutableContext[] observedExpects;

    /** Model parameters per predicate, indexed by active-outcome position. */
    private MutableContext[] params;

    /** Model expectation per worker thread and predicate; row 0 also receives the merged totals. */
    private MutableContext[][] modelExpects;

    private Prior prior;
    private EvalParameters evalParams;

    /**
     * Work item that evaluates the current model on a contiguous slice of the
     * events and accumulates the model expectations into the table reserved
     * for its worker ({@code modelExpects[threadIndex]}).
     */
    public class ModelExpectationComputeTask implements Callable<ModelExpectationComputeTask> {
        private final int startIndex;
        private final int length;
        private final int threadIndex;
        private double loglikelihood = 0.0d;
        private int numEvents = 0;
        private int numCorrect = 0;

        /**
         * @param threadIndex row of {@code modelExpects} this task writes to
         * @param startIndex  index of the first event of the slice
         * @param length      number of events in the slice
         */
        ModelExpectationComputeTask(int threadIndex, int startIndex, int length) {
            this.startIndex = startIndex;
            this.length = length;
            this.threadIndex = threadIndex;
        }

        /**
         * Processes the slice and returns this task so the caller can read the
         * accumulated statistics.
         */
        @Override
        public ModelExpectationComputeTask call() {
            final double[] modelDistribution = new double[GISTrainer.this.numOutcomes];
            final MutableContext[] expectations = GISTrainer.this.modelExpects[this.threadIndex];
            final int end = this.startIndex + this.length;

            for (int ei = this.startIndex; ei < end; ei++) {
                final int[] context = GISTrainer.this.contexts[ei];
                final float[][] allValues = GISTrainer.this.values;
                final float[] eventValues = (allValues != null) ? allValues[ei] : null;
                final int timesSeen = GISTrainer.this.numTimesEventsSeen[ei];

                if (allValues != null) {
                    GISTrainer.this.prior.logPrior(modelDistribution, context, eventValues);
                    GISModel.eval(context, eventValues, modelDistribution, GISTrainer.this.evalParams);
                } else {
                    GISTrainer.this.prior.logPrior(modelDistribution, context);
                    GISModel.eval(context, modelDistribution, GISTrainer.this.evalParams);
                }

                for (int ci = 0; ci < context.length; ci++) {
                    final MutableContext expectation = expectations[context[ci]];
                    final int[] activeOutcomes = expectation.getOutcomes();
                    for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                        final int outcome = activeOutcomes[aoi];
                        if (eventValues == null) {
                            expectation.updateParameter(aoi, modelDistribution[outcome] * timesSeen);
                        } else {
                            expectation.updateParameter(aoi, modelDistribution[outcome] * eventValues[ci] * timesSeen);
                        }
                    }
                }

                final int expectedOutcome = GISTrainer.this.outcomeList[ei];
                this.loglikelihood += StrictMath.log(modelDistribution[expectedOutcome]) * timesSeen;
                this.numEvents += timesSeen;
                // The accuracy is only reported through display(), so skip the argmax when output is off.
                if (GISTrainer.this.printMessages && ArrayMath.argmax(modelDistribution) == expectedOutcome) {
                    this.numCorrect += timesSeen;
                }
            }
            return this;
        }

        synchronized int getNumEvents() {
            return this.numEvents;
        }

        synchronized int getNumCorrect() {
            return this.numCorrect;
        }

        synchronized double getLoglikelihood() {
            return this.loglikelihood;
        }
    }

    /** Creates a trainer that does not print progress messages. */
    public GISTrainer() {
        this.printMessages = false;
    }

    /**
     * Creates a trainer.
     *
     * @param printMessages whether {@link #display(String)} writes to standard output
     */
    GISTrainer(boolean printMessages) {
        this.printMessages = printMessages;
    }

    @Override
    public boolean isSortAndMerge() {
        return true;
    }

    /**
     * Reads the threshold and smoothing settings from the training parameters.
     *
     * <p>If only the deprecated {@link #OLD_LL_THRESHOLD_PARAM} is set to a
     * positive value, it is copied to {@link #LOG_LIKELIHOOD_THRESHOLD_PARAM}.
     *
     * @throws RuntimeException if both simple and Gaussian smoothing are enabled
     */
    @Override
    public void init(TrainingParameters trainingParameters, Map<String, String> map) {
        super.init(trainingParameters, map);

        if (trainingParameters.getDoubleParameter(OLD_LL_THRESHOLD_PARAM, -1.0d) > 0.0d) {
            display("WARNING: the training parameter: " + OLD_LL_THRESHOLD_PARAM
                    + " has been deprecated.  Please use " + LOG_LIKELIHOOD_THRESHOLD_PARAM + " instead");
            if (trainingParameters.getDoubleParameter(LOG_LIKELIHOOD_THRESHOLD_PARAM, -1.0d) < 0.0d) {
                trainingParameters.put(LOG_LIKELIHOOD_THRESHOLD_PARAM,
                        trainingParameters.getDoubleParameter(OLD_LL_THRESHOLD_PARAM, LOG_LIKELIHOOD_THRESHOLD_DEFAULT));
            }
        }
        this.llThreshold = trainingParameters.getDoubleParameter(
                LOG_LIKELIHOOD_THRESHOLD_PARAM, LOG_LIKELIHOOD_THRESHOLD_DEFAULT);

        this.useSimpleSmoothing = trainingParameters.getBooleanParameter(SMOOTHING_PARAM, SMOOTHING_DEFAULT);
        if (this.useSimpleSmoothing) {
            this.smoothingObservation = trainingParameters.getDoubleParameter(
                    SMOOTHING_OBSERVATION_PARAM, SMOOTHING_OBSERVATION);
        }

        this.useGaussianSmoothing = trainingParameters.getBooleanParameter(
                GAUSSIAN_SMOOTHING_PARAM, GAUSSIAN_SMOOTHING_DEFAULT);
        if (this.useGaussianSmoothing) {
            this.sigma = trainingParameters.getDoubleParameter(
                    GAUSSIAN_SMOOTHING_SIGMA_PARAM, GAUSSIAN_SMOOTHING_SIGMA_DEFAULT);
        }

        if (this.useSimpleSmoothing && this.useGaussianSmoothing) {
            throw new RuntimeException("Cannot set both Gaussian smoothing and Simple smoothing");
        }
    }

    /**
     * Trains a model from already indexed data, using the configured number
     * of iterations and the {@code "Threads"} training parameter (default 1).
     */
    @Override
    public MaxentModel doTrain(DataIndexer dataIndexer) throws IOException {
        int threads = this.trainingParameters.getIntParameter("Threads", 1);
        return trainModel(getIterations(), dataIndexer, threads);
    }

    /** Enables or disables simple smoothing. */
    public void setSmoothing(boolean smooth) {
        this.useSimpleSmoothing = smooth;
    }

    /** Sets the observed count used for unseen predicate/outcome pairs under simple smoothing. */
    public void setSmoothingObservation(double timesSeen) {
        this.smoothingObservation = timesSeen;
    }

    /** Enables Gaussian smoothing with the given sigma. */
    public void setGaussianSigma(double sigmaValue) {
        this.useGaussianSmoothing = true;
        this.sigma = sigmaValue;
    }

    /**
     * Trains a model from an event stream with 100 iterations and a cutoff of 0.
     *
     * @throws IOException if indexing the stream fails
     */
    public GISModel trainModel(ObjectStream<Event> objectStream) throws IOException {
        return trainModel(objectStream, 100, 0);
    }

    /**
     * Indexes the event stream with a {@link OnePassDataIndexer} and trains a
     * model on the result.
     *
     * @param objectStream the training events
     * @param iterations   maximum number of training iterations
     * @param cutoff       value stored under the indexer's {@code "Cutoff"} parameter
     * @throws IOException if indexing the stream fails
     */
    public GISModel trainModel(ObjectStream<Event> objectStream, int iterations, int cutoff) throws IOException {
        TrainingParameters indexingParameters = new TrainingParameters();
        indexingParameters.put("Cutoff", cutoff);
        indexingParameters.put("Iterations", iterations);

        OnePassDataIndexer indexer = new OnePassDataIndexer();
        indexer.init(indexingParameters, new HashMap<>());
        indexer.index(objectStream);
        return trainModel(iterations, indexer);
    }

    /** Trains single-threaded with a {@link UniformPrior}. */
    public GISModel trainModel(int iterations, DataIndexer dataIndexer) {
        return trainModel(iterations, dataIndexer, new UniformPrior(), 1);
    }

    /** Trains with a {@link UniformPrior} using the given number of threads. */
    public GISModel trainModel(int iterations, DataIndexer dataIndexer, int threads) {
        return trainModel(iterations, dataIndexer, new UniformPrior(), threads);
    }

    /**
     * Trains a model on the indexed data.
     *
     * @param iterations  maximum number of training iterations
     * @param dataIndexer source of contexts, values, outcomes and labels
     * @param prior       prior applied before each model evaluation
     * @param threads     number of worker threads; must be at least 1
     * @return the trained model
     * @throws IllegalArgumentException if {@code threads} is not positive
     */
    public GISModel trainModel(int iterations, DataIndexer dataIndexer, Prior prior, int threads) {
        if (threads <= 0) {
            throw new IllegalArgumentException("threads must be at least one or greater but is " + threads + "!");
        }
        // One expectation table per worker; the rows are allocated once numPreds is known.
        this.modelExpects = new MutableContext[threads][];

        display("Incorporating indexed data for training...  \n");
        this.contexts = dataIndexer.getContexts();
        this.values = dataIndexer.getValues();
        // NOTE(review): the result is unused; the call is kept in case an indexer
        // implementation relies on it being invoked -- confirm and remove.
        dataIndexer.getPredCounts();
        this.numTimesEventsSeen = dataIndexer.getNumTimesEventsSeen();
        this.numUniqueEvents = this.contexts.length;
        this.prior = prior;
        final double correctionConstant = computeCorrectionConstant();
        display("done.\n");

        this.outcomeLabels = dataIndexer.getOutcomeLabels();
        this.outcomeList = dataIndexer.getOutcomeList();
        this.numOutcomes = this.outcomeLabels.length;
        this.predLabels = dataIndexer.getPredLabels();
        this.prior.setLabels(this.outcomeLabels, this.predLabels);
        this.numPreds = this.predLabels.length;

        display("\tNumber of Event Tokens: " + this.numUniqueEvents + "\n");
        display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        display("\t  Number of Predicates: " + this.numPreds + "\n");

        initializeExpectations(countPredicateOutcomes());
        display("...done.\n");

        if (threads == 1) {
            display("Computing model parameters ...\n");
        } else {
            display("Computing model parameters in " + threads + " threads...\n");
        }
        findParameters(iterations, correctionConstant);
        return new GISModel(this.params, this.predLabels, this.outcomeLabels);
    }

    /**
     * Returns the largest per-event total: the sum of the event's values when
     * it has values, otherwise its number of predicates.
     */
    private double computeCorrectionConstant() {
        double correctionConstant = 0.0d;
        for (int ei = 0; ei < this.contexts.length; ei++) {
            if (this.values != null && this.values[ei] != null) {
                // Summing from zero (instead of seeding with element 0) tolerates an empty value row.
                float total = 0.0f;
                for (float value : this.values[ei]) {
                    total += value;
                }
                if (total > correctionConstant) {
                    correctionConstant = total;
                }
            } else if (this.contexts[ei].length > correctionConstant) {
                correctionConstant = this.contexts[ei].length;
            }
        }
        return correctionConstant;
    }

    /**
     * Builds the {@code [predicate][outcome]} table of observed counts, each
     * event weighted by its multiplicity and, when present, its predicate value.
     */
    private float[][] countPredicateOutcomes() {
        float[][] predCount = new float[this.numPreds][this.numOutcomes];
        for (int ei = 0; ei < this.numUniqueEvents; ei++) {
            final int[] context = this.contexts[ei];
            final int outcome = this.outcomeList[ei];
            final boolean hasValues = this.values != null && this.values[ei] != null;
            for (int ci = 0; ci < context.length; ci++) {
                if (hasValues) {
                    predCount[context[ci]][outcome] += this.numTimesEventsSeen[ei] * this.values[ei][ci];
                } else {
                    predCount[context[ci]][outcome] += this.numTimesEventsSeen[ei];
                }
            }
        }
        return predCount;
    }

    /**
     * Allocates {@link #params}, {@link #modelExpects}, {@link #observedExpects}
     * and {@link #evalParams}, and seeds the observed expectations from the counts.
     *
     * @param predCount observed counts indexed {@code [predicate][outcome]}
     */
    private void initializeExpectations(float[][] predCount) {
        final double smoothedCount = this.smoothingObservation;

        this.params = new MutableContext[this.numPreds];
        for (int ti = 0; ti < this.modelExpects.length; ti++) {
            this.modelExpects[ti] = new MutableContext[this.numPreds];
        }
        this.observedExpects = new MutableContext[this.numPreds];
        this.evalParams = new EvalParameters(this.params, this.numOutcomes);

        // Shared pattern for predicates that are active for every outcome.
        final int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; oi++) {
            allOutcomesPattern[oi] = oi;
        }
        final int[] activeOutcomes = new int[this.numOutcomes];

        for (int pi = 0; pi < this.numPreds; pi++) {
            int numActiveOutcomes = 0;
            int[] outcomePattern;
            if (this.useSimpleSmoothing) {
                numActiveOutcomes = this.numOutcomes;
                outcomePattern = allOutcomesPattern;
            } else {
                // Without smoothing only outcomes actually observed with the predicate get a parameter.
                for (int oi = 0; oi < this.numOutcomes; oi++) {
                    if (predCount[pi][oi] > 0.0f) {
                        activeOutcomes[numActiveOutcomes] = oi;
                        numActiveOutcomes++;
                    }
                }
                if (numActiveOutcomes == this.numOutcomes) {
                    outcomePattern = allOutcomesPattern;
                } else {
                    outcomePattern = new int[numActiveOutcomes];
                    System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes);
                }
            }

            this.params[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            for (int ti = 0; ti < this.modelExpects.length; ti++) {
                this.modelExpects[ti][pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            }
            this.observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);

            for (int aoi = 0; aoi < numActiveOutcomes; aoi++) {
                final int outcome = outcomePattern[aoi];
                this.params[pi].setParameter(aoi, 0.0d);
                for (MutableContext[] threadExpects : this.modelExpects) {
                    threadExpects[pi].setParameter(aoi, 0.0d);
                }
                if (predCount[pi][outcome] > 0.0f) {
                    this.observedExpects[pi].setParameter(aoi, predCount[pi][outcome]);
                } else if (this.useSimpleSmoothing) {
                    this.observedExpects[pi].setParameter(aoi, smoothedCount);
                }
            }
        }
    }

    /**
     * Runs the training iterations. Stops early when the log-likelihood
     * decreases or improves by less than {@link #llThreshold}. Afterwards the
     * references to the training data are dropped.
     *
     * @param iterations         maximum number of iterations
     * @param correctionConstant divisor used in the parameter update
     */
    private void findParameters(int iterations, double correctionConstant) {
        ExecutorService executor = Executors.newFixedThreadPool(this.modelExpects.length, runnable -> {
            Thread thread = new Thread(runnable);
            thread.setName("opennlp.tools.ml.maxent.ModelExpactationComputeTask.nextIteration()");
            thread.setDaemon(true);
            return thread;
        });
        try {
            CompletionService<ModelExpectationComputeTask> completionService =
                    new ExecutorCompletionService<>(executor);
            double previousLogLikelihood = 0.0d;
            display("Performing " + iterations + " iterations.\n");
            for (int iteration = 1; iteration <= iterations; iteration++) {
                if (iteration < 10) {
                    display("  " + iteration + ":  ");
                } else if (iteration < 100) {
                    display(" " + iteration + ":  ");
                } else {
                    display(iteration + ":  ");
                }
                double currentLogLikelihood = nextIteration(correctionConstant, completionService);
                if (iteration > 1) {
                    if (previousLogLikelihood > currentLogLikelihood) {
                        System.err.println("Model Diverging: loglikelihood decreased");
                        break;
                    }
                    if (currentLogLikelihood - previousLogLikelihood < this.llThreshold) {
                        break;
                    }
                }
                previousLogLikelihood = currentLogLikelihood;
            }
        } finally {
            // Release the worker threads even when an iteration fails.
            executor.shutdown();
        }
        this.observedExpects = null;
        this.modelExpects = null;
        this.numTimesEventsSeen = null;
        this.contexts = null;
    }

    /**
     * Computes the parameter increment under Gaussian smoothing by Newton's
     * method, using the merged model expectation in {@code modelExpects[0]}.
     *
     * @param predicate          predicate index
     * @param oid                active-outcome position within the predicate
     * @param correctionConstant the correction constant of the current training
     * @return the increment to add to the parameter
     */
    private double gaussianUpdate(int predicate, int oid, double correctionConstant) {
        final double param = this.params[predicate].getParameters()[oid];
        final double modelValue = this.modelExpects[0][predicate].getParameters()[oid];
        final double observedValue = this.observedExpects[predicate].getParameters()[oid];

        double x = 0.0d;
        for (int step = 0; step < GAUSSIAN_MAX_NEWTON_STEPS; step++) {
            final double scaledModel = modelValue * StrictMath.exp(correctionConstant * x);
            final double f = (scaledModel + ((param + x) / this.sigma)) - observedValue;
            final double fPrime = (scaledModel * correctionConstant) + (1.0d / this.sigma);
            if (fPrime == 0.0d) {
                break;
            }
            final double next = x - (f / fPrime);
            final boolean converged = StrictMath.abs(next - x) < GAUSSIAN_CONVERGENCE;
            x = next;
            if (converged) {
                break;
            }
        }
        return x;
    }

    /**
     * Performs one training iteration: evaluates the model over all events in
     * parallel, merges the per-thread expectations into row 0, updates the
     * parameters and resets all expectation tables.
     *
     * @return the log-likelihood of the training data under the model as it
     *         was before this iteration's update
     * @throws IllegalStateException if the calling thread is interrupted
     * @throws RuntimeException      if a worker task fails
     */
    private double nextIteration(double correctionConstant,
            CompletionService<ModelExpectationComputeTask> completionService) {
        double loglikelihood = 0.0d;
        int numEvents = 0;
        int numCorrect = 0;

        // Split the events into contiguous slices; the first `leftOver` slices get one extra event.
        final int numberOfThreads = this.modelExpects.length;
        final int taskSize = this.numUniqueEvents / numberOfThreads;
        final int leftOver = this.numUniqueEvents % numberOfThreads;
        for (int ti = 0; ti < numberOfThreads; ti++) {
            if (ti < leftOver) {
                completionService.submit(new ModelExpectationComputeTask(ti, (ti * taskSize) + ti, taskSize + 1));
            } else {
                completionService.submit(new ModelExpectationComputeTask(ti, (ti * taskSize) + leftOver, taskSize));
            }
        }

        for (int ti = 0; ti < numberOfThreads; ti++) {
            try {
                ModelExpectationComputeTask finishedTask = completionService.take().get();
                numEvents += finishedTask.getNumEvents();
                numCorrect += finishedTask.getNumCorrect();
                loglikelihood += finishedTask.getLoglikelihood();
            } catch (InterruptedException e) {
                // Restore the flag so code further up the stack can observe the interruption.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interruption is not supported!", e);
            } catch (ExecutionException e) {
                throw new RuntimeException("Exception during training: " + e.getMessage(), e);
            }
        }
        display(".");

        // Fold the expectations of workers 1..n-1 into worker 0's table.
        for (int pi = 0; pi < this.numPreds; pi++) {
            final int[] activeOutcomes = this.params[pi].getOutcomes();
            for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                for (int ti = 1; ti < this.modelExpects.length; ti++) {
                    this.modelExpects[0][pi].updateParameter(aoi, this.modelExpects[ti][pi].getParameters()[aoi]);
                }
            }
        }
        display(".");

        for (int pi = 0; pi < this.numPreds; pi++) {
            final double[] observed = this.observedExpects[pi].getParameters();
            final double[] model = this.modelExpects[0][pi].getParameters();
            final int[] activeOutcomes = this.params[pi].getOutcomes();
            for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                if (this.useGaussianSmoothing) {
                    this.params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, correctionConstant));
                } else {
                    if (model[aoi] == 0.0d) {
                        // aoi is a position in the active-outcome pattern; map it back to the outcome id for the label.
                        System.err.println("Model expects == 0 for " + this.predLabels[pi] + " "
                                + this.outcomeLabels[activeOutcomes[aoi]]);
                    }
                    this.params[pi].updateParameter(aoi,
                            (StrictMath.log(observed[aoi]) - StrictMath.log(model[aoi])) / correctionConstant);
                }
                // Reset every worker's accumulator for the next iteration.
                for (MutableContext[] threadExpects : this.modelExpects) {
                    threadExpects[pi].setParameter(aoi, 0.0d);
                }
            }
        }

        // Floating-point division: yields NaN rather than throwing when there are no events.
        final double accuracy = (double) numCorrect / numEvents;
        display(". loglikelihood=" + loglikelihood + "\t" + accuracy + "\n");
        return loglikelihood;
    }

    /** Prints the message to standard output when message printing is enabled. */
    @Override
    public void display(String str) {
        if (this.printMessages) {
            System.out.print(str);
        }
    }
}
