/*
 * Decompiled with CFR 0.152.
 */
package io.crate.planner.operators;

import io.crate.analyze.OrderBy;
import io.crate.analyze.relations.QuerySplitter;
import io.crate.common.collections.Lists;
import io.crate.data.Row;
import io.crate.execution.dsl.phases.HashJoinPhase;
import io.crate.execution.dsl.phases.MergePhase;
import io.crate.execution.dsl.projection.EvalProjection;
import io.crate.execution.dsl.projection.builder.InputColumns;
import io.crate.execution.dsl.projection.builder.ProjectionBuilder;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.Symbols;
import io.crate.metadata.RelationName;
import io.crate.planner.DependencyCarrier;
import io.crate.planner.ExecutionPlan;
import io.crate.planner.PlannerContext;
import io.crate.planner.ResultDescription;
import io.crate.planner.distribution.DistributionInfo;
import io.crate.planner.distribution.DistributionType;
import io.crate.planner.node.dql.join.Join;
import io.crate.planner.operators.AbstractJoinPlan;
import io.crate.planner.operators.FetchRewrite;
import io.crate.planner.operators.JoinConditionSymbolsExtractor;
import io.crate.planner.operators.LogicalPlan;
import io.crate.planner.operators.LogicalPlanVisitor;
import io.crate.planner.operators.NestedLoopJoin;
import io.crate.planner.operators.PlanHint;
import io.crate.planner.operators.PrintContext;
import io.crate.planner.operators.SubQueryAndParamBinder;
import io.crate.planner.operators.SubQueryResults;
import io.crate.sql.tree.JoinType;
import io.crate.statistics.Stats;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.SequencedCollection;
import java.util.Set;
import java.util.function.Consumer;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

/**
 * Logical operator that joins its two sources ({@code lhs} and {@code rhs}) using a hash join.
 *
 * <p>{@link #build} turns this operator into a {@link HashJoinPhase} wrapped in a {@link Join}
 * execution plan. The hash keys of both sides are derived from the join condition by
 * {@link #createHashSymbols}.
 */
public class HashJoin extends AbstractJoinPlan {

    /**
     * Creates a hash join with {@link AbstractJoinPlan.LookUpJoin#NONE} as lookup-join marker.
     */
    public HashJoin(LogicalPlan lhs, LogicalPlan rhs, Symbol joinCondition, JoinType joinType) {
        super(lhs, rhs, joinCondition, joinType, AbstractJoinPlan.LookUpJoin.NONE);
    }

    /**
     * Creates a hash join.
     *
     * @param lookUpJoin lookup-join marker; {@link #pruneOutputsExcept} uses it to decide whether
     *                   the join can be replaced by one of its (pruned) sources
     */
    public HashJoin(LogicalPlan lhs,
                    LogicalPlan rhs,
                    Symbol joinCondition,
                    JoinType joinType,
                    AbstractJoinPlan.LookUpJoin lookUpJoin) {
        super(lhs, rhs, joinCondition, joinType, lookUpJoin);
    }

    /**
     * Builds the execution plan for this hash join.
     *
     * <p>Both sources are built without limit, offset, ordering or page-size hint. The rows of
     * both sides are then either:
     * <ul>
     *   <li>joined on the single node both sides run on (no merge phases), when the left side has
     *       exactly one execution node, the right side has the same node ids, and the right side
     *       has no remaining limit/offset;</li>
     *   <li>distributed by modulo of the first hash symbol of each side, when neither side has a
     *       remaining limit/offset and the left side has at least one execution node;</li>
     *   <li>otherwise broadcast to the handler node, where the join runs.</li>
     * </ul>
     *
     * <p>Note: {@code limit}, {@code offset}, {@code order} and {@code pageSizeHint} are not used
     * by this method.
     */
    @Override
    public ExecutionPlan build(DependencyCarrier executor,
                               PlannerContext plannerContext,
                               Set<PlanHint> hints,
                               ProjectionBuilder projectionBuilder,
                               int limit,
                               int offset,
                               @Nullable OrderBy order,
                               @Nullable Integer pageSizeHint,
                               Row params,
                               SubQueryResults subQueryResults) {
        ExecutionPlan leftExecutionPlan = this.lhs.build(
            executor, plannerContext, hints, projectionBuilder, -1, 0, null, null, params, subQueryResults);
        ExecutionPlan rightExecutionPlan = this.rhs.build(
            executor, plannerContext, hints, projectionBuilder, -1, 0, null, null, params, subQueryResults);
        SubQueryAndParamBinder paramBinder = new SubQueryAndParamBinder(params, subQueryResults);

        HashSymbols hashSymbols = createHashSymbols(
            this.lhs.relationNames(), this.rhs.relationNames(), this.joinCondition);
        List<Symbol> lhsHashSymbols = hashSymbols.lhsHashSymbols();
        List<Symbol> rhsHashSymbols = hashSymbols.rhsHashSymbols();

        ResultDescription leftResultDesc = leftExecutionPlan.resultDescription();
        ResultDescription rightResultDesc = rightExecutionPlan.resultDescription();
        Collection<String> joinExecutionNodes = leftResultDesc.nodeIds();
        List<Symbol> leftOutputs = this.lhs.outputs();
        List<Symbol> rightOutputs = this.rhs.outputs();
        MergePhase leftMerge = null;
        MergePhase rightMerge = null;

        // Modulo distribution is only used when no side still has a limit/offset to apply and
        // the left side reports at least one execution node to run the join on.
        boolean isDistributed = !leftResultDesc.hasRemainingLimitOrOffset()
            && !rightResultDesc.hasRemainingLimitOrOffset()
            && !joinExecutionNodes.isEmpty();

        if (joinExecutionNodes.size() == 1
            && Lists.equals(joinExecutionNodes, rightResultDesc.nodeIds())
            && !rightResultDesc.hasRemainingLimitOrOffset()) {
            // Both sides run on the same single node: join there without merge phases.
            leftExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_SAME_NODE);
            rightExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_SAME_NODE);
        } else {
            if (isDistributed) {
                // Presumably so that rows with equal join keys end up on the same node.
                // setModuloDistribution may append the distribution key to the outputs.
                leftOutputs = this.setModuloDistribution(lhsHashSymbols, this.lhs.outputs(), leftExecutionPlan);
                rightOutputs = this.setModuloDistribution(rhsHashSymbols, this.rhs.outputs(), rightExecutionPlan);
            } else {
                // Non-distributed fallback: run the join on the handler node only.
                joinExecutionNodes = Collections.singletonList(plannerContext.handlerNode());
                leftExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_BROADCAST);
                rightExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_BROADCAST);
            }
            leftMerge = buildMergePhaseForJoin(plannerContext, leftResultDesc, joinExecutionNodes);
            rightMerge = buildMergePhaseForJoin(plannerContext, rightResultDesc, joinExecutionNodes);
        }

        // The join phase sees the left outputs followed by the right outputs.
        List<Symbol> joinOutputs = Lists.concat(leftOutputs, rightOutputs);
        Stats lhsStats = plannerContext.planStats().get(this.lhs);
        HashJoinPhase joinPhase = new HashJoinPhase(
            plannerContext.jobId(),
            plannerContext.nextExecutionPhaseId(),
            "hash-join",
            Collections.singletonList(NestedLoopJoin.createJoinProjection(this.outputs(), joinOutputs)),
            leftMerge,
            rightMerge,
            leftOutputs.size(),
            rightOutputs.size(),
            joinExecutionNodes,
            // join condition is bound to params/subquery results and resolved against both sides
            InputColumns.create(paramBinder.apply(this.joinCondition), joinOutputs),
            // hash keys are resolved against their own side's outputs only
            InputColumns.create(lhsHashSymbols, new InputColumns.SourceSymbols(leftOutputs)),
            InputColumns.create(rhsHashSymbols, new InputColumns.SourceSymbols(rightOutputs)),
            Symbols.typeView(leftOutputs),
            lhsStats.estimateSizeForColumns(leftOutputs),
            this.joinType);
        return new Join(joinPhase, leftExecutionPlan, rightExecutionPlan, -1, 0, -1, this.outputs().size(), null);
    }

    /**
     * Creates a copy of this join with new sources.
     *
     * @param sources {@code sources.get(0)} becomes the left side, {@code sources.get(1)} the right side
     */
    @Override
    public LogicalPlan replaceSources(List<LogicalPlan> sources) {
        return new HashJoin(sources.get(0), sources.get(1), this.joinCondition, this.joinType, this.lookupJoin);
    }

    /**
     * Prunes the outputs of both sources down to what is needed for {@code outputsToKeep} plus
     * the symbols referenced by the join condition.
     *
     * <p>If this is a lookup join and nothing from the non-lookup side is requested, the join is
     * dropped and only the other side (pruned, but keeping its join-condition symbols) is returned.
     * NOTE(review): presumably the lookup side only pre-filters the other side, which makes the
     * join itself redundant in that case — confirm.
     *
     * @return {@code this} if neither source changed, otherwise a new plan
     */
    @Override
    public LogicalPlan pruneOutputsExcept(SequencedCollection<Symbol> outputsToKeep) {
        LinkedHashSet<Symbol> lhsToKeep = new LinkedHashSet<>();
        LinkedHashSet<Symbol> rhsToKeep = new LinkedHashSet<>();
        for (Symbol outputToKeep : outputsToKeep) {
            Symbols.intersection(outputToKeep, this.lhs.outputs(), lhsToKeep::add);
            Symbols.intersection(outputToKeep, this.rhs.outputs(), rhsToKeep::add);
        }
        // Checked before the join-condition symbols are added, so "empty" means
        // "not requested by the parent".
        if (lhsToKeep.isEmpty() && this.lookupJoin == AbstractJoinPlan.LookUpJoin.RIGHT) {
            Symbols.intersection(this.joinCondition, this.rhs.outputs(), rhsToKeep::add);
            return this.rhs.pruneOutputsExcept(rhsToKeep);
        }
        if (rhsToKeep.isEmpty() && this.lookupJoin == AbstractJoinPlan.LookUpJoin.LEFT) {
            Symbols.intersection(this.joinCondition, this.lhs.outputs(), lhsToKeep::add);
            return this.lhs.pruneOutputsExcept(lhsToKeep);
        }
        Symbols.intersection(this.joinCondition, this.lhs.outputs(), lhsToKeep::add);
        Symbols.intersection(this.joinCondition, this.rhs.outputs(), rhsToKeep::add);
        LogicalPlan newLhs = this.lhs.pruneOutputsExcept(lhsToKeep);
        LogicalPlan newRhs = this.rhs.pruneOutputsExcept(rhsToKeep);
        if (newLhs == this.lhs && newRhs == this.rhs) {
            return this;
        }
        return new HashJoin(newLhs, newRhs, this.joinCondition, this.joinType, this.lookupJoin);
    }

    /**
     * Attempts to rewrite both sources to fetch.
     *
     * <p>Each source is asked to rewrite for the columns it must provide: the used columns that
     * intersect its outputs, plus the join-condition symbols it provides.
     *
     * @return {@code null} if neither source could be rewritten; otherwise a rewrite that
     *         combines the replaced outputs of both sides and joins the (possibly) rewritten sources
     */
    @Override
    @Nullable
    public FetchRewrite rewriteToFetch(Collection<Symbol> usedColumns) {
        LinkedHashSet<Symbol> usedFromLeft = new LinkedHashSet<>();
        LinkedHashSet<Symbol> usedFromRight = new LinkedHashSet<>();
        for (Symbol usedColumn : usedColumns) {
            Symbols.intersection(usedColumn, this.lhs.outputs(), usedFromLeft::add);
            Symbols.intersection(usedColumn, this.rhs.outputs(), usedFromRight::add);
        }
        Symbols.intersection(this.joinCondition, this.lhs.outputs(), usedFromLeft::add);
        Symbols.intersection(this.joinCondition, this.rhs.outputs(), usedFromRight::add);
        FetchRewrite lhsFetchRewrite = this.lhs.rewriteToFetch(usedFromLeft);
        FetchRewrite rhsFetchRewrite = this.rhs.rewriteToFetch(usedFromRight);
        if (lhsFetchRewrite == null && rhsFetchRewrite == null) {
            return null;
        }
        LinkedHashMap<Symbol, Symbol> allReplacedOutputs = new LinkedHashMap<>();
        NestedLoopJoin.setReplacedOutputs(this.lhs, lhsFetchRewrite, allReplacedOutputs);
        NestedLoopJoin.setReplacedOutputs(this.rhs, rhsFetchRewrite, allReplacedOutputs);
        LogicalPlan newLhs = lhsFetchRewrite == null ? this.lhs : lhsFetchRewrite.newPlan();
        LogicalPlan newRhs = rhsFetchRewrite == null ? this.rhs : rhsFetchRewrite.newPlan();
        return new FetchRewrite(
            allReplacedOutputs,
            new HashJoin(newLhs, newRhs, this.joinCondition, this.joinType, this.lookupJoin));
    }

    @Override
    public <C, R> R accept(LogicalPlanVisitor<C, R> visitor, C context) {
        return visitor.visitHashJoin(this, context);
    }

    /**
     * Prints {@code HashJoin[<joinType> | <joinCondition>]}, the plan stats, and then both
     * sources nested below.
     */
    @Override
    public void print(PrintContext printContext) {
        printContext
            .text("HashJoin[")
            .text(this.joinType.toString())
            .text(" | ")
            .text(this.joinCondition.toString())
            .text("]");
        this.printStats(printContext);
        printContext.nest(this.lhs::print, this.rhs::print);
    }

    /**
     * Configures modulo distribution of {@code executionPlan} on the first join symbol.
     *
     * <p>Only {@code joinSymbols.get(0)} is used as distribution key, so {@code joinSymbols} must
     * not be empty. If that symbol is not among {@code planOutputs}, an {@link EvalProjection}
     * appending it is added to the plan and the distribution uses the appended position.
     *
     * @return the outputs the plan produces afterwards: {@code planOutputs} itself, or
     *         {@code planOutputs} followed by the appended join symbol
     */
    private List<Symbol> setModuloDistribution(List<Symbol> joinSymbols,
                                               List<Symbol> planOutputs,
                                               ExecutionPlan executionPlan) {
        List<Symbol> outputs = planOutputs;
        Symbol firstJoinSymbol = joinSymbols.get(0);
        int distributeBySymbolPos = planOutputs.indexOf(firstJoinSymbol);
        if (distributeBySymbolPos < 0) {
            outputs = this.createEvalProjectionForDistributionJoinSymbol(firstJoinSymbol, planOutputs, executionPlan);
            // the symbol was appended after all original outputs
            distributeBySymbolPos = planOutputs.size();
        }
        executionPlan.setDistributionInfo(new DistributionInfo(DistributionType.MODULO, distributeBySymbolPos));
        return outputs;
    }

    /**
     * Adds an {@link EvalProjection} to {@code executionPlan} that emits {@code outputs} followed
     * by {@code firstJoinSymbol}.
     *
     * @return the projection's outputs ({@code outputs} plus {@code firstJoinSymbol})
     */
    private List<Symbol> createEvalProjectionForDistributionJoinSymbol(Symbol firstJoinSymbol,
                                                                       List<Symbol> outputs,
                                                                       ExecutionPlan executionPlan) {
        ArrayList<Symbol> projectionOutputs = new ArrayList<>(outputs.size() + 1);
        projectionOutputs.addAll(outputs);
        projectionOutputs.add(firstJoinSymbol);
        EvalProjection evalProjection = new EvalProjection(
            InputColumns.create(projectionOutputs, new InputColumns.SourceSymbols(outputs)));
        executionPlan.addProjection(evalProjection);
        return projectionOutputs;
    }

    /**
     * Extracts the hash-key symbols of each side from a join condition.
     *
     * <p>The condition is split with {@link QuerySplitter#split}, and each part is passed to
     * {@link JoinConditionSymbolsExtractor#extract}. Symbols of a relation contained in
     * {@code rhsRelationNames} go to the right side (checked first), symbols of a relation in
     * {@code lhsRelationNames} go to the left side, and symbols of any other relation are ignored.
     *
     * @return the hash symbols of both sides; with assertions enabled, both lists are checked to
     *         have the same size
     */
    @VisibleForTesting
    static HashSymbols createHashSymbols(List<RelationName> lhsRelationNames,
                                         List<RelationName> rhsRelationNames,
                                         Symbol symbol) {
        ArrayList<Symbol> lhsHashSymbols = new ArrayList<>();
        ArrayList<Symbol> rhsHashSymbols = new ArrayList<>();
        for (Symbol condition : QuerySplitter.split(symbol).values()) {
            for (Map.Entry<RelationName, List<Symbol>> entry : JoinConditionSymbolsExtractor.extract(condition).entrySet()) {
                RelationName relationName = entry.getKey();
                List<Symbol> symbols = entry.getValue();
                if (symbols == null) {
                    continue;
                }
                if (rhsRelationNames.contains(relationName)) {
                    rhsHashSymbols.addAll(symbols);
                } else if (lhsRelationNames.contains(relationName)) {
                    lhsHashSymbols.addAll(symbols);
                }
            }
        }
        assert (rhsHashSymbols.size() == lhsHashSymbols.size()) : "Number of hash values for left and right hand side of a hash-join must be equal";
        return new HashSymbols(lhsHashSymbols, rhsHashSymbols);
    }

    /**
     * Hash-key symbols of the left and right side of a hash join.
     */
    record HashSymbols(List<Symbol> lhsHashSymbols, List<Symbol> rhsHashSymbols) {
    }
}

