/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.automaticdifferentiation.backward;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import net.finmath.functions.DoubleTernaryOperator;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.montecarlo.automaticdifferentiation.RandomVariableDifferentiable;
import net.finmath.montecarlo.automaticdifferentiation.backward.RandomVariableDifferentiableAADFactory;
import net.finmath.montecarlo.conditionalexpectation.LinearRegression;
import net.finmath.stochastic.ConditionalExpectationEstimator;
import net.finmath.stochastic.RandomVariable;
import net.finmath.stochastic.Scalar;

public class RandomVariableDifferentiableAAD
implements RandomVariableDifferentiable {
    private static final long serialVersionUID = 2459373647785530657L;
    // Default type priority (3) for instances created by this class; compared in getTypePriority() checks
    // to decide which operand of a binary operation performs the operation.
    private static final int typePriorityDefault = 3;
    // Seed derivative d(this)/d(this) = 1 used by getGradient.
    private static final RandomVariable one = new Scalar(1.0);
    // Type priority of this instance, returned by getTypePriority().
    private final int typePriority;
    // Global counter from which every OperatorTreeNode draws its id, so ids increase in creation order.
    private static AtomicLong indexOfNextRandomVariable = new AtomicLong(0L);
    // Underlying (non-differentiable) values; not final because cache() replaces it with a cached version.
    private RandomVariable values;
    // Node of the operator tree recording how this value was computed.
    private final OperatorTreeNode operatorTreeNode;
    // Factory holding the AAD settings; the constructor substitutes a default instance when null is passed.
    private final RandomVariableDifferentiableAADFactory factory;

    public RandomVariableDifferentiableAAD(RandomVariable values, List<OperatorTreeNode> argumentOperatorTreeNodes, List<RandomVariable> argumentValues, ConditionalExpectationEstimator estimator, OperatorType operator, RandomVariableDifferentiableAADFactory factory, int methodArgumentTypePriority) {
        this.values = values;
        this.operatorTreeNode = new OperatorTreeNode(operator, argumentOperatorTreeNodes, argumentValues, estimator, factory);
        this.factory = factory != null ? factory : new RandomVariableDifferentiableAADFactory();
        this.typePriority = methodArgumentTypePriority;
    }

    /**
     * Static factory wrapping a constant as an independent differentiable random variable.
     *
     * @param value the deterministic value
     * @return a new leaf variable holding {@code value} as a scalar
     */
    public static RandomVariableDifferentiableAAD of(double value) {
        final RandomVariable scalar = new Scalar(value);
        return of(scalar);
    }

    /**
     * Static factory wrapping a random variable as an independent differentiable random variable.
     *
     * @param randomVariable the values to wrap
     * @return a new leaf variable holding {@code randomVariable}
     */
    public static RandomVariableDifferentiableAAD of(RandomVariable randomVariable) {
        final RandomVariableDifferentiableAAD wrapped = new RandomVariableDifferentiableAAD(randomVariable);
        return wrapped;
    }

    /**
     * Creates an independent (leaf) variable holding a constant, using a default factory.
     *
     * @param value the deterministic value
     */
    public RandomVariableDifferentiableAAD(double value) {
        this(new Scalar(value), (RandomVariableDifferentiableAADFactory) null);
    }

    /**
     * Creates an independent (leaf) variable holding {@code randomVariable}.
     * If the argument is itself a {@code RandomVariableDifferentiableAAD}, its factory is reused.
     *
     * @param randomVariable the values to wrap
     */
    public RandomVariableDifferentiableAAD(RandomVariable randomVariable) {
        this(randomVariable, null, null, factoryOfAADOrNull(randomVariable));
    }

    /** Returns the factory of {@code candidate} if it is an AAD random variable, otherwise null. */
    private static RandomVariableDifferentiableAADFactory factoryOfAADOrNull(RandomVariable candidate) {
        if (candidate instanceof RandomVariableDifferentiableAAD) {
            return ((RandomVariableDifferentiableAAD) candidate).getFactory();
        }
        return null;
    }

    /**
     * Creates an independent (leaf) variable with the given values and factory.
     *
     * @param values the underlying values
     * @param factory the factory carrying AAD settings; a default is used when null
     */
    public RandomVariableDifferentiableAAD(RandomVariable values, RandomVariableDifferentiableAADFactory factory) {
        this(values, null, null, null, factory);
    }

    /**
     * Creates a result node without a conditional expectation estimator and with the default type priority.
     */
    private RandomVariableDifferentiableAAD(RandomVariable values, List<RandomVariable> arguments, OperatorType operator, RandomVariableDifferentiableAADFactory factory) {
        this(values, arguments, (ConditionalExpectationEstimator) null, operator, factory, typePriorityDefault);
    }

    /**
     * Creates a result node using the class's default type priority.
     *
     * @param values the values of the result
     * @param arguments the operator arguments (differentiable or constant)
     * @param estimator conditional expectation estimator, may be null
     * @param operator the operator that produced {@code values}
     * @param factory the factory carrying AAD settings, may be null
     */
    public RandomVariableDifferentiableAAD(RandomVariable values, List<RandomVariable> arguments, ConditionalExpectationEstimator estimator, OperatorType operator, RandomVariableDifferentiableAADFactory factory) {
        this(values, arguments, estimator, operator, factory, typePriorityDefault);
    }

    /**
     * Creates a result node from a list of argument random variables. Their tree nodes and values are
     * extracted by {@code OperatorTreeNode}.
     *
     * @param values the values of the result
     * @param arguments the operator arguments
     * @param estimator conditional expectation estimator, may be null
     * @param operator the operator that produced {@code values}
     * @param factory the factory carrying AAD settings, may be null
     * @param methodArgumentTypePriority the type priority of the result
     */
    public RandomVariableDifferentiableAAD(RandomVariable values, List<RandomVariable> arguments, ConditionalExpectationEstimator estimator, OperatorType operator, RandomVariableDifferentiableAADFactory factory, int methodArgumentTypePriority) {
        this(
                values,
                OperatorTreeNode.extractOperatorTreeNodes(arguments),
                OperatorTreeNode.extractOperatorValues(arguments),
                estimator,
                operator,
                factory,
                methodArgumentTypePriority);
    }

    /**
     * Returns the node of the operator tree associated with this variable.
     *
     * @return the operator tree node
     */
    public OperatorTreeNode getOperatorTreeNode() {
        return operatorTreeNode;
    }

    /**
     * Returns the underlying (non-differentiable) values.
     *
     * @return the wrapped values
     */
    @Override
    public RandomVariable getValues() {
        return values;
    }

    /**
     * Returns the factory holding the AAD settings of this variable.
     *
     * @return the factory
     */
    public RandomVariableDifferentiableAADFactory getFactory() {
        return factory;
    }

    /**
     * Returns the id of this variable's operator tree node.
     *
     * @return the node id
     */
    @Override
    public Long getID() {
        final OperatorTreeNode node = getOperatorTreeNode();
        return node.id;
    }

    /**
     * Computes derivatives of this variable with respect to the nodes of its operator tree by backward propagation.
     *
     * <p>Nodes are processed from the highest id downwards. An argument node is always created before the
     * nodes that use it, so it has a smaller id. This means every contribution to a node's derivative has
     * been accumulated before that node is processed. If the factory requests leaf nodes only, the derivative
     * of each inner node is removed after it has been propagated. Any processed id contained in
     * {@code independentIDs} is removed from the result.
     *
     * @param independentIDs set of ids whose entries are removed after processing, may be null
     * @return map from node id to derivative
     */
    @Override
    public Map<Long, RandomVariable> getGradient(Set<Long> independentIDs) {
        final Map<Long, RandomVariable> gradient = new HashMap<>();
        gradient.put(getID(), one);

        // TreeMap keeps the pending nodes sorted by id.
        final TreeMap<Long, OperatorTreeNode> pending = new TreeMap<>();
        pending.put(getID(), getOperatorTreeNode());

        while (!pending.isEmpty()) {
            final Map.Entry<Long, OperatorTreeNode> entry = pending.pollLastEntry();
            final Long nodeID = entry.getKey();
            final OperatorTreeNode node = entry.getValue();
            final List<OperatorTreeNode> nodeArguments = node.arguments;

            final boolean isInnerNode = nodeArguments != null && !nodeArguments.isEmpty();
            if (isInnerNode) {
                node.propagateDerivativesFromResultToArgument(gradient);
                if (isGradientRetainsLeafNodesOnly()) {
                    gradient.remove(nodeID);
                }
                // Null arguments are constants and are not traversed.
                for (OperatorTreeNode argument : nodeArguments) {
                    if (argument != null) {
                        pending.put(argument.id, argument);
                    }
                }
            }

            final boolean isListedIndependent = independentIDs != null && independentIDs.contains(nodeID);
            if (isListedIndependent) {
                gradient.remove(nodeID);
            }
        }
        return gradient;
    }

    /** Returns whether the factory requests that the gradient keep only leaf-node entries. */
    private boolean isGradientRetainsLeafNodesOnly() {
        final RandomVariableDifferentiableAADFactory currentFactory = getFactory();
        if (currentFactory == null) {
            return false;
        }
        return currentFactory.isGradientRetainsLeafNodesOnly();
    }

    /**
     * Forward-mode tangents are not implemented by this backward-mode (adjoint) implementation.
     *
     * @param dependentIDs ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Long, RandomVariable> getTangents(Set<Long> dependentIDs) {
        throw new UnsupportedOperationException("Tangents (forward mode) are not supported by backward-mode AAD.");
    }

    /**
     * Compares the underlying values with {@code randomVariable}.
     *
     * @param randomVariable the random variable to compare with
     * @return the result of {@code getValues().equals(randomVariable)}
     */
    @Override
    public boolean equals(RandomVariable randomVariable) {
        final RandomVariable own = getValues();
        return own.equals(randomVariable);
    }

    /** Returns the filtration time of the underlying values. */
    @Override
    public double getFiltrationTime() {
        final RandomVariable underlying = getValues();
        return underlying.getFiltrationTime();
    }

    /** Returns the type priority fixed at construction. */
    @Override
    public int getTypePriority() {
        return typePriority;
    }

    /** Returns the realization of the underlying values for the given path or state. */
    @Override
    public double get(int pathOrState) {
        final RandomVariable underlying = getValues();
        return underlying.get(pathOrState);
    }

    /** Returns the size of the underlying values. */
    @Override
    public int size() {
        final RandomVariable underlying = getValues();
        return underlying.size();
    }

    /** Returns whether the underlying values are deterministic. */
    @Override
    public boolean isDeterministic() {
        final RandomVariable underlying = getValues();
        return underlying.isDeterministic();
    }

    /** Returns the realizations of the underlying values. */
    @Override
    public double[] getRealizations() {
        final RandomVariable underlying = getValues();
        return underlying.getRealizations();
    }

    /** Returns the scalar value of the underlying values. */
    @Override
    public Double doubleValue() {
        final RandomVariable underlying = getValues();
        return underlying.doubleValue();
    }

    /** Returns the minimum of the underlying values. */
    @Override
    public double getMin() {
        final RandomVariable underlying = getValues();
        return underlying.getMin();
    }

    /** Returns the maximum of the underlying values. */
    @Override
    public double getMax() {
        final RandomVariable underlying = getValues();
        return underlying.getMax();
    }

    /** Returns the average of the underlying values. */
    @Override
    public double getAverage() {
        final RandomVariable underlying = getValues();
        return underlying.getAverage();
    }

    /** Returns the average of the underlying values weighted by {@code probabilities}. */
    @Override
    public double getAverage(RandomVariable probabilities) {
        final RandomVariable underlying = getValues();
        return underlying.getAverage(probabilities);
    }

    /** Returns the variance of the underlying values. */
    @Override
    public double getVariance() {
        final RandomVariable underlying = getValues();
        return underlying.getVariance();
    }

    /** Returns the variance of the underlying values weighted by {@code probabilities}. */
    @Override
    public double getVariance(RandomVariable probabilities) {
        final RandomVariable underlying = getValues();
        return underlying.getVariance(probabilities);
    }

    /** Returns the sample variance of the underlying values. */
    @Override
    public double getSampleVariance() {
        final RandomVariable underlying = getValues();
        return underlying.getSampleVariance();
    }

    /** Returns the standard deviation of the underlying values. */
    @Override
    public double getStandardDeviation() {
        final RandomVariable underlying = getValues();
        return underlying.getStandardDeviation();
    }

    /** Returns the standard deviation of the underlying values weighted by {@code probabilities}. */
    @Override
    public double getStandardDeviation(RandomVariable probabilities) {
        final RandomVariable underlying = getValues();
        return underlying.getStandardDeviation(probabilities);
    }

    /** Returns the standard error of the underlying values. */
    @Override
    public double getStandardError() {
        final RandomVariable underlying = getValues();
        return underlying.getStandardError();
    }

    /** Returns the standard error of the underlying values weighted by {@code probabilities}. */
    @Override
    public double getStandardError(RandomVariable probabilities) {
        final RandomVariable underlying = getValues();
        return underlying.getStandardError(probabilities);
    }

    /** Returns the given quantile of the underlying values. */
    @Override
    public double getQuantile(double quantile) {
        final RandomVariable underlying = getValues();
        return underlying.getQuantile(quantile);
    }

    /** Returns the given quantile of the underlying values weighted by {@code probabilities}. */
    @Override
    public double getQuantile(double quantile, RandomVariable probabilities) {
        final RandomVariable underlying = getValues();
        return underlying.getQuantile(quantile, probabilities);
    }

    /** Returns the quantile expectation of the underlying values between the two quantiles. */
    @Override
    public double getQuantileExpectation(double quantileStart, double quantileEnd) {
        final RandomVariable underlying = getValues();
        return underlying.getQuantileExpectation(quantileStart, quantileEnd);
    }

    /** Returns the histogram of the underlying values for the given interval points. */
    @Override
    public double[] getHistogram(double[] intervalPoints) {
        final RandomVariable underlying = getValues();
        return underlying.getHistogram(intervalPoints);
    }

    /** Returns the histogram of the underlying values for the given number of points and width. */
    @Override
    public double[][] getHistogram(int numberOfPoints, double standardDeviations) {
        final RandomVariable underlying = getValues();
        return underlying.getHistogram(numberOfPoints, standardDeviations);
    }

    /**
     * Replaces the underlying values with their cached version. The operator tree node is left unchanged.
     *
     * @return this instance
     */
    @Override
    public RandomVariable cache() {
        final RandomVariable cached = this.values.cache();
        this.values = cached;
        return this;
    }

    /**
     * Applies {@code cap} to the underlying values and records a CAP node.
     * The cap is a constant: its tree node is null, but its value is stored because the CAP partial
     * derivative compares both arguments.
     *
     * @param cap the constant cap level
     * @return the capped variable
     */
    @Override
    public RandomVariable cap(double cap) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().cap(cap),
                Arrays.asList(this.getOperatorTreeNode(), null),
                Arrays.asList(this.getValues(), new Scalar(cap)),
                null,
                OperatorType.CAP,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Applies {@code floor} to the underlying values and records a FLOOR node.
     * The floor is a constant: its tree node is null, but its value is stored for the partial derivative.
     *
     * @param floor the constant floor level
     * @return the floored variable
     */
    @Override
    public RandomVariable floor(double floor) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().floor(floor),
                Arrays.asList(this.getOperatorTreeNode(), null),
                Arrays.asList(this.getValues(), new Scalar(floor)),
                null,
                OperatorType.FLOOR,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Adds a constant. No argument values are stored because the ADD partial derivative is constant.
     *
     * @param value the constant summand
     * @return the sum
     */
    @Override
    public RandomVariable add(double value) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().add(value),
                Arrays.asList(this.getOperatorTreeNode(), null),
                Arrays.asList(null, null),
                null,
                OperatorType.ADD,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Subtracts a constant. No argument values are stored because the SUB partial derivatives are constant.
     *
     * @param value the constant subtrahend
     * @return the difference
     */
    @Override
    public RandomVariable sub(double value) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().sub(value),
                Arrays.asList(this.getOperatorTreeNode(), null),
                Arrays.asList(null, null),
                null,
                OperatorType.SUB,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Multiplies by a constant. Only the constant factor is stored, since the derivative with respect to
     * this variable equals that factor.
     *
     * @param value the constant factor
     * @return the product
     */
    @Override
    public RandomVariable mult(double value) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().mult(value),
                Arrays.asList(this.getOperatorTreeNode(), null),
                Arrays.asList(null, new Scalar(value)),
                null,
                OperatorType.MULT,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Divides by a constant. Only the divisor is stored, since the derivative with respect to this variable
     * is its inverse.
     *
     * @param value the constant divisor
     * @return the quotient
     */
    @Override
    public RandomVariable div(double value) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().div(value),
                Arrays.asList(this.getOperatorTreeNode(), null),
                Arrays.asList(null, new Scalar(value)),
                null,
                OperatorType.DIV,
                this.getFactory(),
                typePriorityDefault);
    }

    /** Raises the values to {@code exponent} and records a POW node with the exponent as a constant argument. */
    @Override
    public RandomVariable pow(double exponent) {
        final RandomVariable powered = getValues().pow(exponent);
        final RandomVariable exponentConstant = new Scalar(exponent);
        return new RandomVariableDifferentiableAAD(powered, Arrays.asList(this, exponentConstant), OperatorType.POW, getFactory());
    }

    /** Returns the average of the values as a new AVERAGE node. */
    @Override
    public RandomVariable average() {
        final RandomVariable averaged = getValues().average();
        return new RandomVariableDifferentiableAAD(averaged, Arrays.asList(this), OperatorType.AVERAGE, getFactory());
    }

    /**
     * Returns the conditional expectation under {@code estimator} as a new node.
     * The estimator is kept so the adjoint can apply it to the incoming derivative.
     */
    @Override
    public RandomVariable getConditionalExpectation(ConditionalExpectationEstimator estimator) {
        final RandomVariable expectation = getValues().getConditionalExpectation(estimator);
        return new RandomVariableDifferentiableAAD(expectation, Arrays.asList(this), estimator, OperatorType.CONDITIONAL_EXPECTATION, getFactory());
    }

    /** Returns the square of the values as a new SQUARED node. */
    @Override
    public RandomVariable squared() {
        final RandomVariable result = getValues().squared();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.SQUARED, getFactory());
    }

    /** Returns the square root of the values as a new SQRT node. */
    @Override
    public RandomVariable sqrt() {
        final RandomVariable result = getValues().sqrt();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.SQRT, getFactory());
    }

    /** Returns the exponential of the values as a new EXP node. */
    @Override
    public RandomVariable exp() {
        final RandomVariable result = getValues().exp();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.EXP, getFactory());
    }

    /** Returns the natural logarithm of the values as a new LOG node. */
    @Override
    public RandomVariable log() {
        final RandomVariable result = getValues().log();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.LOG, getFactory());
    }

    /** Returns the sine of the values as a new SIN node. */
    @Override
    public RandomVariable sin() {
        final RandomVariable result = getValues().sin();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.SIN, getFactory());
    }

    /** Returns the cosine of the values as a new COS node. */
    @Override
    public RandomVariable cos() {
        final RandomVariable result = getValues().cos();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.COS, getFactory());
    }

    /**
     * Adds {@code randomVariable}. If it has a higher type priority, the (commutative) operation is delegated
     * to it. No argument values are stored because the ADD partial derivatives are constant.
     *
     * @param randomVariable the summand
     * @return the sum
     */
    @Override
    public RandomVariable add(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.add(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().add(randomVariable.getValues()),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
                Arrays.asList(null, null),
                null,
                OperatorType.ADD,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Computes {@code this - randomVariable}. If {@code randomVariable} has a higher type priority, the
     * operation is delegated to its reverse subtraction {@code bus}.
     *
     * @param randomVariable the subtrahend
     * @return the difference
     */
    @Override
    public RandomVariable sub(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.bus(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().sub(randomVariable.getValues()),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
                Arrays.asList(null, null),
                null,
                OperatorType.SUB,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Computes {@code randomVariable - this} (reverse subtraction), recorded as a SUB node with arguments
     * (randomVariable, this).
     *
     * @param randomVariable the minuend
     * @return the difference
     */
    @Override
    public RandomVariable bus(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.sub(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().bus(randomVariable.getValues()),
                Arrays.asList(OperatorTreeNode.of(randomVariable), this.getOperatorTreeNode()),
                Arrays.asList(null, null),
                null,
                OperatorType.SUB,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Multiplies by {@code randomVariable}. Both operand values are stored because each is the partial
     * derivative with respect to the other operand.
     *
     * @param randomVariable the factor
     * @return the product
     */
    @Override
    public RandomVariable mult(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.mult(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().mult(randomVariable.getValues()),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
                Arrays.asList(this.getValues(), randomVariable.getValues()),
                null,
                OperatorType.MULT,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Computes {@code this / randomVariable}. If {@code randomVariable} has a higher type priority, the
     * operation is delegated to its reverse division {@code vid}.
     *
     * @param randomVariable the divisor
     * @return the quotient
     */
    @Override
    public RandomVariable div(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.vid(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().div(randomVariable.getValues()),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
                Arrays.asList(this.getValues(), randomVariable.getValues()),
                null,
                OperatorType.DIV,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Computes {@code randomVariable / this} (reverse division), recorded as a DIV node with arguments
     * (randomVariable, this).
     *
     * @param randomVariable the dividend
     * @return the quotient
     */
    @Override
    public RandomVariable vid(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.div(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().vid(randomVariable.getValues()),
                Arrays.asList(OperatorTreeNode.of(randomVariable), this.getOperatorTreeNode()),
                Arrays.asList(randomVariable.getValues(), this.getValues()),
                null,
                OperatorType.DIV,
                this.getFactory(),
                typePriorityDefault);
    }

    /** Caps the values by {@code randomVariable}, delegating to it if its type priority is higher. */
    @Override
    public RandomVariable cap(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() <= getTypePriority()) {
            final RandomVariable capped = getValues().cap(randomVariable.getValues());
            return new RandomVariableDifferentiableAAD(capped, Arrays.asList(this, randomVariable), OperatorType.CAP, getFactory());
        }
        return randomVariable.cap(this);
    }

    /** Floors the values by {@code floor}, delegating to it if its type priority is higher. */
    @Override
    public RandomVariable floor(RandomVariable floor) {
        if (floor.getTypePriority() <= getTypePriority()) {
            final RandomVariable floored = getValues().floor(floor.getValues());
            return new RandomVariableDifferentiableAAD(floored, Arrays.asList(this, floor), OperatorType.FLOOR, getFactory());
        }
        return floor.floor(this);
    }

    /**
     * Computes {@code this * (1 + rate * periodLength)}. If {@code rate} has a higher type priority, the
     * expression is built from its operators instead; otherwise a single ACCRUE node is recorded.
     */
    @Override
    public RandomVariable accrue(RandomVariable rate, double periodLength) {
        if (rate.getTypePriority() <= getTypePriority()) {
            final RandomVariable accrued = getValues().accrue(rate.getValues(), periodLength);
            return new RandomVariableDifferentiableAAD(accrued, Arrays.asList(this, rate, new RandomVariableFromDoubleArray(periodLength)), OperatorType.ACCRUE, getFactory());
        }
        final RandomVariable growthFactor = rate.mult(periodLength).add(1.0);
        return growthFactor.mult(this);
    }

    /**
     * Computes {@code this / (1 + rate * periodLength)}. If {@code rate} has a higher type priority, the
     * expression is built from its operators instead; otherwise a single DISCOUNT node is recorded.
     */
    @Override
    public RandomVariable discount(RandomVariable rate, double periodLength) {
        if (rate.getTypePriority() <= getTypePriority()) {
            final RandomVariable discounted = getValues().discount(rate.getValues(), periodLength);
            return new RandomVariableDifferentiableAAD(discounted, Arrays.asList(this, rate, new RandomVariableFromDoubleArray(periodLength)), OperatorType.DISCOUNT, getFactory());
        }
        final RandomVariable discountFactor = rate.mult(periodLength).add(1.0).invert();
        return discountFactor.mult(this);
    }

    /**
     * Selects {@code valueIfTriggerNonNegative} where this (the trigger) is non-negative, and
     * {@code valueIfTriggerNegative} otherwise. The result is recorded as a CHOOSE node with arguments
     * (trigger, nonNegativeValue, negativeValue).
     *
     * @param valueIfTriggerNonNegative value used where the trigger is non-negative
     * @param valueIfTriggerNegative value used where the trigger is negative
     * @return the selected values
     */
    @Override
    public RandomVariable choose(RandomVariable valueIfTriggerNonNegative, RandomVariable valueIfTriggerNegative) {
        return new RandomVariableDifferentiableAAD(
                this.getValues().choose(valueIfTriggerNonNegative.getValues(), valueIfTriggerNegative.getValues()),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(valueIfTriggerNonNegative), OperatorTreeNode.of(valueIfTriggerNegative)),
                Arrays.asList(this.getValues(), valueIfTriggerNonNegative.getValues(), valueIfTriggerNegative.getValues()),
                null,
                OperatorType.CHOOSE,
                this.getFactory(),
                typePriorityDefault);
    }

    /** Returns {@code 1 / this} as a new INVERT node. */
    @Override
    public RandomVariable invert() {
        final RandomVariable result = getValues().invert();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.INVERT, getFactory());
    }

    /** Returns the absolute value as a new ABS node. */
    @Override
    public RandomVariable abs() {
        final RandomVariable result = getValues().abs();
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this), OperatorType.ABS, getFactory());
    }

    /**
     * Computes {@code this + factor1 * factor2} with a constant {@code factor2}, recorded as an ADDPRODUCT
     * node. The constant factor has a null tree node.
     *
     * @param factor1 the random factor
     * @param factor2 the constant factor
     * @return the result
     */
    @Override
    public RandomVariable addProduct(RandomVariable factor1, double factor2) {
        if (factor1.getTypePriority() > this.getTypePriority()) {
            return factor1.mult(factor2).add(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().addProduct(factor1.getValues(), factor2),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(factor1), null),
                Arrays.asList(this.getValues(), factor1.getValues(), new Scalar(factor2)),
                null,
                OperatorType.ADDPRODUCT,
                this.getFactory(),
                typePriorityDefault);
    }

    /**
     * Computes {@code this + factor1 * factor2}. If either factor has a higher type priority, the expression
     * is built from the factors' operators; otherwise a single ADDPRODUCT node is recorded.
     *
     * @param factor1 the first factor
     * @param factor2 the second factor
     * @return the result
     */
    @Override
    public RandomVariable addProduct(RandomVariable factor1, RandomVariable factor2) {
        if (factor1.getTypePriority() > this.getTypePriority() || factor2.getTypePriority() > this.getTypePriority()) {
            return factor1.mult(factor2).add(this);
        }
        return new RandomVariableDifferentiableAAD(
                this.getValues().addProduct(factor1.getValues(), factor2.getValues()),
                Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(factor1), OperatorTreeNode.of(factor2)),
                Arrays.asList(this.getValues(), factor1.getValues(), factor2.getValues()),
                null,
                OperatorType.ADDPRODUCT,
                this.getFactory(),
                typePriorityDefault);
    }

    /** Computes {@code this + numerator / denominator}, delegating if an argument has a higher type priority. */
    @Override
    public RandomVariable addRatio(RandomVariable numerator, RandomVariable denominator) {
        final int ownPriority = getTypePriority();
        final boolean delegate = numerator.getTypePriority() > ownPriority || denominator.getTypePriority() > ownPriority;
        if (delegate) {
            return numerator.div(denominator).add(this);
        }
        final RandomVariable result = getValues().addRatio(numerator.getValues(), denominator.getValues());
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this, numerator, denominator), OperatorType.ADDRATIO, getFactory());
    }

    /** Computes {@code this - numerator / denominator}, delegating if an argument has a higher type priority. */
    @Override
    public RandomVariable subRatio(RandomVariable numerator, RandomVariable denominator) {
        final int ownPriority = getTypePriority();
        final boolean delegate = numerator.getTypePriority() > ownPriority || denominator.getTypePriority() > ownPriority;
        if (delegate) {
            return numerator.div(denominator).mult(-1.0).add(this);
        }
        final RandomVariable result = getValues().subRatio(numerator.getValues(), denominator.getValues());
        return new RandomVariableDifferentiableAAD(result, Arrays.asList(this, numerator, denominator), OperatorType.SUBRATIO, getFactory());
    }

    /** Returns the NaN indicator of the underlying values (not differentiable; no tree node is created). */
    @Override
    public RandomVariable isNaN() {
        final RandomVariable underlying = getValues();
        return underlying.isNaN();
    }

    /** Returns the path-wise operator of the underlying values. */
    @Override
    public IntToDoubleFunction getOperator() {
        final RandomVariable underlying = getValues();
        return underlying.getOperator();
    }

    /** Returns the realizations of the underlying values as a stream. */
    @Override
    public DoubleStream getRealizationsStream() {
        final RandomVariable underlying = getValues();
        return underlying.getRealizationsStream();
    }

    /**
     * Not supported: arbitrary functions cannot be differentiated by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RandomVariable apply(DoubleUnaryOperator operator) {
        final String reason = "Applying functions is not supported.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Not supported: arbitrary functions cannot be differentiated by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RandomVariable apply(DoubleBinaryOperator operator, RandomVariable argument) {
        final String reason = "Applying functions is not supported.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Not supported: arbitrary functions cannot be differentiated by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RandomVariable apply(DoubleTernaryOperator operator, RandomVariable argument1, RandomVariable argument2) {
        final String reason = "Applying functions is not supported.";
        throw new UnsupportedOperationException(reason);
    }

    /** Returns the variance as a differentiable VARIANCE node. */
    public RandomVariable getVarianceAsRandomVariableAAD() {
        final double variance = getVariance();
        return new RandomVariableDifferentiableAAD(new RandomVariableFromDoubleArray(variance), Arrays.asList(this), OperatorType.VARIANCE, getFactory());
    }

    /** Returns the sample variance as a differentiable SVARIANCE node. */
    public RandomVariable getSampleVarianceAsRandomVariableAAD() {
        final double sampleVariance = getSampleVariance();
        return new RandomVariableDifferentiableAAD(new RandomVariableFromDoubleArray(sampleVariance), Arrays.asList(this), OperatorType.SVARIANCE, getFactory());
    }

    /** Returns the standard deviation as a differentiable STDEV node. */
    public RandomVariable getStandardDeviationAsRandomVariableAAD() {
        final double standardDeviation = getStandardDeviation();
        return new RandomVariableDifferentiableAAD(new RandomVariableFromDoubleArray(standardDeviation), Arrays.asList(this), OperatorType.STDEV, getFactory());
    }

    /** Returns the standard error as a differentiable STDERROR node. */
    public RandomVariable getStandardErrorAsRandomVariableAAD() {
        final double standardError = getStandardError();
        return new RandomVariableDifferentiableAAD(new RandomVariableFromDoubleArray(standardError), Arrays.asList(this), OperatorType.STDERROR, getFactory());
    }

    /** Returns the minimum as a differentiable MIN node. */
    public RandomVariable getMinAsRandomVariableAAD() {
        final double minimum = getMin();
        return new RandomVariableDifferentiableAAD(new RandomVariableFromDoubleArray(minimum), Arrays.asList(this), OperatorType.MIN, getFactory());
    }

    /** Returns the maximum as a differentiable MAX node. */
    public RandomVariable getMaxAsRandomVariableAAD() {
        final double maximum = getMax();
        return new RandomVariableDifferentiableAAD(new RandomVariableFromDoubleArray(maximum), Arrays.asList(this), OperatorType.MAX, getFactory());
    }

    /**
     * Returns a debug representation containing the underlying values and the node id.
     *
     * @return string representation
     */
    @Override
    public String toString() {
        return "RandomVariableDifferentiableAAD [values=" + this.values + ",\n ID=" + this.getID() + "]";
    }

    /**
     * Returns a new independent (leaf) variable with the same values and the same factory.
     *
     * <p>The factory is passed explicitly. Previously the clone fell back to a fresh default factory,
     * which dropped settings such as leaf-only gradients and the Dirac-delta approximation method for
     * everything computed from the clone.
     *
     * @return a new leaf variable
     */
    @Override
    public RandomVariableDifferentiable getCloneIndependent() {
        return new RandomVariableDifferentiableAAD(this.getValues(), this.getFactory());
    }

    private static class OperatorTreeNode
    implements Serializable {
        private static final long serialVersionUID = -8428352552169568990L;
        // Unique id drawn from the global counter at construction. Argument nodes exist before the nodes using
        // them, so arguments always have smaller ids than their results (getGradient processes ids high-to-low).
        private final Long id = indexOfNextRandomVariable.getAndIncrement();
        // Operator that produced this node; null for leaf nodes created without an operator.
        private final OperatorType operatorType;
        // Tree nodes of the operator arguments; a null entry marks a constant that receives no derivative.
        private final List<OperatorTreeNode> arguments;
        // Argument values needed by the partial derivatives; the constructor nulls entries (or the whole list)
        // that no partial derivative to be evaluated will read.
        private final List<RandomVariable> argumentValues;
        // Operator-specific payload; holds the ConditionalExpectationEstimator for CONDITIONAL_EXPECTATION nodes.
        private final Object operator;
        // Factory whose Dirac-delta approximation method is consulted when propagating through CHOOSE nodes.
        private final RandomVariableDifferentiableAADFactory factory;
        // Constant partial derivatives used by getPartialDerivative.
        private static final RandomVariable zero = new Scalar(0.0);
        private static final RandomVariable one = new Scalar(1.0);
        private static final RandomVariable minusOne = new Scalar(-1.0);

        OperatorTreeNode(OperatorType operatorType, List<OperatorTreeNode> arguments, List<RandomVariable> argumentValues, Object operator, RandomVariableDifferentiableAADFactory factory) {
            this.operatorType = operatorType;
            this.arguments = arguments;
            this.operator = operator;
            this.factory = factory;
            if (operatorType != null && (operatorType.equals((Object)OperatorType.ADD) || operatorType.equals((Object)OperatorType.SUB))) {
                argumentValues = null;
            } else if (operatorType != null && operatorType.equals((Object)OperatorType.AVERAGE)) {
                argumentValues = null;
            } else if (operatorType != null && operatorType.equals((Object)OperatorType.MULT)) {
                if (arguments.get(0) == null) {
                    argumentValues.set(1, null);
                }
                if (arguments.get(1) == null) {
                    argumentValues.set(0, null);
                }
            } else if (operatorType != null && operatorType.equals((Object)OperatorType.DIV)) {
                if (arguments.get(1) == null) {
                    argumentValues.set(0, null);
                }
            } else if (operatorType != null && operatorType.equals((Object)OperatorType.ADDPRODUCT)) {
                argumentValues.set(0, null);
                if (arguments.get(1) == null) {
                    argumentValues.set(2, null);
                }
                if (arguments.get(2) == null) {
                    argumentValues.set(1, null);
                }
            } else if (operatorType != null && operatorType.equals((Object)OperatorType.ACCRUE)) {
                if (arguments.get(1) == null && arguments.get(2) == null) {
                    argumentValues.set(0, null);
                }
                if (arguments.get(0) == null && arguments.get(1) == null) {
                    argumentValues.set(1, null);
                }
                if (arguments.get(0) == null && arguments.get(2) == null) {
                    argumentValues.set(2, null);
                }
            } else if (operatorType != null && operatorType.equals((Object)OperatorType.CHOOSE) && arguments.get(0) == null) {
                argumentValues.set(1, null);
                argumentValues.set(2, null);
            }
            this.argumentValues = argumentValues;
        }

        /**
         * Applies the chain rule from this node to each of its non-constant arguments. For each argument,
         * the (operator-adjusted) derivative of this node times the partial derivative is accumulated into
         * {@code derivatives}.
         *
         * @param derivatives map from node id to accumulated derivative; updated in place
         */
        private void propagateDerivativesFromResultToArgument(Map<Long, RandomVariable> derivatives) {
            if (arguments == null) {
                return;
            }
            final int argumentCount = arguments.size();
            for (int position = 0; position < argumentCount; position++) {
                final OperatorTreeNode argument = arguments.get(position);
                // Null arguments are constants and receive no derivative.
                if (argument != null) {
                    propagateToSingleArgument(derivatives, argument, position);
                }
            }
        }

        // Accumulates d(result)/d(argument) * d(root)/d(result) into the argument's derivative entry.
        private void propagateToSingleArgument(Map<Long, RandomVariable> derivatives, OperatorTreeNode argument, int position) {
            final Long argumentID = argument.id;
            final RandomVariable partial = getPartialDerivative(argument, position);
            final RandomVariable resultDerivative = derivatives.get(id);
            final RandomVariable accumulated = derivatives.get(argumentID);
            final RandomVariable adjusted = adjustResultDerivative(resultDerivative, position);
            final RandomVariable updated = accumulated == null
                    ? adjusted.mult(partial)
                    : accumulated.addProduct(partial, adjusted);
            derivatives.put(argumentID, updated);
        }

        // Operator-specific transformation applied to the result's derivative before the chain-rule product.
        private RandomVariable adjustResultDerivative(RandomVariable resultDerivative, int position) {
            switch (operatorType) {
                case AVERAGE:
                    return resultDerivative.average();
                case CONDITIONAL_EXPECTATION:
                    return ((ConditionalExpectationEstimator) operator).getConditionalExpectation(resultDerivative);
                case CHOOSE:
                    if (position == 0) {
                        final RandomVariableDifferentiableAADFactory.DiracDeltaApproximationMethod method = factory.getDiracDeltaApproximationMethod();
                        if (method == RandomVariableDifferentiableAADFactory.DiracDeltaApproximationMethod.REGRESSION_ON_DENSITY
                                || method == RandomVariableDifferentiableAADFactory.DiracDeltaApproximationMethod.REGRESSION_ON_DISTRIBUITON) {
                            return getDiracDeltaRegression(resultDerivative, argumentValues.get(0));
                        }
                    }
                    return resultDerivative;
                default:
                    return resultDerivative;
            }
        }

        private RandomVariable getPartialDerivative(OperatorTreeNode differential, int differentialIndex) {
            RandomVariable derivative;
            if (!this.arguments.contains(differential)) {
                return zero;
            }
            RandomVariable X = this.arguments.size() > 0 && this.argumentValues != null ? this.argumentValues.get(0) : null;
            RandomVariable Y = this.arguments.size() > 1 && this.argumentValues != null ? this.argumentValues.get(1) : null;
            RandomVariable Z = this.arguments.size() > 2 && this.argumentValues != null ? this.argumentValues.get(2) : null;
            block0 : switch (this.operatorType) {
                case SQUARED: {
                    derivative = X.mult(2.0);
                    break;
                }
                case SQRT: {
                    derivative = X.sqrt().invert().mult(0.5);
                    break;
                }
                case EXP: {
                    derivative = X.exp();
                    break;
                }
                case LOG: {
                    derivative = X.invert();
                    break;
                }
                case SIN: {
                    derivative = X.cos();
                    break;
                }
                case COS: {
                    derivative = X.sin().mult(-1.0);
                    break;
                }
                case INVERT: {
                    derivative = X.invert().squared().mult(-1.0);
                    break;
                }
                case AVERAGE: {
                    derivative = one;
                    break;
                }
                case CONDITIONAL_EXPECTATION: {
                    derivative = one;
                    break;
                }
                case VARIANCE: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)X.size());
                    break;
                }
                case STDEV: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)X.size()).mult(0.5).div(Math.sqrt(X.getVariance()));
                    break;
                }
                case MIN: {
                    final double min = X.getMin();
                    derivative = X.apply(new DoubleUnaryOperator(){

                        @Override
                        public double applyAsDouble(double x) {
                            return x == min ? 1.0 : 0.0;
                        }
                    });
                    break;
                }
                case MAX: {
                    final double max = X.getMax();
                    derivative = X.apply(new DoubleUnaryOperator(){

                        @Override
                        public double applyAsDouble(double x) {
                            return x == max ? 1.0 : 0.0;
                        }
                    });
                    break;
                }
                case ABS: {
                    derivative = X.choose(one, minusOne);
                    break;
                }
                case STDERROR: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)X.size()).mult(0.5).div(Math.sqrt(X.getVariance() * (double)X.size()));
                    break;
                }
                case SVARIANCE: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)(X.size() - 1));
                    break;
                }
                case ADD: {
                    derivative = one;
                    break;
                }
                case SUB: {
                    derivative = differentialIndex == 0 ? one : minusOne;
                    break;
                }
                case MULT: {
                    derivative = differentialIndex == 0 ? Y : X;
                    break;
                }
                case DIV: {
                    derivative = differentialIndex == 0 ? Y.invert() : X.div(Y.squared()).mult(-1.0);
                    break;
                }
                case CAP: {
                    if (differentialIndex == 0) {
                        derivative = X.sub(Y).choose(zero, one);
                        break;
                    }
                    derivative = X.sub(Y).choose(one, zero);
                    break;
                }
                case FLOOR: {
                    if (differentialIndex == 0) {
                        derivative = X.sub(Y).choose(one, zero);
                        break;
                    }
                    derivative = X.sub(Y).choose(zero, one);
                    break;
                }
                case AVERAGE2: {
                    derivative = differentialIndex == 0 ? Y : X;
                    break;
                }
                case VARIANCE2: {
                    derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y) * (double)(X.size() - 1)).sub(X.getAverage(Y)))) : X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X) * (double)(X.size() - 1)).sub(Y.getAverage(X))));
                    break;
                }
                case STDEV2: {
                    derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y) * (double)(X.size() - 1)).sub(X.getAverage(Y)))).div(Math.sqrt(X.getVariance(Y))) : X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X) * (double)(X.size() - 1)).sub(Y.getAverage(X)))).div(Math.sqrt(Y.getVariance(X)));
                    break;
                }
                case STDERROR2: {
                    derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y) * (double)(X.size() - 1)).sub(X.getAverage(Y)))).div(Math.sqrt(X.getVariance(Y) * (double)X.size())) : X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X) * (double)(X.size() - 1)).sub(Y.getAverage(X)))).div(Math.sqrt(Y.getVariance(X) * (double)Y.size()));
                    break;
                }
                case POW: {
                    derivative = differentialIndex == 0 ? X.pow(Y.doubleValue() - 1.0).mult(Y) : zero;
                    break;
                }
                case ADDPRODUCT: {
                    if (differentialIndex == 0) {
                        derivative = one;
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = Z;
                        break;
                    }
                    derivative = Y;
                    break;
                }
                case ADDRATIO: {
                    if (differentialIndex == 0) {
                        derivative = one;
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = Z.invert();
                        break;
                    }
                    derivative = Y.div(Z.squared()).mult(-1.0);
                    break;
                }
                case SUBRATIO: {
                    if (differentialIndex == 0) {
                        derivative = one;
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = Z.invert().mult(-1.0);
                        break;
                    }
                    derivative = Y.div(Z.squared());
                    break;
                }
                case ACCRUE: {
                    if (differentialIndex == 0) {
                        derivative = Y.mult(Z).add(1.0);
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = X.mult(Z);
                        break;
                    }
                    derivative = X.mult(Y);
                    break;
                }
                case DISCOUNT: {
                    if (differentialIndex == 0) {
                        derivative = Y.mult(Z).add(1.0).invert();
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = X.mult(Z).div(Y.mult(Z).add(1.0).squared()).mult(-1.0);
                        break;
                    }
                    derivative = X.mult(Y).div(Y.mult(Z).add(1.0).squared()).mult(-1.0);
                    break;
                }
                case CHOOSE: {
                    if (differentialIndex == 0) {
                        switch (this.factory.getDiracDeltaApproximationMethod()) {
                            case ONE: {
                                derivative = Y.sub(Z);
                                break block0;
                            }
                            case ZERO: {
                                derivative = zero;
                                break block0;
                            }
                            case DISCRETE_DELTA: {
                                double epsilon = this.factory.getDiracDeltaApproximationWidthPerStdDev() * X.getStandardDeviation();
                                if (Double.isInfinite(epsilon)) {
                                    derivative = Y.sub(Z);
                                    break block0;
                                }
                                if (epsilon > 0.0) {
                                    derivative = Y.sub(Z);
                                    derivative = derivative.mult(X.add(epsilon / 2.0).choose(one, zero));
                                    derivative = derivative.mult(X.sub(epsilon / 2.0).choose(zero, one));
                                    derivative = derivative.div(epsilon);
                                    break block0;
                                }
                                derivative = zero;
                                break block0;
                            }
                            case REGRESSION_ON_DENSITY: 
                            case REGRESSION_ON_DISTRIBUITON: {
                                derivative = Y.sub(Z);
                                break block0;
                            }
                        }
                        throw new UnsupportedOperationException("Diract Delta Approximation Method " + this.factory.getDiracDeltaApproximationMethod().name() + " not supported.");
                    }
                    if (differentialIndex == 1) {
                        derivative = X.choose(one, zero);
                        break;
                    }
                    derivative = X.choose(zero, one);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Operation " + this.operatorType.name() + " not supported in differentiation.");
                }
            }
            return derivative;
        }

        /**
         * Approximates the Dirac delta contribution of a decision boundary by localizing the given
         * derivative to a narrow band around zero of {@code indicator} and weighting it with a
         * regression-based density estimate at zero.
         *
         * <p>The band has width {@code epsilon = widthPerStdDev * stdDev(indicator)} and is centred at
         * zero. It is roughly {@code [-epsilon/2, epsilon/2)}; exact boundary inclusion follows
         * {@code RandomVariable.choose}. The localized derivative is rescaled by the inverse of the
         * fraction of paths inside the band, then multiplied by the scalar returned by
         * {@link #getDensityRegression(RandomVariable)}.</p>
         *
         * <p>NOTE(review): if no path falls inside the band (e.g. a constant indicator, which gives
         * {@code epsilon == 0}), the rescaling divides by zero and the result is NaN.</p>
         *
         * @param derivative the derivative to localize (presumably the adjoint of a {@code choose}; confirm at call site).
         * @param indicator the random variable whose sign selects the branch; its density at zero is estimated.
         * @return the localized, density-weighted derivative.
         */
        private RandomVariable getDiracDeltaRegression(RandomVariable derivative, RandomVariable indicator) {
            final double diracDeltaApproximationWidthPerStdDev = this.factory.getDiracDeltaApproximationWidthPerStdDev();
            final double epsilon = diracDeltaApproximationWidthPerStdDev * indicator.getStandardDeviation();

            // 1 on paths whose indicator lies in the band around zero, 0 elsewhere.
            final RandomVariable localizedOne = indicator.add(epsilon / 2.0).choose(one, zero)
                    .mult(indicator.sub(epsilon / 2.0).choose(zero, one));

            // Rescale by the band's probability mass (fraction of paths inside the band).
            final RandomVariable localizedDerivative = derivative.mult(localizedOne).div(localizedOne.getAverage());

            return localizedDerivative.mult(this.getDensityRegression(indicator));
        }

        /**
         * Estimates the probability density of {@code indicator} at zero by regressing on empirical
         * probabilities of one-sided intervals around zero.
         *
         * <p>For {@code numberOfSamplePointsHalf} increasing half-widths {@code s}, two fractions of
         * paths are measured: those with the indicator in {@code [-s, 0)} and those in {@code [0, s)}.
         * Exact boundary inclusion follows {@code RandomVariable.choose}. The interval is advanced
         * before its first use, so {@code s} runs from {@code 2 * sampleIntervalWidthHalf} to
         * {@code (numberOfSamplePointsHalf + 1) * sampleIntervalWidthHalf}. The factory's Dirac delta
         * approximation method then selects the estimator:</p>
         * <ul>
         * <li>{@code REGRESSION_ON_DENSITY}: the fractions divided by {@code s} (average density per
         * interval) are regressed on the basis {@code 1, x, x^2}; the intercept is returned.</li>
         * <li>{@code REGRESSION_ON_DISTRIBUITON}: the fractions, negated on the negative side, are
         * regressed on the basis {@code x, x^2, x^3}; the coefficient of {@code x} is returned.</li>
         * </ul>
         *
         * <p>NOTE(review): a zero standard deviation of {@code indicator} makes all sample intervals
         * zero, so the regression is degenerate.</p>
         *
         * @param indicator the random variable whose density at zero is estimated.
         * @return the regression estimate of the density at zero.
         * @throws UnsupportedOperationException if the factory's approximation method is neither of the above.
         */
        private double getDensityRegression(RandomVariable indicator) {
            // Resolve the estimator first; only the two regression variants are supported.
            final boolean isRegressionOnDensity;
            switch (this.factory.getDiracDeltaApproximationMethod()) {
                case REGRESSION_ON_DENSITY:
                    isRegressionOnDensity = true;
                    break;
                case REGRESSION_ON_DISTRIBUITON:
                    isRegressionOnDensity = false;
                    break;
                default:
                    throw new UnsupportedOperationException("Density regression method " + this.factory.getDiracDeltaApproximationMethod().name() + " not supported.");
            }

            final double diracDeltaApproximationDensityRegressionWidthPerStdDev = this.factory.getDiracDeltaApproximationDensityRegressionWidthPerStdDev();
            final double underlyingStdDev = indicator.getStandardDeviation();
            final int numberOfSamplePointsHalf = 50;
            final int numberOfSamplePoints = 2 * numberOfSamplePointsHalf;
            final double sampleIntervalWidthHalf = diracDeltaApproximationDensityRegressionWidthPerStdDev / 2.0 * underlyingStdDev / numberOfSamplePointsHalf;

            // 1 where the indicator is on the respective side of zero, 0 elsewhere.
            final RandomVariable indicatorPositiveValues = indicator.choose(one, zero);
            final RandomVariable indicatorNegativeValues = indicator.choose(zero, one);

            final double[] samplePointX = new double[numberOfSamplePoints];
            final double[] samplePointY = new double[numberOfSamplePoints];
            double sampleInterval = sampleIntervalWidthHalf;
            for (int i = 0; i < numberOfSamplePoints; i += 2) {
                // Accumulate (rather than multiply) so the sample points match the original arithmetic exactly.
                sampleInterval += sampleIntervalWidthHalf;
                final double fractionOnNegValues = indicator.add(sampleInterval).choose(one, zero).mult(indicatorNegativeValues).getAverage();
                final double fractionOnPosValues = indicator.sub(sampleInterval).choose(zero, one).mult(indicatorPositiveValues).getAverage();
                samplePointX[i] = -sampleInterval;
                samplePointX[i + 1] = sampleInterval;
                if (isRegressionOnDensity) {
                    // Average density over each one-sided interval.
                    samplePointY[i] = fractionOnNegValues / sampleInterval;
                    samplePointY[i + 1] = fractionOnPosValues / sampleInterval;
                }
                else {
                    // Empirical distribution increments relative to zero.
                    samplePointY[i] = -fractionOnNegValues;
                    samplePointY[i + 1] = fractionOnPosValues;
                }
            }

            final RandomVariable densityX = new RandomVariableFromDoubleArray(0.0, samplePointX);
            final RandomVariable densityValues = new RandomVariableFromDoubleArray(0.0, samplePointY);
            final RandomVariable[] basisFunctions = isRegressionOnDensity
                    ? new RandomVariable[]{densityX.mult(0.0).add(1.0), densityX, densityX.squared()}
                    : new RandomVariable[]{densityX, densityX.squared(), densityX.pow(3.0)};

            // In both bases the first coefficient is the density estimate at zero:
            // the intercept of the density fit, or the linear slope of the distribution fit.
            return new LinearRegression(basisFunctions).getRegressionCoefficients(densityValues)[0];
        }

        /**
         * Returns the operator tree node attached to a differentiable random variable.
         *
         * @param randomVariable a random variable, possibly {@code null}.
         * @return the node of a {@link RandomVariableDifferentiableAAD}; {@code null} for a
         *         {@code null} argument or any other {@link RandomVariable} implementation.
         */
        private static OperatorTreeNode of(RandomVariable randomVariable) {
            // instanceof is false for null, so this guard also covers a null argument.
            if (!(randomVariable instanceof RandomVariableDifferentiableAAD)) {
                return null;
            }
            final RandomVariableDifferentiableAAD differentiable = (RandomVariableDifferentiableAAD) randomVariable;
            return differentiable.getOperatorTreeNode();
        }

        /**
         * Returns the result of {@code getValues()} on the given random variable, tolerating {@code null}.
         *
         * @param randomVariable a random variable, possibly {@code null}.
         * @return {@code randomVariable.getValues()}, or {@code null} when the argument is {@code null}.
         */
        private static RandomVariable getValue(RandomVariable randomVariable) {
            if (randomVariable == null) {
                return null;
            }
            return randomVariable.getValues();
        }

        /**
         * Maps each argument to its operator tree node (see {@code of}).
         *
         * @param arguments the operator arguments, possibly {@code null}.
         * @return a list of the same size holding each argument's node, where an element is
         *         {@code null} if that argument has no node; {@code null} if {@code arguments} is {@code null}.
         */
        private static List<OperatorTreeNode> extractOperatorTreeNodes(List<RandomVariable> arguments) {
            if (arguments == null) {
                return null;
            }
            return arguments.stream()
                    .map(argument -> OperatorTreeNode.of(argument))
                    .collect(Collectors.toList());
        }

        /**
         * Maps each argument to its values (see {@code getValue}).
         *
         * @param arguments the operator arguments, possibly {@code null}.
         * @return a list of the same size holding each argument's {@code getValues()}, with
         *         {@code null} kept for {@code null} elements; {@code null} if {@code arguments} is {@code null}.
         */
        private static List<RandomVariable> extractOperatorValues(List<RandomVariable> arguments) {
            if (arguments == null) {
                return null;
            }
            return arguments.stream()
                    .map(argument -> OperatorTreeNode.getValue(argument))
                    .collect(Collectors.toList());
        }

        /**
         * Custom deserialization hook. It restores the serialized state, then overwrites this node's
         * {@code id} with a fresh value taken from {@code indexOfNextRandomVariable}.
         *
         * <p>The field is written reflectively. Presumably {@code id} is {@code final} (its
         * declaration is not visible here; TODO confirm), so it cannot be assigned directly.</p>
         *
         * @param stream the stream to read from.
         * @throws IOException if {@link ObjectInputStream#defaultReadObject()} fails.
         * @throws ClassNotFoundException if a class of a serialized object cannot be found.
         * @throws RuntimeException if the {@code id} field cannot be located or written.
         */
        private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
            stream.defaultReadObject();
            try {
                // Look the field up on the declaring class so the lookup does not depend on the runtime type.
                final Field idField = OperatorTreeNode.class.getDeclaredField("id");
                idField.setAccessible(true);
                try {
                    idField.set(this, indexOfNextRandomVariable.getAndIncrement());
                }
                finally {
                    idField.setAccessible(false);
                }
            }
            catch (IllegalAccessException | IllegalArgumentException | NoSuchFieldException | SecurityException e) {
                throw new RuntimeException("Unable to re-assign id of " + this.getClass().getSimpleName() + ".", e);
            }
        }
    }

    private static enum OperatorType {
        ADD,
        MULT,
        DIV,
        SUB,
        SQUARED,
        SQRT,
        LOG,
        SIN,
        COS,
        EXP,
        INVERT,
        CAP,
        FLOOR,
        ABS,
        ADDPRODUCT,
        ADDRATIO,
        SUBRATIO,
        CHOOSE,
        DISCOUNT,
        ACCRUE,
        POW,
        MIN,
        MAX,
        AVERAGE,
        VARIANCE,
        STDEV,
        STDERROR,
        SVARIANCE,
        AVERAGE2,
        VARIANCE2,
        STDEV2,
        STDERROR2,
        CONDITIONAL_EXPECTATION;

    }
}

