package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.class */
/**
 * CPlan template node for a fused outer-product operator.
 *
 * <p>{@link #codegen} instantiates {@link #TEMPLATE} into the Java source of a class extending
 * {@code SpoofOuterProduct}. For {@code LEFT_OUTER_PRODUCT} / {@code RIGHT_OUTER_PRODUCT} the
 * generated body is placed into {@code genexecDense}; for every other type it is placed into
 * {@code genexecCellwise}.
 */
public class CNodeOuterProduct extends CNodeTpl {
    // Source template of the generated operator class. The placeholders %TMP%, %TYPE%,
    // %BODY_dense%, %BODY_cellwise% and %OUT_cellwise% are filled in by codegen().
    // The parameter names a, a1 and a2 must match the names assigned in renameInputs().
    // Split into adjacent literals for readability only; the concatenation is a compile-time
    // constant and yields exactly the same string as a single literal.
    private static final String TEMPLATE =
        "package codegen;\n"
        + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
        + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
        + "import org.apache.sysds.runtime.codegen.SpoofOuterProduct;\n"
        + "import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;\n"
        + "import org.apache.commons.math3.util.FastMath;\n"
        + "\n"
        + "public final class %TMP% extends SpoofOuterProduct { \n"
        + "  public %TMP%() {\n"
        + "    super(OutProdType.%TYPE%);\n"
        + "  }\n"
        + "  protected void genexecDense(double a, double[] a1, int a1i, double[] a2, int a2i, SideInput[] b, double[] scalars, double[] c, int ci, int m, int n, int len, int rix, int cix) { \n"
        + "%BODY_dense%"
        + "  }\n"
        + "  protected double genexecCellwise(double a, double[] a1, int a1i, double[] a2, int a2i, SideInput[] b, double[] scalars, int m, int n, int len, int rix, int cix) { \n"
        + "%BODY_cellwise%"
        + "    return %OUT_cellwise%;\n"
        + "  }\n"
        + "}\n";

    // Outer-product variant; null until setOutProdType is called.
    private SpoofOuterProduct.OutProdType _type;
    // Transpose-self matrix-multiply type passed at construction (NONE when not applicable).
    MMTSJ.MMTSJType _mmtsj;
    private boolean _transposeOutput;

    /** Returns the MMTSJ type this node was constructed with. */
    public MMTSJ.MMTSJType getMMTSJtype() {
        return _mmtsj;
    }

    /**
     * Creates an outer-product template node.
     *
     * @param inputs template inputs; index 0 becomes "a", index 1 "a1", index 2 "a2"
     * @param output root of the generated expression DAG
     * @param mmtsj  MMTSJ type; when not {@code NONE}, input 1 is inserted a second time at
     *               index 1, so that both "a1" and "a2" refer to the same input
     */
    public CNodeOuterProduct(ArrayList<CNode> inputs, CNode output, MMTSJ.MMTSJType mmtsj) {
        super(inputs, output);
        _type = null;
        _transposeOutput = false;
        _mmtsj = mmtsj;
        if (_mmtsj != MMTSJ.MMTSJType.NONE) {
            _inputs.add(1, inputs.get(1));
        }
    }

    /**
     * Renames the data-node inputs to the variable names used in {@link #TEMPLATE}:
     * input 0 -> "a", input 1 -> "a1", input 2 -> "a2". The remaining inputs (from index 3)
     * are handed to the inherited {@code renameInputs(list, 3)}.
     */
    @Override
    public void renameInputs() {
        // "a" must match the first parameter name of genexecDense/genexecCellwise in TEMPLATE.
        rRenameDataNode(_output, _inputs.get(0), "a");
        rRenameDataNode(_output, _inputs.get(1), "a1");
        rRenameDataNode(_output, _inputs.get(2), "a2");
        renameInputs(_inputs, 3);
    }

    /**
     * Generates the Java source of the fused operator class.
     *
     * @param ignored not used; the body is always generated via {@code _output.codegen(false, api)}
     * @param api     generator API forwarded to the output DAG
     * @return complete class source with all template placeholders substituted
     */
    @Override
    public String codegen(boolean ignored, SpoofCompiler.GeneratorAPI api) {
        final String body = _output.codegen(false, api);
        // Presumably clears per-node "generated" markers so the DAG can be emitted again - verify.
        _output.resetGenerated();

        String src = TEMPLATE.replace("%TMP%", createVarname());
        if (_type == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT
            || _type == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
            // Body goes into genexecDense, with %OUT% bound to the output array c;
            // genexecCellwise becomes an empty stub returning 0.
            src = src.replace("%BODY_dense%", body)
                .replace("%OUT%", "c")
                .replace("%BODY_cellwise%", "")
                .replace("%OUT_cellwise%", "0");
        } else {
            // Body goes into genexecCellwise, which returns the output node's variable.
            src = src.replace("%BODY_dense%", "")
                .replace("%BODY_cellwise%", body)
                .replace("%OUT_cellwise%", _output.getVarname());
        }
        // TEMPLATE itself contains neither %LEN% nor %POSOUT%, so these two replacements only
        // resolve placeholders emitted by the generated body; they must run after insertion.
        return src.replace("%LEN%", "len")
            .replace("%POSOUT%", "ci")
            .replace("%TYPE%", _type.toString());
    }

    /** Sets the outer-product variant and invalidates the cached hash code. */
    public void setOutProdType(SpoofOuterProduct.OutProdType type) {
        _type = type;
        _hash = 0;
    }

    /** Returns the outer-product variant, or null if it was never set. */
    public SpoofOuterProduct.OutProdType getOutProdType() {
        return _type;
    }

    /** Intentionally empty: output dimensions are given by {@link #getOutputDimType()}. */
    @Override
    public void setOutputDims() {
    }

    /** Sets the transpose-output flag and invalidates the cached hash code. */
    public void setTransposeOutput(boolean transposeOutput) {
        _transposeOutput = transposeOutput;
        _hash = 0;
    }

    /** Returns the transpose-output flag. */
    public boolean isTransposeOutput() {
        return _transposeOutput;
    }

    /**
     * Maps the outer-product variant to its output dimension type.
     *
     * @throws RuntimeException for variants without a mapping
     */
    @Override
    public SpoofFusedOp.SpoofOutputDimsType getOutputDimType() {
        switch (_type) {
            case LEFT_OUTER_PRODUCT:
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_RANK_DIMS;
            case RIGHT_OUTER_PRODUCT:
                return SpoofFusedOp.SpoofOutputDimsType.ROW_RANK_DIMS;
            case CELLWISE_OUTER_PRODUCT:
                return SpoofFusedOp.SpoofOutputDimsType.INPUT_DIMS;
            case AGG_OUTER_PRODUCT:
                return SpoofFusedOp.SpoofOutputDimsType.SCALAR;
            default:
                throw new RuntimeException("Unsupported outer product type: " + _type);
        }
    }

    /**
     * Creates a copy of this template with the same inputs, output, MMTSJ type, outer-product
     * type and transpose flag.
     *
     * <p>Declared as {@code mo132clone} (decompiler-renamed {@code clone}); the name is kept
     * because the supertype declares it.
     */
    @Override
    public CNodeTpl mo132clone() {
        // Pass NONE so the constructor does not insert input 1 a second time: _inputs already
        // contains that duplicate when _mmtsj != NONE. Copy the list so the clone never
        // shares it with this node.
        CNodeOuterProduct copy =
            new CNodeOuterProduct(new ArrayList<>(_inputs), _output, MMTSJ.MMTSJType.NONE);
        copy._mmtsj = _mmtsj;
        // Carry over the fields that equals()/hashCode() depend on.
        copy.setOutProdType(_type);
        copy.setTransposeOutput(_transposeOutput);
        return copy;
    }

    /**
     * Hash over the inherited hash, the outer-product type and the transpose flag. The value is
     * cached in {@code _hash}, which 0 marks as not yet computed; the setters reset it.
     */
    @Override
    public int hashCode() {
        if (_hash == 0) {
            int typeHash = (_type != null) ? _type.hashCode() : 0;
            _hash = UtilFunctions.intHashCode(
                UtilFunctions.intHashCode(super.hashCode(), typeHash),
                Boolean.hashCode(_transposeOutput));
        }
        return _hash;
    }

    /**
     * Returns true if {@code obj} is a CNodeOuterProduct that is equal per the inherited
     * equality, has the same outer-product type and transpose flag, and equal input references.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof CNodeOuterProduct)) {
            return false;
        }
        CNodeOuterProduct that = (CNodeOuterProduct) obj;
        return super.equals(that)
            && _type == that._type
            && _transposeOutput == that._transposeOutput
            && equalInputReferences(_output, that._output, _inputs, that._inputs);
    }

    /** Returns a short description such as {@code SPOOF OUTERPRODUCT [type=..., to=...]}. */
    @Override
    public String getTemplateInfo() {
        return new StringBuilder()
            .append("SPOOF OUTERPRODUCT [type=")
            .append(_type.name())
            .append(", to=")
            .append(_transposeOutput)
            .append("]")
            .toString();
    }

    /**
     * Returns true only for the JAVA generator, and only if every input also supports it.
     * Inputs are checked in order, and checking stops at the first unsupported input.
     * Note: compile() has a CUDA path, but this method reports CUDA as unsupported.
     */
    @Override
    public boolean isSupported(SpoofCompiler.GeneratorAPI api) {
        if (api != SpoofCompiler.GeneratorAPI.JAVA) {
            return false;
        }
        for (CNode in : _inputs) {
            if (!in.isSupported(api)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compiles the generated source with the native NVRTC binding for the CUDA generator.
     *
     * @return the native result for CUDA, or -1 for any other generator API
     */
    @Override
    public int compile(SpoofCompiler.GeneratorAPI api, String src) {
        if (api == SpoofCompiler.GeneratorAPI.CUDA) {
            // The outer-product type is passed to native code by ordinal, so enum order matters.
            return compile_nvrtc(SpoofCompiler.native_contexts.get(api).longValue(),
                _genVar, src, _type.ordinal());
        }
        return -1;
    }

    private native int compile_nvrtc(long ctx, String name, String src, int type);
}
