package org.biscuitsec.biscuit.datalog;

import biscuit.format.schema.Schema;
import io.vavr.API;
import io.vavr.Tuple2;
import io.vavr.Tuple3;
import io.vavr.control.Either;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.biscuitsec.biscuit.datalog.Term;
import org.biscuitsec.biscuit.datalog.expressions.Expression;
import org.biscuitsec.biscuit.error.Error;

/* loaded from: input_file:org/biscuitsec/biscuit/datalog/Rule.class */
public final class Rule implements Serializable {
    private final Predicate head;
    private final List<Predicate> body;
    private final List<Expression> expressions;
    private final List<Scope> scopes;

    public final Predicate head() {
        return this.head;
    }

    public final List<Predicate> body() {
        return this.body;
    }

    public final List<Expression> expressions() {
        return this.expressions;
    }

    public List<Scope> scopes() {
        return this.scopes;
    }

    public Stream<Either<Error, Tuple2<Origin, Fact>>> apply(Supplier<Stream<Tuple2<Origin, Fact>>> supplier, Long l, SymbolTable symbolTable) {
        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new Combinator(variablesSet(), this.body, supplier, symbolTable), 16), false).map(tuple2 -> {
            Origin origin = (Origin) tuple2._1;
            Map<Long, Term> map = (Map) tuple2._2;
            TemporarySymbolTable temporarySymbolTable = new TemporarySymbolTable(symbolTable);
            Iterator<Expression> it = this.expressions.iterator();
            while (it.hasNext()) {
                try {
                    Term evaluate = it.next().evaluate(map, temporarySymbolTable);
                    if (!(evaluate instanceof Term.Bool)) {
                        return Either.left(new Error.InvalidType());
                    }
                    if (!((Term.Bool) evaluate).value()) {
                        return Either.right(new Tuple3(origin, map, false));
                    }
                } catch (Error e) {
                    return Either.left(e);
                }
            }
            return Either.right(new Tuple3(origin, map, true));
        }).filter(either -> {
            return either.isRight() & ((Boolean) ((Tuple3) either.get())._3).booleanValue();
        }).map(either2 -> {
            Tuple3 tuple3 = (Tuple3) either2.get();
            Origin origin = (Origin) tuple3._1;
            Map map = (Map) tuple3._2;
            Predicate m1344clone = this.head.m1344clone();
            for (int i = 0; i < m1344clone.terms().size(); i++) {
                if (m1344clone.terms().get(i) instanceof Term.Variable) {
                    Term.Variable variable = (Term.Variable) m1344clone.terms().get(i);
                    if (!map.containsKey(Long.valueOf(variable.value()))) {
                        return Either.left(new Error.InternalError());
                    }
                    m1344clone.terms().set(i, (Term) map.get(Long.valueOf(variable.value())));
                }
            }
            origin.add(l.longValue());
            return Either.right(new Tuple2(origin, new Fact(m1344clone)));
        });
    }

    private MatchedVariables variablesSet() {
        HashSet hashSet = new HashSet();
        Iterator<Predicate> it = this.body.iterator();
        while (it.hasNext()) {
            hashSet.addAll((Collection) it.next().terms().stream().filter(term -> {
                return term instanceof Term.Variable;
            }).map(term2 -> {
                return Long.valueOf(((Term.Variable) term2).value());
            }).collect(Collectors.toSet()));
        }
        return new MatchedVariables(hashSet);
    }

    public boolean find_match(FactSet factSet, Long l, TrustedOrigins trustedOrigins, SymbolTable symbolTable) throws Error {
        MatchedVariables variablesSet = variablesSet();
        if (this.body.isEmpty()) {
            return variablesSet.check_expressions(this.expressions, symbolTable).isDefined();
        }
        Iterator<Either<Error, Tuple2<Origin, Fact>>> it = apply(() -> {
            return factSet.stream(trustedOrigins);
        }, l, symbolTable).iterator();
        if (!it.hasNext()) {
            return false;
        }
        Either<Error, Tuple2<Origin, Fact>> next = it.next();
        if (next.isRight()) {
            return true;
        }
        throw ((Error) next.getLeft());
    }

    public boolean check_match_all(FactSet factSet, TrustedOrigins trustedOrigins, SymbolTable symbolTable) throws Error {
        MatchedVariables variablesSet = variablesSet();
        if (this.body.isEmpty()) {
            return variablesSet.check_expressions(this.expressions, symbolTable).isDefined();
        }
        Combinator combinator = new Combinator(variablesSet, this.body, () -> {
            return factSet.stream(trustedOrigins);
        }, symbolTable);
        boolean z = false;
        while (combinator.hasNext()) {
            Map<Long, Term> map = (Map) combinator.next()._2;
            z = true;
            TemporarySymbolTable temporarySymbolTable = new TemporarySymbolTable(symbolTable);
            Iterator<Expression> it = this.expressions.iterator();
            while (it.hasNext()) {
                Term evaluate = it.next().evaluate(map, temporarySymbolTable);
                if (!(evaluate instanceof Term.Bool)) {
                    throw new Error.InvalidType();
                }
                if (!((Term.Bool) evaluate).value()) {
                    return false;
                }
            }
        }
        return z;
    }

    public Rule(Predicate predicate, List<Predicate> list, List<Expression> list2) {
        this.head = predicate;
        this.body = list;
        this.expressions = list2;
        this.scopes = new ArrayList();
    }

    public Rule(Predicate predicate, List<Predicate> list, List<Expression> list2, List<Scope> list3) {
        this.head = predicate;
        this.body = list;
        this.expressions = list2;
        this.scopes = list3;
    }

    public Schema.RuleV2 serialize() {
        Schema.RuleV2.Builder head = Schema.RuleV2.newBuilder().setHead(this.head.serialize());
        for (int i = 0; i < this.body.size(); i++) {
            head.addBody(this.body.get(i).serialize());
        }
        for (int i2 = 0; i2 < this.expressions.size(); i2++) {
            head.addExpressions(this.expressions.get(i2).serialize());
        }
        Iterator<Scope> it = this.scopes.iterator();
        while (it.hasNext()) {
            head.addScope(it.next().serialize());
        }
        return head.m951build();
    }

    public static Either<Error.FormatError, Rule> deserializeV2(Schema.RuleV2 ruleV2) {
        ArrayList arrayList = new ArrayList();
        Iterator<Schema.PredicateV2> it = ruleV2.getBodyList().iterator();
        while (it.hasNext()) {
            Either<Error.FormatError, Predicate> deserializeV2 = Predicate.deserializeV2(it.next());
            if (deserializeV2.isLeft()) {
                return API.Left((Error.FormatError) deserializeV2.getLeft());
            }
            arrayList.add((Predicate) deserializeV2.get());
        }
        ArrayList arrayList2 = new ArrayList();
        Iterator<Schema.ExpressionV2> it2 = ruleV2.getExpressionsList().iterator();
        while (it2.hasNext()) {
            Either<Error.FormatError, Expression> deserializeV22 = Expression.deserializeV2(it2.next());
            if (deserializeV22.isLeft()) {
                return API.Left((Error.FormatError) deserializeV22.getLeft());
            }
            arrayList2.add((Expression) deserializeV22.get());
        }
        ArrayList arrayList3 = new ArrayList();
        Iterator<Schema.Scope> it3 = ruleV2.getScopeList().iterator();
        while (it3.hasNext()) {
            Either<Error.FormatError, Scope> deserialize = Scope.deserialize(it3.next());
            if (deserialize.isLeft()) {
                return API.Left((Error.FormatError) deserialize.getLeft());
            }
            arrayList3.add((Scope) deserialize.get());
        }
        Either<Error.FormatError, Predicate> deserializeV23 = Predicate.deserializeV2(ruleV2.getHead());
        return deserializeV23.isLeft() ? API.Left((Error.FormatError) deserializeV23.getLeft()) : API.Right(new Rule((Predicate) deserializeV23.get(), arrayList, arrayList2, arrayList3));
    }
}
