package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/codegen/cplan/CNodeCell.class */
/**
 * Code-generation plan template for cellwise ({@link SpoofCellwise}) operators.
 *
 * <p>On top of the generic {@link CNodeTpl} state this holds the cellwise configuration:
 * cell type, aggregation operator, sparse-safety, sequence and cast-to-dense flags. It renders
 * that configuration into Java or CUDA source via
 * {@link #codegen(boolean, SpoofCompiler.GeneratorAPI)}.
 */
public class CNodeCell extends CNodeTpl {
    protected static final String JAVA_TEMPLATE = "package codegen;\nimport org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\nimport org.apache.sysds.runtime.codegen.SpoofCellwise;\nimport org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\nimport org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\nimport org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\nimport org.apache.commons.math3.util.FastMath;\n\npublic final class %TMP% extends SpoofCellwise {\n  public %TMP%() {\n    super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);\n  }\n  protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n%BODY_dense%    return %OUT%;\n  }\n}\n";

    /**
     * Name given to the first (main) input in generated code. It matches the {@code a}
     * parameter of {@code genexec} in {@link #JAVA_TEMPLATE}.
     */
    private static final String MAIN_INPUT_NAME = "a";

    /** Text emitted (into code or template info) when an optional setting is absent. */
    private static final String NULL_STRING = "null";

    private SpoofCellwise.CellType _type;
    private Types.AggOp _aggOp;
    private boolean _sparseSafe;
    private boolean _containsSeq;
    private boolean _requiresCastdtm;
    private boolean _multipleConsumers;

    /**
     * Creates a cellwise template over the given inputs and output node.
     *
     * <p>Cell type and aggregation operator start unset ({@code null}). The sequence flag
     * starts {@code true}; all other flags start {@code false}.
     *
     * @param arrayList template inputs; index 0 is the main input
     * @param cNode root of the output DAG
     */
    public CNodeCell(ArrayList<CNode> arrayList, CNode cNode) {
        super(arrayList, cNode);
        this._type = null;
        this._aggOp = null;
        this._sparseSafe = false;
        this._containsSeq = true;
        this._requiresCastdtm = false;
        this._multipleConsumers = false;
    }

    /**
     * Sets the multiple-consumers flag.
     *
     * <p>It is only reported by {@link #getTemplateInfo()} ("mc") and is not part of
     * {@link #equals(Object)} or {@link #hashCode()}.
     */
    public void setMultipleConsumers(boolean z) {
        this._multipleConsumers = z;
    }

    /** @return the multiple-consumers flag */
    public boolean hasMultipleConsumers() {
        return this._multipleConsumers;
    }

    /** Sets the cell type and invalidates the cached hash code. */
    public void setCellType(SpoofCellwise.CellType cellType) {
        this._type = cellType;
        this._hash = 0;
    }

    /** @return the configured cell type, or {@code null} if not yet set */
    public SpoofCellwise.CellType getCellType() {
        return this._type;
    }

    /** Sets the aggregation operator ({@code null} for none) and invalidates the cached hash code. */
    public void setAggOp(Types.AggOp aggOp) {
        this._aggOp = aggOp;
        this._hash = 0;
    }

    /** @return the configured aggregation operator, or {@code null} if none */
    public Types.AggOp getAggOp() {
        return this._aggOp;
    }

    /**
     * Maps the configured aggregation operator to its runtime {@link SpoofCellwise.AggOp}.
     *
     * @return the runtime operator, or {@code null} if no aggregation operator is set
     * @throws RuntimeException if the operator is not SUM, SUM_SQ, MIN or MAX
     */
    public SpoofCellwise.AggOp getSpoofAggOp() {
        if (this._aggOp == null) {
            return null;
        }
        switch (this._aggOp) {
            case SUM:
                return SpoofCellwise.AggOp.SUM;
            case SUM_SQ:
                return SpoofCellwise.AggOp.SUM_SQ;
            case MIN:
                return SpoofCellwise.AggOp.MIN;
            case MAX:
                return SpoofCellwise.AggOp.MAX;
            default:
                // report the offending aggregation operator (the cell type may be unset here)
                throw new RuntimeException("Unsupported aggregation operator: " + this._aggOp.name());
        }
    }

    /**
     * Sets the sparse-safe flag and invalidates the cached hash code.
     *
     * <p>The hash must be invalidated because the flag participates in {@link #hashCode()}.
     */
    public void setSparseSafe(boolean z) {
        this._sparseSafe = z;
        this._hash = 0;
    }

    /** @return the sparse-safe flag */
    public boolean isSparseSafe() {
        return this._sparseSafe;
    }

    /**
     * Sets the flag that is emitted as the {@code %SEQ%} constructor argument of the generated
     * operator.
     *
     * <p>NOTE(review): this flag is not part of equals/hashCode; presumably it is implied by
     * the output DAG. Confirm.
     */
    public void setContainsSeq(boolean z) {
        this._containsSeq = z;
    }

    /** @return the flag emitted as {@code %SEQ%} in generated code */
    public boolean containsSeq() {
        return this._containsSeq;
    }

    /** Sets the cast-dtm flag and invalidates the cached hash code. */
    public void setRequiresCastDtm(boolean z) {
        this._requiresCastdtm = z;
        this._hash = 0;
    }

    /** @return the cast-dtm flag */
    public boolean requiredCastDtm() {
        return this._requiresCastdtm;
    }

    /**
     * Renames the first input within the output DAG to {@link #MAIN_INPUT_NAME}. The remaining
     * inputs are renamed starting at index 1.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public void renameInputs() {
        rRenameDataNode(this._output, this._inputs.get(0), MAIN_INPUT_NAME);
        renameInputs(this._inputs, 1);
    }

    /**
     * Generates the source code of this cellwise operator for the given backend.
     *
     * @param z unused by this template; only the dense body is generated
     * @param generatorAPI target backend (Java or CUDA)
     * @return the language template with all placeholders filled in
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public String codegen(boolean z, SpoofCompiler.GeneratorAPI generatorAPI) {
        this.api = generatorAPI;
        String template = getLanguageTemplate(this, this.api);

        // generate the (dense) body of the output DAG
        String body = this._output.codegen(false, this.api);
        if (this.api == SpoofCompiler.GeneratorAPI.CUDA) {
            // CUDA: accesses written as a.vals(0) refer to the main input directly
            body = body.replace("a.vals(0)", MAIN_INPUT_NAME);
        }
        this._output.resetGenerated();

        String className = getVarname() == null ? createVarname() : getVarname();
        // note: String.replace substitutes literally (no regex interpretation of '$' or '\')
        String code = template.replace(this.api.isJava() ? "%TMP%" : "/*%TMP%*/SPOOF_OP_NAME", className);

        // declare grix only if the generated body references it
        if (body.contains("grix")) {
            code = code.replace("//%NEED_GRIX%", "\t\tuint32_t grix=_grix + rix;");
        } else {
            code = code.replace("//%NEED_GRIX%\n", "");
        }
        code = code.replace("%BODY_dense%", body);

        // non-Java backends square the output for SUM_SQ here
        String outVar = this._output.getVarname();
        boolean squareOutput = !this.api.isJava() && this._aggOp == Types.AggOp.SUM_SQ;
        code = code.replace("%OUT%", squareOutput ? outVar + " * " + outVar : outVar);

        // meta data of the generated operator
        code = code.replace("%TYPE%", getCellType().name());
        code = code.replace("%AGG_OP_NAME%", this._aggOp != null ? "AggOp." + this._aggOp.name() : NULL_STRING);
        code = code.replace("%SPARSE_SAFE%", String.valueOf(isSparseSafe()));
        code = code.replace("%SEQ%", String.valueOf(containsSeq()));

        if (this.api == SpoofCompiler.GeneratorAPI.CUDA) {
            // NOTE(review): without an aggregation operator the initial value is (T)1.0 rather
            // than (T)0.0; the intent is not visible in this file.
            String aggOpName = "IdentityOp";
            String initialValue = "(T)1.0";
            if (this._aggOp != null) {
                switch (this._aggOp) {
                    case SUM:
                    case SUM_SQ:
                        aggOpName = "SumOp";
                        initialValue = "(T)0.0";
                        break;
                    case MIN:
                        aggOpName = "MinOp";
                        initialValue = "MAX<T>()";
                        break;
                    case MAX:
                        aggOpName = "MaxOp";
                        initialValue = "-MAX<T>()";
                        break;
                    default:
                        aggOpName = "IdentityOp";
                        initialValue = "(T)0.0";
                        break;
                }
            }
            code = code.replace("%AGG_OP%", aggOpName).replace("%INITIAL_VALUE%", initialValue);
        }
        return code;
    }

    /** No-op: output dimensions are not derived by this template node. */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public void setOutputDims() {
    }

    /**
     * Creates a copy of this template with all cellwise configuration preserved.
     *
     * <p>The copy receives the same input list and output node, the data type, and every
     * cellwise setting. Because all settings that {@link #equals(Object)} compares are copied,
     * the copy compares equal to this plan.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    /* renamed from: clone */
    public CNodeTpl mo143clone() {
        CNodeCell copy = new CNodeCell(this._inputs, this._output);
        copy.setDataType(getDataType());
        copy.setCellType(getCellType());
        copy.setAggOp(getAggOp());
        copy.setSparseSafe(isSparseSafe());
        copy.setContainsSeq(containsSeq());
        copy.setRequiresCastDtm(requiredCastDtm());
        copy.setMultipleConsumers(hasMultipleConsumers());
        return copy;
    }

    /**
     * Maps the cell type to the output dimension type of the fused operator.
     *
     * @throws RuntimeException for cell types other than NO_AGG, ROW_AGG, COL_AGG and FULL_AGG
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public SpoofFusedOp.SpoofOutputDimsType getOutputDimType() {
        switch (this._type) {
            case NO_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.INPUT_DIMS;
            case ROW_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.ROW_DIMS;
            case COL_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_DIMS_COLS;
            case FULL_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.SCALAR;
            default:
                throw new RuntimeException("Unsupported cell type: " + this._type.toString());
        }
    }

    /**
     * Lazily computes and caches the hash code.
     *
     * <p>The hash covers the super-class hash, cell type, aggregation operator, sparse-safe
     * flag and cast-dtm flag. Setters of these fields reset the cache.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl, org.apache.sysds.hops.codegen.cplan.CNode
    public int hashCode() {
        if (this._hash == 0) {
            int h = super.hashCode();
            h = UtilFunctions.intHashCode(h, this._type != null ? this._type.hashCode() : 0);
            h = UtilFunctions.intHashCode(h, this._aggOp != null ? this._aggOp.hashCode() : 0);
            h = UtilFunctions.intHashCode(h, Boolean.hashCode(this._sparseSafe));
            h = UtilFunctions.intHashCode(h, Boolean.hashCode(this._requiresCastdtm));
            // note: the multiple-consumers flag is intentionally not part of the hash
            this._hash = h;
        }
        return this._hash;
    }

    /**
     * Compares two cellwise templates.
     *
     * <p>They are equal when the super-class state, cell type, aggregation operator, sparse-safe
     * and cast-dtm flags match, and the output DAGs reference equivalent inputs.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl, org.apache.sysds.hops.codegen.cplan.CNode
    public boolean equals(Object obj) {
        if (!(obj instanceof CNodeCell)) {
            return false;
        }
        CNodeCell that = (CNodeCell) obj;
        return super.equals(that)
            && this._type == that._type
            && this._aggOp == that._aggOp
            && this._sparseSafe == that._sparseSafe
            && this._requiresCastdtm == that._requiresCastdtm
            && equalInputReferences(this._output, that._output, this._inputs, that._inputs);
    }

    /** @return a one-line human-readable summary of this template's configuration */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public String getTemplateInfo() {
        StringBuilder sb = new StringBuilder();
        sb.append("SPOOF CELLWISE [type=");
        sb.append(this._type != null ? this._type.name() : NULL_STRING);
        sb.append(", aggOp=").append(this._aggOp != null ? this._aggOp.name() : NULL_STRING);
        sb.append(", sparseSafe=").append(this._sparseSafe);
        sb.append(", castdtm=").append(this._requiresCastdtm);
        sb.append(", mc=").append(this._multipleConsumers);
        sb.append("]");
        return sb.toString();
    }

    /** @return whether the backend is CUDA or Java and the output DAG supports it */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public boolean isSupported(SpoofCompiler.GeneratorAPI generatorAPI) {
        return (generatorAPI == SpoofCompiler.GeneratorAPI.CUDA || generatorAPI == SpoofCompiler.GeneratorAPI.JAVA)
            && this._output.isSupported(generatorAPI);
    }

    /**
     * Compiles the generated source natively.
     *
     * <p>This is implemented only for CUDA (via {@code compile_nvrtc}); other backends return -1.
     *
     * @param generatorAPI target backend
     * @param str generated source code
     * @return the native compile result for CUDA, otherwise -1
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public int compile(SpoofCompiler.GeneratorAPI generatorAPI, String str) {
        if (generatorAPI == SpoofCompiler.GeneratorAPI.CUDA) {
            return compile_nvrtc(SpoofCompiler.native_contexts.get(generatorAPI).longValue(), this._genVar, str,
                this._type.getValue(), this._aggOp != null ? this._aggOp.getValue() : 0, this._sparseSafe);
        }
        return -1;
    }

    private native int compile_nvrtc(long j, String str, String str2, int i, int i2, boolean z);
}
