/*
 * Decompiled with CFR 0.152.
 */
package io.crate.planner.optimizer.costs;

import io.crate.common.collections.Maps;
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.Symbol;
import io.crate.metadata.NodeContext;
import io.crate.metadata.TransactionContext;
import io.crate.planner.operators.AbstractJoinPlan;
import io.crate.planner.operators.Collect;
import io.crate.planner.operators.CorrelatedJoin;
import io.crate.planner.operators.Count;
import io.crate.planner.operators.Filter;
import io.crate.planner.operators.ForeignCollect;
import io.crate.planner.operators.Get;
import io.crate.planner.operators.GroupHashAggregate;
import io.crate.planner.operators.HashAggregate;
import io.crate.planner.operators.HashJoin;
import io.crate.planner.operators.Insert;
import io.crate.planner.operators.JoinPlan;
import io.crate.planner.operators.Limit;
import io.crate.planner.operators.LogicalPlan;
import io.crate.planner.operators.LogicalPlanVisitor;
import io.crate.planner.operators.NestedLoopJoin;
import io.crate.planner.operators.TableFunction;
import io.crate.planner.operators.Union;
import io.crate.planner.optimizer.iterative.GroupReference;
import io.crate.planner.optimizer.iterative.Memo;
import io.crate.planner.selectivity.SelectivityFunctions;
import io.crate.sql.tree.JoinType;
import io.crate.statistics.Stats;
import io.crate.statistics.TableStats;
import io.crate.types.DataTypes;
import java.util.List;
import java.util.Map;
import org.jetbrains.annotations.Nullable;

/**
 * Derives {@link Stats} for a {@link LogicalPlan} by walking the plan tree with an internal
 * {@link LogicalPlanVisitor}.
 *
 * <p>An optional {@link Memo} can be supplied; it is required to resolve
 * {@link GroupReference} nodes and is also used to cache the stats computed for a group.
 */
public class PlanStats {
    private final TableStats tableStats;
    private final StatsVisitor visitor;
    private final NodeContext nodeContext;
    private final TransactionContext txnCtx;

    /**
     * Creates a {@code PlanStats} without a {@link Memo}. Asking such an instance for the stats
     * of a plan containing a {@link GroupReference} throws {@link UnsupportedOperationException}.
     */
    public PlanStats(NodeContext nodeContext, TransactionContext txnCtx, TableStats tableStats) {
        this(nodeContext, txnCtx, tableStats, null);
    }

    /**
     * @param nodeContext passed through to the selectivity estimation
     * @param txnCtx      passed through to the selectivity estimation
     * @param tableStats  source of the per-table base statistics
     * @param memo        used to resolve group references and cache their stats; may be {@code null}
     */
    public PlanStats(NodeContext nodeContext, TransactionContext txnCtx, TableStats tableStats, @Nullable Memo memo) {
        this.nodeContext = nodeContext;
        this.txnCtx = txnCtx;
        this.tableStats = tableStats;
        this.visitor = new StatsVisitor(nodeContext, txnCtx, tableStats, memo);
    }

    /**
     * Returns a new {@code PlanStats} that shares this instance's contexts and table stats but
     * uses the given {@code memo}. This instance is left unchanged.
     */
    public PlanStats withMemo(Memo memo) {
        return new PlanStats(this.nodeContext, this.txnCtx, this.tableStats, memo);
    }

    /**
     * Computes the stats of the given plan.
     *
     * @param logicalPlan the plan (sub-)tree to estimate
     * @return the stats produced by the visitor for {@code logicalPlan}
     */
    public Stats get(LogicalPlan logicalPlan) {
        return logicalPlan.accept(this.visitor, null);
    }

    /**
     * Visitor that computes {@link Stats} bottom-up. The {@code Void} context argument carries
     * no information and is merely forwarded to child plans.
     */
    private static class StatsVisitor extends LogicalPlanVisitor<Void, Stats> {
        /**
         * Sentinel compared against {@code numDocs()} / {@code sizeInBytes()} in the join
         * visitors; presumably marks "statistics not available" -- confirm against {@link Stats}.
         */
        private static final long UNKNOWN = -1L;

        private final NodeContext nodeContext;
        private final TransactionContext txnCtx;
        private final TableStats tableStats;
        @Nullable
        private final Memo memo;

        public StatsVisitor(NodeContext nodeContext, TransactionContext txnCtx, TableStats tableStats, @Nullable Memo memo) {
            this.nodeContext = nodeContext;
            this.txnCtx = txnCtx;
            this.tableStats = tableStats;
            this.memo = memo;
        }

        /**
         * Returns the stats cached in the memo for the referenced group, computing and caching
         * them from the resolved plan on a cache miss.
         *
         * @throws UnsupportedOperationException if this visitor was created without a memo
         */
        @Override
        public Stats visitGroupReference(GroupReference group, Void context) {
            if (this.memo == null) {
                throw new UnsupportedOperationException("Stats cannot be provided for GroupReference without a Memo");
            }
            int groupId = group.groupId();
            Stats stats = this.memo.stats(groupId);
            if (stats == null) {
                LogicalPlan logicalPlan = this.memo.resolve(groupId);
                stats = logicalPlan.accept(this, context);
                this.memo.addStats(groupId, stats);
            }
            return stats;
        }

        /**
         * Caps the source's row count at the limit when the limit is a literal smaller than the
         * source's row count; otherwise the source stats are returned unchanged.
         */
        @Override
        public Stats visitLimit(Limit limit, Void context) {
            Stats stats = limit.source().accept(this, context);
            Symbol limitSymbol = limit.limit();
            if (!(limitSymbol instanceof Literal)) {
                // Non-literal limits (e.g. parameters) are not known here.
                return stats;
            }
            Long numberOfRows = DataTypes.LONG.sanitizeValue(((Literal<?>) limitSymbol).value());
            if (numberOfRows == null) {
                // A literal without a value gives no upper bound; also avoids unboxing null below.
                return stats;
            }
            if (stats.numDocs() > numberOfRows) {
                return stats.withNumDocs(numberOfRows);
            }
            return stats;
        }

        /** Adds up the stats of both sides of the union. */
        @Override
        public Stats visitUnion(Union union, Void context) {
            Stats lhsStats = union.lhs().accept(this, context);
            Stats rhsStats = union.rhs().accept(this, context);
            return lhsStats.add(rhsStats);
        }

        @Override
        public Stats visitJoinPlan(JoinPlan join, Void context) {
            return this.visitAbstractJoinPlan(join, context);
        }

        @Override
        public Stats visitNestedLoopJoin(NestedLoopJoin join, Void context) {
            return this.visitAbstractJoinPlan(join, context);
        }

        /**
         * Estimates join output: the row count starts at {@code lhs * rhs} for CROSS joins and at
         * {@code max(lhs, rhs)} otherwise, and is then refined by the selectivity of the join
         * condition if there is one. If either input carries the {@link #UNKNOWN} sentinel the
         * result carries it too.
         */
        private Stats visitAbstractJoinPlan(AbstractJoinPlan join, Void context) {
            Stats lhsStats = join.lhs().accept(this, context);
            Stats rhsStats = join.rhs().accept(this, context);
            if (isUnknown(lhsStats) || isUnknown(rhsStats)) {
                return unknownJoinStats(lhsStats, rhsStats);
            }
            // The cartesian product of two large inputs can exceed Long.MAX_VALUE; saturate
            // instead of wrapping around to a negative (or sentinel-looking) row count.
            long numRows = join.joinType() == JoinType.CROSS
                ? saturatedMultiply(lhsStats.numDocs(), rhsStats.numDocs())
                : Math.max(lhsStats.numDocs(), rhsStats.numDocs());
            Stats joinStats = knownJoinStats(lhsStats, rhsStats, numRows);
            Symbol joinCondition = join.joinCondition();
            if (joinCondition == null) {
                return joinStats;
            }
            long estimatedNumRows = SelectivityFunctions.estimateNumRows(this.nodeContext, this.txnCtx, joinStats, joinCondition, null);
            return joinStats.withNumDocs(estimatedNumRows);
        }

        /**
         * Estimates hash join output: the row count starts at {@code max(lhs, rhs)} and is then
         * refined by the selectivity of the join condition. If either input carries the
         * {@link #UNKNOWN} sentinel the result carries it too.
         */
        @Override
        public Stats visitHashJoin(HashJoin join, Void context) {
            Stats lhsStats = join.lhs().accept(this, context);
            Stats rhsStats = join.rhs().accept(this, context);
            if (isUnknown(lhsStats) || isUnknown(rhsStats)) {
                return unknownJoinStats(lhsStats, rhsStats);
            }
            long numRows = Math.max(lhsStats.numDocs(), rhsStats.numDocs());
            Stats joinStats = knownJoinStats(lhsStats, rhsStats, numRows);
            long estimatedNumRows = SelectivityFunctions.estimateNumRows(this.nodeContext, this.txnCtx, joinStats, join.joinCondition(), null);
            return joinStats.withNumDocs(estimatedNumRows);
        }

        /**
         * Takes the table's stats and scales the row count by the selectivity of the collect's
         * query. {@link Stats#EMPTY} is returned as-is without estimation.
         */
        @Override
        public Stats visitCollect(Collect collect, Void context) {
            Stats stats = this.tableStats.getStats(collect.relation().tableInfo().ident());
            if (stats.equals(Stats.EMPTY)) {
                return stats;
            }
            Symbol query = collect.where().queryOrFallback();
            long numberOfRows = SelectivityFunctions.estimateNumRows(this.nodeContext, this.txnCtx, stats, query, null);
            return stats.withNumDocs(numberOfRows);
        }

        /** Scales the source's row count by the selectivity of the filter query. */
        @Override
        public Stats visitFilter(Filter filter, Void context) {
            Stats sourceStats = filter.source().accept(this, context);
            Symbol query = filter.query();
            long numRows = SelectivityFunctions.estimateNumRows(this.nodeContext, this.txnCtx, sourceStats, query, null);
            return sourceStats.withNumDocs(numRows);
        }

        /** Always one row of 8 bytes (presumably the size of the single long count value). */
        @Override
        public Stats visitCount(Count count, Void context) {
            return new Stats(1L, 8L, Map.of());
        }

        /** Table stats with the row count replaced by the number of rows the get expects. */
        @Override
        public Stats visitGet(Get get, Void context) {
            Stats stats = this.tableStats.getStats(get.table().relationName());
            return stats.withNumDocs(get.numExpectedRows());
        }

        /** Source stats with the row count replaced by the approximated number of distinct group keys. */
        @Override
        public Stats visitGroupHashAggregate(GroupHashAggregate groupHashAggregate, Void context) {
            Stats stats = groupHashAggregate.source().accept(this, context);
            return stats.withNumDocs(GroupHashAggregate.approximateDistinctValues(stats, groupHashAggregate.groupKeys()));
        }

        /** Source stats with the row count set to one. */
        @Override
        public Stats visitHashAggregate(HashAggregate hashAggregate, Void context) {
            Stats stats = hashAggregate.source().accept(this, context);
            return stats.withNumDocs(1L);
        }

        /** Stats of the first source with the row count set to one. */
        @Override
        public Stats visitInsert(Insert insert, Void context) {
            Stats stats = insert.sources().get(0).accept(this, context);
            return stats.withNumDocs(1L);
        }

        /** Uses the stats of the first source only. */
        @Override
        public Stats visitCorrelatedJoin(CorrelatedJoin join, Void context) {
            return join.sources().get(0).accept(this, context);
        }

        @Override
        public Stats visitTableFunction(TableFunction tableFunction, Void context) {
            return Stats.EMPTY;
        }

        @Override
        public Stats visitForeignCollect(ForeignCollect foreignCollect, Void context) {
            return Stats.EMPTY;
        }

        /**
         * Fallback for plans without a dedicated visit method: plans with exactly one source
         * inherit that source's stats, anything else yields {@link Stats#EMPTY}.
         */
        @Override
        public Stats visitPlan(LogicalPlan logicalPlan, Void context) {
            List<LogicalPlan> sources = logicalPlan.sources();
            if (sources.size() == 1) {
                return sources.get(0).accept(this, context);
            }
            return Stats.EMPTY;
        }

        /** True if either the row count or the size of {@code stats} equals the {@link #UNKNOWN} sentinel. */
        private static boolean isUnknown(Stats stats) {
            return stats.numDocs() == UNKNOWN || stats.sizeInBytes() == UNKNOWN;
        }

        /** Join stats carrying the {@link #UNKNOWN} sentinel but still the column stats of both sides. */
        private static Stats unknownJoinStats(Stats lhsStats, Stats rhsStats) {
            return new Stats(UNKNOWN, UNKNOWN, Maps.concat(lhsStats.statsByColumn(), rhsStats.statsByColumn()));
        }

        /**
         * Join stats for {@code numRows} output rows; the size is {@code numRows} times the sum of
         * both sides' average row sizes, computed with saturating arithmetic.
         */
        private static Stats knownJoinStats(Stats lhsStats, Stats rhsStats, long numRows) {
            long sizeInBytes = saturatedAdd(
                saturatedMultiply(lhsStats.averageSizePerRowInBytes(), numRows),
                saturatedMultiply(rhsStats.averageSizePerRowInBytes(), numRows));
            return new Stats(numRows, sizeInBytes, Maps.concat(lhsStats.statsByColumn(), rhsStats.statsByColumn()));
        }

        /** {@code a * b}, clamped to {@code Long.MIN_VALUE}/{@code Long.MAX_VALUE} instead of overflowing. */
        private static long saturatedMultiply(long a, long b) {
            long high = Math.multiplyHigh(a, b);
            long low = a * b;
            // The 128-bit product fits into a long iff the high word is the sign extension of the low word.
            if ((high == 0L && low >= 0L) || (high == -1L && low < 0L)) {
                return low;
            }
            return (a ^ b) < 0L ? Long.MIN_VALUE : Long.MAX_VALUE;
        }

        /** {@code a + b}, clamped to {@code Long.MIN_VALUE}/{@code Long.MAX_VALUE} instead of overflowing. */
        private static long saturatedAdd(long a, long b) {
            long sum = a + b;
            // Overflow happened iff both operands have the same sign and the sum's sign differs.
            if (((a ^ sum) & (b ^ sum)) < 0L) {
                return a < 0L ? Long.MIN_VALUE : Long.MAX_VALUE;
            }
            return sum;
        }
    }
}

