package io.ballerina.messaging.broker.core.transaction;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.ballerina.messaging.broker.common.ValidationException;
import io.ballerina.messaging.broker.core.BrokerException;
import io.ballerina.messaging.broker.core.store.MessageStore;
import io.ballerina.messaging.broker.core.transaction.Branch;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.concurrent.ThreadSafe;
import javax.transaction.xa.Xid;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Registry of distributed-transaction branches keyed by {@link Xid}.
 *
 * <p>Two collections are maintained:
 * <ul>
 *   <li>{@code branchMap} - branches registered through {@link #register(Branch)};</li>
 *   <li>{@code storedXidSet} - Xids reported by the message store in
 *       {@link #syncWithMessageStore(MessageStore)}. A Xid present here is treated as a prepared
 *       branch that can only be committed or rolled back.</li>
 * </ul>
 *
 * <p>{@code prepare}, {@code commit} and {@code rollback} are serialized on this registry's monitor;
 * {@code forget} and the timeout task synchronize on the individual branch instead.
 */
@ThreadSafe
public class Registry {
    private static final Logger LOGGER = LoggerFactory.getLogger(Registry.class);
    private static final String ASSOCIATED_XID_ERROR_MSG = "Branch still has associated active sessions for xid ";
    private static final String TIMED_OUT_ERROR_MSG = "Transaction timed out for xid ";

    /** Used to build a branch object for a Xid that is only known through {@code storedXidSet}. */
    private final BranchFactory branchFactory;

    /** Branches registered in memory, keyed by their Xid. */
    private final Map<Xid, Branch> branchMap = new ConcurrentHashMap<>();

    /** Xids loaded from the message store; see {@link #syncWithMessageStore(MessageStore)}. */
    private final Set<Xid> storedXidSet = ConcurrentHashMap.newKeySet();

    /** Single-threaded scheduler that runs the branch timeout tasks created in {@link #setTimeout}. */
    private final ScheduledExecutorService branchTimeoutExecutorService = Executors.newSingleThreadScheduledExecutor(
            new ThreadFactoryBuilder().setNameFormat("DtxBranchTimeoutExecutor-%d").build());

    /**
     * Creates a registry.
     *
     * <p>NOTE(review): the decompiler reported this constructor as originally package-private; it is
     * kept public here so existing callers of this artifact keep compiling.
     *
     * @param branchFactory factory used to create recovery branches for stored Xids
     */
    public Registry(BranchFactory branchFactory) {
        this.branchFactory = branchFactory;
    }

    /**
     * Registers a branch under its Xid.
     *
     * @param branch branch to register
     * @throws ValidationException if a branch with the same Xid is already registered
     */
    public void register(Branch branch) throws ValidationException {
        if (Objects.nonNull(this.branchMap.putIfAbsent(branch.getXid(), branch))) {
            throw new ValidationException("Branch with the same xid " + branch.getXid() + " is already registered.");
        }
    }

    /**
     * Removes the branch registered for the given Xid. When no in-memory branch was registered,
     * the Xid is removed from the set of stored Xids instead.
     *
     * @param xid Xid to remove
     */
    public void unregister(Xid xid) {
        if (Objects.isNull(this.branchMap.remove(xid))) {
            this.storedXidSet.remove(xid);
        }
    }

    /**
     * Looks up the in-memory branch for a Xid.
     *
     * @param xid Xid of the branch
     * @return the registered branch, or {@code null} if none is registered
     * @throws ValidationException if the Xid is in the stored Xid set (treated as already prepared)
     */
    public Branch getBranch(Xid xid) throws ValidationException {
        if (this.storedXidSet.contains(xid)) {
            throw new ValidationException("Branch is in prepared stage. Branch can be only be committed or rollbacked.");
        }
        return this.branchMap.get(xid);
    }

    /**
     * Prepares the branch registered for the given Xid.
     *
     * <p>The branch must be registered, have no associated active sessions, not be timed out and be
     * in {@code ACTIVE} state. Note that associations are cleared before the state checks run.
     *
     * @param xid Xid of the branch to prepare
     * @throws ValidationException if any of the preconditions above is violated
     * @throws BrokerException declared for failures raised by the branch's prepare step
     */
    public synchronized void prepare(Xid xid) throws ValidationException, BrokerException {
        if (this.storedXidSet.contains(xid)) {
            // A stored Xid is already prepared; PREPARED -> PREPARED is not a valid transition.
            throw new DtxStateTransitionException(xid, Branch.State.PREPARED, Branch.State.PREPARED);
        }
        Branch branch = this.branchMap.get(xid);
        if (Objects.isNull(branch)) {
            throw new ValidationException(DistributedTransaction.UNKNOWN_XID_ERROR_MSG + xid);
        }
        if (branch.hasAssociatedActiveSessions()) {
            throw new ValidationException(ASSOCIATED_XID_ERROR_MSG + xid);
        }
        checkForBranchExpiration(branch);
        branch.clearAssociations();
        if (branch.getState() == Branch.State.ROLLBACK_ONLY) {
            throw new ValidationException("Transaction can only be rollbacked");
        }
        if (branch.getState() != Branch.State.ACTIVE) {
            throw new ValidationException("Cannot prepare a branch in state " + branch.getState());
        }
        branch.prepare();
    }

    /**
     * Unregisters the branch and fails if it has expired or its timeout task could not be cancelled.
     *
     * @param branch branch to check
     * @throws ValidationException if the branch is considered timed out
     */
    private void checkForBranchExpiration(Branch branch) throws ValidationException {
        if (branch.isExpired() || !cancelTimeoutTask(branch)) {
            unregister(branch.getXid());
            throw new ValidationException(TIMED_OUT_ERROR_MSG + branch.getXid());
        }
    }

    /**
     * Commits the branch for the given Xid, then marks it {@code FORGOTTEN} and unregisters it.
     *
     * <p>If no in-memory branch exists, a recovery branch is created for a stored Xid. For an
     * in-memory branch the session, expiration, rollback-only and two-phase checks are applied
     * first.
     *
     * @param xid      Xid of the branch to commit
     * @param onePhase {@code true} for a one-phase commit, {@code false} for the second phase of a
     *                 two-phase commit
     * @throws ValidationException if the Xid is unknown or the branch is not in a committable state
     * @throws BrokerException     declared for failures raised by the branch's commit step
     */
    public synchronized void commit(Xid xid, boolean onePhase) throws ValidationException, BrokerException {
        Branch branch = this.branchMap.get(xid);
        if (Objects.isNull(branch)) {
            branch = checkForBranchRecovery(xid);
        } else {
            if (branch.hasAssociatedActiveSessions()) {
                throw new ValidationException(ASSOCIATED_XID_ERROR_MSG + xid);
            }
            checkForBranchExpiration(branch);
            if (branch.isRollbackOnly()) {
                throw new ValidationException("Branch is set to rollback only. Can't commit with xid " + xid);
            }
            if (!onePhase && !branch.isPrepared()) {
                throw new ValidationException("Cannot call two-phase commit on a non-prepared branch for xid " + xid);
            }
        }
        // Applies to both in-memory and recovery branches.
        if (onePhase && branch.isPrepared()) {
            throw new ValidationException("Cannot call one-phase commit on a prepared branch for xid " + xid);
        }
        branch.clearAssociations();
        branch.commit(onePhase);
        branch.setState(Branch.State.FORGOTTEN);
        unregister(xid);
    }

    /**
     * Creates a branch, marked as a recovery branch, for a Xid that is present in the stored Xid set.
     *
     * @param xid Xid to recover
     * @return newly created recovery branch
     * @throws UnknownDtxBranchException if the Xid is not in the stored Xid set
     */
    private Branch checkForBranchRecovery(Xid xid) throws UnknownDtxBranchException {
        if (!this.storedXidSet.contains(xid)) {
            throw new UnknownDtxBranchException(xid);
        }
        Branch recoveryBranch = this.branchFactory.createBranch(xid);
        recoveryBranch.markAsRecoveryBranch();
        return recoveryBranch;
    }

    /**
     * Attempts to cancel the branch's pending timeout task without interrupting it.
     *
     * @param branch branch whose timeout task should be cancelled
     * @return {@code true} if there is no timeout task, it was already cancelled, or the
     *         cancellation succeeded; {@code false} otherwise
     */
    private boolean cancelTimeoutTask(Branch branch) {
        Future<?> timeoutTaskFuture = branch.getTimeoutTaskFuture();
        return Objects.isNull(timeoutTaskFuture) || timeoutTaskFuture.isCancelled() || timeoutTaskFuture.cancel(false);
    }

    /**
     * Rolls back the branch for the given Xid, then marks it {@code FORGOTTEN} and unregisters it.
     *
     * <p>If no in-memory branch exists, a recovery branch is created for a stored Xid.
     *
     * @param xid Xid of the branch to roll back
     * @throws ValidationException if the Xid is unknown, the branch has timed out, or it still has
     *                             associated active sessions
     * @throws BrokerException     declared for failures raised by the branch's rollback step
     */
    public synchronized void rollback(Xid xid) throws ValidationException, BrokerException {
        Branch branch = this.branchMap.get(xid);
        if (Objects.isNull(branch)) {
            branch = checkForBranchRecovery(xid);
        } else {
            checkForBranchExpiration(branch);
            if (branch.hasAssociatedActiveSessions()) {
                throw new ValidationException(ASSOCIATED_XID_ERROR_MSG + xid);
            }
            branch.clearAssociations();
        }
        branch.dtxRollback();
        branch.setState(Branch.State.FORGOTTEN);
        unregister(xid);
    }

    /**
     * Forgets a heuristically completed branch: marks it {@code FORGOTTEN} and unregisters it.
     *
     * @param xid Xid of the branch to forget
     * @throws ValidationException if the Xid is not registered, the branch still has associated
     *                             active sessions, or its state is neither {@code HEUR_COM} nor
     *                             {@code HEUR_RB}
     */
    public void forget(Xid xid) throws ValidationException {
        Branch branch = this.branchMap.get(xid);
        if (Objects.isNull(branch)) {
            throw new ValidationException(DistributedTransaction.UNKNOWN_XID_ERROR_MSG + xid);
        }
        synchronized (branch) {
            if (branch.hasAssociatedActiveSessions()) {
                throw new ValidationException(ASSOCIATED_XID_ERROR_MSG + xid);
            }
            if (branch.getState() != Branch.State.HEUR_COM && branch.getState() != Branch.State.HEUR_RB) {
                throw new ValidationException("Branch is not heuristically complete, hence unable to forget. Xid " + xid);
            }
            branch.setState(Branch.State.FORGOTTEN);
            unregister(xid);
        }
    }

    /**
     * Schedules a timeout task for a registered branch. When the delay elapses and the branch is
     * not prepared, the branch is rolled back and its state set to {@code TIMED_OUT}.
     *
     * @param xid      Xid of the branch
     * @param timeout  delay before the branch times out; {@code 0} means no timeout is scheduled
     * @param timeUnit unit of {@code timeout}
     * @throws ValidationException if no branch is registered for the Xid
     */
    public void setTimeout(Xid xid, long timeout, TimeUnit timeUnit) throws ValidationException {
        Branch branch = this.branchMap.get(xid);
        if (Objects.isNull(branch)) {
            throw new ValidationException(DistributedTransaction.UNKNOWN_XID_ERROR_MSG + xid);
        }
        if (timeout == 0) {
            return;
        }
        branch.setTimeoutTaskFuture(this.branchTimeoutExecutorService.schedule(() -> {
            LOGGER.debug("timing out dtx task with xid {}", xid);
            // NOTE(review): rollback() below acquires this registry's monitor while the branch
            // monitor is held; confirm no other path takes the two locks in the opposite order.
            synchronized (branch) {
                if (branch.isPrepared()) {
                    LOGGER.debug("Branch already prepared. Won't be timed out. Xid {}", xid);
                    return;
                }
                try {
                    rollback(xid);
                    // rollback() leaves the branch FORGOTTEN; record that it ended due to a timeout.
                    branch.setState(Branch.State.TIMED_OUT);
                } catch (ValidationException | BrokerException e) {
                    LOGGER.error("Error occurred while rolling back timed out branch with Xid " + xid, e);
                }
            }
        }, timeout, timeUnit));
    }

    /**
     * Replaces the stored Xid set with the Xids currently held by the message store.
     *
     * <p>NOTE(review): the decompiler reported this method as originally package-private; it is
     * kept public here so existing callers of this artifact keep compiling.
     *
     * @param messageStore store to read the Xids from
     * @throws BrokerException declared for failures raised while retrieving the stored Xids
     */
    public void syncWithMessageStore(MessageStore messageStore) throws BrokerException {
        this.storedXidSet.clear();
        messageStore.retrieveStoredXids(this.storedXidSet::add);
    }
}
