package com.github.chen0040.rl.learning.sarsa;

import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
import java.util.Objects;

/* loaded from: input_file:com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.class */
/**
 * SARSA(&lambda;) learner: a {@link SarsaLearner} that keeps one eligibility trace per
 * (state, action) pair so that every temporal-difference error is propagated back to all
 * recently visited pairs, weighted by how recently they were visited.
 *
 * <p>Traces live in a {@link Matrix} indexed by {@code (stateId, actionId)}. After each update
 * every trace is multiplied by {@code gamma * lambda}. How the trace of the pair just visited is
 * bumped is controlled by {@link #setTraceUpdateMode(EligibilityTraceUpdateMode)}.
 */
public class SarsaLambdaLearner extends SarsaLearner {
    /** Trace-decay parameter assigned by every constructor. */
    private static final double DEFAULT_LAMBDA = 0.9d;

    /** Trace-decay parameter; traces are scaled by {@code gamma * lambda} after each update. */
    private double lambda = DEFAULT_LAMBDA;

    /** Eligibility traces, one entry per (state, action); {@code null} after the no-arg constructor. */
    private Matrix e;

    /** Strategy for bumping the trace of the visited pair; see {@link #update}. */
    private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;

    /**
     * Creates a learner with default lambda and trace mode. The trace matrix is left {@code null};
     * supply one via {@link #setEligibility(Matrix)} or {@link #copy(SarsaLearner)}.
     */
    public SarsaLambdaLearner() {
    }

    /**
     * Creates a learner for the given problem size with zeroed eligibility traces.
     *
     * @param stateCount number of states (rows of the trace matrix); forwarded to the superclass
     * @param actionCount number of actions (columns of the trace matrix); forwarded to the superclass
     */
    public SarsaLambdaLearner(int stateCount, int actionCount) {
        super(stateCount, actionCount);
        this.e = new Matrix(stateCount, actionCount);
    }

    /**
     * Creates a learner for the given problem size with zeroed eligibility traces.
     *
     * @param stateCount number of states (rows of the trace matrix); forwarded to the superclass
     * @param actionCount number of actions (columns of the trace matrix); forwarded to the superclass
     * @param alpha forwarded to the superclass constructor (presumably the learning rate — confirm
     *     against {@code SarsaLearner})
     * @param gamma forwarded to the superclass constructor (presumably the discount factor — confirm)
     * @param initialQ forwarded to the superclass constructor (presumably the initial Q value — confirm)
     */
    public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
        super(stateCount, actionCount, alpha, gamma, initialQ);
        this.e = new Matrix(stateCount, actionCount);
    }

    /**
     * Creates a learner initialised from {@code learner}, with freshly zeroed eligibility traces
     * sized to the copied model. If {@code learner} is itself a {@code SarsaLambdaLearner}, its
     * lambda and trace mode are also taken over.
     *
     * @param learner the learner to copy from
     */
    public SarsaLambdaLearner(SarsaLearner learner) {
        copy(learner);
        this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount());
    }

    public EligibilityTraceUpdateMode getTraceUpdateMode() {
        return this.traceUpdateMode;
    }

    /**
     * Sets how the visited pair's trace is bumped on each update: {@code ReplaceTrace} resets it
     * to one, any other mode adds one to it.
     */
    public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
        this.traceUpdateMode = traceUpdateMode;
    }

    public double getLambda() {
        return this.lambda;
    }

    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** Returns the live eligibility-trace matrix (not a copy). */
    public Matrix getEligibility() {
        return this.e;
    }

    /** Replaces the eligibility-trace matrix; the argument is stored, not copied. */
    public void setEligibility(Matrix matrix) {
        this.e = matrix;
    }

    /**
     * Returns a copy of this learner, built with the no-arg constructor followed by
     * {@link #copy(SarsaLearner)}; the eligibility traces are deep-copied.
     */
    public Object clone() {
        SarsaLambdaLearner duplicate = new SarsaLambdaLearner();
        duplicate.copy(this);
        return duplicate;
    }

    /**
     * Copies the state of {@code rhs} into this learner.
     *
     * <p>The superclass state is always copied. When {@code rhs} is a {@code SarsaLambdaLearner},
     * its lambda, trace mode and a deep copy of its traces are copied as well. Otherwise this
     * learner keeps its own lambda and trace mode and its traces are reset to zeros sized to its
     * (now copied) model, since a plain {@code SarsaLearner} carries no trace information.
     *
     * @param rhs the learner to copy from
     */
    @Override // com.github.chen0040.rl.learning.sarsa.SarsaLearner
    public void copy(SarsaLearner rhs) {
        super.copy(rhs);
        if (rhs instanceof SarsaLambdaLearner) {
            SarsaLambdaLearner other = (SarsaLambdaLearner) rhs;
            this.lambda = other.lambda;
            // The source may have been built with the no-arg constructor and have no traces yet.
            this.e = other.e == null ? null : other.e.makeCopy();
            this.traceUpdateMode = other.traceUpdateMode;
        } else if (this.model != null) {
            // Previous traces may not match the newly copied model's dimensions.
            this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount());
        }
    }

    /**
     * Two learners are equal when the superclass considers them equal and their lambda,
     * trace mode and eligibility traces match ({@code null} traces compare equal to each other).
     */
    @Override // com.github.chen0040.rl.learning.sarsa.SarsaLearner
    public boolean equals(Object obj) {
        if (!super.equals(obj) || !(obj instanceof SarsaLambdaLearner)) {
            return false;
        }
        SarsaLambdaLearner other = (SarsaLambdaLearner) obj;
        // Double.compare keeps equals reflexive even for a NaN lambda.
        return Double.compare(this.lambda, other.lambda) == 0
                && this.traceUpdateMode == other.traceUpdateMode
                && Objects.equals(this.e, other.e);
    }

    /**
     * Hash consistent with {@link #equals(Object)}. Only lambda and the trace mode are hashed:
     * the superclass and {@code Matrix} hash codes are not known to agree with their equals.
     */
    @Override
    public int hashCode() {
        return 31 * Double.hashCode(this.lambda) + Objects.hashCode(this.traceUpdateMode);
    }

    /**
     * Performs one SARSA(&lambda;) backup for the transition
     * {@code (currentState, currentAction) --reward--> (nextState, nextAction)}.
     *
     * <p>Computes the TD error {@code delta = reward + gamma * Q(next) - Q(current)}, bumps the
     * visited pair's trace according to the trace mode, then for every pair applies
     * {@code Q += alpha * delta * e} followed by {@code e *= gamma * lambda}.
     *
     * @param currentStateId state the action was taken in
     * @param currentActionId action that was taken
     * @param nextStateId state reached after the action
     * @param nextActionId action selected in the next state
     * @param immediateReward reward received for the transition
     */
    @Override // com.github.chen0040.rl.learning.sarsa.SarsaLearner
    public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward) {
        double currentQ = this.model.getQ(currentStateId, currentActionId);
        // The step size of the visited pair is applied to every pair in the sweep below.
        double alpha = this.model.getAlpha(currentStateId, currentActionId);
        double gamma = this.model.getGamma();
        double nextQ = this.model.getQ(nextStateId, nextActionId);
        double tdError = immediateReward + gamma * nextQ - currentQ;

        int stateCount = this.model.getStateCount();
        int actionCount = this.model.getActionCount();

        if (this.traceUpdateMode == EligibilityTraceUpdateMode.ReplaceTrace) {
            // Replacing traces: clear the other actions of the visited state, then set the
            // visited pair to exactly one.
            for (int actionId = 0; actionId < actionCount; actionId++) {
                this.e.set(currentStateId, actionId, 0.0d);
            }
            this.e.set(currentStateId, currentActionId, 1.0d);
        } else {
            // Accumulating traces: repeated visits add up.
            this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1.0d);
        }

        double decay = gamma * this.lambda;
        for (int stateId = 0; stateId < stateCount; stateId++) {
            for (int actionId = 0; actionId < actionCount; actionId++) {
                double trace = this.e.get(stateId, actionId);
                this.model.setQ(stateId, actionId, this.model.getQ(stateId, actionId) + alpha * tdError * trace);
                // Every trace decays, not only those of the current action.
                this.e.set(stateId, actionId, trace * decay);
            }
        }
    }
}
