/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.query.sqm.mutation.internal.inline;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.hibernate.engine.jdbc.spi.JdbcServices;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.internal.util.collections.CollectionHelper;
import org.hibernate.metamodel.mapping.BasicEntityIdentifierMapping;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.mapping.SelectableConsumer;
import org.hibernate.metamodel.spi.MappingMetamodelImplementor;
import org.hibernate.persister.entity.EntityPersister;
import org.hibernate.query.SemanticException;
import org.hibernate.query.spi.DomainQueryExecutionContext;
import org.hibernate.query.sqm.ComparisonOperator;
import org.hibernate.query.sqm.internal.DomainParameterXref;
import org.hibernate.query.sqm.internal.SqmJdbcExecutionContextAdapter;
import org.hibernate.query.sqm.internal.SqmUtil;
import org.hibernate.query.sqm.mutation.internal.MatchingIdSelectionHelper;
import org.hibernate.query.sqm.mutation.internal.UpdateHandler;
import org.hibernate.query.sqm.mutation.internal.inline.MatchingIdRestrictionProducer;
import org.hibernate.query.sqm.spi.SqmParameterMappingModelResolutionAccess;
import org.hibernate.query.sqm.sql.SqmTranslation;
import org.hibernate.query.sqm.tree.expression.SqmParameter;
import org.hibernate.query.sqm.tree.from.SqmRoot;
import org.hibernate.query.sqm.tree.update.SqmUpdateStatement;
import org.hibernate.spi.NavigablePath;
import org.hibernate.sql.ast.SqlAstJoinType;
import org.hibernate.sql.ast.spi.SqlAliasBaseImpl;
import org.hibernate.sql.ast.spi.SqlSelection;
import org.hibernate.sql.ast.tree.MutationStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.expression.SqlTuple;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.ast.tree.from.TableGroupJoin;
import org.hibernate.sql.ast.tree.from.TableReference;
import org.hibernate.sql.ast.tree.from.TableReferenceJoin;
import org.hibernate.sql.ast.tree.from.UnionTableReference;
import org.hibernate.sql.ast.tree.from.ValuesTableGroup;
import org.hibernate.sql.ast.tree.insert.InsertSelectStatement;
import org.hibernate.sql.ast.tree.insert.Values;
import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate;
import org.hibernate.sql.ast.tree.predicate.InListPredicate;
import org.hibernate.sql.ast.tree.predicate.NullnessPredicate;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.sql.ast.tree.select.QuerySpec;
import org.hibernate.sql.ast.tree.select.SelectClause;
import org.hibernate.sql.ast.tree.update.Assignment;
import org.hibernate.sql.ast.tree.update.UpdateStatement;
import org.hibernate.sql.exec.spi.ExecutionContext;
import org.hibernate.sql.exec.spi.JdbcOperationQueryMutation;
import org.hibernate.sql.exec.spi.JdbcParameterBindings;
import org.hibernate.sql.results.internal.SqlSelectionImpl;

/**
 * {@link UpdateHandler} that executes an SQM update by first selecting the identifiers of all
 * matching entities and then issuing, per table of the entity (in constraint order), one SQL
 * UPDATE restricted by an inline list of those identifiers.
 * <p>
 * When an UPDATE against a table the persister reports as nullable touches fewer rows than there
 * are matching ids, the missing rows are inserted with an INSERT ... SELECT.
 */
public class InlineUpdateHandler
implements UpdateHandler {
    private final SqmUpdateStatement<?> sqmUpdate;
    private final DomainParameterXref domainParameterXref;
    private final MatchingIdRestrictionProducer matchingIdsPredicateProducer;
    private final SessionFactoryImplementor sessionFactory;

    public InlineUpdateHandler(MatchingIdRestrictionProducer matchingIdsPredicateProducer, SqmUpdateStatement<?> sqmUpdate, DomainParameterXref domainParameterXref, DomainQueryExecutionContext context) {
        this.matchingIdsPredicateProducer = matchingIdsPredicateProducer;
        this.domainParameterXref = domainParameterXref;
        this.sqmUpdate = sqmUpdate;
        this.sessionFactory = context.getSession().getFactory();
    }

    /**
     * Selects the matching ids and updates every table of the target entity that has assignments.
     *
     * @param executionContext supplies query options, parameter bindings and the session
     * @return the number of matching ids, or 0 when no entity matched
     * @throws SemanticException if a single assignment spans columns of several tables, or
     *         refers to a column whose qualifier is not a table of the updated entity
     */
    @Override
    public int execute(DomainQueryExecutionContext executionContext) {
        final List<Object> ids = MatchingIdSelectionHelper.selectMatchingIds(this.sqmUpdate, this.domainParameterXref, executionContext);
        if (ids == null || ids.isEmpty()) {
            return 0;
        }
        // Discard parameter expansions recorded while selecting the ids before translating the update itself.
        this.domainParameterXref.clearExpansions();

        final MappingMetamodelImplementor domainModel = this.sessionFactory.getRuntimeMetamodels().getMappingMetamodel();
        final String mutatingEntityName = ((SqmRoot<?>) this.sqmUpdate.getTarget()).getModel().getHibernateEntityName();
        final EntityPersister entityDescriptor = domainModel.getEntityDescriptor(mutatingEntityName);
        final List<Expression> inListExpressions = this.matchingIdsPredicateProducer.produceIdExpressionList(ids, entityDescriptor);

        final SqmTranslation<? extends MutationStatement> translation = this.sessionFactory.getQueryEngine()
                .getSqmTranslatorFactory()
                .createMutationTranslator(
                        this.sqmUpdate,
                        executionContext.getQueryOptions(),
                        this.domainParameterXref,
                        executionContext.getQueryParameterBindings(),
                        executionContext.getSession().getLoadQueryInfluencers(),
                        this.sessionFactory)
                .translate();
        final UpdateStatement updateStatement = (UpdateStatement) translation.getSqlAst();
        final TableGroup updatingTableGroup = updateStatement.getFromClause().getRoots().get(0);
        final Map<String, TableReference> tableReferenceByAlias = collectTableReferencesByAlias(updatingTableGroup);

        final JdbcParameterBindings jdbcParameterBindings = SqmUtil.createJdbcParameterBindings(
                executionContext.getQueryParameterBindings(),
                this.domainParameterXref,
                SqmUtil.generateJdbcParamsXref(this.domainParameterXref, translation::getJdbcParamsBySqmParam),
                new SqmParameterMappingModelResolutionAccess() {
                    @Override
                    @SuppressWarnings("unchecked")
                    public <T> MappingModelExpressible<T> getResolvedMappingModelType(SqmParameter<T> parameter) {
                        // The resolution map is keyed by the parameter, so the stored type matches its T.
                        return (MappingModelExpressible<T>) translation.getSqmParameterMappingModelTypeResolutions().get(parameter);
                    }
                },
                executionContext.getSession());

        final Map<TableReference, List<Assignment>> assignmentsByTable = groupAssignmentsByTable(updateStatement.getAssignments(), tableReferenceByAlias);

        final int rows = ids.size();
        final SqmJdbcExecutionContextAdapter executionContextAdapter = SqmJdbcExecutionContextAdapter.omittingLockingAndPaging(executionContext);
        entityDescriptor.visitConstraintOrderedTables(
                (tableExpression, tableKeyColumnVisitationSupplier) -> updateTable(
                        tableExpression,
                        tableKeyColumnVisitationSupplier,
                        entityDescriptor,
                        updatingTableGroup,
                        assignmentsByTable,
                        inListExpressions,
                        rows,
                        jdbcParameterBindings,
                        executionContextAdapter));
        return rows;
    }

    /** Maps the identification variable of the group's primary and joined table references to the reference. */
    private Map<String, TableReference> collectTableReferencesByAlias(TableGroup tableGroup) {
        final List<TableReferenceJoin> tableReferenceJoins = tableGroup.getTableReferenceJoins();
        final Map<String, TableReference> tableReferenceByAlias = CollectionHelper.mapOfSize(tableReferenceJoins.size() + 1);
        this.collectTableReference(tableGroup.getPrimaryTableReference(), tableReferenceByAlias::put);
        for (int i = 0; i < tableReferenceJoins.size(); ++i) {
            this.collectTableReference(tableReferenceJoins.get(i), tableReferenceByAlias::put);
        }
        return tableReferenceByAlias;
    }

    /**
     * Groups the assignments by the table reference that owns their target columns.
     *
     * @throws SemanticException if one assignment targets columns of more than one table
     */
    private Map<TableReference, List<Assignment>> groupAssignmentsByTable(List<Assignment> assignments, Map<String, TableReference> tableReferenceByAlias) {
        final Map<TableReference, List<Assignment>> assignmentsByTable = new HashMap<>();
        for (int i = 0; i < assignments.size(); ++i) {
            final Assignment assignment = assignments.get(i);
            final List<ColumnReference> assignmentColumnRefs = assignment.getAssignable().getColumnReferences();
            TableReference assignmentTableReference = null;
            for (int c = 0; c < assignmentColumnRefs.size(); ++c) {
                final TableReference tableReference = this.resolveTableReference(assignmentColumnRefs.get(c), tableReferenceByAlias);
                if (assignmentTableReference != null && assignmentTableReference != tableReference) {
                    throw new SemanticException("Assignment referred to columns from multiple tables: " + assignment.getAssignable());
                }
                assignmentTableReference = tableReference;
            }
            assignmentsByTable.computeIfAbsent(assignmentTableReference, key -> new ArrayList<>()).add(assignment);
        }
        return assignmentsByTable;
    }

    /**
     * Runs the UPDATE for one table of the entity, if any assignment targets it.
     * <p>
     * If fewer rows than {@code expectedUpdateCount} were updated and the table is nullable,
     * the missing rows are inserted instead.
     */
    private void updateTable(String tableExpression, Supplier<Consumer<SelectableConsumer>> tableKeyColumnVisitationSupplier, EntityPersister entityDescriptor, TableGroup updatingTableGroup, Map<TableReference, List<Assignment>> assignmentsByTable, List<Expression> inListExpressions, int expectedUpdateCount, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        final TableReference updatingTableReference = updatingTableGroup.getTableReference(updatingTableGroup.getNavigablePath(), tableExpression, false);
        final List<Assignment> assignments = assignmentsByTable.get(updatingTableReference);
        if (assignments == null || assignments.isEmpty()) {
            // Nothing is assigned in this table.
            return;
        }
        final InListPredicate idListPredicate = (InListPredicate) this.matchingIdsPredicateProducer.produceRestriction(inListExpressions, entityDescriptor, 0, null, updatingTableReference, tableKeyColumnVisitationSupplier, executionContext);
        final Expression keyExpression = idListPredicate.getTestExpression();
        final NamedTableReference dmlTableReference = this.resolveUnionTableReference(updatingTableReference, tableExpression);

        final int updateCount = executeMutation(new UpdateStatement(dmlTableReference, assignments, idListPredicate), jdbcParameterBindings, executionContext);
        if (updateCount == expectedUpdateCount || !isNullableTable(entityDescriptor, tableExpression)) {
            return;
        }
        final int insertCount = insertMissingRows(tableExpression, entityDescriptor, updatingTableGroup, assignments, inListExpressions, keyExpression, dmlTableReference, jdbcParameterBindings, executionContext);
        assert insertCount + updateCount == expectedUpdateCount;
    }

    /** Translates {@code statement} to a JDBC mutation, executes it and returns the affected row count. */
    private int executeMutation(MutationStatement statement, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        final JdbcServices jdbcServices = this.sessionFactory.getJdbcServices();
        final JdbcOperationQueryMutation jdbcMutation = jdbcServices.getJdbcEnvironment()
                .getSqlAstTranslatorFactory()
                .buildMutationTranslator(this.sessionFactory, statement)
                .translate(jdbcParameterBindings, executionContext.getQueryOptions());
        return jdbcServices.getJdbcMutationExecutor().execute(
                jdbcMutation,
                jdbcParameterBindings,
                sql -> executionContext.getSession().getJdbcCoordinator().getStatementPreparer().prepareStatement(sql),
                (integer, preparedStatement) -> {},
                executionContext);
    }

    /** Returns whether the persister lists {@code tableExpression} among its tables and marks it nullable. */
    private static boolean isNullableTable(EntityPersister entityDescriptor, String tableExpression) {
        final EntityPersister entityPersister = entityDescriptor.getEntityPersister();
        for (int i = 0; i < entityPersister.getTableSpan(); ++i) {
            if (tableExpression.equals(entityPersister.getTableName(i)) && entityPersister.isNullableTable(i)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Inserts rows into {@code tableExpression} for matching ids that have no row there yet.
     * <p>
     * It builds an INSERT ... SELECT whose source is an inline VALUES list of the ids. That list is
     * left-joined to a fresh root table group of the entity and filtered on the target table's first
     * key column being null. The key columns plus the assigned values are inserted.
     *
     * @return the number of inserted rows
     */
    private int insertMissingRows(String tableExpression, EntityPersister entityDescriptor, TableGroup updatingTableGroup, List<Assignment> assignments, List<Expression> inListExpressions, Expression keyExpression, NamedTableReference dmlTableReference, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        final QuerySpec querySpec = new QuerySpec(true);
        final NavigablePath valuesPath = new NavigablePath("id");

        // One VALUES row per matching id; composite ids contribute one column per tuple element.
        final List<Values> valuesList = new ArrayList<>(inListExpressions.size());
        for (Expression inListExpression : inListExpressions) {
            if (inListExpression instanceof SqlTuple) {
                valuesList.add(new Values(new ArrayList<Expression>(((SqlTuple) inListExpression).getExpressions())));
            } else {
                valuesList.add(new Values(Collections.singletonList(inListExpression)));
            }
        }

        final TableGroup rootTableGroup = entityDescriptor.createRootTableGroup(true, updatingTableGroup.getNavigablePath(), updatingTableGroup.getSourceAlias(), new SqlAliasBaseImpl(updatingTableGroup.getGroupAlias()), () -> predicate -> {}, null);
        final TableReference rootPrimaryTableReference = rootTableGroup.getPrimaryTableReference();

        final List<String> columnNames;
        final ComparisonPredicate joinPredicate;
        if (keyExpression instanceof SqlTuple) {
            final List<? extends Expression> keyExpressions = ((SqlTuple) keyExpression).getExpressions();
            final List<Expression> lhs = new ArrayList<>(keyExpressions.size());
            final List<Expression> rhs = new ArrayList<>(keyExpressions.size());
            final List<String> tupleColumnNames = new ArrayList<>(keyExpressions.size());
            entityDescriptor.getIdentifierMapping().forEachSelectable((i, selectableMapping) -> {
                final ColumnReference columnReference = keyExpressions.get(i).getColumnReference();
                final ColumnReference valuesColumnReference = new ColumnReference(valuesPath.getLocalName(), columnReference.getColumnExpression(), false, null, columnReference.getJdbcMapping());
                tupleColumnNames.add(columnReference.getColumnExpression());
                lhs.add(valuesColumnReference);
                rhs.add(new ColumnReference(rootPrimaryTableReference, selectableMapping.getSelectionExpression(), false, null, columnReference.getJdbcMapping()));
                querySpec.getSelectClause().addSqlSelection(new SqlSelectionImpl(valuesColumnReference));
            });
            columnNames = tupleColumnNames;
            joinPredicate = new ComparisonPredicate(new SqlTuple(lhs, entityDescriptor.getIdentifierMapping()), ComparisonOperator.EQUAL, new SqlTuple(rhs, entityDescriptor.getIdentifierMapping()));
        } else {
            final ColumnReference columnReference = keyExpression.getColumnReference();
            final ColumnReference valuesColumnReference = new ColumnReference(valuesPath.getLocalName(), columnReference.getColumnExpression(), false, null, columnReference.getJdbcMapping());
            columnNames = Collections.singletonList(columnReference.getColumnExpression());
            joinPredicate = new ComparisonPredicate(valuesColumnReference, ComparisonOperator.EQUAL, new ColumnReference(rootPrimaryTableReference, ((BasicEntityIdentifierMapping) entityDescriptor.getIdentifierMapping()).getSelectionExpression(), false, null, columnReference.getJdbcMapping()));
            querySpec.getSelectClause().addSqlSelection(new SqlSelectionImpl(valuesColumnReference));
        }

        final ValuesTableGroup valuesTableGroup = new ValuesTableGroup(valuesPath, null, valuesList, valuesPath.getLocalName(), columnNames, true, this.sessionFactory);
        valuesTableGroup.addNestedTableGroupJoin(new TableGroupJoin(rootTableGroup.getNavigablePath(), SqlAstJoinType.LEFT, rootTableGroup, joinPredicate));
        querySpec.getFromClause().addRoot(valuesTableGroup);
        // Keep only ids whose row is absent from the target table (left join produced nulls).
        querySpec.applyPredicate(new NullnessPredicate(new ColumnReference(rootTableGroup.resolveTableReference(tableExpression), columnNames.get(0), entityDescriptor.getIdentifierMapping().getSingleJdbcMapping())));

        // Target columns: key columns first (matching the key selections above), then assigned columns.
        final List<ColumnReference> targetColumnReferences = new ArrayList<>();
        if (keyExpression instanceof SqlTuple) {
            for (Expression keyPart : ((SqlTuple) keyExpression).getExpressions()) {
                targetColumnReferences.add((ColumnReference) keyPart);
            }
        } else {
            targetColumnReferences.add((ColumnReference) keyExpression);
        }
        for (Assignment assignment : assignments) {
            targetColumnReferences.addAll(assignment.getAssignable().getColumnReferences());
            querySpec.getSelectClause().addSqlSelection(new SqlSelectionImpl(assignment.getAssignedValue()));
        }

        final InsertSelectStatement insertSqlAst = new InsertSelectStatement(dmlTableReference);
        insertSqlAst.addTargetColumnReferences(targetColumnReferences.toArray(new ColumnReference[0]));
        insertSqlAst.setSourceSelectStatement(querySpec);
        return executeMutation(insertSqlAst, jdbcParameterBindings, executionContext);
    }

    /** Passes {@code tableReference} to {@code consumer} keyed by its identification variable. */
    private void collectTableReference(TableReference tableReference, BiConsumer<String, TableReference> consumer) {
        consumer.accept(tableReference.getIdentificationVariable(), tableReference);
    }

    /** Passes the joined table reference of {@code tableReferenceJoin} to {@code consumer}. */
    private void collectTableReference(TableReferenceJoin tableReferenceJoin, BiConsumer<String, TableReference> consumer) {
        this.collectTableReference(tableReferenceJoin.getJoinedTableReference(), consumer);
    }

    /**
     * Looks up the table reference owning {@code columnReference} by its qualifier.
     *
     * @throws SemanticException if the qualifier is not one of the updated entity's table aliases
     */
    private TableReference resolveTableReference(ColumnReference columnReference, Map<String, TableReference> tableReferenceByAlias) {
        final TableReference tableReferenceByQualifier = tableReferenceByAlias.get(columnReference.getQualifier());
        if (tableReferenceByQualifier != null) {
            return tableReferenceByQualifier;
        }
        throw new SemanticException("Assignment referred to column of a joined association: " + columnReference);
    }

    /**
     * Returns a {@link NamedTableReference} usable as a DML target.
     * <p>
     * A {@link UnionTableReference} is replaced by a named reference to {@code tableExpression}
     * with the same alias and optionality. Any other reference is returned (cast) as-is.
     */
    private NamedTableReference resolveUnionTableReference(TableReference tableReference, String tableExpression) {
        if (tableReference instanceof UnionTableReference) {
            return new NamedTableReference(tableExpression, tableReference.getIdentificationVariable(), tableReference.isOptional());
        }
        return (NamedTableReference) tableReference;
    }
}

