package org.nd4j.weightinit;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/weightinit/BaseWeightInitScheme.class */
/**
 * Base class for {@link WeightInitScheme} implementations.
 *
 * <p>Subclasses implement {@link #doCreate(int[], INDArray)} to produce the initialized weights.
 * This class handles copying those weights into an optional parameter view. It also holds the
 * array ordering character returned by {@link #order()}.
 */
public abstract class BaseWeightInitScheme implements WeightInitScheme {
    /** Ordering used when flattening/reshaping weights; fixed at construction. */
    private final char order;

    /** Creates a scheme using the {@code 'c'} ordering. */
    public BaseWeightInitScheme() {
        this('c');
    }

    /**
     * Creates a scheme with the given ordering.
     *
     * @param order ordering character passed to {@code Nd4j.toFlattened} and {@code reshape}
     *     (the no-arg constructor uses {@code 'c'})
     */
    public BaseWeightInitScheme(char order) {
        this.order = order;
    }

    /**
     * Produces the initialized weight array.
     *
     * @param shape the requested weight shape
     * @param paramsView the optional parameter view passed through from
     *     {@link #create(int[], INDArray)}; {@code null} when called from {@link #create(int[])}
     * @return the initialized weights
     */
    public abstract INDArray doCreate(int[] shape, INDArray paramsView);

    /**
     * Creates weights via {@link #doCreate(int[], INDArray)}, then hands the result to
     * {@link #handleParamsView(INDArray, INDArray)} so it ends up in {@code paramView} when one is
     * supplied.
     *
     * @param shape the requested weight shape
     * @param paramView optional destination view; may be {@code null}
     * @return the weights, or {@code paramView} reshaped to the weights' shape when a distinct
     *     view was supplied
     * @throws IllegalArgumentException if {@code paramView}'s length differs from the weights'
     */
    @Override // org.nd4j.weightinit.WeightInitScheme
    public INDArray create(int[] shape, INDArray paramView) {
        INDArray weights = doCreate(shape, paramView);
        return handleParamsView(weights, paramView);
    }

    /**
     * Creates weights via {@link #doCreate(int[], INDArray)} with no parameter view.
     *
     * @param shape the requested weight shape
     * @return the array returned by {@code doCreate}
     */
    @Override // org.nd4j.weightinit.WeightInitScheme
    public INDArray create(int[] shape) {
        return doCreate(shape, null);
    }

    /** @return the ordering character supplied at construction ({@code 'c'} by default) */
    @Override // org.nd4j.weightinit.WeightInitScheme
    public char order() {
        return this.order;
    }

    /**
     * Copies freshly created weights into a parameter view, if one is provided.
     *
     * <p>If {@code paramView} is {@code null}, or is the very same instance as {@code weights},
     * {@code weights} is returned unchanged. Otherwise the weights are flattened using
     * {@link #order()}. Their length is checked against the view's, and the flattened values are
     * assigned into {@code paramView}. Finally the view is returned, reshaped (with
     * {@link #order()}) to the weights' shape.
     *
     * @param weights the initialized weights
     * @param paramView the destination view; may be {@code null}
     * @return {@code weights}, or the reshaped {@code paramView}
     * @throws IllegalArgumentException if the view's length differs from the flattened weights'
     */
    protected INDArray handleParamsView(INDArray weights, INDArray paramView) {
        // Nothing to copy: no view, or doCreate already wrote directly into the view instance.
        if (paramView == null || paramView == weights) {
            return weights;
        }
        INDArray flattened = Nd4j.toFlattened(order(), weights);
        if (flattened.length() != paramView.length()) {
            throw new IllegalArgumentException(
                    "ParamView length does not match initialized weights length (view length: "
                            + paramView.length()
                            + ", view shape: "
                            + Arrays.toString(paramView.shape())
                            + "; flattened length: "
                            + flattened.length()
                            + ")");
        }
        paramView.assign(flattened);
        return paramView.reshape(order(), weights.shape());
    }

    /**
     * Two schemes are equal when both are {@code BaseWeightInitScheme}s, the other accepts this one
     * via {@link #canEqual(Object)}, and their orderings match.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BaseWeightInitScheme)) {
            return false;
        }
        BaseWeightInitScheme other = (BaseWeightInitScheme) obj;
        // canEqual keeps equals symmetric when subclasses add state (Lombok-style pattern).
        return other.canEqual(this) && this.order == other.order;
    }

    /**
     * Hook that subclasses override to restrict which instances they consider comparable.
     *
     * @param obj the object being compared
     * @return {@code true} if {@code obj} is a {@code BaseWeightInitScheme}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof BaseWeightInitScheme;
    }

    /** Hash derived solely from {@link #order()}, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        // Same value as the original Lombok-style (1 * 59) + order.
        return 59 + this.order;
    }
}
