/*
 * Decompiled with CFR 0.152.
 */
package io.crate.expression.symbol;

import io.crate.exceptions.ConversionException;
import io.crate.expression.operator.all.AllOperator;
import io.crate.expression.operator.any.AnyOperator;
import io.crate.expression.scalar.cast.CastMode;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolType;
import io.crate.expression.symbol.SymbolVisitor;
import io.crate.expression.symbol.Symbols;
import io.crate.expression.symbol.format.MatchPrinter;
import io.crate.expression.symbol.format.Style;
import io.crate.metadata.FunctionName;
import io.crate.metadata.Reference;
import io.crate.metadata.functions.Signature;
import io.crate.sql.SqlFormatter;
import io.crate.sql.tree.ColumnType;
import io.crate.sql.tree.Expression;
import io.crate.types.ArrayType;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.jetbrains.annotations.Nullable;

/**
 * A {@link Symbol} representing a function invocation.
 *
 * <p>It consists of:
 * <ul>
 *   <li>the function {@link Signature};</li>
 *   <li>the argument symbols;</li>
 *   <li>the return type;</li>
 *   <li>an optional filter symbol, rendered as {@code FILTER (WHERE ...)}.</li>
 * </ul>
 *
 * <p>The argument list is defensively copied on construction, and all fields are final.
 */
public class Function implements Symbol, Cloneable {

    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Function.class);

    /** Maps arithmetic function names to the infix operator used when printing them. */
    private static final Map<String, String> ARITHMETIC_OPERATOR_MAPPING = Map.ofEntries(
        Map.entry("add", "+"),
        Map.entry("subtract", "-"),
        Map.entry("multiply", "*"),
        Map.entry("divide", "/"),
        Map.entry("mod", "%"),
        Map.entry("modulus", "%")
    );

    private final List<Symbol> arguments;
    protected final DataType<?> returnType;
    protected final Signature signature;
    @Nullable
    protected final Symbol filter;

    /**
     * Deserializes a function from {@code in}, accepting the wire formats of older versions.
     *
     * <ul>
     *   <li>Before 5.0.0 the payload starts with a legacy function-info encoded signature.</li>
     *   <li>The filter is only present from 4.1.0 on.</li>
     *   <li>From 4.2.0 on, the signature and return type follow the arguments explicitly.
     *       Before 5.0.0 they are preceded by a boolean that is read and discarded.
     *       Older streams derive both from the legacy signature.</li>
     * </ul>
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Function(StreamInput in) throws IOException {
        Version version = in.getVersion();
        Signature legacySignature = null;
        if (version.before(Version.V_5_0_0)) {
            legacySignature = Signature.readFromFunctionInfo(in);
        }
        this.filter = version.onOrAfter(Version.V_4_1_0) ? Symbol.nullableFromStream(in) : null;
        this.arguments = List.copyOf(Symbols.fromStream(in));
        if (version.onOrAfter(Version.V_4_2_0)) {
            if (version.before(Version.V_5_0_0)) {
                // Legacy flag: its value is read and discarded (writeTo always sends true).
                in.readBoolean();
            }
            this.signature = new Signature(in);
            this.returnType = DataTypes.fromStream(in);
        } else {
            assert legacySignature != null : "expecting a non-null generated signature";
            this.signature = legacySignature;
            this.returnType = legacySignature.getReturnType().createType();
        }
    }

    /**
     * Creates a function without a filter.
     *
     * @param signature  the resolved function signature
     * @param arguments  the argument symbols (copied)
     * @param returnType the type this invocation evaluates to
     */
    public Function(Signature signature, List<Symbol> arguments, DataType<?> returnType) {
        this(signature, arguments, returnType, null);
    }

    /**
     * Creates a function.
     *
     * @param signature  the resolved function signature
     * @param arguments  the argument symbols (copied; must not contain null)
     * @param returnType the type this invocation evaluates to
     * @param filter     optional filter symbol, printed as {@code FILTER (WHERE ...)}
     */
    public Function(Signature signature, List<Symbol> arguments, DataType<?> returnType, @Nullable Symbol filter) {
        this.signature = signature;
        this.arguments = List.copyOf(arguments);
        this.returnType = returnType;
        this.filter = filter;
    }

    /** Returns the immutable list of argument symbols. */
    public List<Symbol> arguments() {
        return arguments;
    }

    /** Returns the plain function name taken from the signature. */
    public String name() {
        return signature.getName().name();
    }

    public Signature signature() {
        return signature;
    }

    /** Returns the filter symbol, or {@code null} if none was given. */
    @Nullable
    public Symbol filter() {
        return filter;
    }

    @Override
    public DataType<?> valueType() {
        return returnType;
    }

    /**
     * Returns {@code true} if the signature and every argument are deterministic.
     * The filter is not consulted.
     */
    @Override
    public boolean isDeterministic() {
        if (!signature.isDeterministic()) {
            return false;
        }
        for (Symbol arg : arguments) {
            if (!arg.isDeterministic()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code true} if {@code predicate} matches this function.
     * Otherwise, returns {@code true} if it matches any symbol within an argument or within the filter.
     */
    @Override
    public boolean any(Predicate<? super Symbol> predicate) {
        if (predicate.test(this)) {
            return true;
        }
        for (Symbol arg : arguments) {
            if (arg.any(predicate)) {
                return true;
            }
        }
        return filter != null && filter.any(predicate);
    }

    /**
     * Casts this function to {@code targetType}.
     *
     * <p>If this is an array literal ({@code _array}) and the target is an array type, each
     * element is cast to the target's inner type instead of delegating to the default cast.
     */
    @Override
    public Symbol cast(DataType<?> targetType, CastMode... modes) {
        if (targetType instanceof ArrayType && name().equals("_array")) {
            return castArrayElements((ArrayType<?>) targetType, modes);
        }
        return Symbol.super.cast(targetType, modes);
    }

    /** Returns the casted expression if this is a cast function, otherwise {@code this}. */
    @Override
    public Symbol uncast() {
        return isCast() ? arguments.get(0) : this;
    }

    /**
     * Builds a new array function whose elements are each cast to the inner type of {@code targetType}.
     *
     * @throws ConversionException if any element cannot be cast. The element-level failure is
     *                             attached as a suppressed exception.
     */
    private Symbol castArrayElements(ArrayType<?> targetType, CastMode... modes) {
        DataType<?> innerType = targetType.innerType();
        List<Symbol> castedArgs = new ArrayList<>(arguments.size());
        for (Symbol arg : arguments) {
            try {
                castedArgs.add(arg.cast(innerType, modes));
            } catch (ConversionException e) {
                ConversionException failure = new ConversionException(returnType, targetType);
                // Keep the element-level failure so the root cause is not lost.
                failure.addSuppressed(e);
                throw failure;
            }
        }
        return new Function(signature, castedArgs, targetType, null);
    }

    /** Returns {@code true} if this function is one of the cast functions ({@code cast}, {@code _cast}, {@code try_cast}). */
    public boolean isCast() {
        return castMode() != null;
    }

    /**
     * Returns the cast mode implied by the function name, or {@code null} if this is not a cast.
     * The mapping is: {@code cast} to EXPLICIT, {@code _cast} to IMPLICIT, {@code try_cast} to TRY.
     */
    @Nullable
    public CastMode castMode() {
        return switch (name()) {
            case "cast" -> CastMode.EXPLICIT;
            case "_cast" -> CastMode.IMPLICIT;
            case "try_cast" -> CastMode.TRY;
            default -> null;
        };
    }

    @Override
    public long ramBytesUsed() {
        long argumentsSize = arguments.stream().mapToLong(Accountable::ramBytesUsed).sum();
        long filterSize = filter == null ? 0L : filter.ramBytesUsed();
        return SHALLOW_SIZE + argumentsSize + returnType.ramBytesUsed() + filterSize + signature.ramBytesUsed();
    }

    @Override
    public SymbolType symbolType() {
        return SymbolType.FUNCTION;
    }

    @Override
    public <C, R> R accept(SymbolVisitor<C, R> visitor, C context) {
        return visitor.visitFunction(this, context);
    }

    /**
     * Serializes this function.
     *
     * <p>This mirrors {@link #Function(StreamInput)} for the version of {@code out}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        Version version = out.getVersion();
        if (version.before(Version.V_5_0_0)) {
            signature.writeAsFunctionInfo(out, Symbols.typeView(arguments));
        }
        if (version.onOrAfter(Version.V_4_1_0)) {
            Symbol.nullableToStream(filter, out);
        }
        Symbols.toStream(arguments, out);
        if (version.onOrAfter(Version.V_4_2_0)) {
            if (version.before(Version.V_5_0_0)) {
                out.writeBoolean(true);
            }
            signature.writeTo(out);
            DataTypes.toStream(returnType, out);
        }
    }

    /**
     * Two functions are equal if they have the same class, arguments, signature and filter.
     * The return type is not compared, which keeps this consistent with {@link #hashCode()}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Function function = (Function) o;
        return Objects.equals(arguments, function.arguments)
            && Objects.equals(signature, function.signature)
            && Objects.equals(filter, function.filter);
    }

    @Override
    public int hashCode() {
        int result = arguments.hashCode();
        result = 31 * result + signature.hashCode();
        result = 31 * result + Objects.hashCode(filter);
        return result;
    }

    @Override
    public String toString() {
        return toString(Style.UNQUALIFIED);
    }

    /** Renders this function as SQL-like text, using operator or keyword syntax where one exists. */
    @Override
    public String toString(Style style) {
        StringBuilder builder = new StringBuilder();
        String name = signature.getName().name();
        switch (name) {
            case "match" -> MatchPrinter.printMatchPredicate(this, style, builder);
            case "_negate" -> printNegate(builder, style);
            case "subscript", "subscript_obj" -> printSubscriptFunction(builder, style);
            case "_subscript_record" -> printSubscriptRecord(builder, style);
            case "_exists" -> builder.append("EXISTS ").append(arguments.get(0).toString(style));
            case "current_user" -> builder.append("CURRENT_USER");
            case "session_user" -> builder.append("SESSION_USER");
            case "current_schemas" -> builder.append("current_schemas");
            case "current_schema" -> builder.append("current_schema");
            case "op_isnull" -> builder.append("(").append(arguments.get(0).toString(style)).append(" IS NULL)");
            case "op_not" -> builder.append("(NOT ").append(arguments.get(0).toString(style)).append(")");
            case "count" -> {
                if (arguments.isEmpty()) {
                    builder.append("count(*)");
                    printFilter(builder, style);
                } else {
                    printFunctionWithParenthesis(builder, style);
                }
            }
            case "current_timestamp" -> printKeywordOrCall(builder, style, "CURRENT_TIMESTAMP");
            case "current_time" -> printKeywordOrCall(builder, style, "CURRENT_TIME");
            case "_array" -> printArray(builder, style);
            case "case" -> printCase(builder, style);
            default -> printGeneric(name, builder, style);
        }
        return builder.toString();
    }

    /**
     * Handles the names that have no dedicated case in {@link #toString(Style)}.
     *
     * <p>Candidates are checked in this order: ANY/ALL operators, casts, {@code op_*} operators,
     * {@code extract_*}, arithmetic operators, and finally a plain call.
     */
    private void printGeneric(String name, StringBuilder builder, Style style) {
        if (AnyOperator.OPERATOR_NAMES.contains(name)) {
            printAnyOperator(builder, style);
        } else if (AllOperator.OPERATOR_NAMES.contains(name)) {
            printAllOperator(builder, style);
        } else if (isCast()) {
            printCastFunction(builder, style);
        } else if (name.startsWith("op_")) {
            printOperator(builder, style, null);
        } else if (name.startsWith("extract_")) {
            printExtract(builder, style);
        } else {
            String arithmeticOperator = ARITHMETIC_OPERATOR_MAPPING.get(name);
            if (arithmeticOperator != null) {
                printOperator(builder, style, arithmeticOperator);
            } else {
                printFunctionWithParenthesis(builder, style);
            }
        }
    }

    /** Prints the bare {@code keyword} when there are no arguments, otherwise a regular call. */
    private void printKeywordOrCall(StringBuilder builder, Style style, String keyword) {
        if (arguments.isEmpty()) {
            builder.append(keyword);
        } else {
            printFunctionWithParenthesis(builder, style);
        }
    }

    /**
     * Prints a {@code CASE} expression.
     *
     * <p>Argument layout: index 1 is the ELSE value. From index 2 on, consecutive pairs are
     * WHEN/THEN. Index 0 is not printed.
     */
    private void printCase(StringBuilder builder, Style style) {
        builder.append("CASE");
        for (int i = 2; i < arguments.size(); i += 2) {
            builder.append(" WHEN ").append(arguments.get(i).toString(style))
                .append(" THEN ").append(arguments.get(i + 1).toString(style));
        }
        builder.append(" ELSE ").append(arguments.get(1).toString(style)).append(" END");
    }

    private void printNegate(StringBuilder builder, Style style) {
        builder.append("- ").append(arguments.get(0).toString(style));
    }

    /** Prints {@code (record).field}. */
    private void printSubscriptRecord(StringBuilder builder, Style style) {
        builder.append("(")
            .append(arguments.get(0).toString(style))
            .append(").")
            .append(arguments.get(1).toString(style));
    }

    /** Prints an array literal as {@code [a, b, ...]}. */
    private void printArray(StringBuilder builder, Style style) {
        builder.append("[");
        appendArguments(builder, style);
        builder.append("]");
    }

    /** Appends all arguments separated by {@code ", "}. */
    private void appendArguments(StringBuilder builder, Style style) {
        for (int i = 0; i < arguments.size(); i++) {
            if (i > 0) {
                builder.append(", ");
            }
            builder.append(arguments.get(i).toString(style));
        }
    }

    /** Prints {@code (left OP ANY(right))}. The operator is derived from the name after {@code any_}. */
    private void printAnyOperator(StringBuilder builder, Style style) {
        String name = signature.getName().name();
        assert name.startsWith("any_") : "function for printAnyOperator must start with any prefix";
        printQuantifiedOperator(builder, style, name.substring("any_".length()), "ANY");
    }

    /** Prints {@code (left OP ALL(right))}. The operator is derived from the name after {@code _all_}. */
    private void printAllOperator(StringBuilder builder, Style style) {
        String name = signature.getName().name();
        assert name.startsWith("_all_") : "function for printAllOperator must start with all prefix";
        printQuantifiedOperator(builder, style, name.substring("_all_".length()), "ALL");
    }

    /**
     * Prints {@code (left OPERATOR QUANTIFIER(right))}.
     *
     * <p>Underscores in {@code operator} become spaces, and the result is upper-cased.
     */
    private void printQuantifiedOperator(StringBuilder builder, Style style, String operator, String quantifier) {
        assert arguments.size() == 2 : "function's number of arguments must be 2";
        String operatorName = operator.replace('_', ' ').toUpperCase(Locale.ENGLISH);
        builder.append("(")
            .append(arguments.get(0).toString(style))
            .append(" ").append(operatorName).append(" ")
            .append(quantifier).append("(")
            .append(arguments.get(1).toString(style))
            .append("))");
    }

    /**
     * Prints a cast.
     *
     * <p>Implicit casts ({@code _cast}) print only the casted expression. Other casts print
     * {@code name(expr AS type)}, with the type taken from the second argument.
     */
    private void printCastFunction(StringBuilder builder, Style style) {
        String name = signature.getName().name();
        assert arguments.size() == 2 : "Expecting 2 arguments for function " + name;
        if (name.equalsIgnoreCase("_cast")) {
            builder.append(arguments().get(0).toString(style));
            return;
        }
        DataType<?> targetType = arguments.get(1).valueType();
        ColumnType<Expression> columnType = targetType.toColumnType(null);
        builder.append(name)
            .append("(")
            .append(arguments().get(0).toString(style))
            .append(" AS ")
            .append(SqlFormatter.formatSql(columnType))
            .append(")");
    }

    /** Prints {@code extract(field FROM arg)}, where the field is the name after {@code extract_}. */
    private void printExtract(StringBuilder builder, Style style) {
        String name = signature.getName().name();
        assert name.startsWith("extract_") : "name of function passed to printExtract must start with extract_";
        String fieldName = name.substring("extract_".length());
        builder.append("extract(").append(fieldName).append(" FROM ")
            .append(arguments.get(0).toString(style))
            .append(")");
    }

    /**
     * Prints a binary infix operator {@code (left OP right)}.
     *
     * @param operator the operator text. If {@code null}, it is derived from the function name
     *                 after {@code op_}, upper-cased.
     */
    private void printOperator(StringBuilder builder, Style style, @Nullable String operator) {
        String op = operator;
        if (op == null) {
            String name = signature.getName().name();
            assert name.startsWith("op_");
            op = name.substring("op_".length()).toUpperCase(Locale.ENGLISH);
        }
        builder.append("(")
            .append(arguments.get(0).toString(style))
            .append(" ").append(op).append(" ")
            .append(arguments.get(1).toString(style))
            .append(")");
    }

    /**
     * Prints a subscript as {@code base[index]}.
     *
     * <p>For an array-typed {@link Reference} to a nested column, the output is the root column,
     * then the index, then each path segment, e.g. {@code root[idx]['a']['b']}.
     */
    private void printSubscriptFunction(StringBuilder builder, Style style) {
        Symbol base = arguments.get(0);
        String index = null;
        if (base instanceof Reference) {
            Reference ref = (Reference) base;
            if (base.valueType() instanceof ArrayType && !ref.column().path().isEmpty()) {
                builder.append(ref.column().getRoot().quotedOutputName());
                index = arguments.get(1).toString(style);
                builder.append("[").append(index).append("]");
                ref.column().path().forEach(segment -> builder.append("['").append(segment).append("']"));
                return;
            }
        }
        builder.append(base.toString(style));
        index = arguments.get(1).toString(style);
        builder.append("[").append(index).append("]");
    }

    /** Prints {@code displayName(arg1, arg2, ...)}, followed by the filter clause if present. */
    private void printFunctionWithParenthesis(StringBuilder builder, Style style) {
        FunctionName functionName = signature.getName();
        builder.append(functionName.displayName()).append("(");
        appendArguments(builder, style);
        builder.append(")");
        printFilter(builder, style);
    }

    /** Appends {@code  FILTER (WHERE filter)} when a filter is set. */
    private void printFilter(StringBuilder builder, Style style) {
        if (filter != null) {
            builder.append(" FILTER (WHERE ").append(filter.toString(style)).append(")");
        }
    }
}

