package org.apache.sysds.runtime.functionobjects;

import java.util.EnumMap;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.distribution.ExponentialDistribution;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Function object for parameterized builtin operations.
 *
 * <p>Instances are obtained through {@link #getParameterizedBuiltinFnObject(String, String)}.
 * Only the {@code CDF} and {@code INVCDF} codes are evaluated by {@link #execute(HashMap)};
 * objects for the other codes only carry their {@link #bFunc} and throw if executed.
 */
public class ParameterizedBuiltin extends ValueFunction {
    private static final long serialVersionUID = -5966242955816522697L;

    /** The builtin operation this object represents. */
    public ParameterizedBuiltinCode bFunc;

    /** Distribution used by CDF/INVCDF; {@code INVALID} for every other operation. */
    public ProbabilityDistributionCode distFunc;

    /** Opcode name (e.g. {@code "cdf"}, {@code "rmempty"}) to builtin code; lookups are case-sensitive. */
    public static HashMap<String, ParameterizedBuiltinCode> String2ParameterizedBuiltinCode = new HashMap<>();

    /** Lower-case distribution name (e.g. {@code "normal"}, {@code "f"}) to distribution code. */
    public static HashMap<String, ProbabilityDistributionCode> String2DistCode = new HashMap<>();

    // Shared CDF / INVCDF instances, one per supported distribution. They are created eagerly in
    // the static initializer, so callers never race on unsynchronized lazy initialization.
    private static final Map<ProbabilityDistributionCode, ParameterizedBuiltin> CDF_OBJS =
        new EnumMap<>(ProbabilityDistributionCode.class);
    private static final Map<ProbabilityDistributionCode, ParameterizedBuiltin> INVCDF_OBJS =
        new EnumMap<>(ProbabilityDistributionCode.class);

    /** Codes of all parameterized builtin operations. */
    public enum ParameterizedBuiltinCode {
        CDF,
        INVCDF,
        RMEMPTY,
        REPLACE,
        REXPAND,
        LOWER_TRI,
        UPPER_TRI,
        TRANSFORMAPPLY,
        TRANSFORMDECODE,
        PARAMSERV
    }

    /** Probability distributions supported by CDF/INVCDF; {@code INVALID} means "none". */
    public enum ProbabilityDistributionCode {
        INVALID,
        NORMAL,
        EXP,
        CHISQ,
        F,
        T
    }

    static {
        String2ParameterizedBuiltinCode.put("cdf", ParameterizedBuiltinCode.CDF);
        String2ParameterizedBuiltinCode.put("invcdf", ParameterizedBuiltinCode.INVCDF);
        String2ParameterizedBuiltinCode.put("rmempty", ParameterizedBuiltinCode.RMEMPTY);
        String2ParameterizedBuiltinCode.put("replace", ParameterizedBuiltinCode.REPLACE);
        String2ParameterizedBuiltinCode.put("lowertri", ParameterizedBuiltinCode.LOWER_TRI);
        String2ParameterizedBuiltinCode.put("uppertri", ParameterizedBuiltinCode.UPPER_TRI);
        String2ParameterizedBuiltinCode.put("rexpand", ParameterizedBuiltinCode.REXPAND);
        String2ParameterizedBuiltinCode.put("transformapply", ParameterizedBuiltinCode.TRANSFORMAPPLY);
        String2ParameterizedBuiltinCode.put("transformdecode", ParameterizedBuiltinCode.TRANSFORMDECODE);
        String2ParameterizedBuiltinCode.put("paramserv", ParameterizedBuiltinCode.PARAMSERV);

        String2DistCode.put("normal", ProbabilityDistributionCode.NORMAL);
        String2DistCode.put("exp", ProbabilityDistributionCode.EXP);
        String2DistCode.put("chisq", ProbabilityDistributionCode.CHISQ);
        String2DistCode.put("f", ProbabilityDistributionCode.F);
        String2DistCode.put("t", ProbabilityDistributionCode.T);

        for (ProbabilityDistributionCode dcode : ProbabilityDistributionCode.values()) {
            if (dcode != ProbabilityDistributionCode.INVALID) {
                CDF_OBJS.put(dcode, new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode));
                INVCDF_OBJS.put(dcode, new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, dcode));
            }
        }
    }

    private ParameterizedBuiltin(ParameterizedBuiltinCode parameterizedBuiltinCode) {
        this(parameterizedBuiltinCode, ProbabilityDistributionCode.INVALID);
    }

    private ParameterizedBuiltin(ParameterizedBuiltinCode parameterizedBuiltinCode,
            ProbabilityDistributionCode probabilityDistributionCode) {
        this.bFunc = parameterizedBuiltinCode;
        this.distFunc = probabilityDistributionCode;
    }

    /**
     * Returns the function object for an opcode that needs no distribution.
     *
     * @param str opcode name, a key of {@link #String2ParameterizedBuiltinCode}
     * @return the function object
     * @throws DMLRuntimeException if the opcode is unknown, or is cdf/invcdf (which need a distribution)
     */
    public static ParameterizedBuiltin getParameterizedBuiltinFnObject(String str) {
        return getParameterizedBuiltinFnObject(str, null);
    }

    /**
     * Returns the function object for the given opcode.
     *
     * @param str opcode name, a key of {@link #String2ParameterizedBuiltinCode}
     * @param str2 distribution name (case-insensitive); required for cdf/invcdf, ignored otherwise
     * @return a shared instance for cdf/invcdf, a fresh instance for every other opcode
     * @throws DMLRuntimeException if the opcode or distribution is unknown or missing
     */
    public static ParameterizedBuiltin getParameterizedBuiltinFnObject(String str, String str2) {
        ParameterizedBuiltinCode code = String2ParameterizedBuiltinCode.get(str);
        if (code == null) {
            throw new DMLRuntimeException("Invalid parameterized builtin code: " + str);
        }
        switch (code) {
            case CDF:
                return lookupDistributionObject(CDF_OBJS, str2);
            case INVCDF:
                return lookupDistributionObject(INVCDF_OBJS, str2);
            case RMEMPTY:
            case REPLACE:
            case LOWER_TRI:
            case UPPER_TRI:
            case REXPAND:
            case TRANSFORMAPPLY:
            case TRANSFORMDECODE:
            case PARAMSERV:
                return new ParameterizedBuiltin(code);
            default:
                throw new DMLRuntimeException("Invalid parameterized builtin code: " + code);
        }
    }

    /** Resolves a distribution name to its shared instance in {@code cache} (CDF or INVCDF). */
    private static ParameterizedBuiltin lookupDistributionObject(
            Map<ProbabilityDistributionCode, ParameterizedBuiltin> cache, String distName) {
        if (distName == null) {
            throw new DMLRuntimeException("Distribution must be specified for cdf/invcdf.");
        }
        // Locale.ROOT avoids locale-dependent case mapping (e.g. Turkish dotless i in "CHISQ").
        ProbabilityDistributionCode dcode = String2DistCode.get(distName.toLowerCase(Locale.ROOT));
        ParameterizedBuiltin obj = (dcode != null) ? cache.get(dcode) : null;
        if (obj == null) {
            throw new DMLRuntimeException("Invalid distribution code: " + distName);
        }
        return obj;
    }

    /**
     * Evaluates a CDF or inverse CDF.
     *
     * @param hashMap named parameters; must contain {@code "target"} plus distribution-specific keys
     * @return the (inverse) cumulative probability
     * @throws DMLRuntimeException for non-CDF operations, unsupported distributions or invalid parameters
     */
    @Override
    public double execute(HashMap<String, String> hashMap) {
        switch (this.bFunc) {
            case CDF:
            case INVCDF:
                switch (this.distFunc) {
                    case NORMAL:
                    case EXP:
                    case CHISQ:
                    case F:
                    case T:
                        return computeFromDistribution(this.distFunc, hashMap,
                            this.bFunc == ParameterizedBuiltinCode.INVCDF);
                    default:
                        throw new DMLRuntimeException("Unsupported distribution (" + this.distFunc + ").");
                }
            default:
                throw new DMLRuntimeException("ParameterizedBuiltin.execute(): Unknown operation: " + this.bFunc);
        }
    }

    /**
     * Computes a CDF value (lower or upper tail) or an inverse CDF value.
     *
     * @param dcode distribution to use
     * @param params named parameters; {@code "target"} is a quantile for cdf and a probability for
     *     invcdf; the optional {@code "lower.tail"} (default true) selects the tail for cdf only
     * @param inverse true for the inverse CDF
     */
    private static double computeFromDistribution(ProbabilityDistributionCode dcode,
            HashMap<String, String> params, boolean inverse) {
        double target = Double.parseDouble(params.get("target"));
        String lowerTailStr = params.get("lower.tail");
        boolean lowerTail = (lowerTailStr == null) || Boolean.parseBoolean(lowerTailStr);

        RealDistribution dist = createDistribution(dcode, params);
        if (inverse) {
            // NOTE: lower.tail is not applied to the inverse CDF (existing behavior).
            return dist.inverseCumulativeProbability(target);
        }
        double cdf = dist.cumulativeProbability(target);
        return lowerTail ? cdf : 1.0 - cdf;
    }

    /** Builds the commons-math distribution for {@code dcode}, validating its parameters. */
    private static RealDistribution createDistribution(ProbabilityDistributionCode dcode,
            HashMap<String, String> params) {
        switch (dcode) {
            case NORMAL: {
                String meanStr = params.get("mean");
                String sdStr = params.get("sd");
                double mean = (meanStr != null) ? Double.parseDouble(meanStr) : 0.0;
                double sd = (sdStr != null) ? Double.parseDouble(sdStr) : 1.0;
                // Negated '>' so that NaN is rejected as well.
                if (!(sd > 0.0)) {
                    throw new DMLRuntimeException("Standard deviation for Normal distribution must be positive (" + sd + ")");
                }
                return new NormalDistribution(mean, sd);
            }
            case EXP: {
                String rateStr = params.get("rate");
                double rate = (rateStr != null) ? Double.parseDouble(rateStr) : 1.0;
                if (!(rate > 0.0)) {
                    throw new DMLRuntimeException("Rate for Exponential distribution must be positive (" + rate + ")");
                }
                // commons-math parameterizes the exponential distribution by its mean (1/rate).
                return new ExponentialDistribution(1.0 / rate);
            }
            case CHISQ: {
                String dfStr = params.get("df");
                if (dfStr == null) {
                    throw new DMLRuntimeException("Degrees of freedom must be specified for chi-squared distribution (e.g., q=qchisq(0.5, df=20); p=pchisq(target=q, df=1.2))");
                }
                int df = UtilFunctions.parseToInt(dfStr);
                if (df <= 0) {
                    throw new DMLRuntimeException("Degrees of Freedom for chi-squared distribution must be positive (" + df + ")");
                }
                return new ChiSquaredDistribution(df);
            }
            case F: {
                String df1Str = params.get("df1");
                String df2Str = params.get("df2");
                if (df1Str == null || df2Str == null) {
                    throw new DMLRuntimeException("Degrees of freedom must be specified for F distribution (e.g., q = qf(target=0.5, df1=20, df2=30); p=pf(target=q, df1=20, df2=30))");
                }
                int df1 = UtilFunctions.parseToInt(df1Str);
                int df2 = UtilFunctions.parseToInt(df2Str);
                if (df1 <= 0 || df2 <= 0) {
                    throw new DMLRuntimeException("Degrees of Freedom for F distribution must be positive (" + df1 + "," + df2 + ")");
                }
                return new FDistribution(df1, df2);
            }
            case T: {
                String dfStr = params.get("df");
                if (dfStr == null) {
                    throw new DMLRuntimeException("Degrees of freedom is needed to compute probabilities from t distribution (e.g., q = qt(target=0.5, df=10); p = pt(target=q, df=10))");
                }
                int df = UtilFunctions.parseToInt(dfStr);
                if (df <= 0) {
                    throw new DMLRuntimeException("Degrees of Freedom for t distribution must be positive (" + df + ")");
                }
                return new TDistribution(df);
            }
            default:
                throw new DMLRuntimeException("Invalid distribution code: " + dcode);
        }
    }
}
