package org.datavec.api.transform.ndarray;

import java.util.Arrays;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnsMathOpTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.class */
/**
 * Applies a {@link MathOp} element-wise across two or more NDArray columns and writes the result
 * to a new NDArray column.
 *
 * <p>{@link MathOp#Add} and {@link MathOp#Multiply} fold over every input column in order. The
 * subtract, divide and reverse variants combine only the first two input columns.
 * {@link MathOp#Modulus}, {@link MathOp#ScalarMin} and {@link MathOp#ScalarMax} are rejected at
 * execution time.
 */
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {

    /**
     * Creates the transform.
     *
     * @param str    name of the new output column
     * @param mathOp operation to apply across the input columns
     * @param strArr names of the input columns, in operand order
     */
    public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String str, @JsonProperty("mathOp") MathOp mathOp, @JsonProperty("columns") String... strArr) {
        super(str, mathOp, strArr);
    }

    /**
     * Validates the input columns and builds metadata for the output column.
     *
     * @param str    name of the output column
     * @param schema input schema containing the operand columns
     * @return NDArray metadata named {@code str} with the shape of the first input column
     * @throws RuntimeException              if any input column is not of type {@link ColumnType#NDArray}
     * @throws UnsupportedOperationException if any input column's shape differs from the first column's
     */
    @Override // org.datavec.api.transform.transform.BaseColumnsMathOpTransform
    protected ColumnMetaData derivedColumnMetaData(String str, Schema schema) {
        // Check every column's type before comparing shapes, so a non-NDArray column is reported
        // (rather than failing the NDArrayMetaData cast below).
        for (int i = 0; i < this.columns.length; i++) {
            if (schema.getMetaData(this.columns[i]).getColumnType() != ColumnType.NDArray) {
                throw new RuntimeException("Column " + this.columns[i] + " is not an NDArray column");
            }
        }
        NDArrayMetaData first = (NDArrayMetaData) schema.getMetaData(this.columns[0]);
        for (int i = 1; i < this.columns.length; i++) {
            NDArrayMetaData other = (NDArrayMetaData) schema.getMetaData(this.columns[i]);
            if (!Arrays.equals(first.getShape(), other.getShape())) {
                throw new UnsupportedOperationException("Cannot perform NDArray operation on columns with different shapes: Columns \"" + this.columns[0] + "\" and \"" + this.columns[i] + "\" have shapes: " + Arrays.toString(first.getShape()) + " and " + Arrays.toString(other.getShape()));
            }
        }
        return new NDArrayMetaData(str, first.getShape());
    }

    /**
     * Computes the operation for one record's operand values.
     *
     * @param writableArr operand values; each is expected to be an {@link NDArrayWritable}
     * @return a new {@link NDArrayWritable} holding the result
     * @throws IllegalArgumentException if the op is Modulus, ScalarMin or ScalarMax
     * @throws RuntimeException         if the op is not recognised
     */
    @Override // org.datavec.api.transform.transform.BaseColumnsMathOpTransform
    protected Writable doOp(Writable... writableArr) {
        // Work on a copy so the in-place ops below do not modify the first input array.
        INDArray result = ((NDArrayWritable) writableArr[0]).get().dup();
        switch (this.mathOp) {
            case Add:
                for (int i = 1; i < writableArr.length; i++) {
                    result.addi(((NDArrayWritable) writableArr[i]).get());
                }
                break;
            case Subtract:
                // NOTE(review): only writableArr[1] is used for the two-operand ops; this assumes the
                // base class restricts these ops to exactly two columns — confirm.
                result.subi(((NDArrayWritable) writableArr[1]).get());
                break;
            case Multiply:
                for (int i = 1; i < writableArr.length; i++) {
                    result.muli(((NDArrayWritable) writableArr[i]).get());
                }
                break;
            case Divide:
                result.divi(((NDArrayWritable) writableArr[1]).get());
                break;
            case ReverseSubtract:
                result.rsubi(((NDArrayWritable) writableArr[1]).get());
                break;
            case ReverseDivide:
                result.rdivi(((NDArrayWritable) writableArr[1]).get());
                break;
            case Modulus:
            case ScalarMin:
            case ScalarMax:
                throw new IllegalArgumentException("Invalid MathOp: cannot use " + this.mathOp + " with NDArrayColumnsMathOpTransform");
            default:
                throw new RuntimeException("Unknown MathOp: " + this.mathOp);
        }
        // Presumably ensures any queued/asynchronous ND4J ops have completed before the result is
        // handed out — confirm against the ND4J executioner documentation.
        Nd4j.getExecutioner().commit();
        return new NDArrayWritable(result);
    }

    @Override // org.datavec.api.transform.transform.BaseColumnsMathOpTransform
    public String toString() {
        return "NDArrayColumnsMathOpTransform(newColumnName=\"" + this.newColumnName + "\",mathOp=" + this.mathOp + ",columns=" + Arrays.toString(this.columns) + ")";
    }

    /** Not supported by this transform; always throws {@link UnsupportedOperationException}. */
    @Override // org.datavec.api.transform.Transform
    public Object map(Object obj) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not supported by this transform; always throws {@link UnsupportedOperationException}. */
    @Override // org.datavec.api.transform.Transform
    public Object mapSequence(Object obj) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Two instances are equal when they are mutually comparable (see {@link #canEqual(Object)})
     * and the superclass considers their state equal.
     *
     * <p>The previous implementation ignored all superclass state. As a result, any two instances
     * compared equal regardless of operation or columns.
     */
    @Override // org.datavec.api.transform.transform.BaseColumnsMathOpTransform
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof NDArrayColumnsMathOpTransform)) {
            return false;
        }
        // canEqual keeps equality symmetric with possible subclasses. super.equals presumably
        // compares the operation/column state declared in BaseColumnsMathOpTransform — confirm.
        return ((NDArrayColumnsMathOpTransform) obj).canEqual(this) && super.equals(obj);
    }

    /** Returns true if {@code obj} may be considered equal to an instance of this class. */
    @Override // org.datavec.api.transform.transform.BaseColumnsMathOpTransform
    protected boolean canEqual(Object obj) {
        return obj instanceof NDArrayColumnsMathOpTransform;
    }

    /** Delegates to the superclass so the result stays consistent with {@link #equals(Object)}. */
    @Override // org.datavec.api.transform.transform.BaseColumnsMathOpTransform
    public int hashCode() {
        return super.hashCode();
    }
}
