package com.github.chen0040.rl.learning.actorcritic;

import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
import java.util.Set;
import java.util.function.Function;

/* loaded from: input_file:com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.class */
/**
 * Actor-critic learner that additionally keeps a per-(state, action) eligibility
 * trace matrix and spreads each temporal-difference error over all traced pairs.
 *
 * <p>The trace matrix {@code e} has one row per state and one column per action
 * and is scaled by {@code gamma * lambda} inside {@link #update}.
 */
public class ActorCriticLambdaLearner extends ActorCriticLearner {
    /** Eligibility trace per (state, action); stays {@code null} after the no-arg constructor. */
    private Matrix e;
    /** Trace decay factor; multiplied with the model's gamma when traces are decayed. */
    private double lambda = 0.9d;
    /** Configured trace update mode. It is stored and copied but not read by {@link #update}. */
    private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;

    /**
     * Creates a learner with the default lambda (0.9) and trace mode and without a
     * trace matrix; one has to be supplied through {@link #setEligibility} or
     * {@link #copy} before {@link #update} can be used.
     */
    public ActorCriticLambdaLearner() {
    }

    /**
     * Creates a learner for the given problem size with the default lambda (0.9).
     *
     * @param i  number of states (rows of the trace matrix)
     * @param i2 number of actions (columns of the trace matrix)
     */
    public ActorCriticLambdaLearner(int i, int i2) {
        super(i, i2);
        this.e = new Matrix(i, i2);
    }

    /**
     * Creates a learner from an existing one and then starts with a fresh trace
     * matrix sized to the copied model, so traces of the source are not carried over.
     *
     * @param actorCriticLearner learner to copy; may be a plain {@code ActorCriticLearner}
     */
    public ActorCriticLambdaLearner(ActorCriticLearner actorCriticLearner) {
        copy(actorCriticLearner);
        this.e = new Matrix(this.P.getStateCount(), this.P.getActionCount());
    }

    /**
     * Creates a learner for the given problem size with an explicit lambda.
     *
     * @param i  number of states (rows of the trace matrix)
     * @param i2 number of actions (columns of the trace matrix)
     * @param d  forwarded to the superclass constructor (third argument)
     * @param d2 forwarded to the superclass constructor (fourth argument)
     * @param d3 trace decay factor lambda
     * @param d4 forwarded to the superclass constructor (fifth argument)
     */
    public ActorCriticLambdaLearner(int i, int i2, double d, double d2, double d3, double d4) {
        super(i, i2, d, d2, d4);
        this.lambda = d3;
        this.e = new Matrix(i, i2);
    }

    /** @return the configured eligibility trace update mode */
    public EligibilityTraceUpdateMode getTraceUpdateMode() {
        return this.traceUpdateMode;
    }

    /** @param eligibilityTraceUpdateMode the eligibility trace update mode to store */
    public void setTraceUpdateMode(EligibilityTraceUpdateMode eligibilityTraceUpdateMode) {
        this.traceUpdateMode = eligibilityTraceUpdateMode;
    }

    /** @return the trace decay factor lambda */
    public double getLambda() {
        return this.lambda;
    }

    /** @param d the trace decay factor lambda */
    public void setLambda(double d) {
        this.lambda = d;
    }

    /**
     * Creates an independent copy of this learner, including its trace matrix.
     *
     * @return a new learner populated through {@link #copy}
     */
    @Override // com.github.chen0040.rl.learning.actorcritic.ActorCriticLearner
    public ActorCriticLambdaLearner makeCopy() {
        ActorCriticLambdaLearner duplicate = new ActorCriticLambdaLearner();
        duplicate.copy(this);
        return duplicate;
    }

    /**
     * Copies the state of another learner into this one.
     *
     * <p>When the source is itself an {@code ActorCriticLambdaLearner}, its trace
     * matrix (as a copy), lambda and trace mode are taken over. When it is a plain
     * {@code ActorCriticLearner} there is no trace state to take over, so lambda and
     * trace mode of this instance are kept and a fresh trace matrix sized to the
     * copied model is allocated.
     *
     * @param actorCriticLearner learner to copy from
     */
    @Override // com.github.chen0040.rl.learning.actorcritic.ActorCriticLearner
    public void copy(ActorCriticLearner actorCriticLearner) {
        super.copy(actorCriticLearner);
        if (actorCriticLearner instanceof ActorCriticLambdaLearner) {
            ActorCriticLambdaLearner other = (ActorCriticLambdaLearner) actorCriticLearner;
            // The source may have been built with the no-arg constructor and have no trace yet.
            this.e = other.e == null ? null : other.e.makeCopy();
            this.lambda = other.lambda;
            this.traceUpdateMode = other.traceUpdateMode;
        } else {
            // Casting here used to throw ClassCastException, which made the
            // ActorCriticLambdaLearner(ActorCriticLearner) constructor unusable
            // with a plain learner.
            this.e = this.P == null ? null : new Matrix(this.P.getStateCount(), this.P.getActionCount());
        }
    }

    /**
     * Two lambda learners are equal when the superclass considers them equal and
     * their trace matrix, lambda and trace mode match. A missing trace matrix only
     * equals another missing trace matrix.
     */
    @Override // com.github.chen0040.rl.learning.actorcritic.ActorCriticLearner
    public boolean equals(Object obj) {
        if (!super.equals(obj) || !(obj instanceof ActorCriticLambdaLearner)) {
            return false;
        }
        ActorCriticLambdaLearner other = (ActorCriticLambdaLearner) obj;
        boolean sameTrace = this.e == null ? other.e == null : this.e.equals(other.e);
        return sameTrace && this.lambda == other.lambda && this.traceUpdateMode == other.traceUpdateMode;
    }

    /** @return the live trace matrix (not a copy); may be {@code null} */
    public Matrix getEligibility() {
        return this.e;
    }

    /** @param matrix trace matrix to use from now on; stored by reference */
    public void setEligibility(Matrix matrix) {
        this.e = matrix;
    }

    /**
     * Applies one learning step using the eligibility traces.
     *
     * <p>The error term is {@code d + function(i3) - function(i)}. The trace of
     * {@code (i, i2)} is incremented by one, then every (state, action) cell has
     * {@code alpha(i, i2) * error * trace} added to its model value.
     *
     * @param i        index of the state the action was taken in
     * @param i2       index of the action that was taken
     * @param i3       index of the resulting state
     * @param set      not used by this implementation
     * @param d        scalar added to the value difference (presumably the immediate reward)
     * @param function maps a state index to its value estimate
     */
    @Override // com.github.chen0040.rl.learning.actorcritic.ActorCriticLearner
    public void update(int i, int i2, int i3, Set<Integer> set, double d, Function<Integer, Double> function) {
        double tdError = (d + function.apply(i3)) - function.apply(i);
        int stateCount = this.P.getStateCount();
        int actionCount = this.P.getActionCount();
        double gamma = this.P.getGamma();

        // NOTE(review): the trace is always accumulated (+1); traceUpdateMode is not
        // consulted here even though it defaults to ReplaceTrace - confirm intent.
        this.e.set(i, i2, this.e.get(i, i2) + 1.0d);

        for (int s = 0; s < stateCount; s++) {
            for (int a = 0; a < actionCount; a++) {
                double trace = this.e.get(s, a);
                this.P.setQ(s, a, this.P.getQ(s, a) + (this.P.getAlpha(i, i2) * tdError * trace));
                if (a != i2) {
                    // NOTE(review): this clears the trace of the *current* state i, not of
                    // the loop state s, so traces of other states for non-selected actions
                    // are neither cleared nor decayed - confirm this is intended.
                    this.e.set(i, a, 0.0d);
                } else {
                    this.e.set(s, a, this.e.get(s, a) * gamma * this.lambda);
                }
            }
        }
    }
}
