package org.nd4j.linalg.dataset;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/DataSet.class */
public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);
    private static final long serialVersionUID = 1935520764586513365L;
    // Flag bits of the one-byte header that save(OutputStream) writes first and load(InputStream) reads first.
    // Set when a features array follows the header.
    private static final byte BITMASK_FEATURES_PRESENT = 1;
    // Set when a separate labels array follows.
    private static final byte BITMASK_LABELS_PRESENT = 2;
    // Set when labels are the very same object as features (reference equality); no labels array is
    // written and the reader aliases labels to the features it read.
    private static final byte BITMASK_LABELS_SAME_AS_FEATURES = 4;
    // Set when a features mask array follows.
    private static final byte BITMASK_FEATURE_MASK_PRESENT = 8;
    // Set when a labels mask array follows.
    private static final byte BITMASK_LABELS_MASK_PRESENT = 16;
    // Set when Java-serialized example metadata follows (only written when the list is non-empty).
    // The name is presumably a misspelling of "..._PRESENT"; kept because load/save reference it.
    private static final byte BITMASK_METADATA_PRESET = 32;
    // Column / label names; the four-argument constructor initializes both to empty lists.
    private List<String> columnNames;
    private List<String> labelNames;
    // Feature and label arrays; either may be null (the no-arg constructor passes null for both).
    private INDArray features;
    private INDArray labels;
    // Optional mask arrays; null when absent.
    private INDArray featuresMask;
    private INDArray labelsMask;
    // Optional per-example metadata; may be null.
    private List<Serializable> exampleMetaData;
    // Set by markAsPreProcessed(); transient, so it is excluded from Java serialization.
    private transient boolean preProcessed;

    /**
     * Creates a DataSet with null features and labels; it delegates to the two-argument
     * constructor, which in turn leaves both mask arrays null.
     */
    public DataSet() {
        this(null, null);
    }

    /**
     * Returns the stored per-example metadata list itself (no defensive copy).
     *
     * @return the metadata list, possibly null
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<Serializable> getExampleMetaData() {
        return exampleMetaData;
    }

    /**
     * Returns the stored metadata list viewed as {@code List<T>}.
     *
     * <p>The cast is unchecked: {@code cls} is not consulted, so an element of a different type
     * only fails (with ClassCastException) when the caller reads it.
     *
     * @param cls expected element type (unused at runtime)
     * @return the same list instance, possibly null
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    @SuppressWarnings("unchecked")
    public <T extends Serializable> List<T> getExampleMetaData(Class<T> cls) {
        List<?> stored = exampleMetaData;
        return (List<T>) stored;
    }

    /**
     * Sets the per-example metadata list.
     *
     * <p>The list is stored by reference, not copied. The original decompiled assignment of a
     * {@code List<? extends Serializable>} to a {@code List<Serializable>} field does not compile;
     * the unchecked cast below is safe for reads, because every element is Serializable. Code that
     * writes through this field (for example, shuffling reorders it in place) must only store
     * elements already present in the list.
     *
     * @param list metadata, one entry per example; may be null
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    @SuppressWarnings("unchecked")
    public void setExampleMetaData(List<? extends Serializable> list) {
        this.exampleMetaData = (List<Serializable>) list;
    }

    /**
     * Creates a DataSet from features and labels, with no mask arrays.
     *
     * @param features feature array (may be null)
     * @param labels   label array (may be null)
     */
    public DataSet(INDArray features, INDArray labels) {
        this(features, labels, null, null);
    }

    /**
     * Creates a DataSet from features, labels and their optional mask arrays.
     *
     * <p>Column and label name lists start out empty, and the DataSet is not marked as
     * pre-processed.
     *
     * @param features     feature array (may be null)
     * @param labels       label array (may be null)
     * @param featuresMask features mask array, or null
     * @param labelsMask   labels mask array, or null
     */
    public DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        this.columnNames = new ArrayList<>();
        this.labelNames = new ArrayList<>();
        this.preProcessed = false;
        this.features = features;
        this.labels = labels;
        this.featuresMask = featuresMask;
        this.labelsMask = labelsMask;
        // NOTE(review): presumably flushes any operations queued on the executioner; confirm semantics.
        Nd4j.getExecutioner().commit();
    }

    /**
     * Returns whether {@link #markAsPreProcessed()} has been called on this instance.
     *
     * @return the transient pre-processed flag
     */
    public boolean isPreProcessed() {
        return preProcessed;
    }

    /** Sets the transient pre-processed flag to true. */
    public void markAsPreProcessed() {
        preProcessed = true;
    }

    /**
     * Factory for a DataSet with null features, labels and masks.
     *
     * @return a new empty DataSet
     */
    public static DataSet empty() {
        return new DataSet();
    }

    /**
     * Merges several DataSets into one by concatenating their examples.
     *
     * <p>Empty DataSets (per {@code isEmpty()}) are skipped entirely. Among the remaining DataSets,
     * features must be present in all or absent in all, and likewise for labels. Arrays and mask
     * arrays are combined by {@link DataSetUtil#mergeFeatures} and {@link DataSetUtil#mergeLabels}.
     * Example metadata is concatenated only if every non-empty DataSet carries metadata whose size
     * equals its {@code numExamples()}; otherwise the merged DataSet has no metadata.
     *
     * @param list DataSets to merge; must not be empty
     * @return the merged DataSet, or a new empty DataSet if every element of {@code list} is empty
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws IllegalStateException    if features (or labels) are null in some non-empty DataSets
     *                                  but not in others
     */
    public static DataSet merge(List<? extends org.nd4j.linalg.dataset.api.DataSet> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }

        // Pass 1: keep the non-empty DataSets and check that features/labels presence is consistent.
        List<org.nd4j.linalg.dataset.api.DataSet> nonEmpty = new ArrayList<>(list.size());
        boolean anyFeatures = false;
        boolean anyLabels = false;
        for (org.nd4j.linalg.dataset.api.DataSet ds : list) {
            if (ds.isEmpty()) {
                continue;
            }
            boolean first = nonEmpty.isEmpty();
            boolean hasFeatures = ds.getFeatures() != null;
            boolean hasLabels = ds.getLabels() != null;
            if ((anyFeatures && !hasFeatures) || (!first && !anyFeatures && hasFeatures)) {
                throw new IllegalStateException("Cannot merge features: encountered null features in one or more DataSets");
            }
            if ((anyLabels && !hasLabels) || (!first && !anyLabels && hasLabels)) {
                throw new IllegalStateException("Cannot merge labels: enountered null labels in one or more DataSets");
            }
            anyFeatures |= hasFeatures;
            anyLabels |= hasLabels;
            nonEmpty.add(ds);
        }
        if (nonEmpty.isEmpty()) {
            // Nothing to concatenate: do not hand zero-length arrays to DataSetUtil.
            return new DataSet();
        }

        // Pass 2: gather arrays. A mask array is only allocated once some DataSet has that mask;
        // slots for DataSets without it stay null.
        int count = nonEmpty.size();
        INDArray[] featuresToMerge = new INDArray[count];
        INDArray[] labelsToMerge = new INDArray[count];
        INDArray[] featuresMasksToMerge = null;
        INDArray[] labelsMasksToMerge = null;
        for (int k = 0; k < count; k++) {
            org.nd4j.linalg.dataset.api.DataSet ds = nonEmpty.get(k);
            featuresToMerge[k] = ds.getFeatures();
            labelsToMerge[k] = ds.getLabels();
            INDArray featuresMaskHere = ds.getFeaturesMaskArray();
            if (featuresMaskHere != null) {
                if (featuresMasksToMerge == null) {
                    featuresMasksToMerge = new INDArray[count];
                }
                featuresMasksToMerge[k] = featuresMaskHere;
            }
            INDArray labelsMaskHere = ds.getLabelsMaskArray();
            if (labelsMaskHere != null) {
                if (labelsMasksToMerge == null) {
                    labelsMasksToMerge = new INDArray[count];
                }
                labelsMasksToMerge[k] = labelsMaskHere;
            }
        }

        Pair<INDArray, INDArray> mergedFeatures = DataSetUtil.mergeFeatures(featuresToMerge, featuresMasksToMerge);
        Pair<INDArray, INDArray> mergedLabels = DataSetUtil.mergeLabels(labelsToMerge, labelsMasksToMerge);
        DataSet merged = new DataSet(mergedFeatures.getFirst(), mergedLabels.getFirst(),
                mergedFeatures.getSecond(), mergedLabels.getSecond());

        // Metadata is all-or-nothing over the non-empty DataSets. Empty DataSets contribute no
        // examples, so their (typically null) metadata must not cancel everyone else's.
        List<Serializable> metaData = new ArrayList<>();
        for (org.nd4j.linalg.dataset.api.DataSet ds : nonEmpty) {
            List<Serializable> dsMeta = ds.getExampleMetaData();
            if (dsMeta == null || dsMeta.size() != ds.numExamples()) {
                metaData = null;
                break;
            }
            metaData.addAll(dsMeta);
        }
        if (metaData != null) {
            merged.setExampleMetaData(metaData);
        }
        return merged;
    }

    /**
     * Returns a new DataSet holding examples {@code [from, to)} of this one.
     *
     * <p>Each array (features, labels, and both masks) is indexed with
     * {@code NDArrayIndex.interval(from, to)} on its first dimension. Arrays that are null here
     * (e.g. labels of an unlabeled DataSet) stay null in the result instead of causing a
     * NullPointerException.
     *
     * @param from first example index (inclusive)
     * @param to   last example index (exclusive)
     * @return the sub-range DataSet
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public org.nd4j.linalg.dataset.api.DataSet getRange(int from, int to) {
        return new DataSet(intervalOrNull(this.features, from, to), intervalOrNull(this.labels, from, to),
                intervalOrNull(this.featuresMask, from, to), intervalOrNull(this.labelsMask, from, to));
    }

    /** Indexes {@code array} with {@code interval(from, to)} on its first dimension, or returns null for a null array. */
    private static INDArray intervalOrNull(INDArray array, int from, int to) {
        return array == null ? null : array.get(NDArrayIndex.interval(from, to));
    }

    /**
     * Replaces this DataSet's contents with data read from {@code inputStream}, in the format
     * written by {@link #save(OutputStream)}.
     *
     * <p>Format: one header byte of BITMASK_* flags, then each present array read with
     * {@code Nd4j.read} (features, labels, features mask, labels mask), then optionally a
     * Java-serialized metadata list. If the labels-same-as-features bit is set, labels alias the
     * features array just read.
     *
     * <p>All values are read into locals first, so a failure part-way leaves this DataSet
     * unchanged. Metadata is reset to null when the stream carries none. The stream is closed
     * when this method returns, whether or not it succeeds.
     *
     * <p>SECURITY: metadata is read via Java native deserialization ({@link ObjectInputStream});
     * only load data from trusted sources.
     *
     * @param inputStream source stream; wrapped in a BufferedInputStream unless it already is one
     * @throws RuntimeException wrapping any failure while reading
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void load(InputStream inputStream) {
        try (DataInputStream in = new DataInputStream(
                inputStream instanceof BufferedInputStream ? inputStream : new BufferedInputStream(inputStream))) {
            byte header = in.readByte();
            boolean hasFeatures = (header & BITMASK_FEATURES_PRESENT) != 0;
            boolean hasLabels = (header & BITMASK_LABELS_PRESENT) != 0;
            boolean labelsAreFeatures = (header & BITMASK_LABELS_SAME_AS_FEATURES) != 0;
            boolean hasFeaturesMask = (header & BITMASK_FEATURE_MASK_PRESENT) != 0;
            boolean hasLabelsMask = (header & BITMASK_LABELS_MASK_PRESENT) != 0;
            boolean hasMetaData = (header & BITMASK_METADATA_PRESET) != 0;

            INDArray newFeatures = hasFeatures ? Nd4j.read(in) : null;
            INDArray newLabels;
            if (hasLabels) {
                newLabels = Nd4j.read(in);
            } else if (labelsAreFeatures) {
                newLabels = newFeatures;
            } else {
                newLabels = null;
            }
            INDArray newFeaturesMask = hasFeaturesMask ? Nd4j.read(in) : null;
            INDArray newLabelsMask = hasLabelsMask ? Nd4j.read(in) : null;
            List<Serializable> newMetaData = null;
            if (hasMetaData) {
                @SuppressWarnings("unchecked")
                List<Serializable> deserialized = (List<Serializable>) new ObjectInputStream(in).readObject();
                newMetaData = deserialized;
            }

            this.features = newFeatures;
            this.labels = newLabels;
            this.featuresMask = newFeaturesMask;
            this.labelsMask = newLabelsMask;
            // Assigned unconditionally so metadata from before the load does not survive.
            this.exampleMetaData = newMetaData;
        } catch (Exception e) {
            throw new RuntimeException("Error loading DataSet", e);
        }
    }

    /**
     * Replaces this DataSet's contents with data read from {@code file}.
     *
     * <p>The file is read through a 1 MiB buffered stream and delegated to
     * {@link #load(InputStream)}. Both streams are closed by try-with-resources, and close
     * failures are attached as suppressed exceptions instead of masking the primary error.
     *
     * @param file file previously written by {@link #save(File)}
     * @throws RuntimeException wrapping an IOException from opening or closing the file, or
     *                          propagated from {@link #load(InputStream)}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void load(File file) {
        try (FileInputStream fileIn = new FileInputStream(file);
             BufferedInputStream bufferedIn = new BufferedInputStream(fileIn, 1048576)) {
            load(bufferedIn);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes this DataSet to {@code outputStream} in the format read by {@link #load(InputStream)}.
     *
     * <p>Format: one header byte of BITMASK_* flags, then each present array written with
     * {@code Nd4j.write} (features, labels, features mask, labels mask), then, if the metadata list
     * is non-empty, that list Java-serialized. When labels are the same object as features, only
     * the header bit is set and the labels array is not written again.
     *
     * <p>The stream is closed when this method returns. Write failures are propagated; previously
     * they were only printed, so callers could not detect a failed save.
     *
     * @param outputStream destination stream
     * @throws RuntimeException wrapping any IOException raised while writing or closing
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void save(OutputStream outputStream) {
        boolean labelsAliasFeatures = this.labels != null && this.labels == this.features;
        boolean writeMetaData = this.exampleMetaData != null && !this.exampleMetaData.isEmpty();

        byte header = 0;
        if (this.features != null) {
            header |= BITMASK_FEATURES_PRESENT;
        }
        if (this.labels != null) {
            header |= labelsAliasFeatures ? BITMASK_LABELS_SAME_AS_FEATURES : BITMASK_LABELS_PRESENT;
        }
        if (this.featuresMask != null) {
            header |= BITMASK_FEATURE_MASK_PRESENT;
        }
        if (this.labelsMask != null) {
            header |= BITMASK_LABELS_MASK_PRESENT;
        }
        if (writeMetaData) {
            header |= BITMASK_METADATA_PRESET;
        }

        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(outputStream))) {
            out.writeByte(header);
            if (this.features != null) {
                Nd4j.write(this.features, out);
            }
            if (this.labels != null && !labelsAliasFeatures) {
                Nd4j.write(this.labels, out);
            }
            if (this.featuresMask != null) {
                Nd4j.write(this.featuresMask, out);
            }
            if (this.labelsMask != null) {
                Nd4j.write(this.labelsMask, out);
            }
            if (writeMetaData) {
                // Not closed separately: closing `out` below closes the whole chain.
                ObjectOutputStream objectOut = new ObjectOutputStream(out);
                objectOut.writeObject(this.exampleMetaData);
                objectOut.flush();
            }
            out.flush();
        } catch (IOException e) {
            throw new RuntimeException("Error saving DataSet", e);
        }
    }

    /**
     * Writes this DataSet to {@code file}, creating it or truncating existing content.
     *
     * <p>Output goes through a buffered stream and is delegated to {@link #save(OutputStream)}.
     * Both streams are closed by try-with-resources, replacing the decompiled manual close logic.
     *
     * @param file destination file
     * @throws RuntimeException wrapping an IOException from opening or closing the file, or
     *                          propagated from {@link #save(OutputStream)}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void save(File file) {
        try (FileOutputStream fileOut = new FileOutputStream(file, false);
             BufferedOutputStream bufferedOut = new BufferedOutputStream(fileOut)) {
            save(bufferedOut);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Not implemented in this class.
     *
     * @return always {@code null}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSetIterator iterateWithMiniBatches() {
        // NOTE(review): callers must handle null; there is no iterator implementation here.
        return null;
    }

    /**
     * Identifier of this DataSet; this implementation assigns none.
     *
     * @return always the empty string
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public String id() {
        return "";
    }

    /**
     * Returns the features array held by this DataSet (not a copy).
     *
     * @return the features, possibly null
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getFeatures() {
        return features;
    }

    /**
     * Replaces the features array (stored by reference).
     *
     * @param features new features, may be null
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setFeatures(INDArray features) {
        this.features = features;
    }

    /**
     * Counts how many examples fall into each label index.
     *
     * <p>For each tensor along dimension 1 of the labels (each row, for a 2-D label matrix), the
     * label index is taken as the BLAS {@code iamax} of that tensor (index of its largest-magnitude
     * entry). The original also computed a second, discarded {@code iamax} over
     * {@code javaTensorAlongDimension}; that redundant work is removed.
     *
     * @return map from label index to count (as a double); empty when labels are null
     * @throws IllegalStateException if the BLAS wrapper returns a negative index
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public Map<Integer, Double> labelCounts() {
        Map<Integer, Double> counts = new HashMap<>();
        if (this.labels == null) {
            return counts;
        }
        long numTensors = this.labels.tensorssAlongDimension(1);
        for (int i = 0; i < numTensors; i++) {
            INDArray tensor = this.labels.tensorAlongDimension(i, 1);
            int label = Nd4j.getBlasWrapper().iamax(tensor);
            if (label < 0) {
                throw new IllegalStateException("Please check the iamax implementation for " + Nd4j.getBlasWrapper().getClass().getName());
            }
            counts.merge(label, 1.0d, Double::sum);
        }
        return counts;
    }

    /**
     * Applies {@code function} to the feature entries that satisfy {@code condition}, delegating to
     * {@code BooleanIndexing.applyWhere}. Labels are not touched.
     *
     * @param condition selects which entries to transform
     * @param function  transformation applied to the selected entries
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void apply(Condition condition, Function<Number, Number> function) {
        INDArray target = getFeatures();
        BooleanIndexing.applyWhere(target, condition, function);
    }

    /**
     * Returns a copy of this DataSet.
     *
     * <p>Features, labels and mask arrays are duplicated with {@code dup()}. Null features or labels
     * stay null instead of causing a NullPointerException. Example metadata is copied into a new
     * list that holds the same element objects (previously it was dropped).
     *
     * @return the copy
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet copy() {
        INDArray sourceFeatures = getFeatures();
        INDArray sourceLabels = getLabels();
        DataSet result = new DataSet(sourceFeatures == null ? null : sourceFeatures.dup(),
                sourceLabels == null ? null : sourceLabels.dup());
        if (getLabelsMaskArray() != null) {
            result.setLabelsMaskArray(getLabelsMaskArray().dup());
        }
        if (getFeaturesMaskArray() != null) {
            result.setFeaturesMaskArray(getFeaturesMaskArray().dup());
        }
        // NOTE(review): name lists are handed over as-is; whether these setters copy them is not visible here.
        result.setColumnNames(getColumnNames());
        result.setLabelNames(getLabelNames());
        if (this.exampleMetaData != null) {
            result.setExampleMetaData(new ArrayList<Serializable>(this.exampleMetaData));
        }
        return result;
    }

    /**
     * Returns a new DataSet whose features are this DataSet's features reshaped to
     * {@code rows x cols}. The labels are reused as-is; masks and metadata are not carried over.
     *
     * @param rows new first dimension of the features
     * @param cols new second dimension of the features
     * @return the reshaped DataSet
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet reshape(int rows, int cols) {
        INDArray reshapedFeatures = getFeatures().reshape(rows, cols);
        return new DataSet(reshapedFeatures, getLabels());
    }

    /**
     * Multiplies the features in place by {@code d}.
     *
     * @param d multiplier
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void multiplyBy(double d) {
        INDArray target = getFeatures();
        target.muli(Nd4j.scalar(d));
    }

    /**
     * Divides the features in place by {@code i}.
     *
     * @param i divisor
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void divideBy(int i) {
        INDArray target = getFeatures();
        target.divi(Nd4j.scalar(i));
    }

    /** Shuffles the examples using the current wall-clock time in milliseconds as the seed. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void shuffle() {
        long seed = System.currentTimeMillis();
        shuffle(seed);
    }

    /**
     * Shuffles the examples of this DataSet using seed {@code j}.
     *
     * <p>Features, labels and any mask arrays are passed together to {@code Nd4j.shuffle} with a
     * {@code Random(j)}. For each array, the dimension list passed is
     * {@code ArrayUtil.range(1, rank)}, presumably so whole examples along dimension 0 move
     * together (NOTE(review): confirm against Nd4j.shuffle). Example metadata, if present, is
     * reordered with {@code ArrayUtil.shuffleWithMap} using a separate {@code Random(j)}. Null
     * arrays (e.g. the labels of an unlabeled DataSet) are skipped instead of causing a
     * NullPointerException. Does nothing when there are fewer than two examples.
     *
     * @param j random seed
     */
    public void shuffle(long j) {
        if (numExamples() < 2) {
            return;
        }
        List<INDArray> toShuffle = new ArrayList<>();
        List<int[]> dimensions = new ArrayList<>();
        INDArray[] candidates = {getFeatures(), getLabels(), getFeaturesMaskArray(), getLabelsMaskArray()};
        for (INDArray candidate : candidates) {
            if (candidate != null) {
                toShuffle.add(candidate);
                dimensions.add(ArrayUtil.range(1, candidate.rank()));
            }
        }
        Nd4j.shuffle(toShuffle, new Random(j), dimensions);
        if (this.exampleMetaData != null) {
            ArrayUtil.shuffleWithMap(this.exampleMetaData, ArrayUtil.buildInterleavedVector(new Random(j), numExamples()));
        }
    }

    /**
     * Clamps every feature value in place into {@code [d, d2]}.
     *
     * <p>Reads values with {@code getDouble}. The original cast {@code getScalar(i).element()} to
     * Double, which throws ClassCastException for FLOAT arrays because {@code element()} returns a
     * Float for them.
     *
     * @param d  lower bound
     * @param d2 upper bound
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void squishToRange(double d, double d2) {
        INDArray target = getFeatures();
        long length = target.length();
        for (int i = 0; i < length; i++) {
            double value = target.getDouble(i);
            if (value < d) {
                target.putScalar(i, d);
            } else if (value > d2) {
                target.putScalar(i, d2);
            }
        }
    }

    /**
     * Scales the features using {@code FeatureUtil.scaleMinMax} with the given bounds.
     *
     * @param d  target minimum
     * @param d2 target maximum
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void scaleMinAndMax(double d, double d2) {
        INDArray target = getFeatures();
        FeatureUtil.scaleMinMax(d, d2, target);
    }

    /** Scales the features using {@code FeatureUtil.scaleByMax}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void scale() {
        INDArray target = getFeatures();
        FeatureUtil.scaleByMax(target);
    }

    /**
     * Replaces the features with the horizontal stack of the current features and
     * {@code iNDArray}.
     *
     * @param iNDArray columns to append
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void addFeatureVector(INDArray iNDArray) {
        INDArray widened = Nd4j.hstack(getFeatures(), iNDArray);
        setFeatures(widened);
    }

    /**
     * Overwrites feature row {@code row} with {@code vector}.
     *
     * @param vector new row contents
     * @param row    row index to overwrite
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void addFeatureVector(INDArray vector, int row) {
        INDArray target = getFeatures();
        target.putRow(row, vector);
    }

    /** Fits a fresh {@code NormalizerStandardize} on this DataSet, then transforms this DataSet with it. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void normalize() {
        // Typed as the API interface so the DataSet overloads of fit/transform are selected.
        org.nd4j.linalg.dataset.api.DataSet self = this;
        NormalizerStandardize standardizer = new NormalizerStandardize();
        standardizer.fit(self);
        standardizer.transform(self);
    }

    /** Binarizes the features with threshold 0: values above 0 become 1, all others 0. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void binarize() {
        final double threshold = 0.0d;
        binarize(threshold);
    }

    /**
     * Binarizes the features in place: values strictly greater than {@code d} become 1, all others
     * become 0. Values are read through {@code linearView()} and written by linear index.
     *
     * @param d threshold
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void binarize(double d) {
        INDArray target = getFeatures();
        INDArray flat = target.linearView();
        for (int idx = 0; idx < target.length(); idx++) {
            target.putScalar(idx, flat.getDouble(idx) > d ? 1 : 0);
        }
    }

    /**
     * Standardizes the features in place using their mean and standard deviation along
     * dimension 0 (per column, for a 2-D matrix).
     *
     * @deprecated kept for compatibility; {@link #normalize()} uses NormalizerStandardize instead.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    @Deprecated
    public void normalizeZeroMeanZeroUnitVariance() {
        // Both statistics are taken before the features are modified.
        INDArray columnMean = getFeatures().mean(0);
        INDArray columnStd = getFeatures().std(0);
        setFeatures(getFeatures().subiRowVector(columnMean));
        // Epsilon keeps zero-variance columns from dividing by zero.
        columnStd.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        setFeatures(getFeatures().diviRowVector(columnStd));
    }

    /**
     * Returns the size of dimension 1 of the features, narrowed to int.
     *
     * @return number of inputs per example
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int numInputs() {
        long inputs = getFeatures().size(1);
        return (int) inputs;
    }

    /**
     * Checks that features and labels have the same size along dimension 0.
     *
     * @throws IllegalStateException if the sizes differ
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void validate() {
        long featureRows = getFeatures().size(0);
        long labelRows = getLabels().size(0);
        if (featureRows != labelRows) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    /**
     * Returns the BLAS {@code iamax} (index of the largest-magnitude entry) over the whole labels
     * array.
     *
     * @return the outcome index
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int outcome() {
        INDArray labelArray = getLabels();
        return Nd4j.getBlasWrapper().iamax(labelArray);
    }

    /**
     * Replaces the labels with a freshly created {@code numExamples() x i} array.
     *
     * @param i new number of label columns
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setNewNumberOfLabels(int i) {
        INDArray freshLabels = Nd4j.create(numExamples(), i);
        setLabels(freshLabels);
    }

    /**
     * Sets the label row of example {@code i} to {@code FeatureUtil.toOutcomeVector(i2, numOutcomes())}
     * (presumably a one-hot vector for class {@code i2}).
     *
     * <p>Bounds are now exclusive at the top. The original accepted {@code i == numExamples()} and
     * {@code i2 == numOutcomes()}, one past the last valid index, and did not reject a negative
     * {@code i}.
     *
     * @param i  example index, in {@code [0, numExamples())}
     * @param i2 class index, in {@code [0, numOutcomes())}
     * @throws IllegalArgumentException if either index is out of range
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setOutcome(int i, int i2) {
        if (i < 0 || i >= numExamples()) {
            throw new IllegalArgumentException("No example at " + i);
        }
        if (i2 < 0 || i2 >= numOutcomes()) {
            throw new IllegalArgumentException("Illegal label");
        }
        getLabels().putRow(i, FeatureUtil.toOutcomeVector(i2, numOutcomes()));
    }

    /**
     * Returns a DataSet containing only example {@code i}.
     *
     * <p>If this DataSet holds exactly one example, it is returned itself rather than a new
     * instance.
     *
     * @param i example index
     * @return the single-example DataSet
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet get(int i) {
        int total = numExamples();
        if (i < 0 || i >= total) {
            throw new IllegalArgumentException("invalid example number: must be 0 to " + (total - 1) + ", got " + i);
        }
        if (total == 1 && i == 0) {
            return this;
        }
        return new DataSet(getHelper(this.features, i), getHelper(this.labels, i),
                getHelper(this.featuresMask, i), getHelper(this.labelsMask, i));
    }

    /**
     * Returns a DataSet with the examples at the given indices, in the given order, merged via
     * {@link #merge(List)}.
     *
     * @param iArr example indices
     * @return the merged DataSet
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet get(int[] iArr) {
        List<DataSet> picked = new ArrayList<>(iArr.length);
        for (int index : iArr) {
            picked.add(get(index));
        }
        return merge(picked);
    }

    /**
     * Splits the examples into consecutive batches of {@code i} (the last may be smaller) and
     * merges each batch into one DataSet.
     *
     * @param i batch size
     * @return the batches in order
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> batchBy(int i) {
        List<DataSet> batches = Lists.newArrayList();
        for (List<DataSet> chunk : Lists.partition(asList(), i)) {
            batches.add(merge(chunk));
        }
        return batches;
    }

    /**
     * Returns a DataSet containing only the examples whose {@code outcome()} is one of
     * {@code iArr}, merged via {@link #merge(List)}.
     *
     * <p>The wanted labels are held in a HashSet, so each lookup is O(1) instead of a linear scan
     * of a list.
     *
     * @param iArr label indices to keep
     * @return the filtered DataSet
     * @throws IllegalArgumentException (from merge) if no example matches
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet filterBy(int[] iArr) {
        HashSet<Integer> wanted = new HashSet<>();
        for (int label : iArr) {
            wanted.add(label);
        }
        List<DataSet> kept = new ArrayList<>();
        for (DataSet example : asList()) {
            if (wanted.contains(example.outcome())) {
                kept.add(example);
            }
        }
        return merge(kept);
    }

    /**
     * Keeps only examples whose outcome is in {@code iArr}, and relabels them so that label
     * {@code iArr[k]} becomes column {@code k} of a new {@code numExamples x iArr.length} label
     * array.
     *
     * <p>Only features and labels are replaced; masks and metadata are not updated. The decompiled
     * original referenced an undefined variable ({@code r0}) when building each row, so it did not
     * compile; the looked-up index is now used directly.
     *
     * @param iArr label indices to keep, in their new column order
     * @throws IllegalStateException if a kept example's outcome has no entry in {@code iArr}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void filterAndStrip(int[] iArr) {
        DataSet filtered = filterBy(iArr);
        Map<Integer, Integer> newColumnOf = new HashMap<>();
        for (int k = 0; k < iArr.length; k++) {
            newColumnOf.put(iArr[k], k);
        }
        int rows = filtered.numExamples();
        INDArray newLabels = Nd4j.create(rows, iArr.length);
        for (int row = 0; row < rows; row++) {
            Integer column = newColumnOf.get(filtered.get(row).outcome());
            if (column == null) {
                throw new IllegalStateException("Label not found on row " + row);
            }
            newLabels.putRow(row, FeatureUtil.toOutcomeVector(column.intValue(), iArr.length));
        }
        setFeatures(filtered.getFeatures());
        setLabels(newLabels);
    }

    /**
     * Partitions the examples into chunks of {@code i} (the last may be shorter) and returns one
     * merged DataSet per chunk.
     *
     * @param i number of examples per chunk
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> dataSetBatches(int i) {
        List<List<DataSet>> chunks = Lists.partition(asList(), i);
        List<DataSet> merged = new ArrayList<>(chunks.size());
        for (int c = 0; c < chunks.size(); c++) {
            merged.add(merge(chunks.get(c)));
        }
        return merged;
    }

    /**
     * Sorts this DataSet in place with {@link #sortByLabel()} and then batches it with
     * {@link #batchByNumLabels()}.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> sortAndBatchByNumLabels() {
        this.sortByLabel();
        List<DataSet> batches = this.batchByNumLabels();
        return batches;
    }

    /** Batches the examples using a batch size equal to {@link #numOutcomes()}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> batchByNumLabels() {
        int labelCount = numOutcomes();
        return batchBy(labelCount);
    }

    /**
     * Splits this DataSet into one single-example DataSet per row.
     *
     * <p>Element {@code i} holds row {@code i} of the features, labels and (when present) mask
     * arrays, as selected by {@code getHelper}. When example metadata exists for row {@code i}, it
     * is attached as a singleton list.
     *
     * @return a new list with {@link #numExamples()} entries
     * @throws NullPointerException if features or labels are null
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<DataSet> asList() {
        final int count = numExamples();
        List<DataSet> perExample = new ArrayList<>(count);
        // The results of these rank() calls are unused; they only make the method fail fast with a
        // NullPointerException when features or labels are missing.
        getFeatures().rank();
        getLabels().rank();
        for (int row = 0; row < count; row++) {
            DataSet single = new DataSet(getHelper(getFeatures(), row), getHelper(this.labels, row), getHelper(this.featuresMask, row), getHelper(this.labelsMask, row));
            boolean hasMeta = this.exampleMetaData != null && row < this.exampleMetaData.size();
            if (hasMeta) {
                single.setExampleMetaData(Collections.singletonList(this.exampleMetaData.get(row)));
            }
            perExample.add(single);
        }
        return perExample;
    }

    /**
     * Selects example {@code i} of {@code iNDArray} using the inclusive interval [i, i] on
     * dimension 0 and {@code all()} on every other dimension.
     *
     * @return null when {@code iNDArray} is null
     * @throws IllegalStateException if the rank is not between 2 and 5 inclusive
     */
    private INDArray getHelper(INDArray iNDArray, int i) {
        if (iNDArray != null) {
            final int rank = iNDArray.rank();
            if (rank == 2) {
                return iNDArray.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all());
            } else if (rank == 3) {
                return iNDArray.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all());
            } else if (rank == 4) {
                return iNDArray.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
            } else if (rank == 5) {
                return iNDArray.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
            }
            throw new IllegalStateException("Cannot convert to list: feature set rank must be in range 2 to 5 inclusive. Got shape: " + Arrays.toString(iNDArray.shape()));
        }
        return null;
    }

    /**
     * Shuffles this DataSet in place with a seed drawn from {@code random}, then splits it with
     * {@link #splitTestAndTrain(int)}.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public SplitTestAndTrain splitTestAndTrain(int i, Random random) {
        long seed = random.nextLong();
        shuffle(seed);
        return splitTestAndTrain(i);
    }

    /**
     * Splits this DataSet, without shuffling, into its first {@code i} examples and the remaining
     * {@code numExamples() - i} examples.
     *
     * <p>Features and labels of rank 2 to 5 are sliced along dimension 0. A null features or labels
     * array stays null in both halves. Mask arrays are sliced with two indices, and example metadata
     * (when present) is split at the same position.
     *
     * @param i number of examples in the first returned DataSet
     * @return the pair (first {@code i} examples, remaining examples)
     * @throws IllegalStateException if this DataSet has at most one example
     * @throws IllegalArgumentException if {@code i >= numExamples()}
     * @throws UnsupportedOperationException if features or labels have an unsupported rank
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public SplitTestAndTrain splitTestAndTrain(int i) {
        int numExamples = numExamples();
        if (numExamples <= 1) {
            throw new IllegalStateException("Cannot split DataSet with <= 1 rows (data set has " + numExamples + " example)");
        }
        if (i >= numExamples) {
            throw new IllegalArgumentException("Unable to split on size equal or larger than the number of rows (# numExamples=" + numExamples + ", numHoldout=" + i + ")");
        }
        DataSet first = new DataSet();
        DataSet rest = new DataSet();

        first.setFeatures(sliceExampleRange(this.features, 0, i, "Features"));
        rest.setFeatures(sliceExampleRange(this.features, i, numExamples, "Features"));
        first.setLabels(sliceExampleRange(this.labels, 0, i, "Labels"));
        rest.setLabels(sliceExampleRange(this.labels, i, numExamples, "Labels"));

        if (this.featuresMask != null) {
            first.setFeaturesMaskArray(this.featuresMask.get(NDArrayIndex.interval(0, i), NDArrayIndex.all()));
            rest.setFeaturesMaskArray(this.featuresMask.get(NDArrayIndex.interval(i, numExamples), NDArrayIndex.all()));
        }
        if (this.labelsMask != null) {
            first.setLabelsMaskArray(this.labelsMask.get(NDArrayIndex.interval(0, i), NDArrayIndex.all()));
            rest.setLabelsMaskArray(this.labelsMask.get(NDArrayIndex.interval(i, numExamples), NDArrayIndex.all()));
        }
        if (this.exampleMetaData != null) {
            List<Object> firstMeta = new ArrayList<>();
            List<Object> restMeta = new ArrayList<>();
            // Metadata may be shorter than the number of examples; copy only what exists.
            for (int m = 0; m < i && m < this.exampleMetaData.size(); m++) {
                firstMeta.add(this.exampleMetaData.get(m));
            }
            for (int m = i; m < numExamples && m < this.exampleMetaData.size(); m++) {
                restMeta.add(this.exampleMetaData.get(m));
            }
            first.setExampleMetaData((List) firstMeta);
            rest.setExampleMetaData((List) restMeta);
        }
        return new SplitTestAndTrain(first, rest);
    }

    /**
     * Returns rows {@code [from, to)} of {@code arr} along dimension 0, selecting {@code all()} on
     * every other dimension.
     *
     * @param what name used in the error message ("Features" or "Labels")
     * @return null when {@code arr} is null
     * @throws UnsupportedOperationException if the rank is not between 2 and 5 inclusive
     */
    private static INDArray sliceExampleRange(INDArray arr, int from, int to, String what) {
        if (arr == null) {
            return null;
        }
        switch (arr.rank()) {
            case 2:
                return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all());
            case 3:
                return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all());
            case 4:
                return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
            case 5:
                return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
            default:
                // Report the rank of the array actually being sliced (the labels case used to print the features rank).
                throw new UnsupportedOperationException(what + " rank: " + arr.rank());
        }
    }

    /** Returns the labels array held by this DataSet (not a copy; may be null). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getLabels() {
        INDArray current = labels;
        return current;
    }

    /**
     * Returns the label name stored at position {@code i} of the label-name list.
     *
     * @throws IllegalStateException if no label names are set, or {@code i >= labelNames.size()}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public String getLabelName(int i) {
        List<String> names = this.labelNames;
        if (names.isEmpty()) {
            throw new IllegalStateException("Label names are not defined on this dataset. Add label names in order to use getLabelName with an id.");
        }
        if (i >= names.size()) {
            throw new IllegalStateException("Index requested is longer than the number of labels used for classification.");
        }
        return names.get(i);
    }

    /**
     * Maps each value of {@code iNDArray}, interpreted as a label index, to its label name.
     *
     * @param iNDArray label indices (read by linear index; each value is truncated to an int)
     * @return label names in the same order as the indices
     * @throws IllegalStateException if label names are not set or an index is too large
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<String> getLabelNames(INDArray iNDArray) {
        List<String> names = new ArrayList<>();
        for (int i = 0; i < iNDArray.length(); i++) {
            // Look up the name of the index stored at position i, not of position i itself.
            int labelIndex = (int) iNDArray.getDouble(i);
            names.add(getLabelName(labelIndex));
        }
        return names;
    }

    /** Replaces the labels array with {@code newLabels} (stored by reference, no copy). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setLabels(INDArray newLabels) {
        labels = newLabels;
    }

    /**
     * Reorders the examples of this DataSet in place so that labels are interleaved.
     *
     * <p>Examples are first grouped by {@code outcome()}. Rows are then filled round-robin with one
     * example each of label 0, 1, ..., {@code numOutcomes() - 1} per round, for as long as every
     * label still has an example left. Once any label runs out (or never had examples), the
     * remaining examples are placed queue by queue in the grouping map's iteration order.
     *
     * <p>Rows are written with {@link #addRow(DataSet, int)}, so only features and labels are
     * reordered. Masks and example metadata are left as they were. Every row is written exactly once.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void sortByLabel() {
        Map<Integer, Queue<DataSet>> byLabel = new HashMap<>();
        int numOutcomes = numOutcomes();
        int numExamples = numExamples();
        for (DataSet example : asList()) {
            // asList() builds each example from INDArray.get(...) slices of this DataSet's own arrays
            // (views in ND4J), and addRow() below overwrites those arrays row by row. Each example is
            // copied first so a row that has not been placed yet cannot be clobbered by an earlier write.
            DataSet detached = new DataSet(example.getFeatures().dup(), example.getLabels().dup(), null, null);
            Integer label = Integer.valueOf(example.outcome());
            Queue<DataSet> queue = byLabel.get(label);
            if (queue == null) {
                queue = new ArrayDeque<>();
                byLabel.put(label, queue);
            }
            queue.add(detached);
        }
        for (Map.Entry<Integer, Queue<DataSet>> entry : byLabel.entrySet()) {
            log.info("Label {} has {} elements", entry.getKey(), entry.getValue().size());
        }

        int row = 0;
        // Phase 1: one example per label per round while every label still has examples.
        // Guarding numOutcomes > 0 avoids an endless loop when there are no label columns.
        boolean roundRobin = numOutcomes > 0;
        while (roundRobin && row < numExamples) {
            for (int label = 0; label < numOutcomes; label++) {
                Queue<DataSet> queue = byLabel.get(label);
                DataSet next = queue == null ? null : queue.poll();
                if (next == null) {
                    roundRobin = false;
                    break;
                }
                addRow(next, row++);
            }
        }
        // Phase 2: place whatever is left, draining the queues one after another.
        for (Queue<DataSet> queue : byLabel.values()) {
            DataSet next;
            while ((next = queue.poll()) != null) {
                addRow(next, row++);
            }
        }
    }

    /**
     * Overwrites row {@code i} of this DataSet's features and labels with those of {@code dataSet}.
     * Mask arrays are not touched.
     *
     * @param dataSet source example; must not be null
     * @param i row to overwrite, must be in {@code [0, numExamples())}
     * @throws IllegalArgumentException if {@code dataSet} is null or {@code i} is out of range
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void addRow(DataSet dataSet, int i) {
        // i == numExamples() is one past the last row, so it is rejected together with negative indices.
        if (dataSet == null || i < 0 || i >= numExamples()) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        getFeatures().putRow(i, dataSet.getFeatures());
        getLabels().putRow(i, dataSet.getLabels());
    }

    /**
     * Returns the maximum value in {@code dataSet}'s labels, narrowed to float and then truncated
     * to int. Note that this is the maximum label value, not its index.
     */
    private int getLabel(DataSet dataSet) {
        float maxValue = dataSet.getLabels().maxNumber().floatValue();
        return (int) maxValue;
    }

    /** Returns the sum of the features along dimension 1. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray exampleSums() {
        INDArray f = getFeatures();
        return f.sum(1);
    }

    /** Returns the maximum of the features along dimension 1. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray exampleMaxs() {
        INDArray f = getFeatures();
        return f.max(1);
    }

    /** Returns the mean of the features along dimension 1. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray exampleMeans() {
        INDArray f = getFeatures();
        return f.mean(1);
    }

    /** Samples {@code i} examples without replacement using the global ND4J random generator. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int i) {
        org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
        return sample(i, rng);
    }

    /** Samples {@code i} examples without replacement using {@code random}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int i, org.nd4j.linalg.api.rng.Random random) {
        final boolean withReplacement = false;
        return sample(i, random, withReplacement);
    }

    /** Samples {@code i} examples with the global ND4J random generator; {@code z} selects replacement. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int i, boolean z) {
        org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
        return sample(i, rng, z);
    }

    /**
     * Draws up to {@code numSamples} examples uniformly at random and merges them into one DataSet.
     *
     * <p>With replacement, every draw is independent, so an example may appear several times.
     * Without replacement, each example is drawn at most once, and sampling stops early once every
     * example has been drawn.
     *
     * @param numSamples number of examples to draw
     * @param rng random generator used for index selection
     * @param withReplacement whether an example may be drawn more than once
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
        int available = numExamples();
        HashSet<Integer> alreadyPicked = new HashSet<>();
        List<DataSet> picked = new ArrayList<>();
        for (int s = 0; s < numSamples; s++) {
            int idx;
            if (withReplacement) {
                idx = rng.nextInt(available);
            } else {
                if (alreadyPicked.size() >= available) {
                    // Every example has been used; no unseen index remains.
                    break;
                }
                // Rejection sampling: redraw until an index not seen before comes up (add() returns false for repeats).
                do {
                    idx = rng.nextInt(available);
                } while (!alreadyPicked.add(idx));
            }
            picked.add(get(idx));
        }
        return merge(picked);
    }

    /**
     * Rounds every feature value in place with {@code MathUtils.roundDouble(value, i)}.
     *
     * @param i precision argument passed to {@code MathUtils.roundDouble}
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void roundToTheNearest(int i) {
        INDArray f = getFeatures();
        for (int idx = 0; idx < f.length(); idx++) {
            // getDouble reads any floating-point data type. Casting element() to Double throws
            // ClassCastException when the array holds floats.
            f.putScalar(idx, MathUtils.roundDouble(f.getDouble(idx), i));
        }
    }

    /** Returns the size of dimension 1 of the labels array, narrowed to int. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int numOutcomes() {
        long columns = getLabels().size(1);
        return (int) columns;
    }

    /**
     * Returns the number of examples. This is the size of dimension 0 of the features, or of the
     * labels when features are null, or 0 when both are null.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public int numExamples() {
        if (getFeatures() == null) {
            return getLabels() == null ? 0 : (int) getLabels().size(0);
        }
        return (int) getFeatures().size(0);
    }

    /**
     * Renders features, labels and any masks as text, with ';' replaced by newlines.
     * Returns an empty string (and logs at info level) when features or labels are null.
     */
    @Override
    public String toString() {
        if (this.features == null || this.labels == null) {
            log.info("Features or labels are null values");
            return "";
        }
        StringBuilder out = new StringBuilder();
        out.append("===========INPUT===================\n");
        out.append(getFeatures().toString().replace(";", "\n"));
        out.append("\n=================OUTPUT==================\n");
        out.append(getLabels().toString().replace(";", "\n"));
        if (this.featuresMask != null) {
            out.append("\n===========INPUT MASK===================\n");
            out.append(getFeaturesMaskArray().toString().replace(";", "\n"));
        }
        if (this.labelsMask != null) {
            out.append("\n===========OUTPUT MASK===================\n");
            out.append(getLabelsMaskArray().toString().replace(";", "\n"));
        }
        return out.toString();
    }

    /** Returns the label-name list by reference. Deprecated in favour of {@link #getLabelNamesList()}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    @Deprecated
    public List<String> getLabelNames() {
        return labelNames;
    }

    /** Returns the label-name list by reference (no copy). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public List<String> getLabelNamesList() {
        return labelNames;
    }

    /** Stores {@code names} as the label-name list (by reference). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setLabelNames(List<String> names) {
        labelNames = names;
    }

    /** Returns the column-name list by reference. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    @Deprecated
    public List<String> getColumnNames() {
        return columnNames;
    }

    /** Stores {@code names} as the column-name list (by reference). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    @Deprecated
    public void setColumnNames(List<String> names) {
        columnNames = names;
    }

    /**
     * Splits off the first {@code floor(d * numExamples())} examples (at least one) as the first
     * half; see {@link #splitTestAndTrain(int)}.
     *
     * @param d fraction in the open interval (0, 1)
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public SplitTestAndTrain splitTestAndTrain(double d) {
        Preconditions.checkArgument(d > 0.0d && d < 1.0d, "Train fraction must be > 0.0 and < 1.0 - got %s", d);
        int numTrain = Math.max(1, (int) (d * numExamples()));
        return splitTestAndTrain(numTrain);
    }

    /** Iterates over single-example DataSets, as produced by {@link #asList()}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet, java.lang.Iterable
    public Iterator<DataSet> iterator() {
        List<DataSet> examples = asList();
        return examples.iterator();
    }

    /** Returns the features mask (may be null). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getFeaturesMaskArray() {
        return featuresMask;
    }

    /** Replaces the features mask (stored by reference). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setFeaturesMaskArray(INDArray mask) {
        featuresMask = mask;
    }

    /** Returns the labels mask (may be null). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public INDArray getLabelsMaskArray() {
        return labelsMask;
    }

    /** Replaces the labels mask (stored by reference). */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void setLabelsMaskArray(INDArray mask) {
        labelsMask = mask;
    }

    /** Returns true if either the features mask or the labels mask is set. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public boolean hasMaskArrays() {
        return labelsMask != null || featuresMask != null;
    }

    /**
     * Two DataSets are equal when their features, labels, features mask and labels mask are
     * pairwise equal (or both null). Metadata and label names are ignored.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof DataSet)) {
            return false;
        }
        DataSet other = (DataSet) obj;
        return equalOrBothNull(this.features, other.features)
                && equalOrBothNull(this.labels, other.labels)
                && equalOrBothNull(this.featuresMask, other.featuresMask)
                && equalOrBothNull(this.labelsMask, other.labelsMask);
    }

    /** Null-safe equality: true if both are null, false if exactly one is, else {@code a.equals(b)}. */
    private static boolean equalOrBothNull(INDArray a, INDArray b) {
        if (a == null || b == null) {
            return a == b;
        }
        return a.equals(b);
    }

    /** Combines the hash codes of features, labels and both masks (null counts as 0) with factor 31. */
    @Override
    public int hashCode() {
        INDArray[] parts = {getFeatures(), getLabels(), getFeaturesMaskArray(), getLabelsMaskArray()};
        int result = 0;
        for (INDArray part : parts) {
            result = 31 * result + (part == null ? 0 : part.hashCode());
        }
        return result;
    }

    /**
     * Estimates memory use as the total element count of features, labels and both masks,
     * multiplied by {@code Nd4j.sizeOfDataType()}. Null arrays (including null features, which
     * {@link #isEmpty()} allows) contribute 0.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public long getMemoryFootprint() {
        long elements = 0L;
        for (INDArray arr : new INDArray[] {this.features, this.labels, this.featuresMask, this.labelsMask}) {
            if (arr != null) {
                elements += arr.lengthLong();
            }
        }
        return elements * Nd4j.sizeOfDataType();
    }

    /**
     * When a workspace is currently active, replaces each non-null array with the result of its
     * {@code migrate()}. When no workspace is active, this does nothing.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void migrate() {
        if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) {
            return;
        }
        features = features == null ? null : features.migrate();
        labels = labels == null ? null : labels.migrate();
        featuresMask = featuresMask == null ? null : featuresMask.migrate();
        labelsMask = labelsMask == null ? null : labelsMask.migrate();
    }

    /** Replaces each non-null array (features, labels, both masks) with the result of its {@code detach()}. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public void detach() {
        features = features == null ? null : features.detach();
        labels = labels == null ? null : labels.detach();
        featuresMask = featuresMask == null ? null : featuresMask.detach();
        labelsMask = labelsMask == null ? null : labelsMask.detach();
    }

    /** Returns true only when features, labels and both masks are all null. */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public boolean isEmpty() {
        boolean anyPresent = features != null || labels != null || featuresMask != null || labelsMask != null;
        return !anyPresent;
    }

    /**
     * Wraps this DataSet's arrays into a MultiDataSet with one input and one output. Null arrays
     * become null (not single-null) array slots.
     */
    @Override // org.nd4j.linalg.dataset.api.DataSet
    public org.nd4j.linalg.dataset.api.MultiDataSet toMultiDataSet() {
        INDArray f = getFeatures();
        INDArray l = getLabels();
        INDArray fm = getFeaturesMaskArray();
        INDArray lm = getLabelsMaskArray();
        INDArray[] inputs = null;
        INDArray[] outputs = null;
        INDArray[] inputMasks = null;
        INDArray[] outputMasks = null;
        if (f != null) {
            inputs = new INDArray[]{f};
        }
        if (l != null) {
            outputs = new INDArray[]{l};
        }
        if (fm != null) {
            inputMasks = new INDArray[]{fm};
        }
        if (lm != null) {
            outputMasks = new INDArray[]{lm};
        }
        return new MultiDataSet(inputs, outputs, inputMasks, outputMasks);
    }
}
