package org.apache.sysds.parser;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.runtime.util.ProgramConverter;

/**
 * Statement block holding the single {@link FunctionStatement} of a DML function definition.
 *
 * <p>It validates the function body together with the declared inputs and outputs, runs
 * live-variable analysis over the body blocks, and carries the per-function
 * {@code recompileOnce} and {@code nondeterministic} flags.
 */
public class FunctionStatementBlock extends StatementBlock implements Types.FunctionBlock {
    private boolean _recompileOnce = false;
    private boolean _nondeterministic = false;

    /**
     * Returns the only statement of this block as a {@link FunctionStatement}.
     *
     * @throws LanguageException if the block does not contain exactly one statement
     */
    private FunctionStatement getSingleFunctionStatement() {
        // Checked before get(0) so that an empty block is reported as a LanguageException
        // rather than surfacing as an IndexOutOfBoundsException.
        if (_statements.size() != 1) {
            throw new LanguageException(printBlockErrorLocation()
                + "FunctionStatementBlock should have only 1 statement (FunctionStatement)");
        }
        return (FunctionStatement) _statements.get(0);
    }

    /**
     * Validates the function body and checks the declared return values against it.
     *
     * <p>Matrix input parameters must have value type FP64. The body blocks are validated in
     * order, threading the variable set and the constant map from one block to the next. Every
     * declared output must then be defined by the body and agree with the signature in data type
     * and value type. An INT64 scalar returned for a declared FP64 scalar is widened in place.
     *
     * @param dMLProgram enclosing program; stored on this block
     * @param variableSet variables visible on entry to the function body
     * @param hashMap constant values known on entry to the function body
     * @param z flag forwarded to the body blocks' {@code validate} and to
     *          {@code raiseValidateError} for return-value problems (input checks always pass
     *          {@code false})
     * @return the variable set produced by validating the body
     * @throws LanguageException if the block does not hold exactly one statement, or a return
     *         value has a value type that cannot be safely cast to the declared one
     */
    @Override
    public VariableSet validate(DMLProgram dMLProgram, VariableSet variableSet,
            HashMap<String, ConstIdentifier> hashMap, boolean z) {
        FunctionStatement fstmt = getSingleFunctionStatement();

        // Matrix inputs must be FP64; reported with a hard-coded false flag, unlike the
        // return-value checks below, which forward z.
        for (DataIdentifier input : fstmt.getInputParams()) {
            if (input.getDataType() == Types.DataType.MATRIX
                    && input.getValueType() != Types.ValueType.FP64) {
                raiseValidateError("for function " + fstmt.getName() + ", input variable "
                    + input.getName() + " has an unsupported value type of "
                    + input.getValueType() + ".", false);
            }
        }

        this._dmlProg = dMLProgram;

        // Validate the body blocks in order, feeding each one the previous block's results.
        VariableSet ids = variableSet;
        HashMap<String, ConstIdentifier> constVars = hashMap;
        for (StatementBlock sb : fstmt.getBody()) {
            ids = sb.validate(dMLProgram, ids, constVars, z);
            constVars = sb.getConstOut();
        }

        if (fstmt.getBody().size() > 0) {
            _constVarsIn.putAll(fstmt.getBody().get(0).getConstIn());
        }
        // NOTE(review): unlike _constVarsIn (copied for any non-empty body), _constVarsOut is only
        // filled when the body has more than one block -- confirm this asymmetry is intended.
        if (fstmt.getBody().size() > 1) {
            _constVarsOut.putAll(fstmt.getBody().get(fstmt.getBody().size() - 1).getConstOut());
        }

        for (DataIdentifier declared : fstmt.getOutputParams()) {
            validateReturnValue(fstmt, declared, ids, constVars, z);
        }
        return ids;
    }

    /**
     * Checks one declared return value of {@code fstmt} against the variable of the same name
     * in {@code ids}.
     *
     * <p>If an INT64 scalar is returned for a declared FP64 scalar, the variable is widened in
     * place. Its value type is set to FP64 and it is re-added to {@code ids}. An integer constant
     * recorded for it in {@code constVars} is replaced by the equal double constant.
     */
    private void validateReturnValue(FunctionStatement fstmt, DataIdentifier declared,
            VariableSet ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) {
        DataIdentifier actual = ids.getVariable(declared.getName());
        if (actual == null) {
            raiseValidateError("for function " + fstmt.getName() + ", return variable "
                + declared.getName() + " must be defined in function ", conditional);
            // raiseValidateError may return without throwing (the flag is forwarded from the
            // caller); without this return the checks below would dereference null.
            return;
        }

        if (actual.getDataType() != Types.DataType.UNKNOWN
                && !actual.getDataType().equals(declared.getDataType())) {
            raiseValidateError("for function " + fstmt.getName() + ", return variable "
                + actual.getName() + " data type of " + actual.getDataType()
                + " does not match data type in function signature of "
                + declared.getDataType(), conditional);
        }

        Types.ValueType actualVt = actual.getValueType();
        Types.ValueType declaredVt = declared.getValueType();
        if (actualVt == Types.ValueType.UNKNOWN || declaredVt == Types.ValueType.UNKNOWN
                || actualVt.equals(declaredVt)) {
            return; // nothing to reconcile
        }

        String mismatch = "for function " + fstmt.getName() + ", return variable "
            + actual.getName() + " value type of " + actualVt
            + " does not match value type in function signature of " + declaredVt;
        String unsafeCast = " and cannot safely cast " + actualVt + " as " + declaredVt;

        // Value-type coercion is only attempted between scalars.
        if (actual.getDataType() != Types.DataType.SCALAR
                || declared.getDataType() != Types.DataType.SCALAR) {
            throw new LanguageException(actual.printErrorLocation() + mismatch + unsafeCast);
        }

        if (declaredVt == Types.ValueType.FP64) {
            if (actualVt != Types.ValueType.INT64) {
                throw new LanguageException(actual.printErrorLocation() + mismatch
                    + " and cannot safely cast value");
            }
            // Widen a known integer constant to the equal double constant; instanceof replaces
            // an unchecked cast, so any other constant kind is left as is.
            ConstIdentifier constVal = constVars.get(actual.getName());
            if (constVal instanceof IntIdentifier) {
                constVars.put(actual.getName(),
                    new DoubleIdentifier(((IntIdentifier) constVal).getValue(), actual));
            }
            LOG.warn(actual.printWarningLocation() + mismatch + " but was safely cast");
            actual.setValueType(Types.ValueType.FP64);
            ids.addVariable(actual.getName(), actual);
        } else if (declaredVt == Types.ValueType.INT64) {
            throw new LanguageException(actual.printErrorLocation() + mismatch + unsafeCast);
        }
        // NOTE(review): scalar mismatches for declared value types other than FP64 and INT64 are
        // currently accepted silently -- confirm this is intended.
    }

    /** @return always {@link FunctionOp.FunctionType#DML} for this block type */
    public FunctionOp.FunctionType getFunctionOpType() {
        return FunctionOp.FunctionType.DML;
    }

    /**
     * Forward live-variable pass over the function body.
     *
     * <p>Resets {@code _read} and {@code _gen}, then runs each body block's forward pass in
     * order. Each block's read and updated variables are accumulated, as are its generated
     * variables that are not already in {@code _kill}. Kill sets are accumulated from body
     * blocks other than while/for blocks.
     *
     * @param variableSet variables live on entry
     * @return the new {@code _liveOut}: the last block's result plus all updated variables
     * @throws LanguageException if the block does not hold exactly one statement
     */
    @Override
    public VariableSet initializeforwardLV(VariableSet variableSet) {
        FunctionStatement fstmt = getSingleFunctionStatement();
        _read = new VariableSet();
        _gen = new VariableSet();

        VariableSet current = new VariableSet();
        current.addVariables(variableSet);
        for (StatementBlock sb : fstmt.getBody()) {
            current = sb.initializeforwardLV(current);
            // A variable generated here counts as generated by the function only if no earlier
            // body block has put it into _kill.
            for (String varName : sb._gen.getVariableNames()) {
                if (!_kill.getVariableNames().contains(varName)) {
                    _gen.addVariable(varName, sb._gen.getVariable(varName));
                }
            }
            _read.addVariables(sb._read);
            _updated.addVariables(sb._updated);
            // Presumably loop bodies are not guaranteed to execute, so their kills are skipped.
            if (!(sb instanceof WhileStatementBlock) && !(sb instanceof ForStatementBlock)) {
                _kill.addVariables(sb._kill);
            }
        }

        _liveOut = new VariableSet();
        _liveOut.addVariables(current);
        _liveOut.addVariables(_updated);
        return _liveOut;
    }

    /**
     * Backward live-variable pass over the function body.
     *
     * @param variableSet variables live on exit from the body
     * @return a fresh copy of the set returned by the first body block's {@code analyze}
     *         (or of {@code variableSet} if the body is empty)
     */
    @Override
    public VariableSet initializebackwardLV(VariableSet variableSet) {
        FunctionStatement fstmt = (FunctionStatement) _statements.get(0);
        VariableSet current = new VariableSet();
        current.addVariables(variableSet);
        // Walk the body last-to-first, handing each block's analyze() result to its predecessor.
        for (int i = fstmt.getBody().size() - 1; i >= 0; i--) {
            current = fstmt.getBody().get(i).analyze(current);
        }
        VariableSet result = new VariableSet();
        result.addVariables(current);
        return result;
    }

    /**
     * Unsupported for function blocks: always throws.
     *
     * @throws LanguageException always; use {@link #analyze(VariableSet, VariableSet)} instead
     */
    @Override
    public VariableSet analyze(VariableSet variableSet) {
        throw new LanguageException(printBlockErrorLocation() + "Both liveIn and liveOut variables need to be specified for liveness analysis for FunctionStatementBlock");
    }

    /**
     * Liveness analysis with explicit live-in and live-out sets.
     *
     * <p>The new {@code _liveOut} holds those variables from {@code variableSet2} plus
     * {@code _gen} that were already in the previously stored {@code _liveOut} (e.g. from
     * {@link #initializeforwardLV}). The backward pass then runs over the body, and
     * {@code _liveIn} is set to a copy of {@code variableSet}.
     *
     * @param variableSet live-in variables
     * @param variableSet2 live-out variables
     * @return a fresh copy of the new {@code _liveIn}
     */
    public VariableSet analyze(VariableSet variableSet, VariableSet variableSet2) {
        VariableSet candidates = new VariableSet();
        candidates.addVariables(variableSet2);
        candidates.addVariables(_gen);

        VariableSet previousLiveOut = new VariableSet();
        previousLiveOut.addVariables(_liveOut);
        _liveOut = new VariableSet();
        for (String varName : candidates.getVariableNames()) {
            if (previousLiveOut.containsVariable(varName)) {
                _liveOut.addVariable(varName, candidates.getVariable(varName));
            }
        }

        // The returned set is not needed here; the call runs analyze() on every body block.
        initializebackwardLV(_liveOut);

        _liveIn = new VariableSet();
        _liveIn.addVariables(variableSet);
        VariableSet result = new VariableSet();
        result.addVariables(_liveIn);
        return result;
    }

    public void setRecompileOnce(boolean z) {
        this._recompileOnce = z;
    }

    public boolean isRecompileOnce() {
        return this._recompileOnce;
    }

    @Override
    public void setNondeterministic(boolean z) {
        this._nondeterministic = z;
    }

    @Override
    public boolean isNondeterministic() {
        return this._nondeterministic;
    }

    /** Returns a copy made by {@link ProgramConverter#createDeepCopyFunctionStatementBlock}. */
    @Override
    public Types.FunctionBlock cloneFunctionBlock() {
        return ProgramConverter.createDeepCopyFunctionStatementBlock(this, new HashSet<>(), new HashSet<>());
    }

    /** Forwards the repetition estimate {@code d} to every block in the function body. */
    @Override
    public void updateRepetitionEstimates(double d) {
        for (Statement stmt : getStatements()) {
            for (StatementBlock sb : ((FunctionStatement) stmt).getBody()) {
                sb.updateRepetitionEstimates(d);
            }
        }
    }
}
