/*
 * Decompiled with CFR 0.152.
 */
package io.crate.planner.optimizer.joinorder;

import io.crate.analyze.relations.QuerySplitter;
import io.crate.common.collections.Lists;
import io.crate.common.collections.Maps;
import io.crate.common.collections.Sets;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.ScopedSymbol;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolVisitor;
import io.crate.metadata.Reference;
import io.crate.metadata.RelationName;
import io.crate.planner.operators.Filter;
import io.crate.planner.operators.JoinPlan;
import io.crate.planner.operators.LogicalPlan;
import io.crate.planner.operators.LogicalPlanVisitor;
import io.crate.planner.optimizer.iterative.GroupReference;
import io.crate.sql.tree.JoinType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
 * Graph view of a tree of joins: the joined plans are the nodes, equality
 * join conditions between two plans are the edges, and every other predicate
 * found on the way is kept as a filter.
 *
 * <p>Instances are built with {@link #create(LogicalPlan, UnaryOperator)}; the
 * {@code with*} / {@code joinWith} methods return new graphs instead of
 * modifying the receiver. The collections passed to the canonical constructor
 * are stored as-is (no defensive copy).
 *
 * @param nodes        the plans taking part in the join tree
 * @param edges        for each plan, the edges leading to the plans it is joined with
 * @param filters      predicates collected from {@link Filter} operators and from
 *                     join-condition parts that were not turned into edges
 * @param hasCrossJoin {@code true} once a {@link JoinType#CROSS} join was encountered
 */
public record JoinGraph(List<LogicalPlan> nodes, Map<LogicalPlan, Set<Edge>> edges, List<Symbol> filters, boolean hasCrossJoin) {

    /**
     * Combines this graph with {@code other}: nodes and filters are concatenated,
     * edge maps are merged (edge sets of the same plan are unioned) and the
     * cross-join flag is OR-ed.
     *
     * @param other the graph to combine with; its nodes must not already have edges in this graph
     * @return a new graph containing the content of both graphs
     */
    JoinGraph joinWith(JoinGraph other) {
        for (LogicalPlan node : other.nodes) {
            assert !this.edges.containsKey(node) : "LogicalPlan" + node + " can't be in both graphs";
        }
        List<LogicalPlan> newNodes = Lists.concat(this.nodes, other.nodes);
        Map<LogicalPlan, Set<Edge>> newEdges = Maps.merge(this.edges, other.edges, Sets::union);
        List<Symbol> newFilters = Lists.concat(this.filters, other.filters);
        boolean crossJoin = this.hasCrossJoin || other.hasCrossJoin;
        return new JoinGraph(newNodes, newEdges, newFilters, crossJoin);
    }

    /**
     * Returns a new graph whose edges are this graph's edges merged with {@code edges}.
     *
     * @param edges additional edges, keyed by the plan they start from
     * @return a new graph with the merged edge map; nodes, filters and cross-join flag are unchanged
     */
    JoinGraph withEdges(Map<LogicalPlan, Set<Edge>> edges) {
        Map<LogicalPlan, Set<Edge>> newEdges = Maps.merge(this.edges, edges, Sets::union);
        return new JoinGraph(this.nodes, newEdges, this.filters, this.hasCrossJoin);
    }

    /**
     * Returns a graph with {@code filters} appended to the existing filters.
     *
     * @param filters the filters to add
     * @return this graph if {@code filters} is empty, otherwise a new graph
     */
    JoinGraph withFilters(List<Symbol> filters) {
        if (filters.isEmpty()) {
            return this;
        }
        List<Symbol> newFilters = Lists.concat(this.filters, filters);
        return new JoinGraph(this.nodes, this.edges, newFilters, this.hasCrossJoin);
    }

    /**
     * Returns a copy of this graph with the cross-join flag set.
     */
    JoinGraph withCrossJoin() {
        return new JoinGraph(this.nodes, this.edges, this.filters, true);
    }

    /**
     * Returns the number of nodes in this graph.
     */
    public int size() {
        return this.nodes.size();
    }

    /**
     * Returns the edges starting at {@code node}.
     *
     * @param node the plan to look up
     * @return the edges registered for {@code node}, or an empty set if there are none
     */
    public Set<Edge> edges(LogicalPlan node) {
        Set<Edge> result = this.edges.get(node);
        return result == null ? Set.of() : result;
    }

    /**
     * Builds the join graph for {@code plan}.
     *
     * @param plan        the root of the plan tree to convert
     * @param resolvePlan used to resolve every {@link GroupReference} met during
     *                    traversal to the plan it stands for
     * @return the join graph describing {@code plan}
     */
    public static JoinGraph create(LogicalPlan plan, UnaryOperator<LogicalPlan> resolvePlan) {
        return plan.accept(new GraphBuilder(resolvePlan), new HashMap<>());
    }

    /**
     * Visitor that walks a plan tree and assembles the {@link JoinGraph}.
     * The visitor context maps each output symbol of a leaf plan to that plan,
     * so that join conditions can be traced back to the plans they reference.
     */
    private static class GraphBuilder
    extends LogicalPlanVisitor<Map<Symbol, LogicalPlan>, JoinGraph> {
        private final UnaryOperator<LogicalPlan> resolvePlan;

        GraphBuilder(UnaryOperator<LogicalPlan> resolvePlan) {
            this.resolvePlan = resolvePlan;
        }

        /**
         * Fallback for any plan without a dedicated visit method: the plan becomes
         * a single node and all of its outputs are registered in {@code context}.
         */
        @Override
        public JoinGraph visitPlan(LogicalPlan logicalPlan, Map<Symbol, LogicalPlan> context) {
            for (Symbol output : logicalPlan.outputs()) {
                context.put(output, logicalPlan);
            }
            return new JoinGraph(List.of(logicalPlan), Map.of(), List.of(), false);
        }

        /**
         * Resolves the group reference and continues with the resolved plan.
         */
        @Override
        public JoinGraph visitGroupReference(GroupReference groupReference, Map<Symbol, LogicalPlan> context) {
            return this.resolvePlan.apply(groupReference).accept(this, context);
        }

        /**
         * Builds the graph of the filter's source and records the filter query on it.
         */
        @Override
        public JoinGraph visitFilter(Filter filter, Map<Symbol, LogicalPlan> context) {
            JoinGraph source = filter.source().accept(this, context);
            return source.withFilters(List.of(filter.query()));
        }

        /**
         * Builds the graphs of both join sides and combines them.
         * <ul>
         *   <li>{@code CROSS} joins: the combined graph is flagged as containing a cross join.</li>
         *   <li>Any other non-{@code INNER} join: the graphs are combined, the join
         *       condition is not inspected.</li>
         *   <li>{@code INNER} joins: the join condition is split by the relations it
         *       references; parts referencing exactly two relations are handed to the
         *       {@link EdgeCollector}, all other parts are kept as filters.</li>
         * </ul>
         */
        @Override
        public JoinGraph visitJoinPlan(JoinPlan joinPlan, Map<Symbol, LogicalPlan> context) {
            // Both sides are always visited first so that their outputs are in `context`.
            JoinGraph left = joinPlan.lhs().accept(this, context);
            JoinGraph right = joinPlan.rhs().accept(this, context);
            if (joinPlan.joinType() == JoinType.CROSS) {
                return left.joinWith(right).withCrossJoin();
            }
            if (joinPlan.joinType() != JoinType.INNER) {
                return left.joinWith(right);
            }
            Symbol joinCondition = joinPlan.joinCondition();
            EdgeCollector edgeCollector = new EdgeCollector();
            List<Symbol> filters = new ArrayList<>();
            if (joinCondition != null) {
                Map<Set<RelationName>, Symbol> split = QuerySplitter.split(joinCondition);
                for (Map.Entry<Set<RelationName>, Symbol> entry : split.entrySet()) {
                    if (entry.getKey().size() == 2) {
                        entry.getValue().accept(edgeCollector, context);
                    } else {
                        filters.add(entry.getValue());
                    }
                }
            }
            return left.joinWith(right).withEdges(edgeCollector.edges).withFilters(filters);
        }

        /**
         * Symbol visitor that turns equality functions into edges. While visiting,
         * the plan of every field/reference it meets is appended to {@link #sources}
         * ({@code null} if the symbol is not registered in the context).
         */
        private static class EdgeCollector
        extends SymbolVisitor<Map<Symbol, LogicalPlan>, Void> {
            /** Name of the function for which edges are created. */
            private static final String EQ_OPERATOR_NAME = "op_=";

            private final Map<LogicalPlan, Set<Edge>> edges = new HashMap<>();
            private final List<LogicalPlan> sources = new ArrayList<>();

            private EdgeCollector() {
            }

            @Override
            public Void visitField(ScopedSymbol s, Map<Symbol, LogicalPlan> context) {
                this.sources.add(context.get(s));
                return null;
            }

            @Override
            public Void visitReference(Reference ref, Map<Symbol, LogicalPlan> context) {
                this.sources.add(context.get(ref));
                return null;
            }

            /**
             * Visits all arguments and, if {@code f} is the equality operator, adds an
             * edge in both directions between the plans of the last two collected
             * sources. No edge is added when either plan is unknown.
             */
            @Override
            public Void visitFunction(Function f, Map<Symbol, LogicalPlan> context) {
                int sizeSource = this.sources.size();
                for (Symbol argument : f.arguments()) {
                    argument.accept(this, context);
                }
                if (f.name().equals(EQ_OPERATOR_NAME)) {
                    // NOTE(review): this relies on each argument contributing exactly one
                    // source. An argument that is itself a function over several columns
                    // (or over none) would break that; only the assertion guards it, so
                    // with assertions disabled the last two sources are used regardless.
                    assert this.sources.size() == sizeSource + 2 : "Source must be collected for each argument";
                    Symbol fromSymbol = f.arguments().get(0);
                    Symbol toSymbol = f.arguments().get(1);
                    LogicalPlan fromRelation = this.sources.get(this.sources.size() - 2);
                    LogicalPlan toRelation = this.sources.get(this.sources.size() - 1);
                    if (fromRelation != null && toRelation != null) {
                        // Both directions keep the symbols in argument order.
                        this.addEdge(fromRelation, new Edge(toRelation, fromSymbol, toSymbol));
                        this.addEdge(toRelation, new Edge(fromRelation, fromSymbol, toSymbol));
                    }
                }
                return null;
            }

            /**
             * Registers {@code edge} as starting at {@code from}.
             */
            private void addEdge(LogicalPlan from, Edge edge) {
                // One mutable set per plan; avoids re-copying the set on every insert.
                this.edges.computeIfAbsent(from, ignored -> new HashSet<>()).add(edge);
            }
        }
    }

    /**
     * Edge of the join graph, created from an equality condition.
     *
     * @param to    the plan this edge leads to
     * @param left  the first argument of the equality condition
     * @param right the second argument of the equality condition
     */
    public record Edge(LogicalPlan to, Symbol left, Symbol right) {
    }
}

