/*
 * Decompiled with CFR 0.152.
 */
package io.crate.expression.symbol;

import io.crate.expression.symbol.Aggregation;
import io.crate.expression.symbol.AliasSymbol;
import io.crate.expression.symbol.DynamicReference;
import io.crate.expression.symbol.FetchReference;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.InputColumn;
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.MatchPredicate;
import io.crate.expression.symbol.ParameterSymbol;
import io.crate.expression.symbol.ScopedSymbol;
import io.crate.expression.symbol.SelectSymbol;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolVisitor;
import io.crate.expression.symbol.Symbols;
import io.crate.expression.symbol.WindowFunction;
import io.crate.metadata.FunctionType;
import io.crate.metadata.Reference;
import io.crate.types.DataTypes;
import java.util.List;

/**
 * Semantic validation of a SELECT list against its GROUP BY clause.
 *
 * <p>In a grouped or aggregated query every output expression must be one of the following:
 * <ul>
 *   <li>a GROUP BY expression;</li>
 *   <li>an expression built only from grouped expressions and constants;</li>
 *   <li>an aggregation.</li>
 * </ul>
 * Literal GROUP BY keys must also have a defined type.
 */
public final class GroupAndAggregateSemantics {

    /**
     * Validates the output symbols of a query against its GROUP BY symbols.
     *
     * <p>Validation is skipped entirely when there is no GROUP BY clause and no output contains an
     * aggregate function call that is not a window function.
     *
     * @param outputSymbols the SELECT list symbols
     * @param groupBy       the GROUP BY symbols; may be empty
     * @throws IllegalArgumentException if a (possibly aliased) literal GROUP BY key has an undefined
     *                                  type and a non-null value, or if an output references a symbol
     *                                  that is neither grouped nor aggregated
     */
    public static void validate(List<Symbol> outputSymbols, List<Symbol> groupBy) throws IllegalArgumentException {
        boolean containsAggregations = Symbols.any(outputSymbols, GroupAndAggregateSemantics::isGroupingAggregate);
        if (!containsAggregations && groupBy.isEmpty()) {
            return;
        }
        groupBy.forEach(GroupAndAggregateSemantics::ensureTypedGroupKey);
        for (Symbol output : outputSymbols) {
            Symbol offender = output.accept(FindOffendingSymbol.INSTANCE, groupBy);
            if (offender != null) {
                throw new IllegalArgumentException(
                    "'" + offender + "' must appear in the GROUP BY clause or be used in an aggregation function. "
                    + "Perhaps you grouped by an alias that clashes with a column in the relations");
            }
        }
    }

    /**
     * Returns true if {@code symbol} is an aggregate function call that turns the query into an
     * aggregation.
     *
     * <p>Window functions are excluded. An aggregate used with {@code OVER (...)} is evaluated per
     * window, so on its own it does not require the other outputs to be grouped.
     * NOTE(review): this only matters if a {@link WindowFunction} can carry an AGGREGATE signature;
     * otherwise the exclusion is a no-op.
     */
    private static boolean isGroupingAggregate(Symbol symbol) {
        if (!(symbol instanceof Function) || symbol instanceof WindowFunction) {
            return false;
        }
        return ((Function) symbol).signature().getType() == FunctionType.AGGREGATE;
    }

    /** Rejects GROUP BY keys that are untyped, non-null literals. */
    private static void ensureTypedGroupKey(Symbol groupBy) {
        groupBy.accept(EnsureTypedGroupKey.INSTANCE, null);
    }

    /**
     * Visitor that returns the first symbol of an output expression that is neither grouped nor
     * aggregated, or {@code null} if the expression is valid for the given GROUP BY symbols.
     */
    private static class FindOffendingSymbol extends SymbolVisitor<List<Symbol>, Symbol> {
        private static final FindOffendingSymbol INSTANCE = new FindOffendingSymbol();

        private FindOffendingSymbol() {
        }

        @Override
        protected Symbol visitSymbol(Symbol symbol, List<Symbol> groupBy) {
            throw new UnsupportedOperationException("Unsupported symbol: " + symbol);
        }

        @Override
        public Symbol visitFunction(Function function, List<Symbol> groupBy) {
            switch (function.signature().getType()) {
                case SCALAR:
                    // A grouped scalar implicitly groups its arguments, e.g.
                    // `SELECT x + 1 FROM t GROUP BY x + 1`. containedIn also matches a scalar
                    // wrapped by an aliased GROUP BY key, consistent with references and fields.
                    if (containedIn(function, groupBy)) {
                        return null;
                    }
                    // Report the whole scalar expression rather than the nested offender.
                    return firstOffender(function.arguments(), groupBy) == null ? null : function;
                case AGGREGATE:
                    // Aggregate arguments are evaluated per group, so anything goes inside.
                    return null;
                case TABLE:
                case WINDOW:
                    return firstOffender(function.arguments(), groupBy);
                default:
                    throw new IllegalStateException("Unexpected function type: " + function.signature().getType());
            }
        }

        @Override
        public Symbol visitAggregation(Aggregation symbol, List<Symbol> groupBy) {
            throw new AssertionError("`Aggregation` symbols are created in the Planner. Until then there should only be `Function` symbols with type aggregate");
        }

        @Override
        public Symbol visitAlias(AliasSymbol aliasSymbol, List<Symbol> groupBy) {
            if (containedIn(aliasSymbol, groupBy)) {
                return null;
            }
            return aliasSymbol.symbol().accept(this, groupBy);
        }

        @Override
        public Symbol visitReference(Reference ref, List<Symbol> groupBy) {
            return containedIn(ref, groupBy) ? null : ref;
        }

        @Override
        public Symbol visitField(ScopedSymbol symbol, List<Symbol> groupBy) {
            return containedIn(symbol, groupBy) ? null : symbol;
        }

        @Override
        public Symbol visitDynamicReference(DynamicReference ref, List<Symbol> groupBy) {
            return visitReference((Reference) ref, groupBy);
        }

        @Override
        public Symbol visitWindowFunction(WindowFunction function, List<Symbol> groupBy) {
            return firstOffender(function.arguments(), groupBy);
        }

        @Override
        public Symbol visitLiteral(Literal<?> symbol, List<Symbol> groupBy) {
            return null;
        }

        @Override
        public Symbol visitParameterSymbol(ParameterSymbol parameterSymbol, List<Symbol> groupBy) {
            return null;
        }

        @Override
        public Symbol visitSelectSymbol(SelectSymbol selectSymbol, List<Symbol> groupBy) {
            return null;
        }

        @Override
        public Symbol visitInputColumn(InputColumn inputColumn, List<Symbol> groupBy) {
            throw new AssertionError("Must not have `InputColumn`s when doing semantic validation of SELECT LIST / GROUP BY");
        }

        @Override
        public Symbol visitMatchPredicate(MatchPredicate matchPredicate, List<Symbol> groupBy) {
            throw new AssertionError("MATCH predicate cannot be used in SELECT list");
        }

        @Override
        public Symbol visitFetchReference(FetchReference fetchReference, List<Symbol> groupBy) {
            throw new AssertionError("Must not have `FetchReference`s when doing semantic validation of SELECT LIST / GROUP BY");
        }

        /** Returns the first offending symbol found among {@code arguments}, or {@code null} if none. */
        private Symbol firstOffender(Iterable<? extends Symbol> arguments, List<Symbol> groupBy) {
            for (Symbol argument : arguments) {
                Symbol offender = argument.accept(this, groupBy);
                if (offender != null) {
                    return offender;
                }
            }
            return null;
        }

        /**
         * Returns true if {@code symbol} equals a GROUP BY expression, or equals the expression
         * wrapped by an aliased GROUP BY expression.
         */
        private static boolean containedIn(Symbol symbol, List<Symbol> groupBy) {
            for (Symbol groupExpr : groupBy) {
                if (symbol.equals(groupExpr)) {
                    return true;
                }
                if (groupExpr instanceof AliasSymbol && symbol.equals(((AliasSymbol) groupExpr).symbol())) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Visitor that rejects a GROUP BY key that is a literal of undefined type with a non-null value,
     * looking through aliases. Other symbol kinds fall through to the {@link SymbolVisitor} default.
     */
    private static class EnsureTypedGroupKey extends SymbolVisitor<Void, Void> {
        private static final EnsureTypedGroupKey INSTANCE = new EnsureTypedGroupKey();

        private EnsureTypedGroupKey() {
        }

        @Override
        public Void visitLiteral(Literal<?> symbol, Void context) {
            // An untyped NULL literal is tolerated; any other untyped literal is rejected.
            if (symbol.valueType() == DataTypes.UNDEFINED && symbol.value() != null) {
                raiseException(symbol);
            }
            return null;
        }

        @Override
        public Void visitAlias(AliasSymbol aliasSymbol, Void context) {
            return aliasSymbol.symbol().accept(this, context);
        }

        private static void raiseException(Symbol symbol) {
            throw new IllegalArgumentException(
                "Cannot group or aggregate on '" + symbol.toString() + "' with an undefined type. "
                + "Using an explicit type cast will make this work but adds processing overhead to the query.");
        }
    }
}

