package org.deeplearning4j.optimize.solvers.accumulation;

import com.google.common.util.concurrent.AtomicDouble;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.eval.EvaluationBinary;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.NDArrayCompressor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.class */
/**
 * {@link MessageHandler} that encodes (compresses) gradient updates before handing them to a
 * {@link GradientsAccumulator}.
 *
 * <p>Encoding state (current mode, current threshold, iteration counters) is kept per calling
 * thread in {@link ThreadLocal}s, so the adaptive behaviour described below is tracked separately
 * for every thread that calls {@link #encodeUpdates(INDArray)}. Two encodings are used:
 * <ul>
 *   <li><b>bitmap</b>: {@code bitmapEncode} into a freshly allocated int buffer of
 *       {@code length / 16 + 5} elements. This is the initial mode. When the value returned by
 *       {@code bitmapEncode} is below half that buffer size, the thread switches to threshold
 *       mode.</li>
 *   <li><b>threshold</b>: {@code thresholdEncode}. When the first int of the encoded buffer
 *       (presumably the encoded element count; confirm against the ND4J executioner) reaches
 *       {@code length / 16}, the thread falls back to bitmap mode. While in this mode the
 *       threshold may be stepped down by {@code thresholdStep}, at most once every
 *       {@code stepDelay} iterations, never reaching {@code minThreshold}, and only when the
 *       encoded percentage is below {@code stepTrigger}.</li>
 * </ul>
 * If {@code shakeFrequency} is non-zero, every {@code shakeFrequency}-th call made in threshold
 * mode produces one bitmap encoding at a third of the current threshold instead. The mode is
 * not changed by that call.
 */
public class EncodingHandler implements MessageHandler {
    private static final Logger log = LoggerFactory.getLogger(EncodingHandler.class);

    /** Target receiving encoded updates; set by {@link #initialize(GradientsAccumulator)}. */
    protected transient GradientsAccumulator accumulator;
    /** Initial encoding threshold for every thread; also used to configure the compressor. */
    protected double threshold;
    /** Lower bound the adaptive threshold is not stepped down past. */
    protected double minThreshold;
    /** Amount subtracted from the current threshold on each step-down. */
    protected double thresholdStep;
    /** Step-down happens only if the encoded percentage of elements is below this value. */
    protected double stepTrigger;
    /** Every this-many calls (0 = never) a bitmap encoding at threshold/3 is used instead. */
    protected int shakeFrequency;
    /** Minimum number of iterations between two step-downs. */
    protected int stepDelay;
    /** Optional fraction of the update length passed to thresholdEncode as a limit; may be null. */
    protected Double boundary;
    protected NDArrayCompressor compressor;
    /** {@code boundary * length}, computed once on first use; -1 while not yet computed. */
    protected AtomicInteger atomicBoundary = new AtomicInteger(-1);
    protected ThreadLocal<AtomicLong> iterations = new ThreadLocal<>();
    protected ThreadLocal<AtomicLong> lastStep = new ThreadLocal<>();
    protected ThreadLocal<AtomicDouble> currentThreshold = new ThreadLocal<>();
    protected ThreadLocal<AtomicBoolean> bitmapMode = new ThreadLocal<>();

    /** Creates a handler with threshold {@code 1e-3}, no adaptation and no boundary. */
    public EncodingHandler() {
        this(0.001d);
    }

    /** Creates a handler with a fixed threshold and no boundary. */
    public EncodingHandler(double threshold) {
        this(threshold, null);
    }

    /**
     * Creates a handler with a fixed threshold: minThreshold equals threshold and the step size is
     * zero, so the threshold is never stepped down.
     */
    public EncodingHandler(double threshold, Double boundary) {
        this(threshold, threshold, 0.0d, 0.0d, 0, 0, boundary);
    }

    public EncodingHandler(double threshold, double minThreshold, double thresholdStep,
                    double stepTrigger, int stepDelay, int shakeFrequency) {
        this(threshold, minThreshold, thresholdStep, stepTrigger, stepDelay, shakeFrequency, null);
    }

    /**
     * @param threshold      initial per-thread encoding threshold
     * @param minThreshold   lower bound for step-downs
     * @param thresholdStep  amount subtracted per step-down
     * @param stepTrigger    step down only when the encoded percentage is below this value
     * @param stepDelay      minimum iterations between step-downs
     * @param shakeFrequency period of the bitmap "shake" round in threshold mode; 0 disables it
     * @param boundary       optional fraction of the update length used as thresholdEncode limit
     */
    public EncodingHandler(double threshold, double minThreshold, double thresholdStep,
                    double stepTrigger, int stepDelay, int shakeFrequency, Double boundary) {
        this.threshold = threshold;
        this.minThreshold = minThreshold;
        this.thresholdStep = thresholdStep;
        this.stepTrigger = stepTrigger;
        this.stepDelay = stepDelay;
        this.shakeFrequency = shakeFrequency;
        this.boundary = boundary;
    }

    /**
     * Binds this handler to an accumulator and configures the "THRESHOLD" compressor.
     *
     * @throws NullPointerException      if {@code accumulator} is null
     * @throws ND4JIllegalStateException if no "THRESHOLD" compressor is registered
     */
    @Override // org.deeplearning4j.optimize.solvers.accumulation.MessageHandler
    public void initialize(@NonNull GradientsAccumulator accumulator) {
        if (accumulator == null) {
            throw new NullPointerException("accumulator is marked @NonNull but is null");
        }
        this.accumulator = accumulator;
        this.compressor = Nd4j.getCompressor().getCompressor("THRESHOLD");
        if (this.compressor == null) {
            throw new ND4JIllegalStateException("Can't find Threshold compressor implementation!");
        }
        this.compressor.configure(new Object[]{Double.valueOf(this.threshold)});
    }

    /**
     * Encodes {@code updates} with the calling thread's current mode and threshold, adapting
     * both for subsequent calls.
     *
     * @param updates the raw update array
     * @return the encoded array, or {@code null} when {@code thresholdEncode} returned null
     *         (presumably nothing exceeded the threshold)
     */
    public INDArray encodeUpdates(INDArray updates) {
        if (this.bitmapMode.get() == null) {
            // First call on this thread: every thread starts in bitmap mode at the base threshold.
            this.bitmapMode.set(new AtomicBoolean(true));
            this.currentThreshold.set(new AtomicDouble(this.threshold));
            this.iterations.set(new AtomicLong(0L));
            this.lastStep.set(new AtomicLong(0L));
        }
        long iteration = this.iterations.get().incrementAndGet();

        if (this.boundary != null && this.atomicBoundary.get() < 0) {
            // Shared across threads; only the first successful CAS fixes the value.
            this.atomicBoundary.compareAndSet(-1, (int) (updates.lengthLong() * this.boundary.doubleValue()));
        }

        double current = this.currentThreshold.get().get();
        if (this.bitmapMode.get().get()) {
            return encodeInBitmapMode(updates, current);
        }
        if (this.shakeFrequency != 0 && iteration % this.shakeFrequency == 0) {
            // Periodic "shake" round: bitmap at a lower threshold, without changing the mode.
            INDArray shaken = allocateBitmapTarget(updates);
            Nd4j.getExecutioner().bitmapEncode(updates, shaken, current / 3.0d);
            return shaken;
        }
        return encodeInThresholdMode(updates, current, iteration);
    }

    /** Bitmap-encodes {@code updates}; switches the thread to threshold mode if few values were encoded. */
    private INDArray encodeInBitmapMode(INDArray updates, double current) {
        INDArray encoded = allocateBitmapTarget(updates);
        long encodedValues = Nd4j.getExecutioner().bitmapEncode(updates, encoded, current);
        if (encodedValues < bitmapBufferLength(updates) / 2) {
            this.bitmapMode.get().set(false);
            log.debug("Switched to threshold encoding");
        }
        return encoded;
    }

    /**
     * Threshold-encodes {@code updates}. Falls back to bitmap encoding when the result is too
     * dense, and otherwise may step the thread's threshold down.
     */
    private INDArray encodeInThresholdMode(INDArray updates, double current, long iteration) {
        INDArray encoded = Nd4j.getExecutioner().thresholdEncode(updates, current,
                        this.boundary == null ? null : Integer.valueOf(this.atomicBoundary.get()));
        if (encoded == null) {
            return null;
        }

        long length = updates.lengthLong();
        double encodedLength = encoded.data().getInt(0L);
        double encodedPercent = (encodedLength * 100.0d) / length;

        if (encodedLength >= length / 16) {
            // Too dense for threshold encoding to pay off: re-encode as bitmap and switch modes.
            log.debug("Going back to bitmapEncoding");
            this.bitmapMode.get().set(true);
            INDArray bitmap = allocateBitmapTarget(updates);
            Nd4j.getExecutioner().bitmapEncode(updates, bitmap, current);
            return bitmap;
        }

        if (this.minThreshold <= current
                        && this.minThreshold < current - this.thresholdStep
                        && iteration > this.lastStep.get().get() + this.stepDelay
                        && encodedPercent < this.stepTrigger) {
            double lowered = this.currentThreshold.get().addAndGet(-this.thresholdStep);
            // Copy the counter VALUE. Storing the iterations AtomicLong itself here would alias the
            // two counters and make the stepDelay condition above permanently false.
            this.lastStep.get().set(iteration);
            log.debug("Threshold steps down to {}", Double.valueOf(lowered));
        }
        return encoded;
    }

    /** Number of ints allocated for a bitmap encoding of {@code updates}. */
    private static long bitmapBufferLength(INDArray updates) {
        return (updates.lengthLong() / 16) + 5;
    }

    /** Allocates a bitmap-sized int buffer wrapped with {@code updates}' shape information. */
    private static INDArray allocateBitmapTarget(INDArray updates) {
        return Nd4j.createArrayFromShapeBuffer(
                        Nd4j.getDataBufferFactory().createInt(bitmapBufferLength(updates)),
                        updates.shapeInfoDataBuffer());
    }

    /** Not supported by this handler. */
    @Deprecated
    public INDArray decodeUpdates(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /** Hands an encoded update to the bound accumulator. */
    protected void sendMessage(INDArray iNDArray) {
        this.accumulator.receiveUpdate(iNDArray);
    }

    /**
     * Encodes {@code iNDArray} and, if encoding produced a result, sends it to the accumulator.
     *
     * @return {@code true} if a message was sent, {@code false} if encoding returned null
     */
    @Override // org.deeplearning4j.optimize.solvers.accumulation.MessageHandler
    public boolean broadcastUpdates(INDArray iNDArray) {
        INDArray encoded = encodeUpdates(iNDArray);
        if (encoded == null) {
            return false;
        }
        sendMessage(encoded);
        return true;
    }
}
