/*
 * Decompiled with CFR 0.152.
 */
package io.crate.analyze.where;

import io.crate.exceptions.VersioningValidationException;
import io.crate.expression.operator.any.AnyNeqOperator;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.ScopedSymbol;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolType;
import io.crate.expression.symbol.SymbolVisitor;
import io.crate.expression.symbol.WindowFunction;
import io.crate.metadata.FunctionType;
import io.crate.metadata.Reference;
import io.crate.metadata.doc.SysColumns;
import io.crate.sql.tree.ComparisonExpression;
import java.util.Locale;
import java.util.Set;
import java.util.Stack;
import java.util.function.Supplier;

/**
 * Validates the usage of restricted system columns and of disallowed function
 * kinds inside a WHERE clause.
 *
 * <p>The symbol tree of the query is walked once. While walking, the chain of
 * enclosing functions is tracked so that each occurrence of {@code _version},
 * {@code _seq_no}, {@code _primary_term}, {@code _score} or the raw column can be
 * checked against the comparison it appears in.
 */
public final class WhereClauseValidator {
    private static final Visitor VISITOR = new Visitor();

    private WhereClauseValidator() {
    }

    /**
     * Validates the given query symbol tree.
     *
     * @param query the WHERE clause expression to check
     * @throws UnsupportedOperationException if a table function is used, if
     *         {@code _score} is used outside of an allowed comparison, or if the
     *         raw column is referenced
     * @throws IllegalArgumentException if a window function is used
     * @throws RuntimeException as produced by {@link VersioningValidationException}
     *         when {@code _version}, {@code _seq_no} or {@code _primary_term} are
     *         used outside of an allowed comparison
     */
    public static void validate(Symbol query) {
        query.accept(VISITOR, new Visitor.Context());
    }

    /**
     * Symbol visitor performing the actual checks. It keeps no state of its own;
     * all traversal state lives in the {@link Context} passed along.
     */
    // The context type is qualified as Visitor.Context because a member class is
    // only in scope inside the class body, not in the class header.
    private static class Visitor
    extends SymbolVisitor<Visitor.Context, Symbol> {
        private static final String SCORE = "_score";
        private static final Set<String> SCORE_ALLOWED_COMPARISONS = Set.of("op_>=");
        private static final String VERSION = "_version";
        private static final String SEQ_NO = "_seq_no";
        private static final String PRIMARY_TERM = "_primary_term";
        // NOTE(review): this set admits AnyNeqOperator.NAME next to "op_=". An
        // any-equals operator looks like the more natural partner for an equality
        // restriction - confirm that the not-equals variant is really intended.
        private static final Set<String> VERSIONING_ALLOWED_COMPARISONS = Set.of("op_=", AnyNeqOperator.NAME);
        private static final String SCORE_ERROR = String.format(
            Locale.ENGLISH,
            "System column '%s' can only be used within a '%s' comparison without any surrounded predicate",
            SCORE,
            ComparisonExpression.Type.GREATER_THAN_OR_EQUAL.getValue());

        private Visitor() {
        }

        /**
         * Checks a scoped column against the system column rules, then continues
         * with the default handling of the super class.
         */
        @Override
        public Symbol visitField(ScopedSymbol field, Context context) {
            // NOTE(review): this uses sqlFqn() whereas visitReference uses name();
            // confirm that the difference is intentional.
            this.validateSysReference(context, field.column().sqlFqn());
            return super.visitField(field, context);
        }

        /**
         * Checks a column reference against the system column rules, then
         * continues with the default handling of the super class.
         */
        @Override
        public Symbol visitReference(Reference symbol, Context context) {
            this.validateSysReference(context, symbol.column().name());
            return super.visitReference(symbol, context);
        }

        /**
         * Rejects table functions and otherwise descends into the function's
         * arguments with the function recorded as the innermost enclosing one.
         *
         * @return the visited function itself
         * @throws UnsupportedOperationException if the function is a table function
         */
        @Override
        public Symbol visitFunction(Function function, Context context) {
            // Reject before touching the stack so a failure leaves it unchanged.
            if (function.signature().getType().equals(FunctionType.TABLE)) {
                throw new UnsupportedOperationException("Table functions are not allowed in WHERE");
            }
            context.functions.push(function);
            try {
                this.continueTraversal(function, context);
            } finally {
                // Keep push/pop balanced even if validating an argument throws.
                context.functions.pop();
            }
            return function;
        }

        /**
         * Window functions are never valid in a WHERE clause.
         *
         * @throws IllegalArgumentException always
         */
        @Override
        public Symbol visitWindowFunction(WindowFunction symbol, Context context) {
            throw new IllegalArgumentException("Window functions are not allowed in WHERE");
        }

        /** Visits every argument of the given function with this visitor. */
        private void continueTraversal(Function symbol, Context context) {
            for (Symbol argument : symbol.arguments()) {
                argument.accept(this, context);
            }
        }

        /**
         * Returns {@code true} if any of the currently enclosing functions is
         * named {@code op_not}.
         */
        private static boolean insideNotPredicate(Context context) {
            for (Function function : context.functions) {
                if (function.name().equals("op_not")) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns {@code true} if the innermost enclosing function is a cast whose
         * direct parent is one of the required {@code op_} operators and that
         * operator's second argument is a value symbol.
         */
        private static boolean insideCastComparedWithLiteral(Context context, Set<String> requiredFunctionNames) {
            int numFunctions = context.functions.size();
            if (numFunctions < 2) {
                // A cast plus its surrounding operator are both needed.
                return false;
            }
            Function lastFunction = context.functions.get(numFunctions - 1);
            Function parentFunction = context.functions.get(numFunctions - 2);
            if (lastFunction.isCast()
                && parentFunction.name().startsWith("op_")
                && requiredFunctionNames.contains(parentFunction.name())) {
                Symbol rightArg = parentFunction.arguments().get(1);
                return rightArg.symbolType().isValueSymbol();
            }
            return false;
        }

        /**
         * Dispatches on the column name (case-insensitively) and applies the rule
         * set belonging to the matching system column. Columns that are not one of
         * the restricted system columns pass without any check.
         *
         * @throws UnsupportedOperationException for invalid {@code _score} usage or
         *         any usage of the raw column
         */
        private void validateSysReference(Context context, String columnName) {
            if (columnName.equalsIgnoreCase(VERSION)) {
                Visitor.validateSysReference(context, VERSIONING_ALLOWED_COMPARISONS, VersioningValidationException::versionInvalidUsage);
            } else if (columnName.equalsIgnoreCase(SEQ_NO) || columnName.equalsIgnoreCase(PRIMARY_TERM)) {
                Visitor.validateSysReference(context, VERSIONING_ALLOWED_COMPARISONS, VersioningValidationException::seqNoAndPrimaryTermUsage);
            } else if (columnName.equalsIgnoreCase(SCORE)) {
                Visitor.validateSysReference(context, SCORE_ALLOWED_COMPARISONS, () -> new UnsupportedOperationException(SCORE_ERROR));
            } else if (columnName.equalsIgnoreCase(SysColumns.RAW.name())) {
                throw new UnsupportedOperationException("The _raw column is not searchable and cannot be used inside a query");
            }
        }

        /**
         * Verifies that a system column is used directly inside one of the required
         * functions (or inside a cast that is compared via one of them), that it is
         * not nested below an {@code op_not}, and that the second argument of the
         * innermost enclosing function is a value symbol or a parameter.
         *
         * @param context               traversal state holding the enclosing functions
         * @param requiredFunctionNames names of the functions the column may appear in
         * @param error                 supplies the exception thrown on a violation
         */
        private static void validateSysReference(Context context, Set<String> requiredFunctionNames, Supplier<RuntimeException> error) {
            if (context.functions.isEmpty()) {
                // A bare column reference without any surrounding comparison.
                throw error.get();
            }
            Function function = context.functions.lastElement();
            boolean allowedComparison = Visitor.insideCastComparedWithLiteral(context, requiredFunctionNames)
                || requiredFunctionNames.contains(function.name().toLowerCase(Locale.ENGLISH));
            if (!allowedComparison || Visitor.insideNotPredicate(context)) {
                throw error.get();
            }
            // When the column is wrapped in a cast, `function` is that cast, so the
            // checks below apply to the cast's own arguments.
            assert (function.arguments().size() == 2) : "function's number of arguments must be 2";
            Symbol right = function.arguments().get(1);
            if (!right.symbolType().isValueSymbol() && right.symbolType() != SymbolType.PARAMETER) {
                throw error.get();
            }
        }

        /**
         * Traversal state: the functions enclosing the symbol currently being
         * visited, outermost first and innermost last.
         */
        static class Context {
            // Stack is kept because the validator relies on index access
            // (get) and lastElement in addition to push/pop.
            private final Stack<Function> functions = new Stack<>();

            private Context() {
            }
        }
    }
}

