/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.assetderivativevaluation.models;

import java.util.Arrays;
import java.util.Map;
import net.finmath.functions.LinearAlgebra;
import net.finmath.montecarlo.RandomVariableFactory;
import net.finmath.montecarlo.RandomVariableFromArrayFactory;
import net.finmath.montecarlo.model.AbstractProcessModel;
import net.finmath.montecarlo.process.MonteCarloProcess;
import net.finmath.stochastic.RandomVariable;

/**
 * Multi-asset Black-Scholes model, simulated in log-coordinates.
 *
 * <p>The state of component {@code i} is {@code Y_i = log(S_i)}. It starts at
 * {@code log(S_i(0))}, has the constant drift {@code r - sigma_i^2 / 2} and the constant factor
 * loadings {@code F[i][0..m-1]}, where {@code sigma_i^2 = sum_j F[i][j]^2}. The state space
 * transform maps {@code Y_i} back to {@code S_i = exp(Y_i)}; the numeraire is
 * {@code N(t) = exp(r t)}.
 *
 * <p>The {@code initialValues} and {@code factorLoadings} arrays passed to the constructors are
 * copied, so later modification of the caller's arrays does not affect the model;
 * {@link #getFactorLoadingMatrix()} likewise returns a copy.
 */
public class MultiAssetBlackScholesModel
extends AbstractProcessModel {
    private final RandomVariableFactory randomVariableFactory;
    // Model parameters (private copies of the constructor arguments).
    private final double[] initialValues;
    private final double riskFreeRate;
    private final double[][] factorLoadings;
    // Pre-computed constant random variables: one entry per component (and per factor).
    private final RandomVariable[] initialStates;
    private final RandomVariable[] drift;
    private final RandomVariable[][] factorLoadingOnPaths;

    /**
     * Creates the model from a factor loading matrix.
     *
     * @param randomVariableFactory Factory used to create the constant random variables of the model.
     * @param initialValues Initial values {@code S_i(0)}; the model state starts at their logarithm.
     * @param riskFreeRate Constant risk free rate {@code r}.
     * @param factorLoadings Factor loading matrix; row {@code i} holds the constant loadings of component {@code i}.
     *        Must have one row per initial value and all rows must have the same length.
     * @throws IllegalArgumentException If the dimensions of {@code factorLoadings} do not match {@code initialValues}
     *         or its rows have different lengths.
     */
    public MultiAssetBlackScholesModel(RandomVariableFactory randomVariableFactory, double[] initialValues, double riskFreeRate, double[][] factorLoadings) {
        final int numberOfComponents = initialValues.length;
        if (factorLoadings.length != numberOfComponents) {
            throw new IllegalArgumentException("factorLoadings must have one row per component: got " + factorLoadings.length + " rows for " + numberOfComponents + " initial values.");
        }
        final int numberOfFactors = numberOfComponents > 0 ? factorLoadings[0].length : 0;
        for (int underlyingIndex = 0; underlyingIndex < numberOfComponents; ++underlyingIndex) {
            if (factorLoadings[underlyingIndex].length != numberOfFactors) {
                throw new IllegalArgumentException("All rows of factorLoadings must have the same length: row " + underlyingIndex + " has " + factorLoadings[underlyingIndex].length + " entries, expected " + numberOfFactors + ".");
            }
        }

        this.randomVariableFactory = randomVariableFactory;
        // Defensive copies: the pre-computed random variables below must stay consistent with these arrays.
        this.initialValues = initialValues.clone();
        this.riskFreeRate = riskFreeRate;
        this.factorLoadings = copyMatrix(factorLoadings);

        this.initialStates = new RandomVariable[numberOfComponents];
        this.drift = new RandomVariable[numberOfComponents];
        this.factorLoadingOnPaths = new RandomVariable[numberOfComponents][numberOfFactors];
        // NOTE: getRandomVariableForConstant is overridable; it is still used here (as before) so subclasses keep control over it.
        for (int underlyingIndex = 0; underlyingIndex < numberOfComponents; ++underlyingIndex) {
            final double[] loadingsOfUnderlying = this.factorLoadings[underlyingIndex];
            for (int factorIndex = 0; factorIndex < numberOfFactors; ++factorIndex) {
                this.factorLoadingOnPaths[underlyingIndex][factorIndex] = this.getRandomVariableForConstant(loadingsOfUnderlying[factorIndex]);
            }
            final double volatilitySquaredForUnderlying = getSumOfSquares(loadingsOfUnderlying);
            this.initialStates[underlyingIndex] = this.getRandomVariableForConstant(Math.log(this.initialValues[underlyingIndex]));
            // Ito correction: Y = log(S) needs drift r - sigma^2/2 so that S = exp(Y) grows at rate r.
            this.drift[underlyingIndex] = this.getRandomVariableForConstant(riskFreeRate - volatilitySquaredForUnderlying / 2.0);
        }
    }

    /**
     * Creates the model from volatilities and a correlation matrix.
     *
     * @param randomVariableFactory Factory used to create the constant random variables of the model.
     * @param initialValues Initial values {@code S_i(0)}.
     * @param riskFreeRate Constant risk free rate {@code r}.
     * @param volatilities Volatility of each component.
     * @param correlations Correlation matrix of the components.
     * @throws IllegalArgumentException If the dimensions of the arguments are inconsistent.
     */
    public MultiAssetBlackScholesModel(RandomVariableFactory randomVariableFactory, double[] initialValues, double riskFreeRate, double[] volatilities, double[][] correlations) {
        this(randomVariableFactory, initialValues, riskFreeRate, MultiAssetBlackScholesModel.getFactorLoadingsFromVolatilityAndCorrelation(volatilities, correlations));
    }

    /**
     * Converts volatilities and correlations into a factor loading matrix by scaling row {@code i}
     * of the factor matrix of {@code correlations} with {@code volatilities[i]}.
     *
     * <p>NOTE(review): presumably {@code LinearAlgebra.getFactorMatrix} returns a fresh matrix F with
     * {@code F F^T} approximating {@code correlations} — confirm; the result is scaled in place.
     *
     * @throws IllegalArgumentException If the number of volatilities does not match the correlation matrix size.
     */
    private static double[][] getFactorLoadingsFromVolatilityAndCorrelation(double[] volatilities, double[][] correlations) {
        if (volatilities.length != correlations.length) {
            throw new IllegalArgumentException("Number of volatilities (" + volatilities.length + ") must match the size of the correlation matrix (" + correlations.length + ").");
        }
        final double[][] factorLoadings = LinearAlgebra.getFactorMatrix(correlations, correlations.length);
        for (int underlyingIndex = 0; underlyingIndex < factorLoadings.length; ++underlyingIndex) {
            final double volatility = volatilities[underlyingIndex];
            for (int factorIndex = 0; factorIndex < factorLoadings[underlyingIndex].length; ++factorIndex) {
                factorLoadings[underlyingIndex][factorIndex] = factorLoadings[underlyingIndex][factorIndex] * volatility;
            }
        }
        return factorLoadings;
    }

    /** Returns a deep copy of the given matrix (every row is cloned). */
    private static double[][] copyMatrix(double[][] matrix) {
        final double[][] copy = new double[matrix.length][];
        for (int row = 0; row < matrix.length; ++row) {
            copy[row] = matrix[row].clone();
        }
        return copy;
    }

    /** Returns the sum of the squared entries of one row of factor loadings (the variance of that component). */
    private static double getSumOfSquares(double[] loadings) {
        double sumOfSquares = 0.0;
        for (final double loading : loadings) {
            sumOfSquares += loading * loading;
        }
        return sumOfSquares;
    }

    /**
     * Creates the model from a factor loading matrix using a {@link RandomVariableFromArrayFactory}.
     *
     * @param initialValues Initial values {@code S_i(0)}.
     * @param riskFreeRate Constant risk free rate {@code r}.
     * @param factorLoadings Factor loading matrix, one row per component.
     */
    public MultiAssetBlackScholesModel(double[] initialValues, double riskFreeRate, double[][] factorLoadings) {
        this(new RandomVariableFromArrayFactory(), initialValues, riskFreeRate, factorLoadings);
    }

    /**
     * Creates the model from volatilities and correlations using a {@link RandomVariableFromArrayFactory}.
     *
     * @param initialValues Initial values {@code S_i(0)}.
     * @param riskFreeRate Constant risk free rate {@code r}.
     * @param volatilities Volatility of each component.
     * @param correlations Correlation matrix of the components.
     */
    public MultiAssetBlackScholesModel(double[] initialValues, double riskFreeRate, double[] volatilities, double[][] correlations) {
        this(new RandomVariableFromArrayFactory(), initialValues, riskFreeRate, volatilities, correlations);
    }

    /** Returns the initial states {@code log(S_i(0))}, one per component. */
    @Override
    public RandomVariable[] getInitialState(MonteCarloProcess process) {
        return this.initialStates;
    }

    /** Returns the constant log-drifts {@code r - sigma_i^2/2}, independent of time and state. */
    @Override
    public RandomVariable[] getDrift(MonteCarloProcess process, int timeIndex, RandomVariable[] realizationAtTimeIndex, RandomVariable[] realizationPredictor) {
        return this.drift;
    }

    /** Returns the constant factor loadings of the given component, independent of time and state. */
    @Override
    public RandomVariable[] getFactorLoading(MonteCarloProcess process, int timeIndex, int component, RandomVariable[] realizationAtTimeIndex) {
        return this.factorLoadingOnPaths[component];
    }

    /** Maps the log-state to the asset value: {@code S = exp(Y)}. */
    @Override
    public RandomVariable applyStateSpaceTransform(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable randomVariable) {
        return randomVariable.exp();
    }

    /** Maps the asset value to the log-state: {@code Y = log(S)}. */
    @Override
    public RandomVariable applyStateSpaceTransformInverse(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable randomVariable) {
        return randomVariable.log();
    }

    /** Returns the bank account numeraire {@code exp(r t)} as a constant random variable. */
    @Override
    public RandomVariable getNumeraire(MonteCarloProcess process, double time) {
        final double numeraireValue = Math.exp(this.riskFreeRate * time);
        return this.getRandomVariableForConstant(numeraireValue);
    }

    /** Creates a constant random variable via this model's factory. */
    @Override
    public RandomVariable getRandomVariableForConstant(double value) {
        return this.randomVariableFactory.createRandomVariable(value);
    }

    /** Returns the number of assets (components). */
    @Override
    public int getNumberOfComponents() {
        return this.initialValues.length;
    }

    /** Returns the number of factors, i.e. the row length of the factor loading matrix (0 for an empty model). */
    @Override
    public int getNumberOfFactors() {
        return this.factorLoadings.length == 0 ? 0 : this.factorLoadings[0].length;
    }

    /**
     * Returns a copy of this model with some parameters replaced.
     *
     * <p>Recognised keys: {@code randomVariableFactory}, {@code initialValues} ({@code double[]}),
     * {@code riskFreeRate} (any {@link Number}), {@code factorLoadings} ({@code double[][]}), and
     * alternatively {@code volatilities} ({@code double[]}) and/or {@code correlations} ({@code double[][]});
     * a missing one of the latter two is taken from this model.
     *
     * @param dataModified Map of parameter names to new values.
     * @return The new model.
     * @throws IllegalArgumentException If {@code factorLoadings} is given together with volatilities or correlations.
     */
    @Override
    public MultiAssetBlackScholesModel getCloneWithModifiedData(Map<String, Object> dataModified) {
        final RandomVariableFactory newRandomVariableFactory = (RandomVariableFactory)dataModified.getOrDefault("randomVariableFactory", this.randomVariableFactory);
        final double[] newInitialValues = (double[])dataModified.getOrDefault("initialValues", this.initialValues);
        // Accept any Number (e.g. Integer), not only Double.
        final double newRiskFreeRate = ((Number)dataModified.getOrDefault("riskFreeRate", this.riskFreeRate)).doubleValue();
        double[][] newFactorLoadings = (double[][])dataModified.getOrDefault("factorLoadings", this.factorLoadings);
        if (dataModified.containsKey("volatilities") || dataModified.containsKey("correlations")) {
            if (dataModified.containsKey("factorLoadings")) {
                throw new IllegalArgumentException("Inconsistend parameters. Cannot specify volatility or corellation and factorLoadings at the same time.");
            }
            final double[] newVolatilities = (double[])dataModified.getOrDefault("volatilities", this.getVolatilityVector());
            final double[][] newCorrelations = (double[][])dataModified.getOrDefault("correlations", this.getCorrelationMatrix());
            newFactorLoadings = MultiAssetBlackScholesModel.getFactorLoadingsFromVolatilityAndCorrelation(newVolatilities, newCorrelations);
        }
        return new MultiAssetBlackScholesModel(newRandomVariableFactory, newInitialValues, newRiskFreeRate, newFactorLoadings);
    }

    /** Returns a description including the initial values, the risk free rate and the factor loading matrix entries. */
    @Override
    public String toString() {
        return "MonteCarloMultiAssetBlackScholesModel [initialValues=" + Arrays.toString(this.initialValues) + ", riskFreeRate=" + this.riskFreeRate + ", factorLoadings=" + Arrays.deepToString(this.factorLoadings) + "]";
    }

    /** Returns the constant risk free rate {@code r}. */
    public double getRiskFreeRate() {
        return this.riskFreeRate;
    }

    /** Returns a copy of the factor loading matrix (one row per component). */
    public double[][] getFactorLoadingMatrix() {
        return copyMatrix(this.factorLoadings);
    }

    /** Returns the volatility of each component, {@code sqrt(sum_j F[i][j]^2)}. */
    public double[] getVolatilityVector() {
        final double[] volatilities = new double[this.factorLoadings.length];
        for (int underlyingIndex = 0; underlyingIndex < this.factorLoadings.length; ++underlyingIndex) {
            volatilities[underlyingIndex] = Math.sqrt(getSumOfSquares(this.factorLoadings[underlyingIndex]));
        }
        return volatilities;
    }

    /**
     * Returns the instantaneous correlation matrix implied by the factor loadings,
     * {@code (F F^T)[i][j] / (sigma_i sigma_j)}. If a component has zero volatility its correlation
     * is set to 1 on the diagonal and 0 elsewhere.
     */
    public double[][] getCorrelationMatrix() {
        final double[] volatilities = this.getVolatilityVector();
        final double[][] correlations = new double[this.factorLoadings.length][this.factorLoadings.length];
        for (int underlyingIndex1 = 0; underlyingIndex1 < this.factorLoadings.length; ++underlyingIndex1) {
            for (int underlyingIndex2 = 0; underlyingIndex2 < this.factorLoadings.length; ++underlyingIndex2) {
                double covariance = 0.0;
                for (int factorIndex = 0; factorIndex < this.factorLoadings[underlyingIndex1].length; ++factorIndex) {
                    covariance += this.factorLoadings[underlyingIndex1][factorIndex] * this.factorLoadings[underlyingIndex2][factorIndex];
                }
                final double correlation = volatilities[underlyingIndex1] != 0.0 && volatilities[underlyingIndex2] != 0.0 ? covariance / volatilities[underlyingIndex1] / volatilities[underlyingIndex2] : (underlyingIndex1 == underlyingIndex2 ? 1.0 : 0.0);
                correlations[underlyingIndex1][underlyingIndex2] = correlation;
            }
        }
        return correlations;
    }
}

