package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.commons.lang3.StringUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/codegen/cplan/CNodeNary.class */
/**
 * Code-generation plan node that operates on a list of inputs.
 *
 * <p>Each {@link NaryType} maps to a source-code template that calls a
 * {@code LibSpoofPrimitives} routine. {@link #codegen} first emits the code of
 * all inputs, then instantiates the template by substituting the
 * {@code %TMP%}, {@code %INk%}, {@code %INkv%}, {@code %INki%}, {@code %POSk%}
 * and {@code %LEN%} placeholders.
 */
public class CNodeNary extends CNode {
    /**
     * Variable-name prefix that, for a {@link CNodeData} input, selects the
     * templates reading the {@code <name>vals} / {@code <name>ix} arrays.
     *
     * <p>NOTE(review): referencing a GPU timer constant here looks like a
     * decompiler constant-substitution for a short string literal (compare the
     * literal {@code "b"} prefix used alongside it) - confirm against the
     * upstream source before relying on the name of this constant.
     */
    private static final String INDEXED_INPUT_PREFIX = GPUInstruction.MISC_TIMER_ALLOCATE;

    /** Operation performed by this node; fixed at construction. */
    private final NaryType _type;

    /** Supported n-ary operations together with their code templates. */
    public enum NaryType {
        VECT_CBIND,
        VECT_MAX_POOL,
        VECT_AVG_POOL,
        VECT_IM2COL,
        VECT_CONV2DMM;

        /**
         * Tests whether a string is the exact name of one of the constants.
         *
         * @param str candidate name; {@code null} yields {@code false}
         * @return {@code true} if some constant's {@code name()} equals {@code str}
         */
        public static boolean contains(String str) {
            for (NaryType candidate : values()) {
                if (candidate.name().equals(str)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Builds the (still placeholder-bearing) source template for this operation.
         *
         * @param z            selects the template variant that reads the
         *                     {@code vals}/{@code ix} arrays of an input
         *                     (presumably the sparse code path) instead of a plain array
         * @param j            length passed to {@code allocVector} for {@code VECT_CBIND};
         *                     unused by the other operations
         * @param arrayList    inputs of the node; the pooling, im2col and conv2dmm
         *                     operations read their integer parameters from the
         *                     variable names of fixed positions in this list
         * @param generatorAPI target API; not consulted by this method
         * @return source template ending in a newline
         * @throws RuntimeException if no template exists for this constant
         */
        public String getTemplate(boolean z, long j, ArrayList<CNode> arrayList, SpoofCompiler.GeneratorAPI generatorAPI) {
            switch (this) {
                case VECT_CBIND:
                    return cbindTemplate(z, j, arrayList);
                case VECT_MAX_POOL:
                case VECT_AVG_POOL: {
                    String primitive = (this == VECT_MAX_POOL) ? "Maxpool" : "Avgpool";
                    return unaryDnnTemplate(primitive, z, CNodeNary.getDnnParameterString(arrayList, true));
                }
                case VECT_IM2COL:
                    return unaryDnnTemplate("Im2col", z, CNodeNary.getDnnParameterString(arrayList, true));
                case VECT_CONV2DMM:
                    // Two data inputs: the parameter positions are shifted by one.
                    return "    double[] %TMP% = LibSpoofPrimitives.vectConv2dmmWrite(%IN2%, %IN1%, %POS2%, %POS1%, %LEN%, "
                        + CNodeNary.getDnnParameterString(arrayList, false) + ");\n";
                default:
                    throw new RuntimeException("Invalid nary type: " + toString());
            }
        }

        /**
         * Reports whether this operation is a vector primitive.
         *
         * @return {@code true} for every constant currently defined
         */
        public boolean isVectorPrimitive() {
            switch (this) {
                case VECT_CBIND:
                case VECT_MAX_POOL:
                case VECT_AVG_POOL:
                case VECT_IM2COL:
                case VECT_CONV2DMM:
                    return true;
                default:
                    return false;
            }
        }

        /** Emits the allocation of the output vector followed by one write per input. */
        private static String cbindTemplate(boolean z, long outputLen, List<CNode> inputs) {
            StringBuilder code = new StringBuilder();
            code.append("    double[] %TMP% = LibSpoofPrimitives.allocVector(" + outputLen + ", true); //nary cbind\n");
            int offset = 0; // write position inside %TMP%
            for (CNode in : inputs) {
                String name = in.getVarname();
                boolean dataInput = in instanceof CNodeData;
                boolean indexed = z && dataInput && name.startsWith(INDEXED_INPUT_PREFIX);
                if (in.getDataType() == Types.DataType.MATRIX) {
                    // Start position inside the source array; non-data inputs start at 0.
                    String pos;
                    if (!dataInput) {
                        pos = "0";
                    } else if (name.startsWith("b")) {
                        pos = name + ".pos(rix)";
                    } else {
                        pos = name + "i";
                    }
                    String tail = ", " + pos + ", " + offset + ", " + in._cols + ");\n";
                    if (indexed) {
                        code.append("    LibSpoofPrimitives.vectWrite(" + name + "vals, %TMP%, " + name + "ix" + tail);
                    } else {
                        String data = name.startsWith("b") ? name + ".values(rix)" : name;
                        code.append("    LibSpoofPrimitives.vectWrite(" + data + ", %TMP%" + tail);
                    }
                    // Column counts are long; the offset is deliberately kept as int.
                    offset = (int) (offset + in._cols);
                } else {
                    // Non-matrix inputs occupy exactly one output cell.
                    code.append("    %TMP%[" + offset + "] = " + name + ";\n");
                    offset++;
                }
            }
            return code.toString();
        }

        /** Shared template of the single-data-input primitives (pooling and im2col). */
        private static String unaryDnnTemplate(String primitive, boolean z, String dnnParams) {
            String inputArgs = z
                ? "%IN1v%, %IN1i%, %POS1%, alen, len, "
                : "%IN1%, %POS1%, %LEN%, ";
            return "    double[] %TMP% = LibSpoofPrimitives.vect" + primitive + "Write(" + inputArgs + dnnParams + ");\n";
        }
    }

    /**
     * Creates a node over the given inputs and derives its output dimensions.
     *
     * @param cNodeArr inputs, added in array order
     * @param naryType operation to perform
     */
    public CNodeNary(CNode[] cNodeArr, NaryType naryType) {
        for (CNode in : cNodeArr) {
            _inputs.add(in);
        }
        // The type must be set first: setOutputDims() switches on it.
        _type = naryType;
        setOutputDims();
    }

    /**
     * @return the operation performed by this node
     */
    public NaryType getType() {
        return _type;
    }

    /**
     * Generates the code of all inputs followed by this node's own statement.
     *
     * @param z            requests the {@code vals}/{@code ix} template variant; it is
     *                     only honoured when the first input is a non-literal
     *                     {@link CNodeData} whose variable name carries the expected prefix
     * @param generatorAPI target API, forwarded to the inputs and placeholder helpers
     * @return generated code, or the empty string if this node was already generated
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public String codegen(boolean z, SpoofCompiler.GeneratorAPI generatorAPI) {
        if (isGenerated()) {
            return "";
        }

        StringBuilder code = new StringBuilder();
        for (CNode in : _inputs) {
            code.append(in.codegen(z, generatorAPI));
        }

        CNode first = _inputs.get(0);
        boolean indexed = z
            && (first instanceof CNodeData)
            && first.getVarname().startsWith(INDEXED_INPUT_PREFIX)
            && !first.isLiteral();
        String template = _type.getTemplate(indexed, _cols, _inputs, generatorAPI)
            .replace("%TMP%", createVarname());

        String firstName = first.getVarname();
        String statement;
        if (_type == NaryType.VECT_CONV2DMM) {
            String[] names = {firstName, _inputs.get(1).getVarname()};
            statement = replaceBinaryPlaceholders(template, names, false, generatorAPI);
        } else {
            statement = replaceUnaryPlaceholders(template, firstName, false, generatorAPI);
        }
        code.append(statement);

        _generated = true;
        return code.toString();
    }

    /**
     * @return short label of the form {@code n(<op>)}; unknown types fall back to
     *         {@code m(<lower-case constant name>)}
     */
    @Override
    public String toString() {
        switch (_type) {
            case VECT_CBIND:
                return "n(cbind)";
            case VECT_MAX_POOL:
                return "n(maxpool)";
            case VECT_AVG_POOL:
                return "n(avgpool)";
            case VECT_IM2COL:
                return "n(im2col)";
            case VECT_CONV2DMM:
                return "n(conv2dmm)";
            default:
                // Locale.ROOT keeps the label stable regardless of the JVM default locale.
                return "m(" + _type.name().toLowerCase(Locale.ROOT) + ")";
        }
    }

    /**
     * Derives {@code _rows}, {@code _cols} and {@code _dataType} from the inputs.
     *
     * <p>The pooling and conv2dmm cases parse integers from the variable names of
     * fixed input positions and throw {@link NumberFormatException} if a name is
     * not an integer. NOTE(review): the roles of the parsed values (channels,
     * input height/width, filter height/width, with stride 1 and padding 0) are
     * presumed from how they are passed to {@code DnnUtils.getP/getQ} - confirm.
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public void setOutputDims() {
        switch (_type) {
            case VECT_CBIND:
                _rows = _inputs.get(0)._rows;
                _cols = 0L;
                for (CNode in : _inputs) {
                    _cols += in._cols;
                }
                _dataType = Types.DataType.MATRIX;
                break;
            case VECT_MAX_POOL:
            case VECT_AVG_POOL: {
                int c = parseIntInput(_inputs, 6);
                int h = parseIntInput(_inputs, 7);
                int w = parseIntInput(_inputs, 8);
                int r = parseIntInput(_inputs, 11);
                int s = parseIntInput(_inputs, 12);
                long p = DnnUtils.getP(h, r, 1L, 0L);
                long q = DnnUtils.getQ(w, s, 1L, 0L);
                _rows = _inputs.get(0)._rows;
                _cols = c * p * q;
                _dataType = Types.DataType.MATRIX;
                break;
            }
            case VECT_IM2COL:
                _rows = 1L;
                _cols = -1L; // column count is not derived here
                _dataType = Types.DataType.MATRIX;
                break;
            case VECT_CONV2DMM: {
                int h = parseIntInput(_inputs, 8);
                int w = parseIntInput(_inputs, 9);
                int k = parseIntInput(_inputs, 10);
                int r = parseIntInput(_inputs, 12);
                int s = parseIntInput(_inputs, 13);
                long p = DnnUtils.getP(h, r, 1L, 0L);
                long q = DnnUtils.getQ(w, s, 1L, 0L);
                _rows = _inputs.get(0)._rows;
                _cols = k * p * q;
                _dataType = Types.DataType.MATRIX;
                break;
            }
            default:
                break;
        }
    }

    /**
     * @return hash combining the superclass hash with the operation type; computed
     *         on first use and cached in {@code _hash} while that field is 0
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public int hashCode() {
        if (_hash == 0) {
            _hash = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
        }
        return _hash;
    }

    /**
     * @param obj object to compare with
     * @return {@code true} if {@code obj} is a {@code CNodeNary} that the superclass
     *         considers equal and that has the same operation type
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public boolean equals(Object obj) {
        if (!(obj instanceof CNodeNary)) {
            return false;
        }
        CNodeNary that = (CNodeNary) obj;
        return super.equals(that) && _type == that._type;
    }

    /**
     * @param generatorAPI target API
     * @return {@code true} only for the {@code JAVA} API and only if every input
     *         reports itself as supported; inputs after the first unsupported one
     *         are not queried
     */
    @Override // org.apache.sysds.hops.codegen.cplan.CNode
    public boolean isSupported(SpoofCompiler.GeneratorAPI generatorAPI) {
        if (generatorAPI != SpoofCompiler.GeneratorAPI.JAVA) {
            return false;
        }
        for (CNode in : _inputs) {
            if (!in.isSupported(generatorAPI)) {
                return false;
            }
        }
        return true;
    }

    /** Parses the variable name of the input at {@code index} as a decimal int. */
    private static int parseIntInput(List<CNode> inputs, int index) {
        return Integer.parseInt(inputs.get(index).getVarname());
    }

    /**
     * Builds the trailing argument list shared by the pooling, im2col and
     * conv2dmm templates: {@code "rix, "} followed by eight comma-separated ints.
     *
     * @param inputs      inputs of the node
     * @param singleInput {@code true} if the parameters start at position 6,
     *                    {@code false} if they are shifted by one position
     * @return argument string without surrounding parentheses
     */
    private static String getDnnParameterString(List<CNode> inputs, boolean singleInput) {
        int shift = singleInput ? 0 : 1;
        int c = parseIntInput(inputs, shift + 6);
        int h = parseIntInput(inputs, shift + 7);
        int w = parseIntInput(inputs, shift + 8);
        int k = parseIntInput(inputs, shift + 9);
        int r = parseIntInput(inputs, shift + 11);
        int s = parseIntInput(inputs, shift + 12);
        int p = (int) DnnUtils.getP(h, r, 1L, 0L);
        int q = (int) DnnUtils.getQ(w, s, 1L, 0L);
        return "rix, " + StringUtils.join(new int[]{c, p, q, k, r, s, h, w}, ',');
    }

    /**
     * Substitutes the placeholders of the first two inputs and {@code %LEN%}.
     *
     * @param template     template still containing {@code %IN1..2%}-style placeholders
     * @param names        variable names of the first two inputs (exactly two are read)
     * @param unused       not consulted by this method
     * @param generatorAPI target API; selects {@code .values(rix)} versus
     *                     {@code .vals(0)} for names starting with {@code "b"}
     * @return template with the placeholders replaced
     */
    private String replaceBinaryPlaceholders(String template, String[] names, boolean unused, SpoofCompiler.GeneratorAPI generatorAPI) {
        String result = template;
        for (int idx = 0; idx < 2; idx++) {
            String name = names[idx];
            int slot = idx + 1; // placeholders are 1-based
            CNode in = _inputs.get(idx);

            result = result
                .replace("%IN" + slot + "v%", name + "vals")
                .replace("%IN" + slot + "i%", name + "ix");

            // Data expression: names starting with "b" are dereferenced through an accessor.
            String data;
            if (!name.startsWith("b")) {
                data = name;
            } else if (generatorAPI == SpoofCompiler.GeneratorAPI.JAVA) {
                data = name + ".values(rix)";
            } else {
                data = name + ".vals(0)";
            }
            result = result.replace("%IN" + slot + "%", data);

            // Position expression: only matrix data inputs have a non-zero start.
            String pos;
            if (!((in instanceof CNodeData) && in.getDataType().isMatrix())) {
                pos = "0";
            } else if (!name.startsWith("b")) {
                pos = name + "i";
            } else if (!TemplateUtils.isMatrix(in) || _type == NaryType.VECT_CONV2DMM) {
                pos = "0";
            } else {
                pos = name + ".pos(rix)";
            }
            result = result.replace("%POS" + slot + "%", pos);
        }

        CNode first = _inputs.get(0);
        if (first.getDataType().isMatrix()) {
            result = result.replace("%LEN%", first.getVectorLength(generatorAPI));
        }
        return result;
    }
}
