package org.nd4j.linalg.cpu.nativecpu.compression;

import org.apache.commons.math3.util.FastMath;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.compression.impl.AbstractCompressor;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/compression/CpuThreshold.class */
/**
 * CPU implementation of the "THRESHOLD" compressor.
 *
 * <p>Compression counts the elements whose absolute value is greater than or equal to
 * {@link #threshold}. It allocates an int buffer holding a {@value #HEADER_LENGTH}-int header
 * plus one int per such element, then delegates the actual encoding to
 * {@code convertDataEx(..., DataBuffer.TypeEx.THRESHOLD, ...)}. Decompression is likewise
 * delegated to {@code convertDataEx}.
 */
public class CpuThreshold extends AbstractCompressor {
    private static final Logger log = LoggerFactory.getLogger(CpuThreshold.class);

    /** Number of int slots at the start of the compressed buffer reserved for the header. */
    private static final int HEADER_LENGTH = 4;

    /** Size in bytes of one slot of the compressed (int) buffer. */
    private static final int INT_SIZE_BYTES = 4;

    /** Non-negative magnitude an element must reach to be counted/encoded. */
    protected float threshold = 0.001f;

    /** Returns the algorithm name recorded in every {@link CompressionDescriptor} produced here. */
    @Override
    public String getDescriptor() {
        return "THRESHOLD";
    }

    /**
     * Configures this compressor from generic parameters.
     *
     * @param vars the first element must be a {@link Number}; its absolute value becomes the new
     *     threshold. Remaining elements are ignored.
     * @throws ND4JIllegalStateException if no parameter is supplied or the first one is not a Number
     */
    @Override
    public void configure(Object... vars) {
        // Guard against empty varargs so callers get the documented exception, not an AIOOBE.
        if (vars == null || vars.length == 0 || !(vars[0] instanceof Number)) {
            throw new ND4JIllegalStateException("Threshold value should be Number");
        }
        this.threshold = FastMath.abs(((Number) vars[0]).floatValue());
        log.info("Setting threshold to [{}]", Float.valueOf(this.threshold));
    }

    /**
     * Compresses an array's data and wraps the result in an array marked as compressed.
     *
     * <p>Pending ops are committed and the array is made host-resident first, because compression
     * reads the data directly from host memory.
     *
     * @param array the array to compress
     * @return an array backed by the compressed buffer (reusing the input's shape info), or
     *     {@code null} when {@link #compress(DataBuffer)} returns {@code null}
     */
    @Override
    public INDArray compress(INDArray array) {
        Nd4j.getExecutioner().commit();
        Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.HOST);

        DataBuffer compressed = compress(array.data());
        if (compressed == null) {
            return null;
        }

        INDArray result = Nd4j.createArrayFromShapeBuffer(compressed, array.shapeInfoDataBuffer());
        result.markAsCompressed(true);
        return result;
    }

    /**
     * Returns the compression type recorded in produced descriptors.
     *
     * <p>NOTE(review): thresholding discards elements below {@link #threshold}, so this looks
     * semantically lossy. It is kept as LOSSLESS because callers may depend on the current
     * value; confirm before changing.
     */
    @Override
    public CompressionType getCompressionType() {
        return CompressionType.LOSSLESS;
    }

    /**
     * Decodes a THRESHOLD-encoded buffer into the global data type.
     *
     * @param buffer a buffer previously produced by {@link #compress(DataBuffer)}
     * @return the decoded buffer, as returned by {@code convertDataEx}
     */
    @Override
    public DataBuffer decompress(DataBuffer buffer) {
        return Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.THRESHOLD, buffer, getGlobalTypeEx());
    }

    /**
     * Threshold-encodes a data buffer.
     *
     * <p>The compressed int buffer has {@code count + 4} slots. Slot 0 holds the number of
     * elements with {@code |x| >= threshold}. Slot 1 holds the original element count. Slot 2
     * holds the threshold's float bits. Slot 3 is written as 0. The remaining slots are filled
     * by the native {@code convertDataEx} call.
     *
     * @param buffer the buffer to encode; its length must fit in an {@code int}
     * @return the compressed buffer, or {@code null} if fewer than 2 elements reach the threshold
     * @throws ND4JIllegalStateException if the buffer has more than {@code Integer.MAX_VALUE} elements
     */
    @Override
    public DataBuffer compress(DataBuffer buffer) {
        long numElements = buffer.length();
        if (numElements > Integer.MAX_VALUE) {
            // The header (and the row shape used for counting) store the length as an int.
            throw new ND4JIllegalStateException("Threshold compression supports at most "
                    + Integer.MAX_VALUE + " elements, got " + numElements);
        }

        int matches = countAboveThreshold(buffer);
        if (matches < 2) {
            return null;
        }

        int elementSize = Nd4j.sizeOfDataType(buffer.dataType());
        long compressedLength = (long) matches + HEADER_LENGTH;

        IntPointer pointer = new IntPointer(compressedLength);
        pointer.put(0L, matches);
        pointer.put(1L, (int) numElements);
        pointer.put(2L, Float.floatToIntBits(this.threshold));
        pointer.put(3L, 0);

        CompressionDescriptor descriptor = new CompressionDescriptor();
        // Size in bytes of the whole int buffer allocated above (header + one slot per match).
        descriptor.setCompressedLength(compressedLength * INT_SIZE_BYTES);
        descriptor.setOriginalLength(numElements * elementSize);
        descriptor.setOriginalElementSize(elementSize);
        descriptor.setNumberOfElements(numElements);
        descriptor.setCompressionAlgorithm(getDescriptor());
        descriptor.setCompressionType(getCompressionType());

        CompressedDataBuffer result = new CompressedDataBuffer(pointer, descriptor);

        Nd4j.getNDArrayFactory().convertDataEx(getBufferTypeEx(buffer), buffer.addressPointer(),
                DataBuffer.TypeEx.THRESHOLD, pointer, numElements);
        Nd4j.getAffinityManager().tagLocation(buffer, AffinityManager.Location.HOST);
        return result;
    }

    /** Counts the elements of {@code buffer} whose absolute value is >= the current threshold. */
    private int countAboveThreshold(DataBuffer buffer) {
        // View the flat buffer as a 1 x N row so MatchCondition can run over it.
        int[] rowShape = {1, (int) buffer.length()};
        DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(rowShape).getFirst();
        INDArray row = Nd4j.createArrayFromShapeBuffer(buffer, shapeInfo);

        MatchCondition condition =
                new MatchCondition(row, Conditions.absGreaterThanOrEqual(Float.valueOf(this.threshold)));
        // Integer.MAX_VALUE as the dimension requests a full reduction to a single count.
        return Nd4j.getExecutioner().exec((Accumulation) condition, Integer.MAX_VALUE).getInt(0);
    }

    /** Not supported: this compressor only works on whole {@link DataBuffer}s. */
    @Override
    protected CompressedDataBuffer compressPointer(DataBuffer.TypeEx srcType, Pointer srcPointer, int length, int elementSize) {
        throw new UnsupportedOperationException();
    }

    /** Returns the current (non-negative) threshold. */
    public float getThreshold() {
        return this.threshold;
    }

    /**
     * Sets the threshold.
     *
     * <p>The absolute value is stored, matching {@link #configure(Object...)}. A negative
     * threshold would otherwise make the {@code |x| >= threshold} test match every element.
     *
     * @param threshold the new threshold; its sign is ignored
     */
    public void setThreshold(float threshold) {
        this.threshold = FastMath.abs(threshold);
    }
}
