/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.dropout;

import org.deeplearning4j.nn.conf.dropout.DropoutHelper;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dropout regularization. On each forward pass every activation is multiplied by a freshly generated random
 * mask produced by ND4J's {@link DropOutInverted} op, parameterized by an activation retain probability.
 * <p>
 * The retain probability is either a fixed value ({@link #getP()}) or, when a schedule is configured,
 * {@link #getPSchedule()} evaluated at the current iteration/epoch (in that case {@code p} is {@code NaN}).
 * <p>
 * When the ND4J backend reports itself as CUDA and the optional
 * {@code org.deeplearning4j.nn.layers.dropout.CudnnDropoutHelper} class can be loaded and reports itself as
 * supported, forward and backward passes are delegated to that helper instead of the built-in mask path.
 * <p>
 * {@code mask}, {@code helper} and {@code initializedHelper} are per-instance runtime state: they are excluded
 * from JSON serialization and from {@link #equals(Object)}/{@link #hashCode()}.
 */
@JsonIgnoreProperties(value={"mask", "helper", "initializedHelper"})
public class Dropout implements IDropout {
    private static final Logger log = LoggerFactory.getLogger(Dropout.class);

    /** Fixed activation retain probability; {@code NaN} when {@link #pSchedule} is used instead. */
    private double p;
    /** Optional schedule for the retain probability; takes precedence over {@link #p} when non-null. */
    private ISchedule pSchedule;
    /** Mask from the most recent built-in forward pass; consumed (cleared) by {@link #backprop}. */
    private transient INDArray mask;
    /** Optional backend-specific implementation; {@code null} means the built-in mask path is used. */
    private transient DropoutHelper helper;
    /** Whether helper detection has already been attempted for this instance. */
    private boolean initializedHelper = false;

    /**
     * Creates dropout with a fixed activation retain probability.
     *
     * @param activationRetainProbability probability of retaining each activation; must be in (0, 1]
     * @throws IllegalArgumentException if the probability is negative, zero, NaN or greater than 1
     */
    public Dropout(double activationRetainProbability) {
        this(activationRetainProbability, null);
        if (activationRetainProbability < 0.0) {
            throw new IllegalArgumentException("Activation retain probability must be > 0. Got: " + activationRetainProbability);
        }
        if (activationRetainProbability == 0.0) {
            throw new IllegalArgumentException("Invalid probability value: Dropout with 0.0 probability of retaining activations is not supported");
        }
        // NaN compares false against everything above, and values > 1 are not probabilities
        if (Double.isNaN(activationRetainProbability) || activationRetainProbability > 1.0) {
            throw new IllegalArgumentException("Activation retain probability must be in range (0, 1]. Got: " + activationRetainProbability);
        }
    }

    /**
     * Creates dropout whose retain probability is read from a schedule on every forward pass.
     *
     * @param activationRetainProbabilitySchedule schedule evaluated with (iteration, epoch)
     */
    public Dropout(ISchedule activationRetainProbabilitySchedule) {
        this(Double.NaN, activationRetainProbabilitySchedule);
    }

    /**
     * Unvalidated constructor, annotated for JSON deserialization and also used by {@link #clone()}.
     *
     * @param activationRetainProbability         fixed retain probability ({@code NaN} when a schedule is used)
     * @param activationRetainProbabilitySchedule retain-probability schedule, or {@code null}
     */
    protected Dropout(@JsonProperty(value="p") double activationRetainProbability, @JsonProperty(value="pSchedule") ISchedule activationRetainProbabilitySchedule) {
        this.p = activationRetainProbability;
        this.pSchedule = activationRetainProbabilitySchedule;
    }

    /**
     * Attempts to set up a backend-specific dropout helper. This happens only on the CUDA backend, and only when
     * {@code CudnnDropoutHelper} can be loaded reflectively and reports itself as supported. Always marks
     * detection as done, whether or not a helper was installed.
     *
     * @param dataType data type passed to the helper's constructor
     */
    protected void initializeHelper(DataType dataType) {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if ("CUDA".equalsIgnoreCase(backend)) {
            try {
                helper = Class.forName("org.deeplearning4j.nn.layers.dropout.CudnnDropoutHelper")
                        .asSubclass(DropoutHelper.class)
                        .getConstructor(DataType.class)
                        .newInstance(dataType);
                log.debug("CudnnDropoutHelper successfully initialized");
                if (!helper.checkSupported()) {
                    helper = null;
                }
            } catch (ClassNotFoundException e) {
                // The cuDNN module is optional: silently fall back to the built-in implementation
                helper = null;
            } catch (Throwable t) {
                // Never keep a helper whose construction or support check failed part-way
                helper = null;
                log.warn("Could not initialize CudnnDropoutHelper", t);
            }
        }
        initializedHelper = true;
    }

    /**
     * Applies dropout to {@code inputActivations}, writing the result into {@code output}.
     *
     * @param inputActivations activations to drop out; cast to {@code output}'s data type if they differ
     * @param output           destination array; must have a floating point data type
     * @param iteration        current iteration (used to evaluate the schedule, if any)
     * @param epoch            current epoch (used to evaluate the schedule, if any)
     * @param workspaceMgr     workspace manager used to allocate the mask
     * @return {@code output}
     */
    @Override
    public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
        Preconditions.checkState(output.dataType().isFPType(), "Output array must be a floating point type, got %s for array of shape %ndShape", output.dataType(), output);
        double currP = pSchedule != null ? pSchedule.valueAt(iteration, epoch) : p;
        if (!initializedHelper) {
            initializeHelper(output.dataType());
        }
        if (helper != null) {
            // Use the effective (possibly scheduled) probability: the raw field p is NaN when a schedule is set
            helper.applyDropout(inputActivations, output, currP);
            return output;
        }
        INDArray inputCast = inputActivations;
        if (inputCast != output && inputCast.dataType() != output.dataType()) {
            inputCast = inputCast.castTo(output.dataType());
        }
        // Kept in a field so backprop can reuse it; allocated through the workspace manager as ArrayType.INPUT
        mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0);
        Nd4j.getExecutioner().exec((RandomOp) new DropOutInverted(mask, mask, currP));
        Nd4j.getExecutioner().exec((Op) new OldMulOp(inputCast, mask, output));
        return output;
    }

    /**
     * Backpropagates through the most recent forward pass. This either delegates to the helper, or multiplies
     * {@code gradAtOutput} by the stored mask into {@code gradAtInput} and then clears the mask.
     *
     * @param gradAtOutput gradient with respect to the dropout output
     * @param gradAtInput  destination for the gradient with respect to the dropout input
     * @param iteration    current iteration (unused here)
     * @param epoch        current epoch (unused here)
     * @return {@code gradAtInput}
     * @throws IllegalStateException if no mask is available (no forward pass, or already consumed/cleared)
     */
    @Override
    public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
        if (helper != null) {
            helper.backprop(gradAtOutput, gradAtInput);
            return gradAtInput;
        }
        Preconditions.checkState(mask != null, "Cannot perform backprop: Dropout mask array is absent (already cleared?)");
        INDArray m = mask.dataType() == gradAtInput.dataType() ? mask : mask.castTo(gradAtInput.dataType());
        Nd4j.getExecutioner().exec((Op) new OldMulOp(gradAtOutput, m, gradAtInput));
        // Single use: a second backprop without a new forward pass fails the precondition above
        mask = null;
        return gradAtInput;
    }

    /** Discards the stored mask. */
    @Override
    public void clear() {
        mask = null;
    }

    /**
     * Returns a copy with the same retain probability and a cloned schedule. Runtime state (mask, helper,
     * helper-detection flag) is not copied.
     */
    @Override
    public Dropout clone() {
        return new Dropout(p, pSchedule == null ? null : pSchedule.clone());
    }

    public double getP() {
        return p;
    }

    public ISchedule getPSchedule() {
        return pSchedule;
    }

    public INDArray getMask() {
        return mask;
    }

    public DropoutHelper getHelper() {
        return helper;
    }

    public boolean isInitializedHelper() {
        return initializedHelper;
    }

    public void setP(double p) {
        this.p = p;
    }

    public void setPSchedule(ISchedule pSchedule) {
        this.pSchedule = pSchedule;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    public void setHelper(DropoutHelper helper) {
        this.helper = helper;
    }

    /** Setting {@code false} makes the next {@link #applyDropout} call re-run helper detection. */
    public void setInitializedHelper(boolean initializedHelper) {
        this.initializedHelper = initializedHelper;
    }

    @Override
    public String toString() {
        return "Dropout(p=" + getP() + ", pSchedule=" + getPSchedule() + ", mask=" + getMask() + ", helper=" + getHelper() + ", initializedHelper=" + isInitializedHelper() + ")";
    }

    /**
     * Configuration equality: compares {@code p} (via {@link Double#compare}) and {@code pSchedule} only.
     * Runtime state (mask, helper, helper-detection flag) is ignored, so an instance stays equal to its clone
     * after it has been used.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Dropout)) {
            return false;
        }
        Dropout other = (Dropout) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(getP(), other.getP()) != 0) {
            return false;
        }
        ISchedule mine = getPSchedule();
        ISchedule theirs = other.getPSchedule();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Symmetry hook for subclasses, in the style of Lombok-generated {@code equals}. */
    protected boolean canEqual(Object other) {
        return other instanceof Dropout;
    }

    /** Hash over the same fields as {@link #equals(Object)}: {@code p} and {@code pSchedule}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        long pBits = Double.doubleToLongBits(getP());
        result = result * prime + (int) (pBits >>> 32 ^ pBits);
        ISchedule schedule = getPSchedule();
        result = result * prime + (schedule == null ? 43 : schedule.hashCode());
        return result;
    }
}

