package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * CPlan template node for a row-wise fused operator.
 *
 * <p>{@link #codegen(boolean)} instantiates {@link #TEMPLATE} into the Java source of a final
 * class extending {@code SpoofRowwise}, with one dense and one sparse {@code genexec} body
 * generated from the output DAG of this template.
 */
public class CNodeRow extends CNodeTpl {
    private static final String TEMPLATE = "package codegen;\nimport org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\nimport org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\nimport org.apache.sysds.runtime.codegen.SpoofRowwise;\nimport org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\nimport org.apache.commons.math3.util.FastMath;\n\npublic final class %TMP% extends SpoofRowwise { \n  public %TMP%() {\n    super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n  }\n  protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n%BODY_dense%  }\n  protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n%BODY_sparse%  }\n}\n";
    private static final String TEMPLATE_ROWAGG_OUT = "    c[rix] = %IN%;\n";
    private static final String TEMPLATE_FULLAGG_OUT = "    c[0] += %IN%;\n";
    private static final String TEMPLATE_NOAGG_OUT = "    LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";

    /** Row type; selects the output statement and the output dimension type. Null until set. */
    private SpoofRowwise.RowType _type;
    /** Constant second dimension, emitted as %CONST_DIM2% into the generated constructor (-1 if unset). */
    private long _constDim2;
    /** Number of vector intermediates, emitted as %VECT_MEM% into the generated constructor (-1 if unset). */
    private int _numVectors;

    public CNodeRow(ArrayList<CNode> arrayList, CNode cNode) {
        super(arrayList, cNode);
        this._type = null;
        this._constDim2 = -1L;
        this._numVectors = -1;
    }

    /** Sets the row type and invalidates the cached hash code. */
    public void setRowType(SpoofRowwise.RowType rowType) {
        this._type = rowType;
        this._hash = 0;
    }

    public SpoofRowwise.RowType getRowType() {
        return this._type;
    }

    /** Sets the number of vector intermediates and invalidates the cached hash code. */
    public void setNumVectorIntermediates(int i) {
        this._numVectors = i;
        this._hash = 0;
    }

    public int getNumVectorIntermediates() {
        return this._numVectors;
    }

    /** Sets the constant second dimension and invalidates the cached hash code. */
    public void setConstDim2(long j) {
        this._constDim2 = j;
        this._hash = 0;
    }

    public long getConstDim2() {
        return this._constDim2;
    }

    /**
     * Renames the main input to {@code "a"}, the name of the dense input array in the generated
     * {@code genexec(double[] a, ...)} signature. The remaining inputs (from index 1) are then
     * renamed by the inherited {@code renameInputs} helper.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public void renameInputs() {
        // literal "a" must match the parameter name used in TEMPLATE's dense genexec
        rRenameDataNode(this._output, this._inputs.get(0), "a");
        renameInputs(this._inputs, 1);
    }

    /**
     * Generates the Java source of the fused row-wise operator.
     *
     * @param sparse ignored; both the dense and the sparse body are always generated
     * @return the complete source of the generated class
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public String codegen(boolean sparse) {
        String bodyDense = this._output.codegen(false) + getOutputStatement(this._output.getVarname());
        // reset generated flags so the output DAG is emitted again for the sparse body
        this._output.resetGenerated();
        // NOTE: the class name is allocated between dense and sparse body generation to keep
        // the same variable-id sequence as before (createVarname draws from a shared sequence)
        String className = createVarname();
        String bodySparse = this._output.codegen(true) + getOutputStatement(this._output.getVarname());

        String src = TEMPLATE.replace("%TMP%", className);
        src = src.replace("%BODY_dense%", bodyDense);
        src = src.replace("%BODY_sparse%", bodySparse);
        // the following placeholders are replaced AFTER the bodies are inserted, so occurrences
        // inside the generated bodies (e.g. %LEN%) are substituted as well
        src = src.replace("%OUT%", "c");
        src = src.replace("%POSOUT%", "0");
        src = src.replace("%LEN%", "len");
        src = src.replace("%TYPE%", this._type.name());
        src = src.replace("%CONST_DIM2%", String.valueOf(this._constDim2));
        src = src.replace("%TB1%", String.valueOf(
            TemplateUtils.containsBinary(this._output, CNodeBinary.BinType.VECT_MATRIXMULT)));
        src = src.replace("%VECT_MEM%", String.valueOf(this._numVectors));
        return src;
    }

    /**
     * Returns the statement that writes the body result {@code str} to the output {@code c}:
     * a vector write for NO_AGG variants, {@code c[0] +=} for FULL_AGG, {@code c[rix] =} for
     * ROW_AGG, and an empty string for all other row types (presumably the body itself writes
     * the output for the COL_AGG variants — verify against the vector primitives).
     */
    private String getOutputStatement(String str) {
        switch (this._type) {
            case NO_AGG:
            case NO_AGG_B1:
            case NO_AGG_CONST:
                return TEMPLATE_NOAGG_OUT.replace("%IN%", str).replace("%LEN%", this._output.getVarname() + ".length");
            case FULL_AGG:
                return TEMPLATE_FULLAGG_OUT.replace("%IN%", str);
            case ROW_AGG:
                return TEMPLATE_ROWAGG_OUT.replace("%IN%", str);
            default:
                return "";
        }
    }

    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public void setOutputDims() {
        // intentionally empty: output dimensions are described via getOutputDimType()
    }

    /**
     * Maps the row type to the output dimension type of the fused operator.
     *
     * @throws RuntimeException for row types without a mapping
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public SpoofFusedOp.SpoofOutputDimsType getOutputDimType() {
        switch (this._type) {
            case NO_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.INPUT_DIMS;
            case NO_AGG_B1:
                return SpoofFusedOp.SpoofOutputDimsType.ROW_RANK_DIMS;
            case NO_AGG_CONST:
                return SpoofFusedOp.SpoofOutputDimsType.INPUT_DIMS_CONST2;
            case FULL_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.SCALAR;
            case ROW_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.ROW_DIMS;
            case COL_AGG:
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_DIMS_COLS;
            case COL_AGG_T:
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_DIMS_ROWS;
            case COL_AGG_B1:
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_RANK_DIMS;
            case COL_AGG_B1_T:
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
            case COL_AGG_B1R:
                return SpoofFusedOp.SpoofOutputDimsType.RANK_DIMS_COLS;
            case COL_AGG_CONST:
                return SpoofFusedOp.SpoofOutputDimsType.VECT_CONST2;
            default:
                throw new RuntimeException("Unsupported row type: " + this._type.toString());
        }
    }

    /**
     * Creates a copy sharing the same inputs and output, with all row-template specific fields
     * (row type, vector intermediates, constant second dimension) copied.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    /* renamed from: clone */
    public CNodeTpl mo131clone() {
        CNodeRow cNodeRow = new CNodeRow(this._inputs, this._output);
        cNodeRow.setRowType(this._type);
        cNodeRow.setNumVectorIntermediates(this._numVectors);
        // _constDim2 participates in equals/hashCode and in the generated constructor call,
        // so it must be carried over as well
        cNodeRow.setConstDim2(this._constDim2);
        return cNodeRow;
    }

    /** Hash over the base template plus row type, constant dim2, and vector count; cached in _hash. */
    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl, org.apache.sysds.hops.codegen.cplan.CNode
    public int hashCode() {
        if (this._hash == 0) {
            // null-safe on _type, consistent with the == comparison in equals()
            int typeHash = (this._type != null) ? this._type.hashCode() : 0;
            int h = UtilFunctions.intHashCode(super.hashCode(), typeHash);
            h = UtilFunctions.intHashCode(h, Long.hashCode(this._constDim2));
            h = UtilFunctions.intHashCode(h, Integer.hashCode(this._numVectors));
            this._hash = h;
        }
        return this._hash;
    }

    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl, org.apache.sysds.hops.codegen.cplan.CNode
    public boolean equals(Object obj) {
        if (!(obj instanceof CNodeRow)) {
            return false;
        }
        CNodeRow cNodeRow = (CNodeRow) obj;
        return super.equals(obj) && this._type == cNodeRow._type && this._numVectors == cNodeRow._numVectors && this._constDim2 == cNodeRow._constDim2 && equalInputReferences(this._output, cNodeRow._output, this._inputs, cNodeRow._inputs);
    }

    @Override // org.apache.sysds.hops.codegen.cplan.CNodeTpl
    public String getTemplateInfo() {
        return "SPOOF ROWAGGREGATE [type=" + this._type.name() + ", reqVectMem=" + this._numVectors + "]";
    }
}
