package org.apache.sysds.runtime.instructions.gpu;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.class */
public class SpoofCUDAInstruction extends GPUInstruction {
    private static final Log LOG = LogFactory.getLog(SpoofCUDAInstruction.class.getName());
    public static SpoofCUDAOperator.PrecisionProxy proxy = null;
    private final SpoofCUDAOperator _op;
    private final CPOperand[] _in;
    public final CPOperand _out;

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction$DoublePrecision.class */
    public static class DoublePrecision extends SpoofCUDAOperator.PrecisionProxy {
        @Override // org.apache.sysds.runtime.codegen.SpoofCUDAOperator.PrecisionProxy
        public int exec(SpoofCUDAOperator spoofCUDAOperator) {
            return spoofCUDAOperator.execute_dp(this.ctx);
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction$SinglePrecision.class */
    public static class SinglePrecision extends SpoofCUDAOperator.PrecisionProxy {
        @Override // org.apache.sysds.runtime.codegen.SpoofCUDAOperator.PrecisionProxy
        public int exec(SpoofCUDAOperator spoofCUDAOperator) {
            return spoofCUDAOperator.execute_sp(this.ctx);
        }
    }

    public static void resetFloatingPointPrecision() {
        if (DMLScript.FLOATING_POINT_PRECISION.equalsIgnoreCase("single")) {
            proxy = new SinglePrecision();
        } else {
            if (!DMLScript.FLOATING_POINT_PRECISION.equalsIgnoreCase(Statement.DOUBLE_VALUE_TYPE)) {
                throw new DMLRuntimeException("Unsupported floating point precision: " + DMLScript.FLOATING_POINT_PRECISION);
            }
            proxy = new DoublePrecision();
        }
    }

    private SpoofCUDAInstruction(SpoofCUDAOperator spoofCUDAOperator, CPOperand[] cPOperandArr, CPOperand cPOperand, String str, String str2) {
        super(null, str, str2);
        this._op = spoofCUDAOperator;
        this._in = cPOperandArr;
        this._out = cPOperand;
        this.instString = str2;
        this.instOpcode = str;
    }

    public static SpoofCUDAInstruction parseInstruction(String str) {
        if (proxy == null) {
            throw new RuntimeException("SpoofCUDA Executor has not been initialized");
        }
        String[] instructionPartsWithValueType = InstructionUtils.getInstructionPartsWithValueType(str);
        ArrayList arrayList = new ArrayList();
        Integer cUDAopID = CodegenUtils.getCUDAopID(instructionPartsWithValueType[2]);
        SpoofOperator createInstance = CodegenUtils.createInstance(CodegenUtils.getClass(instructionPartsWithValueType[2]));
        SpoofCUDAOperator createCUDAInstrcution = createInstance.createCUDAInstrcution(cUDAopID, proxy);
        String str2 = instructionPartsWithValueType[0] + "CUDA" + createInstance.getSpoofType();
        for (int i = 3; i < instructionPartsWithValueType.length - 2; i++) {
            arrayList.add(new CPOperand(instructionPartsWithValueType[i]));
        }
        return new SpoofCUDAInstruction(createCUDAInstrcution, (CPOperand[]) arrayList.toArray(new CPOperand[0]), new CPOperand(instructionPartsWithValueType[instructionPartsWithValueType.length - 2]), str2, str);
    }

    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        ArrayList<MatrixObject> arrayList = new ArrayList<>();
        ArrayList<ScalarObject> arrayList2 = new ArrayList<>();
        for (CPOperand cPOperand : this._in) {
            if (cPOperand.getDataType() == Types.DataType.MATRIX) {
                arrayList.add(executionContext.getMatrixInputForGPUInstruction(cPOperand.getName(), getExtendedOpcode()));
            } else if (cPOperand.getDataType() == Types.DataType.SCALAR) {
                arrayList2.add(executionContext.getScalarInput(cPOperand));
            }
        }
        try {
            if (this._out.getDataType() == Types.DataType.MATRIX) {
                this._op.execute(executionContext, arrayList, arrayList2, this._out.getName());
                executionContext.releaseMatrixOutputForGPUInstruction(this._out.getName());
            } else if (this._out.getDataType() == Types.DataType.SCALAR) {
                executionContext.setScalarOutput(this._out.getName(), this._op.execute(executionContext, arrayList, arrayList2));
            }
            for (CPOperand cPOperand2 : this._in) {
                if (cPOperand2.getDataType() == Types.DataType.MATRIX) {
                    executionContext.releaseMatrixInputForGPUInstruction(cPOperand2.getName());
                }
            }
        } catch (Exception e) {
            LOG.error("SpoofCUDAInstruction: " + this._op.getName() + " operator failed to execute :(\n");
            throw new DMLRuntimeException(e);
        }
    }

    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        return Pair.of(this._out.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(executionContext, this._in)));
    }
}
