/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.interestrate.models;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import net.finmath.exception.CalculationException;
import net.finmath.marketdata.model.AnalyticModel;
import net.finmath.marketdata.model.curves.DiscountCurve;
import net.finmath.marketdata.model.curves.DiscountCurveFromForwardCurve;
import net.finmath.marketdata.model.curves.ForwardCurve;
import net.finmath.montecarlo.RandomVariableFactory;
import net.finmath.montecarlo.RandomVariableFromArrayFactory;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.montecarlo.interestrate.LIBORMarketModel;
import net.finmath.montecarlo.interestrate.LIBORModel;
import net.finmath.montecarlo.interestrate.models.covariance.ShortRateVolatilityModel;
import net.finmath.montecarlo.model.AbstractProcessModel;
import net.finmath.montecarlo.process.MonteCarloProcess;
import net.finmath.stochastic.RandomVariable;
import net.finmath.time.TimeDiscretization;

/**
 * One-factor Hull-White short-rate model in which the short rate itself is the simulated state variable.
 *
 * <p>The model has a single component and a single factor: component 0 of the process is the short rate, and the
 * state-space transform is the identity. Mean reversion and volatility are taken piecewise constant from a
 * {@link ShortRateVolatilityModel}. The initial term structure is the discount curve implied by the forward curve
 * ({@link DiscountCurveFromForwardCurve}). The optional discount curve is only used to rescale the numeraire.
 *
 * <p>Zero-coupon bonds, forward rates and LIBORs are evaluated from the simulated short rate using the affine
 * representation {@code P(t,T) = A(t,T) * exp(-B(t,T) * r(t))}.
 */
public class HullWhiteModelWithDirectSimulation extends AbstractProcessModel implements LIBORModel {

    private final TimeDiscretization liborPeriodDiscretization;
    private final AnalyticModel analyticModel;
    private final ForwardCurve forwardRateCurve;
    private final DiscountCurve discountCurve;
    /** Discount curve implied by {@link #forwardRateCurve}; serves as the model's initial term structure. */
    private final DiscountCurve discountCurveFromForwardCurve;
    private final RandomVariableFactory randomVariableFactory = new RandomVariableFromArrayFactory();
    /** Unadjusted numeraires per simulation time index, valid for {@link #numerairesProcess} only. */
    private final ConcurrentHashMap<Integer, RandomVariable> numeraires;
    /** Process for which {@link #numeraires} was filled; the cache is cleared when a different process is used. */
    private MonteCarloProcess numerairesProcess = null;
    private final ShortRateVolatilityModel volatilityModel;
    /** Lazily computed initial state (see {@link #getInitialState(MonteCarloProcess)}). */
    private RandomVariable[] initialState;

    /**
     * Creates the model.
     *
     * @param liborPeriodDiscretization Tenor used by {@link #getLIBOR} and the LIBOR period accessors.
     * @param analyticModel Analytic model passed to {@code discountCurve.getDiscountFactor(...)}.
     * @param forwardRateCurve Forward curve defining the initial term structure of the short rate.
     * @param discountCurve Discount curve used to rescale the numeraire; may be {@code null}, in which case no rescaling
     *        is applied (but {@link #getNumeraire} for negative times then fails).
     * @param volatilityModel Piecewise-constant mean reversion and volatility of the short rate.
     * @param properties Currently ignored.
     */
    public HullWhiteModelWithDirectSimulation(TimeDiscretization liborPeriodDiscretization, AnalyticModel analyticModel, ForwardCurve forwardRateCurve, DiscountCurve discountCurve, ShortRateVolatilityModel volatilityModel, Map<String, ?> properties) {
        this.liborPeriodDiscretization = liborPeriodDiscretization;
        this.analyticModel = analyticModel;
        this.forwardRateCurve = forwardRateCurve;
        this.discountCurve = discountCurve;
        this.volatilityModel = volatilityModel;
        this.discountCurveFromForwardCurve = new DiscountCurveFromForwardCurve(forwardRateCurve);
        this.numeraires = new ConcurrentHashMap<>();
    }

    /** @return 1: the only state component is the short rate. */
    @Override
    public int getNumberOfComponents() {
        return 1;
    }

    /** @return 1: the model is driven by a single Brownian factor. */
    @Override
    public int getNumberOfFactors() {
        return 1;
    }

    /** Identity: the simulated state is the short rate itself. */
    @Override
    public RandomVariable applyStateSpaceTransform(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable randomVariable) {
        return randomVariable;
    }

    /** Identity: the simulated state is the short rate itself. */
    @Override
    public RandomVariable applyStateSpaceTransformInverse(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable randomVariable) {
        return randomVariable;
    }

    /**
     * Returns the initial short rate: the continuously compounded forward rate of the initial curve over the first
     * simulation time step of {@code process}.
     *
     * <p>The value is computed on the first call and cached; later calls return the same array, even when a different
     * {@code process} is passed.
     */
    @Override
    public RandomVariable[] getInitialState(MonteCarloProcess process) {
        if (initialState == null) {
            final double dt = process.getTimeDiscretization().getTimeStep(0);
            final double initialShortRate = Math.log(discountCurveFromForwardCurve.getDiscountFactor(0.0) / discountCurveFromForwardCurve.getDiscountFactor(dt)) / dt;
            initialState = new RandomVariable[] { new RandomVariableFromDoubleArray(initialShortRate) };
        }
        return initialState;
    }

    /**
     * Returns the numeraire N(t) at {@code time}.
     *
     * <ul>
     * <li>For {@code time < 0}, the discount factor of the discount curve (requires a non-null discount curve).</li>
     * <li>At the first simulation time, the constant 1.</li>
     * <li>At a simulation time t_k, {@code exp(sum_{i<k} r(t_i) * (t_{i+1} - t_i))}. This value is cached per index for
     * the most recently used process. If a discount curve is given, it is multiplied by the constant
     * {@code avg(1/N) / P(0,t)}, so that the sample average of 1/N(t) equals the curve's discount factor.</li>
     * <li>Between simulation times, the numeraire at the previous simulation time rolled forward with the short rate
     * observed there.</li>
     * </ul>
     */
    @Override
    public RandomVariable getNumeraire(MonteCarloProcess process, double time) throws CalculationException {
        if (time < 0.0) {
            return randomVariableFactory.createRandomVariable(discountCurve.getDiscountFactor(analyticModel, time));
        }
        if (time == process.getTime(0)) {
            return randomVariableFactory.createRandomVariable(1.0);
        }

        final int timeIndex = process.getTimeIndex(time);
        if (timeIndex < 0) {
            // Off-grid time: getTimeIndex returned -(insertion point) - 1, so -timeIndex - 2 is the previous grid index.
            final int previousTimeIndex = -timeIndex - 2;
            final double previousTime = process.getTime(previousTimeIndex);
            final RandomVariable rate = getShortRate(process, previousTimeIndex);
            final RandomVariable integratedRate = rate.mult(time - previousTime);
            return getNumeraire(process, previousTime).mult(integratedRate.exp());
        }

        if (process != numerairesProcess) {
            numeraires.clear();
            numerairesProcess = process;
        }

        RandomVariable numeraire = numeraires.get(timeIndex);
        if (numeraire == null) {
            // Accumulate the short rate over all steps up to timeIndex, caching every intermediate numeraire.
            RandomVariable integratedRate = process.getStochasticDriver().getRandomVariableForConstant(0.0);
            for (int i = 0; i < timeIndex; ++i) {
                final double dt = process.getTimeDiscretization().getTimeStep(i);
                integratedRate = integratedRate.addProduct(getShortRate(process, i), dt);
                numeraire = integratedRate.exp();
                numeraires.put(i + 1, numeraire);
            }
        }

        if (discountCurve != null) {
            // Deterministic rescaling so that the sample average of 1/N(t) reprices the discount curve.
            final double deterministicNumeraireAdjustment = numeraire.invert().getAverage() / discountCurve.getDiscountFactor(analyticModel, time);
            numeraire = numeraire.mult(deterministicNumeraireAdjustment);
        }
        return numeraire;
    }

    /**
     * Returns the drift of the short rate over the step [t_i, t_{i+1}] as {@code theta - aEff * r(t_i)}.
     *
     * <p>{@code aEff = a * B(t_i, t_{i+1}) / dt}. {@code theta} combines the slope of the initial forward curve, the
     * mean-reverting pull towards the current forward, and the term {@code phi} built from {@link #getDV}.
     */
    @Override
    public RandomVariable[] getDrift(MonteCarloProcess process, int timeIndex, RandomVariable[] realizationAtTimeIndex, RandomVariable[] realizationPredictor) {
        final double time = process.getTime(timeIndex);
        final double timeNext = process.getTime(timeIndex + 1);
        // Time after timeNext; on the last step the last step size is re-used to extrapolate.
        final double timeAfterNext = timeIndex < process.getTimeDiscretization().getNumberOfTimes() - 2
                ? process.getTime(timeIndex + 2)
                : timeNext + process.getTimeDiscretization().getTimeStep(timeIndex);

        final double df0 = discountCurveFromForwardCurve.getDiscountFactor(time);
        final double df1 = discountCurveFromForwardCurve.getDiscountFactor(timeNext);
        final double df2 = discountCurveFromForwardCurve.getDiscountFactor(timeAfterNext);

        // Continuously compounded forwards of the initial curve over the current and the next step.
        // At time 0 the (cached) initial state is used instead.
        final double forward = time > 0.0
                ? -Math.log(df1 / df0) / (timeNext - time)
                : getInitialState(process)[0].get(0);
        final double forwardNext = -Math.log(df2 / df1) / (timeAfterNext - timeNext);
        final double forwardChange = (forwardNext - forward) / (timeNext - time);

        final int timeIndexVolatility = getVolatilityTimeIndexFloor(time);
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexVolatility).doubleValue();
        // For constant a on the step this equals (1 - exp(-a dt)) / dt.
        final double meanReversionEffective = meanReversion * getB(time, timeNext) / (timeNext - time);
        final double phi = (getDV(0.0, timeNext) - Math.exp(-meanReversion * (timeNext - time)) * getDV(0.0, time)) / (timeNext - time);
        final double theta = forwardChange + meanReversionEffective * forward + phi;

        return new RandomVariable[] { realizationAtTimeIndex[0].mult(-meanReversionEffective).add(theta) };
    }

    /**
     * Returns the (deterministic) factor loading of the short rate over the step [t_i, t_{i+1}].
     *
     * <p>It is sigma scaled by {@code sqrt((1 - exp(-2 a dt)) / (2 a dt))}. This is the ratio of the Ornstein-Uhlenbeck
     * conditional standard deviation over dt to {@code sigma * sqrt(dt)}. For {@code a == 0} the result is NaN.
     */
    @Override
    public RandomVariable[] getFactorLoading(MonteCarloProcess process, int timeIndex, int componentIndex, RandomVariable[] realizationAtTimeIndex) {
        final double time = process.getTime(timeIndex);
        final double timeNext = process.getTime(timeIndex + 1);
        final double dt = timeNext - time;

        final int timeIndexVolatility = getVolatilityTimeIndexFloor(time);
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexVolatility).doubleValue();
        final double volatility = volatilityModel.getVolatility(timeIndexVolatility).doubleValue();

        final double scaling = Math.sqrt((1.0 - Math.exp(-2.0 * meanReversion * dt)) / (2.0 * meanReversion * dt));
        final double volatilityEffective = scaling * volatility;
        return new RandomVariable[] { new RandomVariableFromDoubleArray(volatilityEffective) };
    }

    @Override
    public RandomVariable getRandomVariableForConstant(double value) {
        return randomVariableFactory.createRandomVariable(value);
    }

    /**
     * Returns the simple forward rate for [periodStart, periodEnd] observed at {@code time}.
     *
     * <p>It is computed as {@code (P(time,periodStart) / P(time,periodEnd) - 1) / (periodEnd - periodStart)} from the
     * model's analytic zero-coupon bonds. {@code time} must be a simulation time of {@code process}.
     */
    @Override
    public RandomVariable getForwardRate(MonteCarloProcess process, double time, double periodStart, double periodEnd) throws CalculationException {
        final RandomVariable bondToPeriodStart = getZeroCouponBond(process, time, periodStart);
        final RandomVariable bondToPeriodEnd = getZeroCouponBond(process, time, periodEnd);
        return bondToPeriodStart.div(bondToPeriodEnd).sub(1.0).div(periodEnd - periodStart);
    }

    /** Returns the LIBOR for the tenor period {@code liborIndex}, observed at simulation time index {@code timeIndex}. */
    @Override
    public RandomVariable getLIBOR(MonteCarloProcess process, int timeIndex, int liborIndex) throws CalculationException {
        final double time = process.getTime(timeIndex);
        final RandomVariable bondToPeriodStart = getZeroCouponBond(process, time, getLiborPeriod(liborIndex));
        final RandomVariable bondToPeriodEnd = getZeroCouponBond(process, time, getLiborPeriod(liborIndex + 1));
        return bondToPeriodStart.div(bondToPeriodEnd).sub(1.0).div(getLiborPeriodDiscretization().getTimeStep(liborIndex));
    }

    @Override
    public TimeDiscretization getLiborPeriodDiscretization() {
        return liborPeriodDiscretization;
    }

    @Override
    public int getNumberOfLibors() {
        return liborPeriodDiscretization.getNumberOfTimeSteps();
    }

    @Override
    public double getLiborPeriod(int timeIndex) {
        return liborPeriodDiscretization.getTime(timeIndex);
    }

    @Override
    public int getLiborPeriodIndex(double time) {
        return liborPeriodDiscretization.getTimeIndex(time);
    }

    @Override
    public AnalyticModel getAnalyticModel() {
        return analyticModel;
    }

    @Override
    public DiscountCurve getDiscountCurve() {
        return discountCurve;
    }

    @Override
    public ForwardCurve getForwardRateCurve() {
        return forwardRateCurve;
    }

    /** Not supported by this model. */
    @Override
    public LIBORMarketModel getCloneWithModifiedData(Map<String, Object> dataModified) {
        throw new UnsupportedOperationException();
    }

    /** Returns the simulated short rate (component 0 of the process) at simulation time index {@code timeIndex}. */
    private RandomVariable getShortRate(MonteCarloProcess process, int timeIndex) throws CalculationException {
        return process.getProcessValue(timeIndex, 0);
    }

    /**
     * Returns {@code P(time, maturity) = A(time, maturity) * exp(-B(time, maturity) * r(time))}.
     *
     * <p>{@code time} must be a simulation time of {@code process}. For off-grid times {@code process.getTimeIndex}
     * returns a negative index, which is passed on unchanged.
     */
    private RandomVariable getZeroCouponBond(MonteCarloProcess process, double time, double maturity) throws CalculationException {
        final int timeIndex = process.getTimeIndex(time);
        final RandomVariable shortRate = getShortRate(process, timeIndex);
        final double a = getA(process, time, maturity);
        final double b = getB(time, maturity);
        return shortRate.mult(-b).exp().mult(a);
    }

    /**
     * Returns the deterministic factor {@code A(time, maturity)} of the affine bond formula, fitted to the initial
     * curve.
     *
     * <p>Uses the forward rate over [time, time + dt] as the reference short rate, where dt is the simulation step
     * starting at {@code time}. At the last simulation time there is no further step, so the last step size is re-used.
     * This matches the extrapolation in {@link #getDrift}.
     */
    private double getA(MonteCarloProcess process, double time, double maturity) {
        final int timeIndex = process.getTimeIndex(time);
        final int timeStepIndex = Math.min(timeIndex, process.getTimeDiscretization().getNumberOfTimeSteps() - 1);
        final double dt = process.getTimeDiscretization().getTimeStep(timeStepIndex);
        final double zeroRate = -Math.log(discountCurveFromForwardCurve.getDiscountFactor(time + dt) / discountCurveFromForwardCurve.getDiscountFactor(time)) / dt;
        final double b = getB(time, maturity);
        final double lnA = Math.log(discountCurveFromForwardCurve.getDiscountFactor(maturity) / discountCurveFromForwardCurve.getDiscountFactor(time))
                + b * zeroRate
                - 0.5 * getShortRateConditionalVariance(0.0, time) * b * b;
        return Math.exp(lnA);
    }

    /**
     * Returns the index of the last time of the volatility model's discretization that is less than or equal to
     * {@code time}, or -1 if {@code time} precedes the first one.
     *
     * <p>NOTE(review): relies on {@code getTimeIndex} returning {@code -(insertion point) - 1} for off-grid times, as the
     * rest of this class assumes.
     */
    private int getVolatilityTimeIndexFloor(double time) {
        final int timeIndex = volatilityModel.getTimeDiscretization().getTimeIndex(time);
        return timeIndex >= 0 ? timeIndex : -timeIndex - 2;
    }

    /**
     * Returns the index of the volatility-model interval containing {@code time}. Its parameters apply to the first
     * integration segment starting at {@code time}. The result is clamped to 0 for times before the grid.
     */
    private int getIntegrationStartIndex(double time) {
        return Math.max(getVolatilityTimeIndexFloor(time), 0);
    }

    /** Returns the integrated mean reversion {@code M(time, maturity) = int_time^maturity a(u) du}. */
    private double getMRTime(double time, double maturity) {
        final int timeIndexStart = getIntegrationStartIndex(time);
        final int timeIndexEnd = getVolatilityTimeIndexFloor(maturity);

        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            final double timeNext = volatilityModel.getTimeDiscretization().getTime(timeIndex);
            final double meanReversion = volatilityModel.getMeanReversion(timeIndex - 1).doubleValue();
            integral += meanReversion * (timeNext - timePrev);
            timePrev = timeNext;
        }
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexEnd).doubleValue();
        integral += meanReversion * (maturity - timePrev);
        return integral;
    }

    /**
     * Returns {@code B(time, maturity) = int_time^maturity exp(-M(s, maturity)) ds}, integrated exactly on each
     * interval of constant mean reversion.
     */
    private double getB(double time, double maturity) {
        final int timeIndexStart = getIntegrationStartIndex(time);
        final int timeIndexEnd = getVolatilityTimeIndexFloor(maturity);

        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            final double timeNext = volatilityModel.getTimeDiscretization().getTime(timeIndex);
            final double meanReversion = volatilityModel.getMeanReversion(timeIndex - 1).doubleValue();
            integral += (Math.exp(-getMRTime(timeNext, maturity)) - Math.exp(-getMRTime(timePrev, maturity))) / meanReversion;
            timePrev = timeNext;
        }
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexEnd).doubleValue();
        integral += (Math.exp(-getMRTime(maturity, maturity)) - Math.exp(-getMRTime(timePrev, maturity))) / meanReversion;
        return integral;
    }

    /**
     * Returns {@code int_time^maturity (sigma(s)/a(s))^2 (1 - exp(-M(s, maturity)))^2 ds}, integrated per interval of
     * constant parameters.
     *
     * <p>NOTE(review): not referenced within this class.
     */
    private double getV(double time, double maturity) {
        if (time == maturity) {
            return 0.0;
        }
        final int timeIndexStart = getIntegrationStartIndex(time);
        final int timeIndexEnd = getVolatilityTimeIndexFloor(maturity);

        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            final double timeNext = volatilityModel.getTimeDiscretization().getTime(timeIndex);
            final double meanReversion = volatilityModel.getMeanReversion(timeIndex - 1).doubleValue();
            final double volatility = volatilityModel.getVolatility(timeIndex - 1).doubleValue();
            integral += volatility * volatility * (timeNext - timePrev) / (meanReversion * meanReversion);
            integral -= volatility * volatility * 2.0 * (Math.exp(-getMRTime(timeNext, maturity)) - Math.exp(-getMRTime(timePrev, maturity))) / (meanReversion * meanReversion * meanReversion);
            integral += volatility * volatility * (Math.exp(-2.0 * getMRTime(timeNext, maturity)) - Math.exp(-2.0 * getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion * meanReversion);
            timePrev = timeNext;
        }
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexEnd).doubleValue();
        final double volatility = volatilityModel.getVolatility(timeIndexEnd).doubleValue();
        integral += volatility * volatility * (maturity - timePrev) / (meanReversion * meanReversion);
        integral -= volatility * volatility * 2.0 * (Math.exp(-getMRTime(maturity, maturity)) - Math.exp(-getMRTime(timePrev, maturity))) / (meanReversion * meanReversion * meanReversion);
        integral += volatility * volatility * (Math.exp(-2.0 * getMRTime(maturity, maturity)) - Math.exp(-2.0 * getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion * meanReversion);
        return integral;
    }

    /**
     * Returns {@code int_time^maturity sigma(s)^2 / a(s) * (exp(-M(s, maturity)) - exp(-2 M(s, maturity))) ds},
     * integrated per interval of constant parameters.
     *
     * <p>Used by {@link #getDrift} for the term {@code phi}.
     */
    private double getDV(double time, double maturity) {
        if (time == maturity) {
            return 0.0;
        }
        final int timeIndexStart = getIntegrationStartIndex(time);
        final int timeIndexEnd = getVolatilityTimeIndexFloor(maturity);

        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            final double timeNext = volatilityModel.getTimeDiscretization().getTime(timeIndex);
            final double meanReversion = volatilityModel.getMeanReversion(timeIndex - 1).doubleValue();
            final double volatility = volatilityModel.getVolatility(timeIndex - 1).doubleValue();
            integral += volatility * volatility * (Math.exp(-getMRTime(timeNext, maturity)) - Math.exp(-getMRTime(timePrev, maturity))) / (meanReversion * meanReversion);
            integral -= volatility * volatility * (Math.exp(-2.0 * getMRTime(timeNext, maturity)) - Math.exp(-2.0 * getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion);
            timePrev = timeNext;
        }
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexEnd).doubleValue();
        final double volatility = volatilityModel.getVolatility(timeIndexEnd).doubleValue();
        integral += volatility * volatility * (Math.exp(-getMRTime(maturity, maturity)) - Math.exp(-getMRTime(timePrev, maturity))) / (meanReversion * meanReversion);
        integral -= volatility * volatility * (Math.exp(-2.0 * getMRTime(maturity, maturity)) - Math.exp(-2.0 * getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion);
        return integral;
    }

    /**
     * Returns the variance of the short rate at {@code maturity} conditional on its value at {@code time}.
     *
     * <p>The variance is {@code int_time^maturity sigma(s)^2 exp(-2 M(s, maturity)) ds}, integrated exactly on each
     * interval of constant parameters.
     *
     * @param time Conditioning time.
     * @param maturity Time at which the short-rate variance is evaluated.
     * @return The conditional variance.
     */
    public double getShortRateConditionalVariance(double time, double maturity) {
        final int timeIndexStart = getIntegrationStartIndex(time);
        final int timeIndexEnd = getVolatilityTimeIndexFloor(maturity);

        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            final double timeNext = volatilityModel.getTimeDiscretization().getTime(timeIndex);
            final double meanReversion = volatilityModel.getMeanReversion(timeIndex - 1).doubleValue();
            final double volatility = volatilityModel.getVolatility(timeIndex - 1).doubleValue();
            integral += volatility * volatility * (Math.exp(-2.0 * getMRTime(timeNext, maturity)) - Math.exp(-2.0 * getMRTime(timePrev, maturity))) / (2.0 * meanReversion);
            timePrev = timeNext;
        }
        final double meanReversion = volatilityModel.getMeanReversion(timeIndexEnd).doubleValue();
        final double volatility = volatilityModel.getVolatility(timeIndexEnd).doubleValue();
        integral += volatility * volatility * (Math.exp(-2.0 * getMRTime(maturity, maturity)) - Math.exp(-2.0 * getMRTime(timePrev, maturity))) / (2.0 * meanReversion);
        return integral;
    }

    /**
     * Returns {@code Var(r(time)) * B(time, maturity)^2}, the variance of {@code log P(time, maturity)} as seen from 0
     * under the affine bond formula.
     *
     * @param time Observation time of the bond.
     * @param maturity Bond maturity.
     * @return The integrated squared bond volatility.
     */
    public double getIntegratedBondSquaredVolatility(double time, double maturity) {
        final double b = getB(time, maturity);
        return getShortRateConditionalVariance(0.0, time) * b * b;
    }

    /** Not supported by this model. */
    @Override
    public Map<String, RandomVariable> getModelParameters() {
        throw new UnsupportedOperationException();
    }
}

