package io.siddhi.extension.execution.streamingml.bayesian.util;

import io.siddhi.extension.execution.streamingml.bayesian.model.NormalDistribution;
import java.util.Arrays;
import org.apache.log4j.Logger;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:io/siddhi/extension/execution/streamingml/bayesian/util/LinearRegression.class */
/**
 * Bayesian linear regression model expressed as a SameDiff graph.
 *
 * <p>The regression weights are modelled by a {@link NormalDistribution} whose location and scale
 * are trainable {@code (numFeatures, 1)} variables; the scale, like the observation-noise scale,
 * is passed through softplus so it stays strictly positive. The training loss is the negated
 * average, over {@code numSamples} sampled weight vectors, of the log-probability of the target
 * under a normal likelihood centred on {@code xIn . w}.
 *
 * <p>NOTE(review): no prior / KL term is visible in this class; if one is added it must be done by
 * {@link BayesianModel} — confirm there.
 */
public class LinearRegression extends BayesianModel {
    private static final Logger logger = Logger.getLogger(LinearRegression.class.getName());
    private static final long serialVersionUID = -5112177245729410690L;

    /** Distribution over the weight column vector; loc and scale have shape (numFeatures, 1). */
    private NormalDistribution weights;
    /** Softplus-transformed observation-noise scale, shape (1, 1). */
    private SDVariable likelihoodScale;
    /** Negated mean log-likelihood built in {@link #specifyModel()}; not read within this class. */
    private SDVariable loss;

    /** Creates a model whose graph is defined later by {@link #specifyModel()}. */
    public LinearRegression() {
    }

    /**
     * Copy constructor that delegates to the {@link BayesianModel} copy constructor.
     *
     * <p>The fields declared in this class ({@code weights}, {@code likelihoodScale},
     * {@code loss}) are not copied here. NOTE(review): presumably they are rebuilt when
     * {@link #specifyModel()} runs on the copy — confirm in BayesianModel.
     *
     * @param linearRegression the model to copy from
     */
    public LinearRegression(LinearRegression linearRegression) {
        super(linearRegression);
    }

    /**
     * Declares the graph inputs, the trainable variables and the loss of the model.
     *
     * @return the trainable variables, in order: weight locations, raw (pre-softplus) weight
     *     scales and raw (pre-softplus) likelihood scale
     */
    @Override
    SDVariable[] specifyModel() {
        // Graph inputs: one feature row of width numFeatures and its target.
        this.xIn = this.sd.var("xIn", 1, this.numFeatures);
        this.yIn = this.sd.var("yIn", 1);

        // Trainable variables; the scales are unconstrained and mapped through softplus below.
        SDVariable weightLocVar = this.sd.var("wLocVar", this.numFeatures, 1);
        SDVariable weightScaleVar = this.sd.var("wScaleVar", this.numFeatures, 1);
        SDVariable likelihoodScaleVar = this.sd.var("likScaleVar", 1, 1);

        this.likelihoodScale = this.sd.softplus(likelihoodScaleVar);
        this.weights = new NormalDistribution(weightLocVar, this.sd.softplus(weightScaleVar), this.sd);

        // Monte Carlo estimate: for each sampled weight vector, score yIn under a normal
        // likelihood centred on xIn . w with the learned noise scale.
        SDVariable[] logLikelihoods = new SDVariable[this.numSamples];
        for (int i = 0; i < this.numSamples; i++) {
            SDVariable predictedMean = this.xIn.mmul(this.weights.sample());
            NormalDistribution likelihood =
                    new NormalDistribution(predictedMean, this.likelihoodScale, this.sd);
            logLikelihoods[i] = likelihood.logProbability(this.yIn);
        }
        this.loss = this.sd.neg(this.sd.mergeAvg(logLikelihoods));

        return new SDVariable[]{weightLocVar, weightScaleVar, likelihoodScaleVar};
    }

    /**
     * Computes the point prediction as the mean of the predictive samples.
     *
     * @param predictiveDensity predictive samples; assumed shape (1, nSamples) — TODO confirm
     * @return the mean along dimension 1 of the first row
     */
    @Override
    double predictionFromPredictiveDensity(INDArray predictiveDensity) {
        return predictiveDensity.mean(1).toDoubleVector()[0];
    }

    /**
     * Computes the confidence as the inverse standard deviation of the predictive samples.
     *
     * @param predictiveDensity predictive samples; assumed shape (1, nSamples) — TODO confirm
     * @return {@code 1 / std} along dimension 1 of the first row; this is
     *     {@code Double.POSITIVE_INFINITY} when all samples are identical
     */
    @Override
    double confidenceFromPredictiveDensity(INDArray predictiveDensity) {
        return 1.0d / predictiveDensity.std(1).toDoubleVector()[0];
    }

    /**
     * Returns the current parameters of the weight distribution.
     *
     * @return a two-row array: row 0 holds the weight locations, row 1 holds the weight scales
     */
    @Override
    protected double[][] getUpdatedWeights() {
        double[] weightLocations = this.weights.getLoc().getArr().toDoubleVector();
        // Only stringify the vector when debug output is actually enabled.
        if (logger.isDebugEnabled()) {
            logger.debug(Arrays.toString(weightLocations));
        }
        double[] weightScales = this.weights.getScale().getArr().toDoubleVector();
        // Must be an array of arrays: the original `new double[]{...}` did not compile.
        return new double[][]{weightLocations, weightScales};
    }

    /**
     * Placeholder evaluation hook that is not implemented for this model.
     *
     * @param features ignored
     * @param target ignored
     * @return always {@code 0.0}
     */
    @Override
    public double evaluate(double[] features, Object target) {
        return 0.0d;
    }

    /**
     * Draws samples from the predictive distribution for the given input.
     *
     * <p>Each of the {@code nSamples} columns uses an independent weight draw
     * {@code loc + scale * eps} (with {@code eps ~ N(0, 1)} per element), followed by additive
     * normal noise scaled by the learned likelihood scale.
     *
     * @param features input row(s); assumed shape (1, numFeatures) — TODO confirm with callers
     * @param nSamples number of predictive samples to draw
     * @return {@code features . W + noise}, with one column per sample
     */
    @Override
    INDArray estimatePredictiveDistribution(INDArray features, int nSamples) {
        // RNG draw order is preserved: weight noise first, then observation noise.
        INDArray weightSamples = Nd4j.randn(new long[]{this.numFeatures, nSamples})
                .mulColumnVector(this.weights.getScale().getArr())
                .addColumnVector(this.weights.getLoc().getArr());
        INDArray observationNoise = Nd4j.randn(new long[]{1, nSamples})
                .mul(this.likelihoodScale.getArr());
        return features.mmul(weightSamples).add(observationNoise);
    }
}
