package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import java.util.StringJoiner;
import org.apache.commons.collections.CollectionUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.class */
/**
 * Code-generation plan template for a fused multi-aggregate operator.
 *
 * <p>The template holds several output expressions over a shared set of inputs. In the generated
 * {@code genexec} method, the value of output {@code i} is folded into {@code c[i]} using the
 * aggregation {@code _aggOps.get(i)} (SUM, SUM_SQ, MIN or MAX). Outputs and aggregation operations
 * are therefore paired by index, and both {@link #hashCode()} and {@link #equals(Object)} treat
 * them as ordered.
 */
public class CNodeMultiAgg extends CNodeTpl {
    /** Java source skeleton of the generated operator class; placeholders are filled in {@code codegen}. */
    private static final String TEMPLATE = "package codegen;\nimport org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\nimport org.apache.sysds.runtime.codegen.SpoofCellwise;\nimport org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\nimport org.apache.sysds.runtime.codegen.SpoofMultiAggregate;\nimport org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\nimport org.apache.commons.math3.util.FastMath;\n\npublic final class %TMP% extends SpoofMultiAggregate { \n  public %TMP%() {\n    super(%SPARSE_SAFE%, %AGG_OP%);\n  }\n  protected void genexec(double a, SideInput[] b, double[] scalars, double[] c, int m, int n, long grix, int rix, int cix) { \n%BODY_dense%  }\n}\n";
    // Per-output aggregation statements: %IX% is the output index, %IN% the output's variable.
    private static final String TEMPLATE_OUT_SUM = "    c[%IX%] += %IN%;\n";
    private static final String TEMPLATE_OUT_SUMSQ = "    c[%IX%] += %IN% * %IN%;\n";
    private static final String TEMPLATE_OUT_MIN = "    c[%IX%] = Math.min(c[%IX%], %IN%);\n";
    private static final String TEMPLATE_OUT_MAX = "    c[%IX%] = Math.max(c[%IX%], %IN%);\n";
    /** Name of the scalar {@code genexec} parameter ({@code double a} in TEMPLATE) bound to the first input. */
    private static final String MAIN_INPUT_VARNAME = "a";

    /** Output expressions; output {@code i} is aggregated into {@code c[i]}. */
    private ArrayList<CNode> _outputs;
    /** Aggregation per output, paired by index with {@link #_outputs}. */
    private ArrayList<Types.AggOp> _aggOps;
    /** Root hops associated with this template (opaque here; not part of equality). */
    private ArrayList<Hop> _roots;
    /** Emitted as the first super-constructor argument of the generated class. */
    private boolean _sparseSafe;

    /**
     * Creates a multi-aggregate template.
     *
     * @param inputs  template inputs; the first one is renamed to {@code a} by {@link #renameInputs()}
     * @param outputs output expressions, one per aggregation operation
     */
    public CNodeMultiAgg(ArrayList<CNode> inputs, ArrayList<CNode> outputs) {
        super(inputs, null);
        this._outputs = outputs;
        this._aggOps = null;
        this._roots = null;
        this._sparseSafe = false;
    }

    /** @return the output expressions (live list, not a copy) */
    public ArrayList<CNode> getOutputs() {
        return this._outputs;
    }

    /** Resets the visit status of every output expression. */
    @Override
    public void resetVisitStatusOutputs() {
        for (CNode output : this._outputs) {
            output.resetVisitStatus();
        }
    }

    /**
     * Sets the aggregation operations, one per output (same order as the outputs), and
     * invalidates the cached hash.
     */
    public void setAggOps(ArrayList<Types.AggOp> aggOps) {
        this._aggOps = aggOps;
        this._hash = 0;
    }

    /** @return the aggregation operations, index-aligned with the outputs (live list) */
    public ArrayList<Types.AggOp> getAggOps() {
        return this._aggOps;
    }

    public void setRootNodes(ArrayList<Hop> roots) {
        this._roots = roots;
    }

    public ArrayList<Hop> getRootNodes() {
        return this._roots;
    }

    /**
     * Sets the sparse-safe flag. The flag is part of the generated code and of equality, so the
     * cached hash is invalidated.
     */
    public void setSparseSafe(boolean sparseSafe) {
        this._sparseSafe = sparseSafe;
        this._hash = 0;
    }

    public boolean isSparseSafe() {
        return this._sparseSafe;
    }

    /**
     * Renames the data nodes referenced by the outputs. The first input is bound to the
     * {@code genexec} parameter {@code a}. The remaining inputs are renamed by the inherited
     * helper starting at index 1 (presumably to side-input accessors; see CNodeTpl).
     */
    @Override
    public void renameInputs() {
        rRenameDataNode(this._outputs, this._inputs.get(0), MAIN_INPUT_VARNAME);
        renameInputs(this._outputs, this._inputs, 1);
    }

    /**
     * Generates the Java source of the fused multi-aggregate operator.
     *
     * @param sparse ignored; outputs are always generated in dense form
     * @param api    generator API forwarded to the output expressions
     * @return complete source of the generated class
     * @throws IllegalStateException if an aggregation operation is not SUM, SUM_SQ, MIN or MAX
     */
    @Override
    public String codegen(boolean sparse, SpoofCompiler.GeneratorAPI api) {
        StringBuilder body = new StringBuilder();
        // Emit the code of all output expressions before resetting the generated state
        // (presumably so sub-expressions shared by several outputs are emitted only once).
        for (CNode output : this._outputs) {
            body.append(output.codegen(false, api));
        }
        for (CNode output : this._outputs) {
            output.resetGenerated();
        }
        // One aggregation statement per output, writing into c[i].
        for (int i = 0; i < this._outputs.size(); i++) {
            CNode output = this._outputs.get(i);
            String aggTemplate = getAggTemplate(i);
            // An output that is the main input itself is referenced through the genexec parameter.
            String inVar = isMainInputData(output) ? MAIN_INPUT_VARNAME : output.getVarname();
            body.append(aggTemplate.replace("%IN%", inVar).replace("%IX%", String.valueOf(i)));
        }
        String code = TEMPLATE.replace("%TMP%", createVarname()).replace("%BODY_dense%", body.toString());
        StringJoiner aggOpList = new StringJoiner(",");
        for (Types.AggOp aggOp : this._aggOps) {
            aggOpList.add("AggOp." + aggOp.name());
        }
        return code.replace("%AGG_OP%", aggOpList.toString())
            .replace("%SPARSE_SAFE%", String.valueOf(isSparseSafe()));
    }

    /** Returns true if {@code node} is a data node for the same hop as the first template input. */
    private boolean isMainInputData(CNode node) {
        return node instanceof CNodeData
            && ((CNodeData) node).getHopID() == ((CNodeData) this._inputs.get(0)).getHopID();
    }

    /** No-op: outputs of this template are scalars (see {@link #getOutputDimType()}). */
    @Override
    public void setOutputDims() {
    }

    @Override
    public SpoofFusedOp.SpoofOutputDimsType getOutputDimType() {
        return SpoofFusedOp.SpoofOutputDimsType.MULTI_SCALAR;
    }

    /**
     * Creates a shallow copy sharing the input, output, aggregation and root lists. Also carries
     * over the sparse-safe flag, so the copy compares equal to this template.
     */
    @Override
    /* renamed from: clone */
    public CNodeTpl mo133clone() {
        CNodeMultiAgg copy = new CNodeMultiAgg(this._inputs, this._outputs);
        copy.setAggOps(getAggOps());
        copy.setSparseSafe(isSparseSafe());
        copy.setRootNodes(getRootNodes());
        return copy;
    }

    /**
     * Hash over the base template, each (output, aggregation) pair in order, and the sparse-safe
     * flag. The hash is cached in {@code _hash}, and 0 means "not yet computed".
     */
    @Override
    public int hashCode() {
        if (this._hash == 0) {
            int h = super.hashCode();
            for (int i = 0; i < this._outputs.size(); i++) {
                int pair = UtilFunctions.intHashCode(this._outputs.get(i).hashCode(), this._aggOps.get(i).hashCode());
                h = UtilFunctions.intHashCode(h, pair);
            }
            h = UtilFunctions.intHashCode(h, Boolean.hashCode(this._sparseSafe));
            this._hash = h;
        }
        return this._hash;
    }

    /**
     * Two multi-aggregate templates are equal if their base templates are equal and they have the
     * same sparse-safe flag and the same aggregation operations in the same order. Their outputs
     * must also reference equivalent inputs. Order matters because aggregation {@code i} applies
     * to output {@code i}.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof CNodeMultiAgg)) {
            return false;
        }
        CNodeMultiAgg that = (CNodeMultiAgg) obj;
        return super.equals(obj)
            && this._sparseSafe == that._sparseSafe
            && Objects.equals(this._aggOps, that._aggOps)
            && equalInputReferences(this._outputs, that._outputs, this._inputs, that._inputs);
    }

    @Override
    public String getTemplateInfo() {
        return "SPOOF MULTIAGG [aggOps=" + Arrays.toString(this._aggOps.toArray(new Types.AggOp[0])) + "]";
    }

    /**
     * Returns the aggregation statement template for output {@code i}.
     *
     * @throws IllegalStateException if the aggregation is not SUM, SUM_SQ, MIN or MAX
     */
    private String getAggTemplate(int i) {
        Types.AggOp aggOp = this._aggOps.get(i);
        switch (aggOp) {
            case SUM:
                return TEMPLATE_OUT_SUM;
            case SUM_SQ:
                return TEMPLATE_OUT_SUMSQ;
            case MIN:
                return TEMPLATE_OUT_MIN;
            case MAX:
                return TEMPLATE_OUT_MAX;
            default:
                throw new IllegalStateException(
                    "Unsupported aggregation type for multi-aggregate template: " + aggOp);
        }
    }

    /**
     * Supported only for the JAVA generator, and only if every input is supported as well.
     * NOTE(review): {@link #compile} handles CUDA, yet this method never reports CUDA as supported.
     */
    @Override
    public boolean isSupported(SpoofCompiler.GeneratorAPI api) {
        if (api != SpoofCompiler.GeneratorAPI.JAVA) {
            return false;
        }
        for (CNode input : this._inputs) {
            if (!input.isSupported(api)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compiles the generated source natively for the CUDA generator via {@code compile_nvrtc}.
     *
     * @return the native result for CUDA, or -1 for any other generator API
     */
    @Override
    public int compile(SpoofCompiler.GeneratorAPI generatorAPI, String src) {
        if (generatorAPI == SpoofCompiler.GeneratorAPI.CUDA) {
            return compile_nvrtc(SpoofCompiler.native_contexts.get(generatorAPI).longValue(), this._genVar, src, this._sparseSafe);
        }
        return -1;
    }

    private native int compile_nvrtc(long ctx, String name, String src, boolean sparseSafe);
}
