/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.optimizer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.finmath.functions.LinearAlgebra;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.optimizer.SolverException;
import net.finmath.optimizer.StochasticOptimizer;
import net.finmath.stochastic.RandomVariable;

/**
 * Levenberg(-Marquardt) least-squares optimizer whose parameters, values and target values are
 * {@link RandomVariable}s.
 * <p>
 * Subclasses implement {@link #setValues(RandomVariable[], RandomVariable[])}, which maps a parameter
 * vector to a value vector. The optimizer minimizes the mean (over the value index) of
 * {@code getAverage()} of the squared deviation between values and target values. Derivatives are
 * obtained by forward finite differences unless
 * {@link #setDerivatives(RandomVariable[], RandomVariable[][])} is overridden.
 * <p>
 * The normal equations are assembled from {@code getAverage()} of products of random variables. The
 * resulting increment is therefore a {@code double[]} that is added to each parameter random variable.
 */
public abstract class StochasticLevenbergMarquardt
implements Serializable,
Cloneable,
StochasticOptimizer {
    private static final long serialVersionUID = 4560864869394838155L;

    /** Shared logger; static because {@link Logger} does not implement {@link Serializable}. */
    private static final Logger LOGGER = Logger.getLogger("net.finmath");

    /** Factor applied to lambda each time the regularized normal equations cannot be solved. */
    private static final double LAMBDA_MULTIPLICATOR_ON_SOLVER_FAILURE = 16.0;

    private final RegularizationMethod regularizationMethod;
    private RandomVariable[] initialParameters = null;
    /** Finite-difference step per parameter; if null, a step of (|p| + 1) * 1E-8 is used. */
    private RandomVariable[] parameterSteps = null;
    private RandomVariable[] targetValues = null;
    private final int maxIteration;
    private double lambda;
    private final double lambdaInitialValue = 0.001;
    private double lambdaDivisor = 3.0;
    private double lambdaMultiplicator = 2.0;
    private final double errorTolerance;
    private int iteration = 0;
    private RandomVariable[] parameterTest = null;
    private RandomVariable[] valueTest = null;
    private RandomVariable[] parameterCurrent = null;
    private RandomVariable[] valueCurrent = null;
    private RandomVariable[][] derivativeCurrent = null;
    private double errorMeanSquaredCurrent = Double.POSITIVE_INFINITY;
    private double errorRootMeanSquaredChange = Double.POSITIVE_INFINITY;
    /** True if {@code derivativeCurrent} holds the derivatives evaluated at {@code parameterCurrent}. */
    private boolean isParameterCurrentDerivativeValid;
    private int numberOfThreads = 1;
    /** Executor for derivative evaluation; transient since executors are typically not serializable. */
    private transient ExecutorService executor = null;
    private boolean executorShutdownWhenDone = true;

    /**
     * Demonstration: fits {@code p} such that {@code p1^2 = 25} and {@code (2 p0 + p1)^2 = 100},
     * then prints the best-fit parameters and the achieved RMS error.
     */
    public static void main(String[] args) throws SolverException {
        RandomVariable[] initialParameters = new RandomVariable[]{new RandomVariableFromDoubleArray(2.0), new RandomVariableFromDoubleArray(2.0)};
        RandomVariable[] parameterSteps = new RandomVariable[]{new RandomVariableFromDoubleArray(1.0), new RandomVariableFromDoubleArray(1.0)};
        int maxIteration = 100;
        RandomVariable[] targetValues = new RandomVariable[]{new RandomVariableFromDoubleArray(25.0), new RandomVariableFromDoubleArray(100.0)};
        StochasticLevenbergMarquardt optimizer = new StochasticLevenbergMarquardt(initialParameters, targetValues, parameterSteps, maxIteration, 1.0E-12, null){
            private static final long serialVersionUID = -282626938650139518L;

            @Override
            public void setValues(RandomVariable[] parameters, RandomVariable[] values) {
                values[0] = parameters[0].mult(0.0).add(parameters[1]).squared();
                values[1] = parameters[0].mult(2.0).add(parameters[1]).squared();
            }
        };
        optimizer.run();
        RandomVariable[] bestParameters = optimizer.getBestFitParameters();
        System.out.println("The solver for problem 1 required " + optimizer.getIterations() + " iterations. The best fit parameters are:");
        for (int i = 0; i < bestParameters.length; ++i) {
            System.out.println("\tparameter[" + i + "]: " + bestParameters[i]);
        }
        System.out.println("The solver accuracy is " + optimizer.getRootMeanSquaredError());
    }

    /**
     * Creates an optimizer.
     *
     * @param regularizationMethod how lambda enters the diagonal of the normal equations.
     * @param initialParameters starting point of the optimization.
     * @param targetValues values the model output is fitted to.
     * @param parameterSteps finite-difference step per parameter, or null for a relative default step.
     * @param maxIteration the optimizer stops once the iteration count exceeds this number.
     * @param errorTolerance the optimizer stops once an accepted step improves the RMS error by at most this amount.
     * @param executorService executor for derivative evaluation, or null. A supplied executor is not shut down by this class.
     */
    public StochasticLevenbergMarquardt(RegularizationMethod regularizationMethod, RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this.regularizationMethod = regularizationMethod;
        this.initialParameters = initialParameters;
        this.targetValues = targetValues;
        this.parameterSteps = parameterSteps;
        this.maxIteration = maxIteration;
        this.errorTolerance = errorTolerance;
        this.executor = executorService;
        this.executorShutdownWhenDone = executorService == null;
    }

    /**
     * Creates an optimizer using {@link RegularizationMethod#LEVENBERG_MARQUARDT}.
     *
     * @see #StochasticLevenbergMarquardt(RegularizationMethod, RandomVariable[], RandomVariable[], RandomVariable[], int, double, ExecutorService)
     */
    public StochasticLevenbergMarquardt(RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this(RegularizationMethod.LEVENBERG_MARQUARDT, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, executorService);
    }

    /**
     * Creates an optimizer that, when {@code numberOfThreads > 1}, creates its own fixed thread pool
     * in {@link #run()} and shuts it down when {@link #run()} finishes.
     *
     * @see #StochasticLevenbergMarquardt(RegularizationMethod, RandomVariable[], RandomVariable[], RandomVariable[], int, double, ExecutorService)
     */
    public StochasticLevenbergMarquardt(RegularizationMethod regularizationMethod, RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] parameterSteps, int maxIteration, double errorTolerance, int numberOfThreads) {
        this(regularizationMethod, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, null);
        this.numberOfThreads = numberOfThreads;
    }

    /** @return the current value of the regularization parameter lambda. */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets lambda. Note that {@link #run()} resets lambda to its initial value (0.001) at its start,
     * so a value set before calling {@link #run()} is overwritten.
     *
     * @param lambda the new lambda.
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** @return the factor by which lambda is multiplied after a rejected step. */
    public double getLambdaMultiplicator() {
        return this.lambdaMultiplicator;
    }

    /**
     * @param lambdaMultiplicator factor by which lambda is multiplied after a rejected step; must be &gt; 1.
     * @throws IllegalArgumentException if the factor is not &gt; 1 (including NaN).
     */
    public void setLambdaMultiplicator(double lambdaMultiplicator) {
        // Written as !(x > 1) so that NaN is rejected as well.
        if (!(lambdaMultiplicator > 1.0)) {
            throw new IllegalArgumentException("Parameter lambdaMultiplicator is required to be > 1.");
        }
        this.lambdaMultiplicator = lambdaMultiplicator;
    }

    /** @return the divisor applied to lambda after an accepted step. */
    public double getLambdaDivisor() {
        return this.lambdaDivisor;
    }

    /**
     * @param lambdaDivisor divisor applied to lambda after an accepted step; must be &gt; 1.
     * @throws IllegalArgumentException if the divisor is not &gt; 1 (including NaN).
     */
    public void setLambdaDivisor(double lambdaDivisor) {
        // Written as !(x > 1) so that NaN is rejected as well.
        if (!(lambdaDivisor > 1.0)) {
            throw new IllegalArgumentException("Parameter lambdaDivisor is required to be > 1.");
        }
        this.lambdaDivisor = lambdaDivisor;
    }

    /** @return the best parameters found so far (the internal array), or null before {@link #run()}. */
    @Override
    public RandomVariable[] getBestFitParameters() {
        return this.parameterCurrent;
    }

    /** @return square root of the mean squared error at the best parameters; infinity until a point is accepted. */
    @Override
    public double getRootMeanSquaredError() {
        return Math.sqrt(this.errorMeanSquaredCurrent);
    }

    /**
     * Sets the error a trial point has to beat to be accepted. {@link #run()} does not reset this
     * value, so setting it beforehand restricts acceptance to points with a smaller error.
     *
     * @param errorMeanSquaredCurrent the mean squared error of the current best point.
     */
    public void setErrorMeanSquaredCurrent(double errorMeanSquaredCurrent) {
        this.errorMeanSquaredCurrent = errorMeanSquaredCurrent;
    }

    /** @return the number of iterations performed by the last {@link #run()}. */
    @Override
    public int getIterations() {
        return this.iteration;
    }

    /**
     * Hook around {@link #setValues(RandomVariable[], RandomVariable[])}; the default simply delegates.
     *
     * @param parameters the parameters at which to evaluate.
     * @param values the array receiving the values.
     * @throws SolverException if the evaluation fails.
     */
    protected void prepareAndSetValues(RandomVariable[] parameters, RandomVariable[] values) throws SolverException {
        this.setValues(parameters, values);
    }

    /**
     * Hook around {@link #setDerivatives(RandomVariable[], RandomVariable[][])}; the default delegates
     * and ignores {@code values}. {@link #run()} calls this with the current best parameters and values.
     *
     * @param parameters the parameters at which to evaluate the derivatives.
     * @param values the values at {@code parameters}.
     * @param derivatives array receiving {@code derivatives[parameterIndex][valueIndex]}.
     * @throws SolverException if the evaluation fails.
     */
    protected void prepareAndSetDerivatives(RandomVariable[] parameters, RandomVariable[] values, RandomVariable[][] derivatives) throws SolverException {
        this.setDerivatives(parameters, derivatives);
    }

    /**
     * Evaluates the model.
     *
     * @param var1 the parameters.
     * @param var2 the array to be filled with the values at the given parameters.
     * @throws SolverException if the evaluation fails.
     */
    public abstract void setValues(RandomVariable[] var1, RandomVariable[] var2) throws SolverException;

    /**
     * Default derivative: forward finite differences of {@link #prepareAndSetValues} around the
     * current best parameters, relative to the current best values.
     * <p>
     * The {@code parameters} argument is ignored. The base point is always the current best
     * parameter vector, because the differences are taken against the values stored for that point.
     * One task per parameter is run on the executor if present, otherwise sequentially on the calling
     * thread. A failed model evaluation yields NaN derivatives for that parameter.
     *
     * @param parameters ignored (see above).
     * @param derivatives array receiving {@code derivatives[parameterIndex][valueIndex]}; its rows are reused.
     * @throws SolverException if waiting for a task is interrupted or a task terminates with an exception.
     */
    public void setDerivatives(RandomVariable[] parameters, RandomVariable[][] derivatives) throws SolverException {
        final RandomVariable[] baseParameters = this.parameterCurrent;
        final int numberOfParameters = baseParameters.length;

        List<Future<RandomVariable[]>> valueFutures = new ArrayList<>(numberOfParameters);
        for (int parameterIndex = 0; parameterIndex < numberOfParameters; ++parameterIndex) {
            Callable<RandomVariable[]> worker = this.createFiniteDifferenceWorker(baseParameters, derivatives[parameterIndex], parameterIndex);
            if (this.executor != null) {
                valueFutures.add(this.executor.submit(worker));
            } else {
                FutureTask<RandomVariable[]> valueFutureTask = new FutureTask<>(worker);
                valueFutureTask.run();
                valueFutures.add(valueFutureTask);
            }
        }

        for (int parameterIndex = 0; parameterIndex < numberOfParameters; ++parameterIndex) {
            try {
                derivatives[parameterIndex] = valueFutures.get(parameterIndex).get();
            }
            catch (InterruptedException e) {
                // Preserve the interrupt status for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new SolverException(e);
            }
            catch (ExecutionException e) {
                throw new SolverException(e);
            }
        }
    }

    /** Creates a task computing the forward finite difference of all values w.r.t. one parameter. */
    private Callable<RandomVariable[]> createFiniteDifferenceWorker(RandomVariable[] baseParameters, final RandomVariable[] derivative, final int parameterIndex) {
        // Each task works on its own copy so concurrent tasks do not interfere.
        final RandomVariable[] shiftedParameters = baseParameters.clone();
        return new Callable<RandomVariable[]>() {

            @Override
            public RandomVariable[] call() {
                RandomVariable[] steps = StochasticLevenbergMarquardt.this.parameterSteps;
                RandomVariable parameterFiniteDifference = steps != null
                        ? steps[parameterIndex]
                        : shiftedParameters[parameterIndex].abs().add(1.0).mult(1.0E-8);
                shiftedParameters[parameterIndex] = shiftedParameters[parameterIndex].add(parameterFiniteDifference);
                try {
                    StochasticLevenbergMarquardt.this.prepareAndSetValues(shiftedParameters, derivative);
                }
                catch (Exception e) {
                    // Deliberate best effort: a failed evaluation produces NaN derivatives instead of aborting.
                    Arrays.fill(derivative, new RandomVariableFromDoubleArray(Double.NaN));
                }
                RandomVariable[] baseValues = StochasticLevenbergMarquardt.this.valueCurrent;
                for (int valueIndex = 0; valueIndex < baseValues.length; ++valueIndex) {
                    derivative[valueIndex] = derivative[valueIndex].sub(baseValues[valueIndex]).div(parameterFiniteDifference);
                }
                return derivative;
            }
        };
    }

    /** @return true once the iteration count exceeds the maximum or the last accepted improvement is within tolerance. */
    boolean done() {
        return this.iteration > this.maxIteration || this.errorRootMeanSquaredChange <= this.errorTolerance;
    }

    /**
     * Runs the optimization.
     * <p>
     * Each iteration evaluates the trial parameters. The trial point is accepted if its mean squared
     * error is strictly smaller than the current one; a NaN error is therefore rejected. Lambda is
     * divided by the divisor after an accepted point and multiplied by the multiplicator after a
     * rejected one. Derivatives are (re)computed only when the current point has changed. The
     * regularized normal equations are then solved for the next trial point.
     * <p>
     * An executor created by this method (or when none was supplied) is shut down on exit.
     *
     * @throws SolverException if a model or derivative evaluation fails, or if the normal equations
     *         cannot be solved for any finite lambda.
     */
    @Override
    public void run() throws SolverException {
        if (this.numberOfThreads > 1 && this.executor == null) {
            this.executor = Executors.newFixedThreadPool(this.numberOfThreads);
            this.executorShutdownWhenDone = true;
        }
        try {
            int numberOfParameters = this.initialParameters.length;
            int numberOfValues = this.targetValues.length;
            this.parameterTest = this.initialParameters.clone();
            this.parameterCurrent = this.initialParameters.clone();
            this.valueTest = new RandomVariable[numberOfValues];
            this.valueCurrent = new RandomVariable[numberOfValues];
            Arrays.fill(this.valueCurrent, new RandomVariableFromDoubleArray(Double.NaN));
            this.derivativeCurrent = new RandomVariable[numberOfParameters][numberOfValues];
            this.iteration = 0;
            this.lambda = this.lambdaInitialValue;
            this.isParameterCurrentDerivativeValid = false;

            while (true) {
                ++this.iteration;
                this.prepareAndSetValues(this.parameterTest, this.valueTest);
                double errorMeanSquaredTest = this.getMeanSquaredError(this.valueTest);

                // A NaN test error compares false and is therefore treated as a rejected point.
                boolean isPointAccepted = this.errorMeanSquaredCurrent > errorMeanSquaredTest;
                if (isPointAccepted) {
                    this.parameterCurrent = this.parameterTest.clone();
                    this.valueCurrent = this.valueTest.clone();
                    this.errorRootMeanSquaredChange = Math.sqrt(this.errorMeanSquaredCurrent) - Math.sqrt(errorMeanSquaredTest);
                    this.errorMeanSquaredCurrent = errorMeanSquaredTest;
                    this.isParameterCurrentDerivativeValid = false;
                }

                if (this.done()) {
                    break;
                }

                if (isPointAccepted) {
                    this.lambda /= this.lambdaDivisor;
                } else {
                    this.lambda *= this.lambdaMultiplicator;
                }

                // Derivatives only depend on the current point; after a rejected step they are still valid.
                if (!this.isParameterCurrentDerivativeValid) {
                    this.prepareAndSetDerivatives(this.parameterCurrent, this.valueCurrent, this.derivativeCurrent);
                    this.isParameterCurrentDerivativeValid = true;
                }

                double[] parameterIncrement = this.solveForParameterIncrement();
                for (int i = 0; i < this.parameterCurrent.length; ++i) {
                    this.parameterTest[i] = this.parameterCurrent[i].add(parameterIncrement[i]);
                }

                if (LOGGER.isLoggable(Level.FINE)) {
                    LOGGER.fine(this.buildIterationLogMessage());
                }
            }
        }
        finally {
            if (this.executor != null && this.executorShutdownWhenDone) {
                this.executor.shutdown();
                this.executor = null;
            }
        }
    }

    /**
     * Solves the regularized normal equations at the current point for the parameter increment.
     * Null derivative entries are skipped. If the linear solver throws, lambda is increased and the
     * solve is retried.
     *
     * @return the increment to add to the current parameters.
     * @throws SolverException if lambda becomes non-finite before the system could be solved.
     */
    private double[] solveForParameterIncrement() throws SolverException {
        int numberOfParameters = this.parameterCurrent.length;
        int numberOfValues = this.valueCurrent.length;

        // Lambda-independent parts, computed once per iteration.
        double[][] derivativeProducts = new double[numberOfParameters][numberOfParameters];
        for (int i = 0; i < numberOfParameters; ++i) {
            for (int j = i; j < numberOfParameters; ++j) {
                double alphaElement = 0.0;
                for (int valueIndex = 0; valueIndex < numberOfValues; ++valueIndex) {
                    if (this.derivativeCurrent[i][valueIndex] == null || this.derivativeCurrent[j][valueIndex] == null) {
                        continue;
                    }
                    alphaElement += this.derivativeCurrent[i][valueIndex].mult(this.derivativeCurrent[j][valueIndex]).getAverage();
                }
                derivativeProducts[i][j] = alphaElement;
                derivativeProducts[j][i] = alphaElement;
            }
        }

        double[] beta = new double[numberOfParameters];
        for (int i = 0; i < numberOfParameters; ++i) {
            double betaElement = 0.0;
            RandomVariable[] derivativeCurrentSingleParam = this.derivativeCurrent[i];
            for (int valueIndex = 0; valueIndex < numberOfValues; ++valueIndex) {
                if (derivativeCurrentSingleParam[valueIndex] == null) {
                    continue;
                }
                betaElement += this.targetValues[valueIndex].sub(this.valueCurrent[valueIndex]).mult(derivativeCurrentSingleParam[valueIndex]).getAverage();
            }
            beta[i] = betaElement;
        }

        while (true) {
            // Fresh matrix per attempt, in case the solver modifies its arguments.
            double[][] hessianMatrix = new double[numberOfParameters][];
            for (int i = 0; i < numberOfParameters; ++i) {
                hessianMatrix[i] = derivativeProducts[i].clone();
                hessianMatrix[i][i] = this.regularizeDiagonalElement(derivativeProducts[i][i]);
            }
            try {
                return LinearAlgebra.solveLinearEquationSymmetric(hessianMatrix, beta.clone());
            }
            catch (Exception e) {
                // The solver signals a singular/ill-conditioned system by throwing; more damping may help.
                this.lambda *= LAMBDA_MULTIPLICATOR_ON_SOLVER_FAILURE;
                if (Double.isNaN(this.lambda) || Double.isInfinite(this.lambda)) {
                    throw new SolverException("Unable to solve the regularized normal equations for any finite lambda.", e);
                }
            }
        }
    }

    /** Applies the lambda regularization of the configured {@link RegularizationMethod} to a diagonal element. */
    private double regularizeDiagonalElement(double diagonalElement) {
        if (this.regularizationMethod == RegularizationMethod.LEVENBERG) {
            return diagonalElement + this.lambda;
        }
        // Marquardt scaling degenerates for a zero diagonal; fall back to plain lambda there.
        return diagonalElement == 0.0 ? this.lambda : diagonalElement * (1.0 + this.lambda);
    }

    /** Builds the FINE-level log line describing the current iteration. */
    private String buildIterationLogMessage() {
        StringBuilder logString = new StringBuilder();
        logString.append("Iteration: ").append(this.iteration)
                .append("\tLambda=").append(this.lambda)
                .append("\tError Current (RMS):").append(Math.sqrt(this.errorMeanSquaredCurrent))
                .append("\tError Change:").append(this.errorRootMeanSquaredChange)
                .append("\t");
        for (int i = 0; i < this.parameterCurrent.length; ++i) {
            logString.append("[").append(i).append("] = ").append(this.parameterCurrent[i].doubleValue()).append("\t");
        }
        return logString.toString();
    }

    /**
     * Mean over the value index of {@code getAverage()} of the squared deviation from the targets.
     *
     * @param value the values to compare against the target values (same length).
     * @return the mean squared error (NaN for an empty array).
     */
    public double getMeanSquaredError(RandomVariable[] value) {
        double error = 0.0;
        for (int valueIndex = 0; valueIndex < value.length; ++valueIndex) {
            double deviationSquared = value[valueIndex].sub(this.targetValues[valueIndex]).squared().getAverage();
            error += deviationSquared;
        }
        return error / (double)value.length;
    }

    /**
     * Not supported by this base class; subclasses wanting
     * {@code getCloneWithModifiedTargetValues} must override it.
     *
     * @throws CloneNotSupportedException always.
     */
    @Override
    public StochasticLevenbergMarquardt clone() throws CloneNotSupportedException {
        throw new CloneNotSupportedException();
    }

    /**
     * Returns a clone with new target values, optionally starting from this optimizer's best
     * parameters (only if this optimizer is done). With the default {@link #clone()} this always throws.
     *
     * @param newTargetVaues the new target values (copied).
     * @param newWeights currently ignored.
     * @param isUseBestParametersAsInitialParameters whether to start from this optimizer's best fit.
     * @return the cloned optimizer.
     * @throws CloneNotSupportedException if {@link #clone()} is not supported.
     */
    public StochasticLevenbergMarquardt getCloneWithModifiedTargetValues(RandomVariable[] newTargetVaues, RandomVariable[] newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        StochasticLevenbergMarquardt clonedOptimizer = this.clone();
        clonedOptimizer.targetValues = newTargetVaues.clone();
        if (isUseBestParametersAsInitialParameters && this.done()) {
            clonedOptimizer.initialParameters = this.getBestFitParameters();
        }
        return clonedOptimizer;
    }

    /**
     * List-based variant of
     * {@link #getCloneWithModifiedTargetValues(RandomVariable[], RandomVariable[], boolean)}.
     *
     * @param newTargetVaues the new target values (copied into an array).
     * @param newWeights currently ignored.
     * @param isUseBestParametersAsInitialParameters whether to start from this optimizer's best fit.
     * @return the cloned optimizer.
     * @throws CloneNotSupportedException if {@link #clone()} is not supported.
     */
    public StochasticLevenbergMarquardt getCloneWithModifiedTargetValues(List<RandomVariable> newTargetVaues, List<RandomVariable> newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        StochasticLevenbergMarquardt clonedOptimizer = this.clone();
        clonedOptimizer.targetValues = newTargetVaues.toArray(new RandomVariable[0]);
        if (isUseBestParametersAsInitialParameters && this.done()) {
            clonedOptimizer.initialParameters = this.getBestFitParameters();
        }
        return clonedOptimizer;
    }

    /** How lambda is applied to the diagonal of the normal equations. */
    public static enum RegularizationMethod {
        /** Adds lambda to each diagonal element. */
        LEVENBERG,
        /** Scales each diagonal element by (1 + lambda); a zero diagonal element is replaced by lambda. */
        LEVENBERG_MARQUARDT;

    }
}

