/*
 * Decompiled with CFR 0.152.
 */
package io.crate.analyze.where;

import io.crate.common.collections.CartesianList;
import io.crate.expression.eval.EvaluatingNormalizer;
import io.crate.expression.operator.EqOperator;
import io.crate.expression.operator.Operator;
import io.crate.expression.operator.Operators;
import io.crate.expression.operator.any.AnyEqOperator;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.MatchPredicate;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolType;
import io.crate.expression.symbol.SymbolVisitor;
import io.crate.expression.symbol.format.Style;
import io.crate.metadata.ColumnIdent;
import io.crate.metadata.FunctionType;
import io.crate.metadata.Reference;
import io.crate.metadata.Scalar;
import io.crate.metadata.TransactionContext;
import io.crate.metadata.functions.Signature;
import io.crate.session.Session;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import io.crate.types.TypeSignature;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

/**
 * Extracts, for a fixed list of columns, the concrete value combinations that a query pins those
 * columns to via {@code col = value} and {@code col = ANY(array)} comparisons.
 *
 * <p>The query is normalized, and every equality comparison on a tracked column is replaced by an
 * {@link EqProxy}. For each column the distinct proxies are collected, together with a null-marker
 * proxy meaning "no equality chosen for this column". Every combination of the cartesian product is
 * then tried: the chosen proxies are switched to {@code TRUE} and the tree is normalized again. Each
 * combination that turns the whole query into {@code TRUE} contributes one row of values.
 */
public class EqualityExtractor {

    /** Default upper bound on the number of value combinations that are evaluated. */
    private static final int MAX_ITERATIONS = 10000;

    /**
     * Placeholder comparison that is part of every column's candidate list. Selecting it for a
     * column means "this column is not restricted by an equality in this combination".
     */
    private static final Function NULL_MARKER = new Function(Signature.builder("null_marker", FunctionType.SCALAR).argumentTypes(new TypeSignature[0]).returnType(DataTypes.UNDEFINED.getTypeSignature()).features(Scalar.Feature.DETERMINISTIC).build(), List.of(), DataTypes.UNDEFINED);
    private static final EqProxy NULL_MARKER_PROXY = new EqProxy(NULL_MARKER);

    private final EvaluatingNormalizer normalizer;

    public EqualityExtractor(EvaluatingNormalizer normalizer) {
        this.normalizer = normalizer;
    }

    /**
     * Returns the maximum number of value combinations to evaluate before giving up with
     * {@link EqMatches#NONE}. It is overridable so tests can lower the limit.
     */
    @VisibleForTesting
    protected int maxIterations() {
        return MAX_ITERATIONS;
    }

    /**
     * Like {@link #extractMatches(List, Symbol, TransactionContext, Session.TimeoutToken)}, but it
     * does not short-circuit when the only unknown is a full-text match predicate.
     */
    public EqMatches extractParentMatches(List<ColumnIdent> columns, Symbol symbol, @Nullable TransactionContext coordinatorTxnCtx, Session.TimeoutToken timeoutToken) {
        return extractMatches(columns, symbol, false, coordinatorTxnCtx, timeoutToken);
    }

    /**
     * Extracts the value combinations for {@code columns} that satisfy {@code symbol}.
     *
     * <p>Returns {@link EqMatches#NONE} when the only unknown part of the query is a
     * {@code MatchPredicate} or a {@code match} function.
     *
     * @param columns the columns whose values should be extracted; the order determines the
     *                order of the values inside each match row
     * @param symbol the query to analyze
     * @param txnCtx transaction context used for normalization
     * @param timeoutToken checked periodically while combinations are evaluated
     */
    public EqMatches extractMatches(List<ColumnIdent> columns, Symbol symbol, TransactionContext txnCtx, Session.TimeoutToken timeoutToken) {
        return extractMatches(columns, symbol, true, txnCtx, timeoutToken);
    }

    /**
     * Core extraction routine.
     *
     * @return {@link EqMatches#NONE} if:
     *         <ul>
     *           <li>the short-circuit applies;</li>
     *           <li>the iteration limit is reached; or</li>
     *           <li>a combination that leaves at least one column unconstrained (null marker)
     *               already satisfies the query.</li>
     *         </ul>
     *         Otherwise it returns the satisfying value rows (or {@code null} matches if none were
     *         found) together with the unknown symbols.
     */
    private EqMatches extractMatches(Collection<ColumnIdent> columns,
                                     Symbol query,
                                     boolean shortCircuitOnMatchPredicateUnknown,
                                     TransactionContext txnCtx,
                                     Session.TimeoutToken timeoutToken) {
        ProxyInjectingVisitor.Context context = new ProxyInjectingVisitor.Context(columns);
        Symbol normalizedQuery = normalizer.normalize(query, txnCtx);
        Symbol proxiedTree = normalizedQuery.accept(ProxyInjectingVisitor.INSTANCE, context);

        if (shortCircuitOnMatchPredicateUnknown && context.unknowns.size() == 1) {
            Symbol unknown = context.unknowns.iterator().next();
            if (unknown instanceof MatchPredicate
                || (unknown instanceof Function fn && fn.name().equals("match"))) {
                return EqMatches.NONE;
            }
        }

        // One candidate list per column (in column order); each candidate list includes the null marker.
        List<List<EqProxy>> comparisons = context.comparisonValues();
        List<List<EqProxy>> combinations = CartesianList.of(comparisons);
        List<List<Symbol>> result = new ArrayList<>();
        int iterations = 0;
        for (List<EqProxy> proxies : combinations) {
            iterations++;
            if (iterations >= maxIterations()) {
                return EqMatches.NONE;
            }
            if (iterations % 100 == 0) {
                timeoutToken.check();
            }

            // Switch the chosen comparisons to TRUE; the null marker stays untouched.
            boolean anyNull = false;
            for (EqProxy proxy : proxies) {
                if (proxy == NULL_MARKER_PROXY) {
                    anyNull = true;
                } else {
                    proxy.setTrue();
                }
            }

            Symbol normalized = normalizer.normalize(proxiedTree, txnCtx);
            if (normalized == Literal.BOOLEAN_TRUE) {
                if (anyNull) {
                    // The query holds even though some column is not pinned to a value.
                    return EqMatches.NONE;
                }
                if (!proxies.isEmpty()) {
                    List<Symbol> row = new ArrayList<>(proxies.size());
                    for (EqProxy proxy : proxies) {
                        proxy.reset();
                        // origin is `col = value`; collect the value side.
                        row.add(proxy.origin.arguments().get(1));
                    }
                    result.add(row);
                }
            } else {
                for (EqProxy proxy : proxies) {
                    proxy.reset();
                }
            }
        }
        return new EqMatches(result.isEmpty() ? null : result, context.unknowns);
    }

    /**
     * Result of an extraction.
     *
     * @param matches one value row per satisfying combination, or {@code null} if none was found
     * @param unknowns symbols the visitor could not turn into equality proxies
     */
    public record EqMatches(@Nullable List<List<Symbol>> matches, Set<Symbol> unknowns) {
        public static final EqMatches NONE = new EqMatches(null, Set.of());
    }

    /**
     * Rewrites a normalized query in three ways:
     * <ul>
     *   <li>Equality comparisons on tracked columns become {@link EqProxy} instances.</li>
     *   <li>Everything it cannot reason about is recorded in {@code Context#unknowns} and
     *       replaced by {@code TRUE}.</li>
     *   <li>A logical operator is replaced by {@code FALSE} when a tracked column occurs under
     *       {@code NOT}, or when tracked and untracked columns occur together under {@code OR}.</li>
     * </ul>
     */
    private static class ProxyInjectingVisitor extends SymbolVisitor<ProxyInjectingVisitor.Context, Symbol> {

        public static final ProxyInjectingVisitor INSTANCE = new ProxyInjectingVisitor();

        private ProxyInjectingVisitor() {
        }

        @Override
        protected Symbol visitSymbol(Symbol symbol, Context ctx) {
            return symbol;
        }

        @Override
        public Symbol visitMatchPredicate(MatchPredicate matchPredicate, Context ctx) {
            ctx.unknowns.add(matchPredicate);
            return Literal.BOOLEAN_TRUE;
        }

        @Override
        public Symbol visitReference(Reference ref, Context ctx) {
            Comparison comparison = ctx.comparisons.get(ref.column());
            if (comparison == null) {
                // Not a tracked column.
                if (ctx.isUnderOrOperator) {
                    ctx.foundNonPKColumnUnderOr = true;
                }
                ctx.unknowns.add(ref);
                return super.visitReference(ref, ctx);
            }
            if (ctx.isUnderOrOperator) {
                ctx.foundPKColumnUnderOr = true;
            }
            if (ctx.isUnderNotPredicate) {
                ctx.foundPKColumnUnderNot = true;
            }
            if (ref.valueType().equals(DataTypes.BOOLEAN) && ctx.isUnderLogicalOperator) {
                // A bare boolean column used as a predicate is treated as `col = TRUE`.
                ctx.proxyBelow = true;
                return comparison.add(new Function(EqOperator.SIGNATURE, List.of(ref, Literal.BOOLEAN_TRUE), EqOperator.RETURN_TYPE));
            }
            return super.visitReference(ref, ctx);
        }

        @Override
        public Symbol visitFunction(Function function, Context ctx) {
            String functionName = function.name();
            List<Symbol> arguments = function.arguments();
            boolean prevIsUnderNotPredicate = ctx.isUnderNotPredicate;
            ctx.isUnderLogicalOperator = false;

            if (functionName.equals("op_=")) {
                Symbol firstArg = arguments.get(0).accept(this, ctx);
                // Only `trackedCol = <expression without columns>` can be proxied.
                if (firstArg instanceof Reference ref && !arguments.get(1).any(Symbol.IS_COLUMN)) {
                    Comparison comparison = ctx.comparisons.get(ref.column());
                    if (comparison != null) {
                        ctx.proxyBelow = true;
                        return comparison.add(function);
                    }
                }
            } else if (functionName.equals(AnyEqOperator.NAME) && arguments.get(1).symbolType().isValueSymbol()) {
                Symbol firstArg = arguments.get(0).accept(this, ctx);
                if (firstArg instanceof Reference ref) {
                    Comparison comparison = ctx.comparisons.get(ref.column());
                    if (comparison != null) {
                        ctx.proxyBelow = true;
                        return comparison.add(function);
                    }
                }
            } else if (Operators.LOGICAL_OPERATORS.contains(functionName)) {
                return visitLogicalOperator(function, ctx, prevIsUnderNotPredicate);
            }
            ctx.unknowns.add(function);
            return Literal.BOOLEAN_TRUE;
        }

        /** Visits the arguments of an AND / OR / NOT function and rebuilds it. */
        private Symbol visitLogicalOperator(Function function, Context ctx, boolean prevIsUnderNotPredicate) {
            String functionName = function.name();
            List<Symbol> arguments = function.arguments();
            switch (functionName) {
                case "op_or" -> ctx.isUnderOrOperator = true;
                case "op_not" -> ctx.isUnderNotPredicate = true;
                case "op_and" -> {
                    ctx.isUnderOrOperator = false;
                    ctx.foundNonPKColumnUnderOr = false;
                    ctx.foundPKColumnUnderOr = false;
                }
                default -> throw new IllegalStateException("Unexpected function: " + functionName);
            }
            // NOTE(review): isUnderOrOperator is not restored when leaving an OR. A later sibling
            // (for example, the next argument of an enclosing AND) that references an untracked
            // column is therefore still counted as "under OR". Confirm whether this is intended.

            boolean proxyBelowPre = ctx.proxyBelow;
            boolean proxyBelowPost = proxyBelowPre;
            List<Symbol> newArgs = new ArrayList<>(arguments.size());
            for (Symbol arg : arguments) {
                // Re-arm for every argument: visiting a function argument clears this flag.
                // Before this fix, a boolean PK reference following such an argument was not
                // proxied, which made the result depend on the argument order.
                ctx.isUnderLogicalOperator = true;
                ctx.proxyBelow = proxyBelowPre;
                newArgs.add(arg.accept(this, ctx));
                proxyBelowPost = ctx.proxyBelow || proxyBelowPost;
            }
            if ((ctx.foundPKColumnUnderOr && ctx.foundNonPKColumnUnderOr) || ctx.foundPKColumnUnderNot) {
                // NOTE(review): this path returns without restoring isUnderNotPredicate / proxyBelow.
                return Literal.BOOLEAN_FALSE;
            }
            ctx.isUnderNotPredicate = prevIsUnderNotPredicate;
            ctx.proxyBelow = proxyBelowPost;
            if (!proxyBelowPost && function.valueType().equals(DataTypes.BOOLEAN)) {
                // No proxy anywhere below: this subtree does not constrain tracked columns.
                return Literal.BOOLEAN_TRUE;
            }
            return new Function(function.signature(), newArgs, function.valueType());
        }

        /** Mutable state carried through one visit of the query tree. */
        private static class Context {
            // LinkedHashMap keeps the caller's column order, which defines the order of values per row.
            private final Map<ColumnIdent, Comparison> comparisons;
            private boolean proxyBelow;
            private final Set<Symbol> unknowns = new HashSet<>();
            private boolean isUnderNotPredicate = false;
            private boolean foundPKColumnUnderNot = false;
            private boolean isUnderOrOperator = false;
            private boolean foundNonPKColumnUnderOr = false;
            private boolean foundPKColumnUnderOr = false;
            private boolean isUnderLogicalOperator = false;

            private Context(Collection<ColumnIdent> references) {
                this.comparisons = LinkedHashMap.newLinkedHashMap(references.size());
                for (ColumnIdent reference : references) {
                    this.comparisons.put(reference, new Comparison());
                }
            }

            /** Returns an immutable snapshot of every column's candidate proxies, in column order. */
            private List<List<EqProxy>> comparisonValues() {
                List<List<EqProxy>> comps = new ArrayList<>(comparisons.size());
                for (Comparison comparison : comparisons.values()) {
                    comps.add(List.copyOf(comparison.proxies.values()));
                }
                return comps;
            }
        }

        /** The distinct equality proxies found for a single tracked column. */
        private static class Comparison {

            // Keyed by the `col = value` function so equal comparisons share one proxy.
            final Map<Function, EqProxy> proxies = new HashMap<>();

            public Comparison() {
                proxies.put(NULL_MARKER, NULL_MARKER_PROXY);
            }

            /**
             * Registers a comparison and returns the proxy to place in the tree.
             *
             * <p>For {@code = ANY}, the per-element child proxies are registered; the returned
             * {@link AnyEqProxy} itself is not.
             */
            public EqProxy add(Function compared) {
                if (compared.name().equals(AnyEqOperator.NAME)) {
                    AnyEqProxy anyEqProxy = new AnyEqProxy(compared, proxies);
                    for (EqProxy proxiedProxy : anyEqProxy) {
                        proxies.putIfAbsent(proxiedProxy.origin(), proxiedProxy);
                    }
                    return anyEqProxy;
                }
                return proxies.computeIfAbsent(compared, EqProxy::new);
            }
        }
    }

    /**
     * Mutable stand-in for an equality comparison inside the proxied query tree. All
     * {@link Symbol} methods delegate to {@code current}, which is either the original comparison
     * or {@code TRUE}.
     */
    static class EqProxy implements Symbol {

        protected Symbol current;
        protected final Function origin;

        private Function origin() {
            return origin;
        }

        EqProxy(Function origin) {
            this.origin = origin;
            this.current = origin;
        }

        /** Restores the original comparison. */
        public void reset() {
            current = origin;
        }

        /** Makes this comparison evaluate as {@code TRUE}. */
        public void setTrue() {
            current = Literal.BOOLEAN_TRUE;
        }

        @Override
        public SymbolType symbolType() {
            return current.symbolType();
        }

        @Override
        public <C, R> R accept(SymbolVisitor<C, R> visitor, C context) {
            return current.accept(visitor, context);
        }

        @Override
        public DataType<?> valueType() {
            return current.valueType();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            throw new UnsupportedOperationException("writeTo not supported for " + EqProxy.class.getSimpleName());
        }

        @Override
        public String toString(Style style) {
            if (this == NULL_MARKER_PROXY) {
                return "NULL";
            }
            StringBuilder sb = new StringBuilder()
                .append("(")
                .append(origin.arguments().get(0).toString(style))
                .append("=")
                .append(origin.arguments().get(1).toString(style))
                .append(")");
            if (current != origin) {
                sb.append(" TRUE");
            }
            return sb.toString();
        }

        public long ramBytesUsed() {
            return origin.ramBytesUsed();
        }
    }

    /**
     * Stand-in for a {@code col = ANY(array)} comparison. It is expanded into one child proxy
     * {@code col = elem} per array element. When a child is set to {@code TRUE}, this proxy is
     * also set to {@code TRUE}, and visitor calls are forwarded to that child until it is reset.
     */
    private static class AnyEqProxy extends EqProxy implements Iterable<EqProxy> {

        private Map<Function, EqProxy> proxies;
        @Nullable
        private ChildEqProxy delegate = null;

        private AnyEqProxy(Function compared, Map<Function, EqProxy> existingProxies) {
            super(compared);
            initProxies(existingProxies);
        }

        /**
         * Creates (or reuses from {@code existingProxies}) one {@code col = elem} proxy per array
         * element, and links existing child proxies to this parent.
         */
        private void initProxies(Map<Function, EqProxy> existingProxies) {
            Symbol left = origin.arguments().getFirst();
            Signature signature = origin.signature();
            assert (signature != null) : "Expecting non-null signature while analyzing";
            Literal<?> arrayLiteral = (Literal<?>) origin.arguments().get(1);
            proxies = new HashMap<>();
            for (Literal<?> arrayElem : Literal.explodeCollection(arrayLiteral)) {
                Function f = new Function(EqOperator.SIGNATURE, Arrays.asList(left, arrayElem), Operator.RETURN_TYPE);
                EqProxy existingProxy = existingProxies.get(f);
                if (existingProxy == null) {
                    existingProxy = new ChildEqProxy(f, this);
                } else if (existingProxy instanceof ChildEqProxy childEqProxy) {
                    childEqProxy.addParent(this);
                }
                proxies.put(f, existingProxy);
            }
        }

        @Override
        @NotNull
        public Iterator<EqProxy> iterator() {
            return proxies.values().iterator();
        }

        @Override
        public <C, R> R accept(SymbolVisitor<C, R> visitor, C context) {
            if (delegate != null) {
                return delegate.accept(visitor, context);
            }
            return super.accept(visitor, context);
        }

        private void setDelegate(@Nullable ChildEqProxy childEqProxy) {
            delegate = childEqProxy;
        }

        private void cleanDelegate() {
            delegate = null;
        }

        /** Element proxy of one or more {@link AnyEqProxy} parents; it propagates state changes to them. */
        private static class ChildEqProxy extends EqProxy {

            private final List<AnyEqProxy> parentProxies = new ArrayList<>();

            private ChildEqProxy(Function origin, AnyEqProxy parent) {
                super(origin);
                addParent(parent);
            }

            private void addParent(AnyEqProxy parentProxy) {
                parentProxies.add(parentProxy);
            }

            @Override
            public void setTrue() {
                super.setTrue();
                for (AnyEqProxy parent : parentProxies) {
                    parent.setTrue();
                    parent.setDelegate(this);
                }
            }

            @Override
            public void reset() {
                super.reset();
                for (AnyEqProxy parent : parentProxies) {
                    parent.reset();
                    parent.cleanDelegate();
                }
            }
        }
    }
}

