/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.writable.batch;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.AbstractWritableRecordBatch;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * A record batch backed by one or more {@link INDArray}s that all have the same size along
 * dimension 0.
 *
 * <p>Record {@code i} of the batch is a list containing one {@link NDArrayWritable} per backing
 * array, each wrapping the inclusive interval {@code [i, i]} of that array along dimension 0
 * (all other dimensions are taken in full).
 *
 * <p>The backing list is stored as passed in (no defensive copy) and is returned as-is by
 * {@link #getArrays()}. The setters perform no validation: {@link #setArrays(List)} does not
 * recompute the size, so callers using it must keep {@link #setSize(long)} consistent.
 */
public class NDArrayRecordBatch
extends AbstractWritableRecordBatch {
    /** Backing arrays; each contributes one writable per record. */
    private List<INDArray> arrays;
    /** Number of records, i.e. the common size of dimension 0 of every backing array. */
    private long size;

    /**
     * Creates a batch from the given arrays.
     *
     * @param arrays the backing arrays; must be non-empty and agree on the size of dimension 0
     * @throws IllegalArgumentException if no arrays are given or the dimension-0 sizes differ
     */
    public NDArrayRecordBatch(INDArray ... arrays) {
        this(Arrays.asList(arrays));
    }

    /**
     * Creates a batch from the given list of arrays.
     *
     * @param arrays the backing arrays; must be non-null, non-empty, and agree on the size of
     *               dimension 0
     * @throws NullPointerException     if {@code arrays} is null
     * @throws IllegalArgumentException if {@code arrays} is empty or the dimension-0 sizes differ
     */
    public NDArrayRecordBatch(@NonNull List<INDArray> arrays) {
        if (arrays == null) {
            throw new NullPointerException("arrays is marked @NonNull but is null");
        }
        Preconditions.checkArgument(!arrays.isEmpty(), "Input list must not be empty");
        this.arrays = arrays;
        // The batch size always comes from the first array -- including the single-array case,
        // which would otherwise leave the size at 0 and make every record unreachable.
        this.size = arrays.get(0).size(0);
        for (int i = 1; i < arrays.size(); ++i) {
            long sizeI = arrays.get(i).size(0);
            if (sizeI != this.size) {
                throw new IllegalArgumentException("Invalid input arrays: all arrays must have same size for dimension 0. arrays.get(0).size(0)=" + this.size + ", arrays.get(" + i + ").size(0)=" + sizeI);
            }
        }
    }

    /**
     * Returns the number of records in this batch.
     *
     * <p>The size is stored as a {@code long} and narrowed with a plain cast, so it will not be
     * meaningful for batches larger than {@link Integer#MAX_VALUE}.
     */
    @Override
    public int size() {
        return (int) this.size;
    }

    /**
     * Returns record {@code index}: one {@link NDArrayWritable} per backing array.
     *
     * @param index record index in {@code [0, size)}
     * @return a new list with one writable per backing array, in backing-array order
     * @throws IllegalArgumentException if {@code index} is out of range
     */
    @Override
    public List<Writable> get(int index) {
        // Explicit check (same exception type and message as Preconditions.checkArgument) so the
        // error string is only built on failure rather than on every call.
        if (index < 0 || index >= this.size) {
            throw new IllegalArgumentException("Invalid index: " + index + ", size = " + this.size);
        }
        // One entry per backing array, not per record in the batch.
        List<Writable> out = new ArrayList<>(this.arrays.size());
        for (INDArray orig : this.arrays) {
            out.add(new NDArrayWritable(NDArrayRecordBatch.getExample(index, orig)));
        }
        return out;
    }

    /**
     * Selects the inclusive interval {@code [idx, idx]} of {@code from} along dimension 0, taking
     * every other dimension in full. Because an interval (not a point) index is used, dimension 0
     * is kept with length 1. Whether the result is a view or a copy is decided by
     * {@link INDArray#get}; presumably a view -- confirm against ND4J semantics.
     */
    private static INDArray getExample(int idx, INDArray from) {
        INDArrayIndex[] idxs = new INDArrayIndex[from.rank()];
        idxs[0] = NDArrayIndex.interval((long) idx, (long) idx, true);
        for (int i = 1; i < idxs.length; ++i) {
            idxs[i] = NDArrayIndex.all();
        }
        return from.get(idxs);
    }

    /** Returns the backing list of arrays (not a copy). */
    public List<INDArray> getArrays() {
        return this.arrays;
    }

    /** Returns the number of records as a {@code long}. */
    public long getSize() {
        return this.size;
    }

    /** Replaces the backing arrays without validating them or updating the size. */
    public void setArrays(List<INDArray> arrays) {
        this.arrays = arrays;
    }

    /** Overrides the record count without validating it against the backing arrays. */
    public void setSize(long size) {
        this.size = size;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof NDArrayRecordBatch)) {
            return false;
        }
        NDArrayRecordBatch other = (NDArrayRecordBatch) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.getSize() == other.getSize()
                && Objects.equals(this.getArrays(), other.getArrays());
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object other) {
        return other instanceof NDArrayRecordBatch;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        List<INDArray> backing = this.getArrays();
        int result = 1;
        result = result * prime + (backing == null ? 43 : backing.hashCode());
        result = result * prime + Long.hashCode(this.getSize());
        return result;
    }

    @Override
    public String toString() {
        return "NDArrayRecordBatch(arrays=" + this.getArrays() + ", size=" + this.getSize() + ")";
    }
}

