/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;

/**
 * Abstract base configuration for recurrent layers.
 *
 * <p>Extends {@link FeedForwardLayer} with a separate weight-initialization function,
 * {@link #weightInitFnRecurrent}, taken from the builder's {@code weightInitRecurrent(...)} setters.
 * Input must be of type {@link InputType.Type#RNN}. The output type is a recurrent type with
 * {@code nOut} features and the same time-series length as the input.
 */
public abstract class BaseRecurrentLayer extends FeedForwardLayer {
    /** Weight-initialization function set via the builder's {@code weightInitRecurrent}; may be null. */
    protected IWeightInit weightInitFnRecurrent;

    /**
     * Creates the layer configuration from a builder.
     *
     * <p>Only {@code weightInitFnRecurrent} is copied here. The builder's
     * {@code recurrentConstraints} and {@code inputWeightConstraints} are not read by this
     * constructor. NOTE(review): presumably they are consumed by superclass/shared constraint
     * initialization — confirm.
     *
     * @param builder the builder holding the configuration; raw subclass builders are accepted
     */
    protected BaseRecurrentLayer(Builder<?> builder) {
        super(builder);
        this.weightInitFnRecurrent = builder.weightInitFnRecurrent;
    }

    /** No-argument constructor. NOTE(review): presumably needed for deserialization — confirm. */
    public BaseRecurrentLayer() {
    }

    /**
     * Computes the output type of this layer.
     *
     * @param layerIndex index of this layer; used only in the error message
     * @param inputType  the input type; must be non-null and of type {@link InputType.Type#RNN}
     * @return a recurrent type with {@code nOut} features and the input's time-series length
     * @throws IllegalStateException if {@code inputType} is null or is not an RNN type
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex + ", layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        InputType.InputTypeRecurrent recurrentInput = (InputType.InputTypeRecurrent) inputType;
        return InputType.recurrent(this.nOut, recurrentInput.getTimeSeriesLength());
    }

    /**
     * Sets {@code nIn} from the size of the recurrent input type.
     *
     * @param inputType the input type; must be non-null and of type {@link InputType.Type#RNN}
     * @param override  if true, replace {@code nIn} even when it is already positive
     * @throws IllegalStateException if {@code inputType} is null or is not an RNN type
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        // A non-positive nIn is treated as "not yet set".
        if (this.nIn <= 0L || override) {
            InputType.InputTypeRecurrent recurrentInput = (InputType.InputTypeRecurrent) inputType;
            this.nIn = recurrentInput.getSize();
        }
    }

    /**
     * Returns the preprocessor needed to feed {@code inputType} into this layer.
     * This method delegates to {@link InputTypeUtil#getPreprocessorForInputTypeRnnLayers}.
     *
     * @param inputType the incoming input type
     * @return the preprocessor chosen by {@code InputTypeUtil} for RNN layers
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, this.getLayerName());
    }

    /** @return the recurrent weight-initialization function; may be null */
    public IWeightInit getWeightInitFnRecurrent() {
        return this.weightInitFnRecurrent;
    }

    /** @param weightInitFnRecurrent the recurrent weight-initialization function; may be null */
    public void setWeightInitFnRecurrent(IWeightInit weightInitFnRecurrent) {
        this.weightInitFnRecurrent = weightInitFnRecurrent;
    }

    @Override
    public String toString() {
        return "BaseRecurrentLayer(super=" + super.toString() + ", weightInitFnRecurrent=" + this.getWeightInitFnRecurrent() + ")";
    }

    /**
     * Compares superclass state and {@code weightInitFnRecurrent}.
     * This uses the {@code canEqual} pattern so that equality stays symmetric across subclasses.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BaseRecurrentLayer)) {
            return false;
        }
        BaseRecurrentLayer other = (BaseRecurrentLayer) o;
        // canEqual lets a subclass opt out of equality with a less-derived type.
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        return Objects.equals(this.getWeightInitFnRecurrent(), other.getWeightInitFnRecurrent());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof BaseRecurrentLayer;
    }

    /** Combines the superclass hash with {@code weightInitFnRecurrent}; consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = super.hashCode();
        IWeightInit recurrentInit = this.getWeightInitFnRecurrent();
        // 43 is the contribution of a null field (kept so hash values are unchanged).
        result = result * PRIME + (recurrentInit == null ? 43 : recurrentInit.hashCode());
        return result;
    }

    /**
     * Builder for {@link BaseRecurrentLayer} subclasses.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
        /** Constraints set via {@link #constrainRecurrent}; null until set. */
        protected List<LayerConstraint> recurrentConstraints;
        /** Constraints set via {@link #constrainInputWeights}; null until set. */
        protected List<LayerConstraint> inputWeightConstraints;
        /** Weight-init function set via the {@code weightInitRecurrent} overloads; null until set. */
        protected IWeightInit weightInitFnRecurrent;

        /**
         * Sets the recurrent-weight constraints.
         * The stored list is an {@link Arrays#asList} view, so it is fixed-size and backed by the argument array.
         *
         * @param constraints the constraints to store
         * @return this builder
         */
        @SuppressWarnings("unchecked") // safe: T is the concrete builder type by the recursive bound
        public T constrainRecurrent(LayerConstraint ... constraints) {
            this.setRecurrentConstraints(Arrays.asList(constraints));
            return (T) this;
        }

        /**
         * Sets the input-weight constraints.
         * The stored list is an {@link Arrays#asList} view, so it is fixed-size and backed by the argument array.
         *
         * @param constraints the constraints to store
         * @return this builder
         */
        @SuppressWarnings("unchecked") // safe: T is the concrete builder type by the recursive bound
        public T constrainInputWeights(LayerConstraint ... constraints) {
            this.setInputWeightConstraints(Arrays.asList(constraints));
            return (T) this;
        }

        /**
         * Sets the recurrent weight-initialization function directly.
         *
         * @param weightInit the function to use; stored as-is
         * @return this builder
         */
        @SuppressWarnings("unchecked") // safe: T is the concrete builder type by the recursive bound
        public T weightInitRecurrent(IWeightInit weightInit) {
            this.setWeightInitFnRecurrent(weightInit);
            return (T) this;
        }

        /**
         * Sets the recurrent weight initialization from a {@link WeightInit} enum value.
         *
         * @param weightInit the scheme; must not be null or {@link WeightInit#DISTRIBUTION}
         * @return this builder
         * @throws NullPointerException          if {@code weightInit} is null
         * @throws UnsupportedOperationException if {@code weightInit} is {@code DISTRIBUTION};
         *                                       use {@link #weightInitRecurrent(Distribution)} instead
         */
        @SuppressWarnings("unchecked") // safe: T is the concrete builder type by the recursive bound
        public T weightInitRecurrent(WeightInit weightInit) {
            Objects.requireNonNull(weightInit, "weightInit");
            if (weightInit == WeightInit.DISTRIBUTION) {
                // Point callers at the recurrent overload, not the generic weightInit(Distribution).
                throw new UnsupportedOperationException("Not supported!, Use weightInitRecurrent(Distribution distribution) instead!");
            }
            this.setWeightInitFnRecurrent(weightInit.getWeightInitFunction());
            return (T) this;
        }

        /**
         * Sets the recurrent weight initialization to sample from a distribution.
         *
         * @param dist the distribution, wrapped in a {@link WeightInitDistribution}
         * @return this builder
         */
        @SuppressWarnings("unchecked") // safe: T is the concrete builder type by the recursive bound
        public T weightInitRecurrent(Distribution dist) {
            this.setWeightInitFnRecurrent(new WeightInitDistribution(dist));
            return (T) this;
        }

        public List<LayerConstraint> getRecurrentConstraints() {
            return this.recurrentConstraints;
        }

        public List<LayerConstraint> getInputWeightConstraints() {
            return this.inputWeightConstraints;
        }

        public IWeightInit getWeightInitFnRecurrent() {
            return this.weightInitFnRecurrent;
        }

        public void setRecurrentConstraints(List<LayerConstraint> recurrentConstraints) {
            this.recurrentConstraints = recurrentConstraints;
        }

        public void setInputWeightConstraints(List<LayerConstraint> inputWeightConstraints) {
            this.inputWeightConstraints = inputWeightConstraints;
        }

        public void setWeightInitFnRecurrent(IWeightInit weightInitFnRecurrent) {
            this.weightInitFnRecurrent = weightInitFnRecurrent;
        }
    }
}

