package org.apache.sysds.hops;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.FunctionCallCP;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/FunctionOp.class */
/**
 * Hop representing a function call (opcode {@value #OPCODE}).
 *
 * <p>The callee is identified by a namespace and a name. Its kind is given by
 * {@link FunctionType}: a DML-level function, a multi-return builtin (e.g. {@code qr}, {@code lu},
 * {@code eigen}, {@code svd}, {@code lstm}, {@code batch_norm2d*}, {@code transformencode}) or
 * unknown. Multi-return builtins additionally carry their output hops, which drive memory
 * estimation.
 */
public class FunctionOp extends Hop {
    public static final String OPCODE = "fcall";

    /** Fixed memory estimate for calls of type {@link FunctionType#UNKNOWN} (20 * 1024 * 1024). */
    private static final double UNKNOWN_FUNCTION_MEM_ESTIMATE = 20d * 1024 * 1024;

    private FunctionType _type;
    private String _fnamespace;
    private String _fname;
    /** Passed through unchanged to the generated {@link FunctionCallCP} lop. */
    private boolean _opt = true;
    private boolean _pseudo;
    private String[] _inputNames;
    private String[] _outputNames;
    /** Output hops of a multi-return call; may be null if not supplied at construction. */
    private ArrayList<Hop> _outputHops;

    /** Kind of the called function. */
    public enum FunctionType {
        DML,
        MULTIRETURN_BUILTIN,
        UNKNOWN
    }

    /** Creates an instance with default field values; only used by {@link #clone()}. */
    private FunctionOp() {
        // all fields keep their declared defaults
    }

    /**
     * Creates a function call that also records its output hops.
     *
     * @param functionType kind of the callee
     * @param fnamespace   namespace of the callee
     * @param fname        name of the callee
     * @param inputNames   names of the callee's input parameters (may be null)
     * @param inputs       argument hops; each is linked as an input of this hop
     * @param outputNames  names of the outputs
     * @param outputHops   hops representing the individual outputs (may be null)
     */
    public FunctionOp(FunctionType functionType, String fnamespace, String fname, String[] inputNames,
            List<Hop> inputs, String[] outputNames, ArrayList<Hop> outputHops) {
        this(functionType, fnamespace, fname, inputNames, inputs, outputNames, false);
        _outputHops = outputHops;
    }

    /**
     * Creates a non-pseudo function call without output hops.
     *
     * @param unusedFlag currently ignored by this class; retained for API compatibility
     */
    public FunctionOp(FunctionType functionType, String fnamespace, String fname, String[] inputNames,
            List<Hop> inputs, String[] outputNames, boolean unusedFlag) {
        this(functionType, fnamespace, fname, inputNames, inputs, outputNames, unusedFlag, false);
    }

    /**
     * Creates a function call without output hops.
     *
     * @param unusedFlag currently ignored by this class; retained for API compatibility
     * @param pseudo     value reported by {@link #isPseudoFunctionCall()}
     */
    public FunctionOp(FunctionType functionType, String fnamespace, String fname, String[] inputNames,
            List<Hop> inputs, String[] outputNames, boolean unusedFlag, boolean pseudo) {
        super(fnamespace + "::" + fname, Types.DataType.UNKNOWN, Types.ValueType.UNKNOWN);
        _type = functionType;
        _fnamespace = fnamespace;
        _fname = fname;
        _inputNames = inputNames;
        _outputNames = outputNames;
        _pseudo = pseudo;
        // link arguments in both directions (child list of this hop, parent list of the argument)
        for (Hop input : inputs) {
            getInput().add(input);
            input.getParent().add(this);
        }
    }

    /** No arity constraint: a function call accepts any number of inputs. */
    @Override // org.apache.sysds.hops.Hop
    public void checkArity() {
    }

    /** Returns the key built by {@link DMLProgram#constructFunctionKey} from namespace and name. */
    public String getFunctionKey() {
        return DMLProgram.constructFunctionKey(getFunctionNamespace(), getFunctionName());
    }

    public String getFunctionNamespace() {
        return _fnamespace;
    }

    public String getFunctionName() {
        return _fname;
    }

    public void setFunctionName(String fname) {
        _fname = fname;
    }

    public void setFunctionNamespace(String fnamespace) {
        _fnamespace = fnamespace;
    }

    public void setInputVariableNames(String[] inputNames) {
        _inputNames = inputNames;
    }

    /** Returns the (live, possibly null) list of output hops. */
    public ArrayList<Hop> getOutputs() {
        return _outputHops;
    }

    public String[] getInputVariableNames() {
        return _inputNames;
    }

    public String[] getOutputVariableNames() {
        return _outputNames;
    }

    /**
     * Returns whether {@code name} is one of the output variable names.
     * Returns false when no output names are set.
     */
    public boolean containsOutput(String name) {
        String[] outputNames = getOutputVariableNames();
        return outputNames != null && Arrays.stream(outputNames).anyMatch(out -> out.equals(name));
    }

    public FunctionType getFunctionType() {
        return _type;
    }

    public void setCallOptimized(boolean opt) {
        _opt = opt;
    }

    public boolean isPseudoFunctionCall() {
        return _pseudo;
    }

    @Override // org.apache.sysds.hops.Hop
    public boolean allowsAllExecTypes() {
        return false;
    }

    /**
     * Sets {@code _memEstimate}:
     * 1 for DML calls, a fixed constant for unknown calls, and for multi-return builtins the
     * input/output size (output and intermediate estimates are refreshed first when all output
     * hops have known dimensions). Other/null types leave the estimate untouched.
     */
    @Override // org.apache.sysds.hops.Hop
    public void computeMemEstimate(MemoTable memo) {
        if (_type == FunctionType.DML) {
            _memEstimate = 1.0d;
        } else if (_type == FunctionType.UNKNOWN) {
            _memEstimate = UNKNOWN_FUNCTION_MEM_ESTIMATE;
        } else if (_type == FunctionType.MULTIRETURN_BUILTIN) {
            List<Hop> outputs = getOutputs();
            // without recorded output hops the output dims cannot be known
            boolean outputDimsKnown = outputs != null && outputs.stream().allMatch(out -> out.dimsKnown());
            if (outputDimsKnown) {
                long nnz = getNnz() >= 0 ? getNnz() : getLength();
                _outputMemEstimate = computeOutputMemEstimate(getDim1(), getDim2(), nnz);
                _processingMemEstimate = computeIntermediateMemEstimate(getDim1(), getDim2(), nnz);
            }
            _memEstimate = getInputOutputSize();
        }
    }

    /**
     * Estimates the combined size of the outputs of a supported multi-return builtin, based on
     * the dimensions of the recorded output hops (the arguments are not consulted).
     *
     * @throws RuntimeException if this is not a multi-return builtin or the builtin is unsupported
     */
    @Override // org.apache.sysds.hops.Hop
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        if (getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
            throw new RuntimeException("Invalid call of computeOutputMemEstimate in FunctionOp.");
        }
        if (isFunctionNamed("qr")) {
            return sumOutputSizes(2, 0.5d);
        }
        if (isFunctionNamed("lu")) {
            // outputs P, L, U: P is estimated with one non-zero per row, L and U as half-dense
            long p = outputSize(0, 1.0d / getOutputs().get(0).getDim2());
            long l = outputSize(1, 0.5d);
            long u = outputSize(2, 0.5d);
            return p + l + u;
        }
        if (isFunctionNamed("eigen")) {
            // each output sized by its own dims, independent of the output order
            return sumOutputSizes(2, 1.0d);
        }
        if (isFunctionNamed("lstm", "lstm_backward")) {
            return 0;
        }
        if (isFunctionNamed("batch_norm2d", "batch_norm2d_train")) {
            return sumOutputSizes(5, 1.0d);
        }
        if (isFunctionNamed("batch_norm2d_test")) {
            return sumOutputSizes(1, 1.0d);
        }
        if (isFunctionNamed("batch_norm2d_backward")) {
            return sumOutputSizes(3, 1.0d);
        }
        if (isFunctionNamed("svd")) {
            return sumOutputSizes(3, 1.0d);
        }
        throw new RuntimeException("Invalid call of computeOutputMemEstimate in FunctionOp.");
    }

    /**
     * Estimates intermediate memory of a supported multi-return builtin from the first input's
     * dimensions (the arguments are not consulted).
     *
     * @throws RuntimeException if this is not a multi-return builtin or the builtin is unsupported
     */
    @Override // org.apache.sysds.hops.Hop
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        if (getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
            throw new RuntimeException("Invalid call of computeIntermediateMemEstimate in FunctionOp.");
        }
        Hop in = getInput().get(0);
        if (isFunctionNamed("qr")) {
            return OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), in.getDim2(), 1.0d);
        }
        if (isFunctionNamed("lu")) {
            return OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), 1L, 1.0d);
        }
        if (isFunctionNamed("eigen")) {
            return OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), in.getDim2(), 1.0d)
                + (3 * OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), 1L, 1.0d));
        }
        if (isFunctionNamed("batch_norm2d", "batch_norm2d_backward", "batch_norm2d_train",
                "batch_norm2d_test", "lstm", "lstm_backward")) {
            return 0;
        }
        if (isFunctionNamed("svd")) {
            return OptimizerUtils.estimateSizeExactSparsity(1L, in.getDim2(), 1.0d);
        }
        throw new RuntimeException("Invalid call of computeIntermediateMemEstimate in FunctionOp.");
    }

    /** Not supported for function calls. */
    @Override // org.apache.sysds.hops.Hop
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        throw new RuntimeException("Invalid call of inferOutputCharacteristics in FunctionOp.");
    }

    /** True for the lstm and batch_norm2d families of functions. */
    @Override // org.apache.sysds.hops.Hop
    public boolean isGPUEnabled() {
        return isFunctionNamed("lstm", "lstm_backward", "batch_norm2d", "batch_norm2d_backward",
            "batch_norm2d_train", "batch_norm2d_test");
    }

    /** Builds (once) a {@link FunctionCallCP} lop over the lops of all inputs. */
    @Override // org.apache.sysds.hops.Hop
    public Lop constructLops() {
        if (getLops() != null) {
            return getLops();
        }
        Types.ExecType execType = optFindExecType();
        ArrayList<Lop> inputLops = new ArrayList<>(getInput().size());
        for (Hop input : getInput()) {
            inputLops.add(input.constructLops());
        }
        FunctionCallCP call = new FunctionCallCP(inputLops, _fnamespace, _fname, _inputNames,
            _outputNames, _outputHops, _opt, execType);
        setLineNumbers(call);
        setLops(call);
        return getLops();
    }

    @Override // org.apache.sysds.hops.Hop
    public String getOpString() {
        return OPCODE;
    }

    /**
     * Chooses the execution type. Only internal multi-return builtins may leave CP:
     * transformencode may go to Spark, lstm* requires GPU (error otherwise),
     * batch_norm2d/batch_norm2d_backward use GPU when an accelerator is enabled,
     * and batch_norm2d_train is always assigned GPU.
     *
     * @param flag not consulted by this implementation
     */
    @Override // org.apache.sysds.hops.Hop
    protected Types.ExecType optFindExecType(boolean flag) {
        checkAndSetForcedPlatform();
        if (getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction()) {
            if (isFunctionNamed("transformencode")) {
                boolean useSpark = _etypeForced == Types.ExecType.SPARK
                    || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
                        && OptimizerUtils.isSparkExecutionMode());
                _etype = useSpark ? Types.ExecType.SPARK : Types.ExecType.CP;
            } else if (isFunctionNamed("lstm", "lstm_backward")) {
                if (!DMLScript.USE_ACCELERATOR) {
                    throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
                }
                _etype = Types.ExecType.GPU;
            } else if (isFunctionNamed("batch_norm2d", "batch_norm2d_backward")) {
                _etype = DMLScript.USE_ACCELERATOR ? Types.ExecType.GPU : Types.ExecType.CP;
            } else if (isFunctionNamed("batch_norm2d_train")) {
                // NOTE(review): assigned GPU even without an accelerator; unlike lstm no early error
                _etype = Types.ExecType.GPU;
            } else {
                _etype = Types.ExecType.CP;
            }
        } else {
            _etype = Types.ExecType.CP;
        }
        return _etype;
    }

    /** True if the namespace equals {@link DMLProgram#INTERNAL_NAMESPACE} (null-safe). */
    private boolean isBuiltinFunction() {
        return DMLProgram.INTERNAL_NAMESPACE.equals(getFunctionNamespace());
    }

    /** True if the function name equals any of {@code candidates}, ignoring case (null-safe). */
    private boolean isFunctionNamed(String... candidates) {
        String name = getFunctionName();
        for (String candidate : candidates) {
            if (candidate.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    /** Size estimate of output hop {@code index} at the given sparsity. */
    private long outputSize(int index, double sparsity) {
        Hop out = getOutputs().get(index);
        return OptimizerUtils.estimateSizeExactSparsity(out.getDim1(), out.getDim2(), sparsity);
    }

    /** Sum of {@link #outputSize} over the first {@code count} output hops. */
    private long sumOutputSizes(int count, double sparsity) {
        long total = 0;
        for (int i = 0; i < count; i++) {
            total += outputSize(i, sparsity);
        }
        return total;
    }

    /** Size information is not refreshed for function calls. */
    @Override // org.apache.sysds.hops.Hop
    public void refreshSizeInformation() {
    }

    /**
     * Copies generic hop attributes (without references) and all call-specific fields;
     * name arrays and the output-hop list are shallow-copied.
     */
    @Override // org.apache.sysds.hops.Hop
    public Object clone() throws CloneNotSupportedException {
        FunctionOp ret = new FunctionOp();
        ret.clone(this, false);
        ret._type = _type;
        ret._fnamespace = _fnamespace;
        ret._fname = _fname;
        ret._opt = _opt;
        ret._pseudo = _pseudo;
        ret._inputNames = _inputNames != null ? _inputNames.clone() : null;
        ret._outputNames = _outputNames != null ? _outputNames.clone() : null;
        if (_outputHops != null) {
            ret._outputHops = new ArrayList<>(_outputHops);
        }
        return ret;
    }

    /** Function calls are never reported equal to another hop. */
    @Override // org.apache.sysds.hops.Hop
    public boolean compare(Hop hop) {
        return false;
    }

    @Override // org.apache.sysds.hops.Hop
    public String toString() {
        return getOpString();
    }
}
