package io.siddhi.extension.execution.streamingml.bayesian.util;

import io.siddhi.core.exception.SiddhiAppCreationException;
import java.io.Serializable;
import org.apache.log4j.Logger;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Sgd;

/**
 * Base class for streaming Bayesian models whose training graph is built with ND4J {@link SameDiff}.
 *
 * <p>Subclasses build the graph and return its trainable variables from {@link #specifyModel()},
 * and implement the predictive-distribution helpers. This class handles:
 * <ul>
 *   <li>creating one {@link GradientUpdater} per trainable variable;</li>
 *   <li>applying a gradient step on every {@link #update(double[], double[])} call;</li>
 *   <li>keeping the variable arrays ({@code weightStates}) and the updater state arrays
 *       ({@code viewArrays}), so a model can be rebuilt from another instance through the copy
 *       constructor followed by {@link #initiateModel()}.</li>
 * </ul>
 */
public abstract class BayesianModel implements Serializable {
    private static final Logger logger = Logger.getLogger(BayesianModel.class.getName());
    private static final long serialVersionUID = -3217237991548906395L;

    /** Computation graph; recreated on every {@link #initiateModel()} call. */
    SameDiff sd;
    /**
     * Graph inputs that {@link #update} and the predictors feed. Not assigned in this class.
     * Presumably a subclass sets them while building the graph in {@link #specifyModel()} (confirm).
     */
    SDVariable xIn;
    SDVariable yIn;
    /** Number of input features; incremented by one in {@link #initiateModel()} when a bias is added. */
    int numFeatures;
    int numSamples;
    /** When true, a constant 1.0 feature is appended to every input vector. */
    private boolean addBias;
    /** Sample count passed to {@link #estimatePredictiveDistribution(INDArray, int)}. */
    private int predictionSamples;
    private OptimizerType optimizerType;
    private double learningRate;
    /** Trainable variables returned by {@link #specifyModel()}. */
    private SDVariable[] vars;
    /** One updater per entry of {@link #vars}. */
    private GradientUpdater[] updaters;
    private IUpdater optimizer;
    /** Current array of each trainable variable; kept in sync with {@link #vars} after every update. */
    private INDArray[] weightStates;
    /** Per-variable updater state buffers; the updaters write into these in place. */
    private INDArray[] viewArrays;

    /** Gradient-descent algorithms supported by {@link #createUpdater()}. */
    public enum OptimizerType {
        ADAM,
        ADAGRAD,
        SGD,
        NADAM
    }

    /**
     * Creates a model with default settings: no bias, one sample, Adam with learning rate 0.05,
     * 1000 prediction samples, and an unset feature count (-1).
     */
    public BayesianModel() {
        this.numFeatures = -1;
        this.numSamples = 1;
        this.addBias = false;
        this.optimizerType = OptimizerType.ADAM;
        this.learningRate = 0.05d;
        this.predictionSamples = 1000;
    }

    /**
     * Copies the configuration and the saved training state of {@code bayesianModel}.
     *
     * <p>This is a shallow copy: the {@code weightStates} and {@code viewArrays} arrays are shared
     * with the source model. The graph, variables and updaters are not copied; call
     * {@link #initiateModel()} to rebuild them and restore the saved state into them.
     *
     * <p>NOTE(review): {@code numFeatures} is copied as-is. If the source was already initiated with
     * {@code addBias}, it already counts the bias column, and {@link #initiateModel()} will count it
     * again. Confirm how callers use this constructor.
     *
     * @param bayesianModel model whose configuration and state are copied
     */
    public BayesianModel(BayesianModel bayesianModel) {
        this.numFeatures = bayesianModel.numFeatures;
        this.numSamples = bayesianModel.numSamples;
        this.addBias = bayesianModel.addBias;
        this.predictionSamples = bayesianModel.predictionSamples;
        this.optimizer = bayesianModel.optimizer;
        this.weightStates = bayesianModel.weightStates;
        this.optimizerType = bayesianModel.optimizerType;
        this.learningRate = bayesianModel.learningRate;
        this.viewArrays = bayesianModel.viewArrays;
    }

    /**
     * Builds the SameDiff graph and initializes or restores the gradient updaters.
     *
     * <p>Fresh model ({@code weightStates == null}): the optimizer is created from the configured
     * type and learning rate. The current variable arrays are captured, and a state buffer sized by
     * {@link IUpdater#stateSize(long)} is allocated for each variable.
     *
     * <p>Restored model: each saved array is written back into its own variable, and the saved
     * updater state buffers are reused.
     *
     * <p>NOTE(review): with {@code addBias} set, {@code numFeatures} is incremented on every call,
     * so calling this more than once per instance counts the bias twice.
     *
     * @throws SiddhiAppCreationException if the saved state does not match the model structure
     */
    public void initiateModel() {
        this.sd = SameDiff.create();
        if (this.addBias) {
            // The bias is modelled as an extra input feature that is always 1.0.
            this.numFeatures++;
        }
        this.vars = specifyModel();
        int length = this.vars.length;
        if (this.weightStates == null) {
            this.optimizer = createUpdater();
            this.weightStates = new INDArray[length];
            this.viewArrays = new INDArray[length];
            for (int i = 0; i < length; i++) {
                this.weightStates[i] = this.vars[i].getArr();
                // The updater state size depends on the total number of elements in the variable.
                long elementCount = 1;
                for (long dim : this.vars[i].getShape()) {
                    elementCount *= dim;
                }
                this.viewArrays[i] = Nd4j.create(1, this.optimizer.stateSize(elementCount));
            }
        } else {
            if (length != this.weightStates.length) {
                throw new SiddhiAppCreationException(String.format("Failed to restore model due to unmatched number of variables. Expected %d variables. Given %d", Integer.valueOf(length), Integer.valueOf(this.weightStates.length)));
            }
            if (this.viewArrays == null || this.viewArrays.length != length) {
                throw new SiddhiAppCreationException("Failed recovering the state of the model. Invalid state for the gradient updaters");
            }
            for (int i = 0; i < length; i++) {
                try {
                    // Each variable gets its own saved array back.
                    this.vars[i].setArray(this.weightStates[i]);
                } catch (Exception e) {
                    throw new SiddhiAppCreationException("Failed recovering the state of the gradient updaters. Invalid state for the variables", e);
                }
            }
        }
        this.updaters = new GradientUpdater[length];
        for (int i = 0; i < length; i++) {
            // SGD is stateless, so it is instantiated without a state view.
            if (OptimizerType.SGD.equals(this.optimizerType)) {
                this.updaters[i] = this.optimizer.instantiate(null, true);
            } else {
                this.updaters[i] = this.optimizer.instantiate(this.viewArrays[i], true);
            }
        }
        logger.debug("Successfully initiated gradient optimizer : " + this.optimizer.getClass().getSimpleName());
    }

    /**
     * Applies one optimizer step to every trainable variable, using the gradients from the last
     * backward pass.
     *
     * <p>If a variable's gradient has a NaN mean, a warning is logged and the method returns.
     * That variable and all variables after it are left unchanged for this step. Variables before
     * it have already been updated.
     *
     * <p>NOTE(review): the iteration index passed to {@code applyUpdater} is always 1. For
     * bias-corrected optimizers such as Adam, this fixes the bias correction at its first-step
     * value. Confirm that this is intended before changing it, because it affects the effective
     * learning rate.
     */
    private void updateVariables() {
        for (int i = 0; i < this.vars.length; i++) {
            SDVariable variable = this.vars[i];
            INDArray gradient = variable.getGradient().getArr();
            if (Double.isNaN(gradient.mean(new int[0]).toDoubleVector()[0])) {
                logger.warn(String.format("invalid gradients. skipping variable update of %s", variable.getVarName()));
                return;
            }
            long[] shape = gradient.shape();
            // The updater state was sized from the element count (see initiateModel), so the
            // updater works on the flattened gradient. applyUpdater turns it into the step in place.
            INDArray step = Nd4j.toFlattened(gradient);
            this.updaters[i].applyUpdater(step, 1, 0);
            INDArray updated = variable.getArr().sub(step.reshape(shape));
            variable.setArray(updated);
            // sub() allocates a new array, so the saved snapshot must be refreshed. Otherwise
            // weightStates keeps the initial weights, and a restored model would lose all training.
            this.weightStates[i] = updated;
        }
    }

    /**
     * Runs one training step on a single observation.
     *
     * @param dArr  feature values bound to {@code xIn}; a 1.0 bias feature is appended when
     *              {@code addBias} is set
     * @param dArr2 values bound to {@code yIn}
     * @return the graph's end result (logged as the model loss), computed before the weights are
     *         updated
     */
    public double[] update(double[] dArr, double[] dArr2) {
        INDArray features = Nd4j.create(dArr);
        INDArray targets = Nd4j.create(dArr2);
        if (this.addBias) {
            features = Nd4j.append(features, 1, 1.0d, 1);
        }
        this.xIn.setArray(features);
        this.yIn.setArray(targets);
        INDArray loss = this.sd.execAndEndResult();
        this.sd.execBackwards();
        // Debug level: this runs once per event, and INDArray#toString is not free.
        if (logger.isDebugEnabled()) {
            logger.debug(getClass().getName() + " model loss : " + loss.toString());
        }
        updateVariables();
        return loss.toDoubleVector();
    }

    /**
     * Predicts a point value for one feature vector.
     *
     * @param dArr feature values; a 1.0 bias feature is appended when {@code addBias} is set
     * @return the prediction derived from the estimated predictive distribution
     */
    public Double predict(double[] dArr) {
        INDArray features = Nd4j.create(dArr);
        if (this.addBias) {
            features = Nd4j.append(features, 1, 1.0d, 1);
        }
        return Double.valueOf(predictionFromPredictiveDensity(estimatePredictiveDistribution(features, this.predictionSamples)));
    }

    /**
     * Predicts a point value and its confidence for one feature vector.
     *
     * @param dArr feature values; a 1.0 bias feature is appended when {@code addBias} is set
     * @return a two-element array: {prediction, confidence}, both derived from the same estimated
     *         predictive distribution
     */
    public Double[] predictWithStd(double[] dArr) {
        INDArray features = Nd4j.create(dArr);
        // Raw-input dump is diagnostic only; keep it out of info-level logs.
        if (logger.isDebugEnabled()) {
            logger.debug(features.toString());
        }
        if (this.addBias) {
            features = Nd4j.append(features, 1, 1.0d, 1);
        }
        INDArray distribution = estimatePredictiveDistribution(features, this.predictionSamples);
        return new Double[]{Double.valueOf(predictionFromPredictiveDensity(distribution)), Double.valueOf(confidenceFromPredictiveDensity(distribution))};
    }

    protected abstract double[][] getUpdatedWeights();

    public abstract double evaluate(double[] dArr, Object obj);

    abstract INDArray estimatePredictiveDistribution(INDArray iNDArray, int i);

    /** Builds the graph in {@link #sd} and returns the trainable variables to optimize. */
    abstract SDVariable[] specifyModel();

    abstract double predictionFromPredictiveDensity(INDArray iNDArray);

    abstract double confidenceFromPredictiveDensity(INDArray iNDArray);

    /** Creates the ND4J updater configuration for the configured optimizer type and learning rate. */
    private IUpdater createUpdater() {
        switch (this.optimizerType) {
            case ADAM:
                return new Adam(this.learningRate);
            case SGD:
                return new Sgd(this.learningRate);
            case ADAGRAD:
                return new AdaGrad(this.learningRate);
            case NADAM:
                return new Nadam(this.learningRate);
            default:
                return new Adam(this.learningRate);
        }
    }

    public void setAddBias(boolean z) {
        this.addBias = z;
    }

    public int getNumFeatures() {
        return this.numFeatures;
    }

    public void setNumFeatures(int i) {
        this.numFeatures = i;
    }

    public int getNumSamples() {
        return this.numSamples;
    }

    public void setNumSamples(int i) {
        this.numSamples = i;
    }

    public OptimizerType getOptimizerType() {
        return this.optimizerType;
    }

    public void setOptimizerType(OptimizerType optimizerType) {
        this.optimizerType = optimizerType;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    public void setPredictionSamples(int i) {
        this.predictionSamples = i;
    }
}
