package org.nd4j.linalg.api.ops.impl.transforms.comparison;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.class */
/**
 * Transform op {@code "car"} (op number 13) that evaluates a {@link Condition} on its first input
 * and mixes its two inputs accordingly.
 *
 * <p>The gradient routing in {@link #doDiff(List)} sends the incoming gradient to the second
 * input where the condition matches the first input, and to the first input elsewhere. This
 * implies the forward pass takes elements from the second input where the condition holds and
 * from the first input otherwise. NOTE(review): confirm against the native kernel.
 *
 * <p>The condition parameters are handed to the native op through {@code extraArgs} in the order
 * {@code [compare, set, eps, mode.index]}.
 */
public class CompareAndReplace extends BaseTransformSameOp {
    /** Condition evaluated on the first input; also used to build the gradient masks. */
    private Condition condition;
    /** Comparison value from {@link Condition#getValue()}; {@code extraArgs[0]}. */
    private double compare;
    /** Always initialised to {@code 0.0} by the constructors; {@code extraArgs[1]}. */
    private double set;
    /** Epsilon threshold from {@link Condition#epsThreshold()}; {@code extraArgs[2]}. */
    private double eps;
    /** Condition mode from {@link Condition#conditionType()}; its index is {@code extraArgs[3]}. */
    private Conditions.ConditionMode mode;

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff    owning graph
     * @param sDVariable  first input; the condition is evaluated on it
     * @param sDVariable2 second input; receives the gradient where the condition matches
     * @param condition   condition to evaluate (dereferenced immediately, so must not be null)
     */
    public CompareAndReplace(SameDiff sameDiff, SDVariable sDVariable, SDVariable sDVariable2, Condition condition) {
        super(sameDiff, sDVariable, sDVariable2, false);
        initFromCondition(condition);
    }

    /** No-arg constructor; all fields stay at their defaults until properties are applied. */
    public CompareAndReplace() {
    }

    /**
     * Creates the op for eager execution with no explicit output array.
     *
     * @param iNDArray  first input; the condition is evaluated on it
     * @param iNDArray2 second input
     * @param condition condition to evaluate (must not be null)
     */
    public CompareAndReplace(INDArray iNDArray, INDArray iNDArray2, Condition condition) {
        this(iNDArray, iNDArray2, (INDArray) null, condition);
    }

    /**
     * Creates the op for eager execution.
     *
     * @param iNDArray  first input; the condition is evaluated on it
     * @param iNDArray2 second input
     * @param iNDArray3 output array; may be null (presumably the base class then allocates one —
     *                  TODO confirm in BaseTransformSameOp)
     * @param condition condition to evaluate (must not be null)
     */
    public CompareAndReplace(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, Condition condition) {
        super(iNDArray, iNDArray2, iNDArray3);
        // Previously this constructor never stored the condition, which left doDiff with a null
        // condition and made the op's state differ from the SameDiff constructor's.
        initFromCondition(condition);
    }

    /**
     * Copies all condition parameters into the op's fields and builds {@code extraArgs}.
     * Shared by both constructors so they initialise identical state.
     */
    private void initFromCondition(Condition condition) {
        this.condition = condition;
        this.compare = condition.getValue();
        this.set = 0.0d;
        this.mode = condition.conditionType();
        this.eps = condition.epsThreshold();
        // Order matters: the native op reads [compare, set, eps, mode index].
        this.extraArgs = new Object[]{Double.valueOf(this.compare), Double.valueOf(this.set), Double.valueOf(this.eps), Double.valueOf(this.mode.index)};
    }

    /**
     * Exports the op's configuration.
     *
     * @return an insertion-ordered map with keys {@code compare}, {@code set}, {@code eps} and
     *         {@code mode} (the {@link Conditions.ConditionMode} itself)
     */
    @Override
    public Map<String, Object> propertiesForFunction() {
        Map<String, Object> props = new LinkedHashMap<>();
        props.put("compare", Double.valueOf(this.compare));
        props.put("set", Double.valueOf(this.set));
        props.put("eps", Double.valueOf(this.eps));
        props.put("mode", this.mode);
        return props;
    }

    /**
     * Restores the op's configuration from a property map, as produced by
     * {@link #propertiesForFunction()} or by a deserializer.
     *
     * <p>{@code mode} may be any {@link Number} (taken as the mode index) or a
     * {@link Conditions.ConditionMode}. The {@code compare}, {@code set} and {@code eps} values
     * may be any {@link Number}. When both {@code mode} and {@code compare} are present, the
     * condition is rebuilt from both. When only {@code mode} is present, the mode's default
     * condition is used. Note that {@code extraArgs} is not recomputed here.
     *
     * @param map property map; values for recognised keys must be of the types described above
     */
    @Override
    public void setPropertiesForFunction(Map<String, Object> map) {
        boolean hasMode = map.containsKey("mode");
        boolean hasCompare = map.containsKey("compare");

        if (hasMode) {
            Object rawMode = map.get("mode");
            if (rawMode instanceof Number) {
                // Accept Long/Double etc. as well as Integer: deserializers do not always
                // preserve the exact boxed type.
                int modeIndex = ((Number) rawMode).intValue();
                this.mode = Conditions.ConditionMode.fromNumber(modeIndex);
                if (!hasCompare) {
                    this.condition = Conditions.fromInt(modeIndex);
                }
            } else if (rawMode instanceof Conditions.ConditionMode) {
                Conditions.ConditionMode conditionMode = (Conditions.ConditionMode) rawMode;
                this.mode = conditionMode;
                if (!hasCompare) {
                    this.condition = Conditions.fromInt(conditionMode.index);
                }
            }
        }

        if (hasCompare) {
            Double compareValue = ((Number) map.get("compare")).doubleValue();
            this.compare = compareValue;
            // Guard against an unrecognised "mode" value leaving this.mode unset (was an NPE).
            if (hasMode && this.mode != null) {
                this.condition = Conditions.fromInt(this.mode.index, compareValue);
            }
        }
        if (map.containsKey("set")) {
            this.set = ((Number) map.get("set")).doubleValue();
        }
        if (map.containsKey("eps")) {
            this.eps = ((Number) map.get("eps")).doubleValue();
        }
    }

    /** @return the native op number, {@code 13} */
    @Override
    public int opNum() {
        return 13;
    }

    /** @return the op name, {@code "car"} */
    @Override
    public String opName() {
        return "car";
    }

    /**
     * This op has no ONNX mapping.
     *
     * @throws NoOpNameFoundException always
     */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    /**
     * This op has no TensorFlow mapping.
     *
     * @throws NoOpNameFoundException always
     */
    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    /**
     * Splits the output gradient between the two inputs using a mask built from the condition.
     * The mask is evaluated on the first input.
     *
     * @param list gradients w.r.t. the op's output; only element 0 is used
     * @return {@code [grad * (1 - mask), grad * mask]}, i.e. the gradients for the first and
     *         second input
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        SDVariable maskMatched = this.sameDiff.matchCondition(arg(0), this.condition).castTo(arg().dataType());
        SDVariable maskNotMatched = maskMatched.rsub(1.0d);
        return Arrays.asList(list.get(0).mul(maskNotMatched), list.get(0).mul(maskMatched));
    }

    /**
     * Requires exactly two inputs of the same data type; the output has that type.
     *
     * @param list input data types
     * @return a singleton list with the shared input data type
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), list);
        Preconditions.checkState(list.get(0) == list.get(1), "Input data types must be the same: got %s", list);
        return Collections.singletonList(list.get(0));
    }
}
