package io.siddhi.extension.execution.streamingml.bayesian.util;

import io.siddhi.core.exception.SiddhiAppCreationException;
import io.siddhi.extension.execution.streamingml.bayesian.model.CategoricalDistribution;
import io.siddhi.extension.execution.streamingml.bayesian.model.NormalDistribution;
import java.util.ArrayList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:io/siddhi/extension/execution/streamingml/bayesian/util/SoftmaxRegression.class */
/**
 * Bayesian multiclass softmax regression model.
 *
 * <p>The weight matrix ({@code numFeatures x noOfClasses}) is modeled by a {@link NormalDistribution}
 * whose location ({@code wLoc}) and softplus-transformed scale ({@code wScale}) are the trainable
 * variables returned from {@link #specifyModel()}. String class labels are mapped to integer class
 * indices in order of first appearance; at most {@code noOfClasses} distinct labels are accepted.
 */
public class SoftmaxRegression extends BayesianModel {
    private static final Logger logger = LogManager.getLogger(SoftmaxRegression.class.getName());
    private static final long serialVersionUID = 3330926145654494163L;

    /** Distribution over the weight matrix; created in {@link #specifyModel()}. */
    private NormalDistribution weights;
    /** Negated mean log-probability over the Monte Carlo samples; set in {@link #specifyModel()}. */
    private SDVariable loss;
    /** Known class labels; a label's position in this list is its class index. */
    private final List<String> classes;
    /** Maximum number of distinct class labels the model accepts. */
    private int noOfClasses;
    /** Evaluator used by {@link #evaluate(double[], Object)}. */
    private final PrequentialEvaluation eval;

    /**
     * Creates a model that accepts up to {@code noOfClasses} distinct class labels.
     *
     * @param noOfClasses maximum number of distinct class labels
     */
    public SoftmaxRegression(int noOfClasses) {
        this.classes = new ArrayList<>();
        this.noOfClasses = noOfClasses;
        this.eval = new PrequentialEvaluation();
    }

    /**
     * Creates a copy of {@code other}.
     *
     * <p>Base-model state is handled by the superclass copy constructor (defined outside this file).
     * The label-to-index mapping is copied, not reset, so that class indices produced by the copy
     * resolve to the same labels as in {@code other}. The evaluator instance is shared with
     * {@code other}, not duplicated.
     *
     * @param other model to copy
     */
    public SoftmaxRegression(SoftmaxRegression other) {
        super(other);
        this.classes = new ArrayList<>(other.classes);
        this.noOfClasses = other.noOfClasses;
        this.eval = other.eval;
    }

    /**
     * Builds the model graph: a {@code 1 x numFeatures} input, a single-element label input, and
     * a normally distributed weight matrix. The loss is the negated average, over
     * {@code numSamples} weight samples, of the categorical log-probability of the label.
     *
     * @return the trainable variables: the weight location and the softplus-transformed scale
     */
    @Override
    SDVariable[] specifyModel() {
        this.xIn = this.sd.var("xIn", new int[]{1, this.numFeatures});
        this.yIn = this.sd.var("yIn", new int[]{1});
        SDVariable weightLoc = this.sd.var("wLoc", new int[]{this.numFeatures, this.noOfClasses});
        // softplus keeps the scale strictly positive while the underlying variable is unconstrained
        SDVariable weightScale =
                this.sd.softplus("wScale", this.sd.var(new int[]{this.numFeatures, this.noOfClasses}));
        this.weights = new NormalDistribution(weightLoc, weightScale, this.sd);

        SDVariable[] logLikelihoods = new SDVariable[this.numSamples];
        for (int s = 0; s < this.numSamples; s++) {
            // presumably interpreted as logits by CategoricalDistribution (cf. the softmax in
            // estimatePredictiveDistribution) -- TODO confirm against CategoricalDistribution
            SDVariable logits = this.xIn.mmul(this.weights.sample());
            logLikelihoods[s] = new CategoricalDistribution(logits, this.sd).logProbability(this.yIn);
        }
        this.loss = this.sd.neg(this.sd.mergeAvg(logLikelihoods));
        return new SDVariable[]{weightLoc, weightScale};
    }

    /**
     * Returns the index of the class with the highest mean probability.
     *
     * <p>Assumes the input has classes along axis 0 and samples along axis 1, as produced by
     * {@link #estimatePredictiveDistribution(INDArray, int)}.
     */
    @Override
    double predictionFromPredictiveDensity(INDArray iNDArray) {
        return iNDArray.mean(new int[]{1}).argMax(new int[0]).toDoubleVector()[0];
    }

    /**
     * Returns the highest mean class probability. The axis layout is assumed to be the same as in
     * {@link #predictionFromPredictiveDensity(INDArray)}.
     */
    @Override
    double confidenceFromPredictiveDensity(INDArray iNDArray) {
        return iNDArray.mean(new int[]{1}).max(new int[0]).toDoubleVector()[0];
    }

    /**
     * Returns the flattened weight location and scale arrays.
     *
     * @return a two-element array: {@code [0]} the flattened location, {@code [1]} the flattened scale
     */
    @Override
    protected double[][] getUpdatedWeights() {
        double[] flatLoc = this.weights.getLoc().getArr().reshape(new long[]{-1}).toDoubleVector();
        double[] flatScale = this.weights.getScale().getArr().reshape(new long[]{-1}).toDoubleVector();
        return new double[][]{flatLoc, flatScale};
    }

    /**
     * Evaluates the prediction for {@code dArr} against the true label {@code obj}.
     *
     * <p>Side effect: an unseen label is registered as a new class; this throws
     * {@link SiddhiAppCreationException} if the class capacity is exhausted.
     *
     * @param dArr feature vector
     * @param obj  true class label; must be a {@link String}
     * @return the value reported by the prequential evaluator
     */
    @Override
    public double evaluate(double[] dArr, Object obj) {
        int trueIndex = addClass((String) obj);
        return this.eval.evaluate(trueIndex, predict(dArr).intValue());
    }

    /**
     * Samples {@code i} weight matrices and returns softmax class probabilities for each sample.
     *
     * <p>The reshape to {@code (noOfClasses, i)} assumes {@code iNDArray} is a single
     * {@code 1 x numFeatures} row.
     *
     * @param iNDArray feature row
     * @param i        number of weight samples
     * @return probabilities shaped {@code (noOfClasses, i)}: classes along axis 0, samples along axis 1
     */
    @Override
    INDArray estimatePredictiveDistribution(INDArray iNDArray, int i) {
        int numWeights = this.numFeatures * this.noOfClasses;
        // reparameterised sampling: loc + scale * N(0, 1), one column per sample
        INDArray noise = Nd4j.randn(new long[]{numWeights, i});
        INDArray scaleColumn = this.weights.getScale().getArr().reshape(numWeights, 1L);
        INDArray locColumn = this.weights.getLoc().getArr().reshape(numWeights, 1L);
        INDArray weightSamples = noise.mulColumnVector(scaleColumn)
                .addColumnVector(locColumn)
                .reshape(new int[]{this.numFeatures, this.noOfClasses * i});
        // one row of class scores per sample, so softmax normalises across classes
        INDArray perSampleScores = iNDArray.mmul(weightSamples).reshape(this.noOfClasses, i).transpose();
        return Transforms.softmax(perSampleScores).transpose();
    }

    /**
     * Sets the maximum number of distinct class labels.
     *
     * @param i new class capacity
     */
    public void setNoOfClasses(int i) {
        this.noOfClasses = i;
    }

    /**
     * Trains on one example, registering {@code str} as a new class if it is unseen.
     *
     * @param dArr feature vector
     * @param str  class label
     * @return the result of the base-class update with a one-hot target vector
     * @throws SiddhiAppCreationException if {@code str} is new and the class capacity is exhausted
     */
    public double[] update(double[] dArr, String str) {
        int classIndex = addClass(str);
        double[] oneHot = new double[this.noOfClasses];
        oneHot[classIndex] = 1.0d;
        return super.update(dArr, oneHot);
    }

    /**
     * Returns the index of {@code str}, registering it as a new class if unseen.
     *
     * @throws SiddhiAppCreationException if {@code str} is new and {@code noOfClasses} labels are
     *                                    already registered
     */
    private int addClass(String str) {
        int index = this.classes.indexOf(str);
        if (index >= 0) {
            return index;
        }
        if (this.classes.size() >= this.noOfClasses) {
            throw new SiddhiAppCreationException(String.format("Only %s classes are expected by the model . But found %s", Integer.valueOf(this.noOfClasses), Integer.valueOf(this.classes.size() + 1)));
        }
        this.classes.add(str);
        // str was absent, so after add() it sits at the last position
        return this.classes.size() - 1;
    }

    /**
     * Returns the label registered for a class index.
     *
     * @param number class index
     * @return the class label
     * @throws IndexOutOfBoundsException if no label has been registered at that index
     */
    public String getClassLabel(Number number) {
        return this.classes.get(number.intValue());
    }
}
