package org.datavec.api.transform.ndarray;

import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.class */
/**
 * Column transform that applies a scalar {@link MathOp} element-wise between every value of an
 * NDArray column and a fixed {@code double} operand.
 *
 * <p>The input array is never modified: {@link #map(Writable)} operates on a {@code dup()} of it
 * and returns the result wrapped in a new {@link NDArrayWritable}. {@link MathOp#Modulus} is not
 * supported and is rejected with an {@link UnsupportedOperationException}.
 */
public class NDArrayScalarOpTransform extends BaseColumnTransform {
    /** Operation applied between each array element and {@link #scalar}. */
    private final MathOp mathOp;
    /** Scalar operand used by {@link #mathOp}. */
    private final double scalar;

    /**
     * Creates the transform.
     *
     * @param str    name of the NDArray column to transform
     * @param mathOp operation to apply element-wise
     * @param d      scalar operand for the operation
     */
    public NDArrayScalarOpTransform(@JsonProperty("columnName") String str, @JsonProperty("mathOp") MathOp mathOp, @JsonProperty("scalar") double d) {
        super(str);
        this.mathOp = mathOp;
        this.scalar = d;
    }

    /**
     * Returns the metadata of the output column. It is a clone of the input metadata with its
     * name set to {@code str}; nothing else is changed.
     *
     * @param str            name to give the output column
     * @param columnMetaData metadata of the input column
     * @return renamed clone of {@code columnMetaData}
     * @throws IllegalStateException if {@code columnMetaData} is not an {@link NDArrayMetaData}
     */
    @Override
    public ColumnMetaData getNewColumnMetaData(String str, ColumnMetaData columnMetaData) {
        if (!(columnMetaData instanceof NDArrayMetaData)) {
            throw new IllegalStateException("Column " + str + " is not a NDArray column");
        }
        NDArrayMetaData renamed = ((NDArrayMetaData) columnMetaData).mo38clone();
        renamed.setName(str);
        return renamed;
    }

    /**
     * Applies the configured scalar operation to a copy of the writable's array.
     *
     * @param writable input value; must be an {@link NDArrayWritable}
     * @return a new writable holding the transformed copy
     * @throws IllegalArgumentException      if {@code writable} is not an {@link NDArrayWritable}
     *                                       (including {@code null})
     * @throws UnsupportedOperationException for {@link MathOp#Modulus} or any op not handled below
     */
    @Override
    public NDArrayWritable map(Writable writable) {
        if (!(writable instanceof NDArrayWritable)) {
            // Null-safe: a null input must not NPE while building the message.
            throw new IllegalArgumentException("Input writable is not an NDArrayWritable: is "
                    + (writable == null ? null : writable.getClass()));
        }
        // Work on a copy so the input writable's array is left unmodified.
        INDArray dup = ((NDArrayWritable) writable).get().dup();
        // Switch directly on the enum rather than on ordinal() values.
        switch (this.mathOp) {
            case Add:
                dup.addi(this.scalar);
                break;
            case Subtract:
                dup.subi(this.scalar);
                break;
            case Multiply:
                dup.muli(this.scalar);
                break;
            case Divide:
                dup.divi(this.scalar);
                break;
            case Modulus:
                throw new UnsupportedOperationException(this.mathOp + " is not supported for NDArrayWritable");
            case ReverseSubtract:
                dup.rsubi(this.scalar);
                break;
            case ReverseDivide:
                dup.rdivi(this.scalar);
                break;
            case ScalarMin:
                // dup=false: the result is written into 'dup' itself, so the return value is unused.
                Transforms.min(dup, this.scalar, false);
                break;
            case ScalarMax:
                // dup=false: the result is written into 'dup' itself, so the return value is unused.
                Transforms.max(dup, this.scalar, false);
                break;
            default:
                throw new UnsupportedOperationException("Unknown or not supported op: " + this.mathOp);
        }
        // NOTE(review): presumably flushes any queued/asynchronous ops before the array is handed
        // to the caller - TODO confirm against the ND4J executioner documentation.
        Nd4j.getExecutioner().commit();
        return new NDArrayWritable(dup);
    }

    @Override
    public String toString() {
        return "NDArrayScalarOpTransform(mathOp=" + this.mathOp + ",scalar=" + this.scalar + ")";
    }

    /**
     * Applies the transform to a raw {@link INDArray}, outside the {@link Writable} wrapper.
     *
     * @param obj input object; must be an {@link INDArray}
     * @return the transformed copy, as an {@link INDArray}
     * @throws RuntimeException if {@code obj} is not an {@link INDArray} (including {@code null})
     */
    @Override
    public Object map(Object obj) {
        if (obj instanceof INDArray) {
            NDArrayWritable result = map(new NDArrayWritable((INDArray) obj));
            return result.get();
        }
        // Null-safe: a null input must not NPE while building the message.
        throw new RuntimeException("Unsupported class: " + (obj == null ? null : obj.getClass()));
    }
}
