package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:edu/stanford/nlp/sentiment/SentimentTraining.class */
/**
 * Command-line entry point and training loop for a {@link SentimentModel}.
 *
 * <p>Training performs minibatch AdaGrad updates over the training trees for a configured
 * number of epochs. It can optionally evaluate on a dev set and write intermediate checkpoints
 * at a configured epoch interval.
 */
public class SentimentTraining {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SentimentTraining.class);

    /** Format of the dev-set score embedded in checkpoint file names. */
    private static final NumberFormat NF = new DecimalFormat("0.00");

    /** Zero-padded format of the checkpoint counter embedded in checkpoint file names. */
    private static final NumberFormat FILENAME = new DecimalFormat("0000");

    /** Added to the AdaGrad denominator so a zero accumulated gradient never divides by zero. */
    private static final double ADAGRAD_EPSILON = 0.001d;

    private SentimentTraining() {
    }

    /**
     * Performs one AdaGrad update of {@code model}'s parameters using the gradient computed over
     * {@code trainingBatch}.
     *
     * @param model the model whose parameters are read and then overwritten in place
     * @param trainingBatch the trees contributing to this batch's cost and gradient
     * @param sumGradSquare running per-parameter sum of squared gradients; updated in place
     */
    private static void executeOneTrainingBatch(SentimentModel model, List<Tree> trainingBatch, double[] sumGradSquare) {
        SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trainingBatch);
        double[] theta = model.paramsToVector();
        double[] gradf = gcFunc.derivativeAt(theta);
        log.info("batch cost: " + gcFunc.valueAt(theta));

        double learningRate = model.op.trainOptions.learningRate;
        for (int feature = 0; feature < gradf.length; feature++) {
            sumGradSquare[feature] += gradf[feature] * gradf[feature];
            // Per-parameter step size shrinks as that parameter's squared gradients accumulate.
            theta[feature] -= learningRate * gradf[feature] / (Math.sqrt(sumGradSquare[feature]) + ADAGRAD_EPSILON);
        }
        model.vectorToParams(theta);
    }

    /**
     * Trains {@code model} in place on {@code trainingTrees}.
     *
     * <p>Each epoch copies the training trees, optionally shuffles them with {@code model.rand},
     * and runs one AdaGrad update per batch of {@code trainOptions.batchSize} trees.
     *
     * <ul>
     *   <li>When {@code trainOptions.adagradResetFrequency > 0}, the AdaGrad accumulator is reset
     *       to {@code initialAdagradWeight} every that many epochs (never before epoch 0).</li>
     *   <li>When {@code trainOptions.debugOutputEpochs > 0}, a checkpoint step runs after the last
     *       batch of every that-many-th epoch (see {@link #runDebugCheckpoint}).</li>
     *   <li>When {@code trainOptions.maxTrainTimeSeconds > 0`, training stops once the elapsed
     *       time exceeds it.</li>
     * </ul>
     *
     * @param model the model to train; its parameters are updated in place
     * @param modelPath base path for intermediate checkpoints, or {@code null} to write none
     * @param trainingTrees the training trees
     * @param devTrees dev trees evaluated at each checkpoint, or {@code null} to skip evaluation
     * @throws IllegalArgumentException if {@code trainOptions.batchSize} is not positive
     */
    public static void train(SentimentModel model, String modelPath, List<Tree> trainingTrees, List<Tree> devTrees) {
        Timing timing = new Timing();
        // 1000L forces long arithmetic so a large time limit cannot overflow.
        long maxTrainTimeMillis = model.op.trainOptions.maxTrainTimeSeconds * 1000L;
        int debugCycle = 0;

        double[] sumGradSquare = new double[model.totalParamSize()];
        Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);

        int batchSize = model.op.trainOptions.batchSize;
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        // Ceiling division: no trailing empty batch when the tree count is a multiple of batchSize.
        int numBatches = (trainingTrees.size() + batchSize - 1) / batchSize;
        log.info("Training on " + trainingTrees.size() + " trees in " + numBatches + " batches");
        log.info("Times through each training batch: " + model.op.trainOptions.epochs);

        for (int epoch = 0; epoch < model.op.trainOptions.epochs; ++epoch) {
            log.info("======================================");
            log.info("Starting epoch " + epoch);

            int resetFrequency = model.op.trainOptions.adagradResetFrequency;
            if (epoch > 0 && resetFrequency > 0 && epoch % resetFrequency == 0) {
                log.info("Resetting adagrad weights to " + model.op.trainOptions.initialAdagradWeight);
                Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
            }

            // Shuffle a copy so the caller's list is never reordered.
            List<Tree> shuffledSentences = Generics.newArrayList(trainingTrees);
            if (model.op.trainOptions.shuffleMatrices) {
                Collections.shuffle(shuffledSentences, model.rand);
            }

            for (int batch = 0; batch < numBatches; ++batch) {
                log.info("======================================");
                log.info("Epoch " + epoch + " batch " + batch);

                int startTree = batch * batchSize;
                int endTree = Math.min(startTree + batchSize, shuffledSentences.size());
                executeOneTrainingBatch(model, shuffledSentences.subList(startTree, endTree), sumGradSquare);

                long totalElapsed = timing.report();
                log.info("Finished epoch " + epoch + " batch " + batch + "; total training time " + totalElapsed + " ms");
                if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
                    // Leave the batch loop; the check after it ends training entirely.
                    break;
                }

                int debugEpochs = model.op.trainOptions.debugOutputEpochs;
                if (batch == numBatches - 1 && debugEpochs > 0 && (epoch + 1) % debugEpochs == 0) {
                    runDebugCheckpoint(model, modelPath, devTrees, debugCycle);
                    debugCycle++;
                }
            }

            long elapsedAfterEpoch = timing.report();
            if (maxTrainTimeMillis > 0 && elapsedAfterEpoch > maxTrainTimeMillis) {
                log.info("Max training time exceeded, exiting");
                return;
            }
        }
    }

    /**
     * Optionally evaluates on the dev set and writes an intermediate model checkpoint.
     *
     * @param model model to evaluate and serialize
     * @param modelPath base output path, or {@code null} to skip writing a checkpoint
     * @param devTrees dev trees to evaluate, or {@code null} to skip evaluation (score is then 0)
     * @param debugCycle zero-based checkpoint counter embedded in the file name
     */
    private static void runDebugCheckpoint(SentimentModel model, String modelPath, List<Tree> devTrees, int debugCycle) {
        double score = 0.0d;
        if (devTrees != null) {
            Evaluate eval = new Evaluate(model);
            eval.eval(devTrees);
            eval.printSummary();
            score = eval.exactNodeAccuracy() * 100.0d;
        }
        if (modelPath != null) {
            model.saveSerialized(checkpointPath(modelPath, debugCycle, score));
        }
    }

    /**
     * Builds a checkpoint file name by inserting {@code -CCCC-SS.SS} (checkpoint counter and dev
     * score) before a {@code .ser.gz} or {@code .gz} suffix. If the path has neither suffix,
     * the tag is appended to the unmodified path.
     */
    private static String checkpointPath(String modelPath, int debugCycle, double score) {
        String tag = "-" + FILENAME.format(debugCycle) + "-" + NF.format(score);
        if (modelPath.endsWith(".ser.gz")) {
            return modelPath.substring(0, modelPath.length() - ".ser.gz".length()) + tag + ".ser.gz";
        }
        if (modelPath.endsWith(".gz")) {
            return modelPath.substring(0, modelPath.length() - ".gz".length()) + tag + ".gz";
        }
        // No recognized suffix: keep the full path rather than chopping characters off it.
        return modelPath + tag;
    }

    /**
     * Runs a numerical gradient check of the model's cost function on {@code trainingTrees}.
     *
     * @param model the model whose current parameters are checked
     * @param trainingTrees trees over which the cost and gradient are computed
     * @return the result of {@code SentimentCostAndGradient.gradientCheck}; presumably
     *     {@code true} when the analytic and numeric gradients agree (TODO confirm)
     */
    public static boolean runGradientCheck(SentimentModel model, List<Tree> trainingTrees) {
        SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trainingTrees);
        // NOTE(review): 50 is passed through as gradientCheck's second argument; presumably the
        // number of randomly sampled parameters to check — confirm against gradientCheck.
        return gcFunc.gradientCheck(model.totalParamSize(), 50, model.paramsToVector());
    }

    /**
     * Command-line entry point.
     *
     * <p>Recognized flags (case-insensitive):
     * <ul>
     *   <li>{@code -train}: train the model and save it to {@code -model}</li>
     *   <li>{@code -gradientcheck}: run {@link #runGradientCheck} before training</li>
     *   <li>{@code -trainpath <file>}: training trees (default {@code sentimentTreesDebug.txt})</li>
     *   <li>{@code -devpath <file>}: optional dev trees</li>
     *   <li>{@code -model <file>}: output model path (required with {@code -train})</li>
     *   <li>{@code -filterUnknown}: pass trees through {@code SentimentUtils.filterUnknownRoots}</li>
     * </ul>
     *
     * <p>Any other argument is handed to {@link RNNOptions#setOption}.
     *
     * @throws IllegalArgumentException on an unknown argument, a flag missing its value, or
     *     {@code -train} given without {@code -model}
     */
    public static void main(String[] args) {
        RNNOptions op = new RNNOptions();

        String trainPath = "sentimentTreesDebug.txt";
        String devPath = null;
        String modelPath = null;
        boolean doGradientCheck = false;
        boolean runTraining = false;
        boolean filterUnknown = false;

        int argIndex = 0;
        while (argIndex < args.length) {
            String arg = args[argIndex];
            if (arg.equalsIgnoreCase("-train")) {
                runTraining = true;
                argIndex++;
            } else if (arg.equalsIgnoreCase("-gradientcheck")) {
                doGradientCheck = true;
                argIndex++;
            } else if (arg.equalsIgnoreCase("-trainpath")) {
                trainPath = requireValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-devpath")) {
                devPath = requireValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-model")) {
                modelPath = requireValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-filterUnknown")) {
                filterUnknown = true;
                argIndex++;
            } else {
                int newIndex = op.setOption(args, argIndex);
                if (newIndex == argIndex) {
                    throw new IllegalArgumentException("Unknown argument " + args[argIndex]);
                }
                argIndex = newIndex;
            }
        }

        // Fail before loading data / training: without a path the trained model cannot be saved.
        if (runTraining && modelPath == null) {
            throw new IllegalArgumentException("-train requires -model <path> to save the trained model");
        }

        List<Tree> trainingTrees = SentimentUtils.readTreesWithGoldLabels(trainPath);
        log.info("Read in " + trainingTrees.size() + " training trees");
        if (filterUnknown) {
            trainingTrees = SentimentUtils.filterUnknownRoots(trainingTrees);
            log.info("Filtered training trees: " + trainingTrees.size());
        }

        List<Tree> devTrees = null;
        if (devPath != null) {
            devTrees = SentimentUtils.readTreesWithGoldLabels(devPath);
            log.info("Read in " + devTrees.size() + " dev trees");
            if (filterUnknown) {
                devTrees = SentimentUtils.filterUnknownRoots(devTrees);
                log.info("Filtered dev trees: " + devTrees.size());
            }
        }

        log.info("Sentiment model options:\n" + op);
        SentimentModel model = new SentimentModel(op, trainingTrees);

        if (op.trainOptions.initialMatrixLogPath != null) {
            StringUtils.printToFile(new File(op.trainOptions.initialMatrixLogPath), model.toString(), false, false, "utf-8");
        }

        if (doGradientCheck) {
            log.info("Gradient check result: " + runGradientCheck(model, trainingTrees));
        }

        if (runTraining) {
            train(model, modelPath, trainingTrees, devTrees);
            model.saveSerialized(modelPath);
        }
    }

    /**
     * Returns the value that follows the flag at {@code flagIndex}.
     *
     * @throws IllegalArgumentException if the flag is the last argument
     */
    private static String requireValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for argument " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }
}
