package org.apache.sysds.utils.stats;

import java.util.concurrent.atomic.LongAdder;
import org.apache.sysds.utils.NativeHelper;

/* loaded from: input_file:org/apache/sysds/utils/stats/NativeStatistics.class */
/**
 * Process-wide counters for calls into the native BLAS backend.
 *
 * <p>Each counter is a {@link LongAdder}, so the increment methods may be
 * invoked from several threads. {@link #reset()} clears the adders one after
 * another and is therefore not atomic with respect to concurrent increments.
 */
public class NativeStatistics {
    // Number of times incrementFailuresCounter() was invoked.
    private static final LongAdder numFailures = new LongAdder();

    // Dense call counters and their accumulated times.
    private static final LongAdder numLibMatrixMultCalls = new LongAdder();
    private static final LongAdder libMatrixMultTime = new LongAdder();
    private static final LongAdder numConv2dCalls = new LongAdder();
    private static final LongAdder conv2dTime = new LongAdder();
    private static final LongAdder numConv2dBwdDataCalls = new LongAdder();
    private static final LongAdder conv2dBwdDataTime = new LongAdder();
    private static final LongAdder numConv2dBwdFilterCalls = new LongAdder();
    private static final LongAdder conv2dBwdFilterTime = new LongAdder();

    // Sparse call counters (no times are tracked for these).
    private static final LongAdder numSparseConv2dCalls = new LongAdder();
    private static final LongAdder numSparseConv2dBwdFilterCalls = new LongAdder();
    private static final LongAdder numSparseConv2dBwdDataCalls = new LongAdder();

    /**
     * Records a failure and then always throws.
     *
     * @throws RuntimeException unconditionally, after the failure counter has
     *         been incremented
     */
    public static void incrementFailuresCounter() {
        numFailures.increment();
        throw new RuntimeException("Unexpected ERROR: OOM caused during JNI transfer. Please disable native BLAS by setting enviroment variable: SYSTEMDS_BLAS=none");
    }

    /** Adds one to the dense matrix-multiplication call counter. */
    public static void incrementNumLibMatrixMultCalls() {
        numLibMatrixMultCalls.increment();
    }

    /**
     * Adds to the accumulated dense matrix-multiplication time.
     *
     * @param j amount to add (reported scaled by 1e-9, so presumably nanoseconds)
     */
    public static void incrementLibMatrixMultTime(long j) {
        libMatrixMultTime.add(j);
    }

    /** Adds one to the dense conv2d call counter. */
    public static void incrementNumConv2dCalls() {
        numConv2dCalls.increment();
    }

    /**
     * Adds to the accumulated dense conv2d time.
     *
     * @param j amount to add (reported scaled by 1e-9, so presumably nanoseconds)
     */
    public static void incrementConv2dTime(long j) {
        conv2dTime.add(j);
    }

    /** Adds one to the dense conv2d backward-data call counter. */
    public static void incrementNumConv2dBwdDataCalls() {
        numConv2dBwdDataCalls.increment();
    }

    /**
     * Adds to the accumulated dense conv2d backward-data time.
     *
     * @param j amount to add (reported scaled by 1e-9, so presumably nanoseconds)
     */
    public static void incrementConv2dBwdDataTime(long j) {
        conv2dBwdDataTime.add(j);
    }

    /** Adds one to the dense conv2d backward-filter call counter. */
    public static void incrementNumConv2dBwdFilterCalls() {
        numConv2dBwdFilterCalls.increment();
    }

    /**
     * Adds to the accumulated dense conv2d backward-filter time.
     *
     * @param j amount to add (reported scaled by 1e-9, so presumably nanoseconds)
     */
    public static void incrementConv2dBwdFilterTime(long j) {
        conv2dBwdFilterTime.add(j);
    }

    /** Adds one to the sparse conv2d call counter. */
    public static void incrementNumSparseConv2dCalls() {
        numSparseConv2dCalls.increment();
    }

    /** Adds one to the sparse conv2d backward-filter call counter. */
    public static void incrementNumSparseConv2dBwdFilterCalls() {
        numSparseConv2dBwdFilterCalls.increment();
    }

    /** Adds one to the sparse conv2d backward-data call counter. */
    public static void incrementNumSparseConv2dBwdDataCalls() {
        numSparseConv2dBwdDataCalls.increment();
    }

    /** Resets every counter and accumulated time, including the failure counter, to zero. */
    public static void reset() {
        numLibMatrixMultCalls.reset();
        numSparseConv2dCalls.reset();
        numSparseConv2dBwdDataCalls.reset();
        numSparseConv2dBwdFilterCalls.reset();
        numConv2dCalls.reset();
        numConv2dBwdDataCalls.reset();
        numConv2dBwdFilterCalls.reset();
        numFailures.reset();
        libMatrixMultTime.reset();
        conv2dTime.reset();
        conv2dBwdFilterTime.reset();
        conv2dBwdDataTime.reset();
    }

    /**
     * Builds a three-line report of the native call counters and times.
     *
     * <p>The lines list, in the order given by their labels: dense call counts
     * (mult/conv/bwdF/bwdD), sparse call counts (conv/bwdF/bwdD) and dense
     * times (mult/conv/bwdF/bwdD). Each line is prefixed with the BLAS name
     * returned by {@code NativeHelper.getCurrentBLAS()}.
     *
     * @return the formatted report, each line terminated by {@code ".\n"}
     */
    public static String displayStatistics() {
        StringBuilder sb = new StringBuilder();
        String currentBLAS = NativeHelper.getCurrentBLAS();

        // Every slot must carry its own counter value, in label order; the
        // builder itself must never be concatenated into the output.
        sb.append("Native ").append(currentBLAS).append(" calls (dense mult/conv/bwdF/bwdD):\t")
            .append(numLibMatrixMultCalls.longValue()).append("/")
            .append(numConv2dCalls.longValue()).append("/")
            .append(numConv2dBwdFilterCalls.longValue()).append("/")
            .append(numConv2dBwdDataCalls.longValue()).append(".\n");

        sb.append("Native ").append(currentBLAS).append(" calls (sparse conv/bwdF/bwdD):\t")
            .append(numSparseConv2dCalls.longValue()).append("/")
            .append(numSparseConv2dBwdFilterCalls.longValue()).append("/")
            .append(numSparseConv2dBwdDataCalls.longValue()).append(".\n");

        sb.append("Native ").append(currentBLAS).append(" times (dense mult/conv/bwdF/bwdD):\t")
            .append(formatScaledTime(libMatrixMultTime)).append("/")
            .append(formatScaledTime(conv2dTime)).append("/")
            .append(formatScaledTime(conv2dBwdFilterTime)).append("/")
            .append(formatScaledTime(conv2dBwdDataTime)).append(".\n");

        return sb.toString();
    }

    /** Formats an accumulated time scaled by 1e-9 with three decimal places (default locale). */
    private static String formatScaledTime(LongAdder time) {
        return String.format("%.3f", time.longValue() * 1.0E-9d);
    }
}
