/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine.aggregation.impl;

import io.crate.Streamer;
import io.crate.common.MutableLong;
import io.crate.data.Input;
import io.crate.data.breaker.RamAccounting;
import io.crate.execution.engine.aggregation.AggregationFunction;
import io.crate.execution.engine.aggregation.DocValueAggregator;
import io.crate.execution.engine.aggregation.impl.templates.BinaryDocValueAggregator;
import io.crate.execution.engine.aggregation.impl.templates.SortedNumericDocValueAggregator;
import io.crate.expression.reference.doc.lucene.LuceneReferenceResolver;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.Symbol;
import io.crate.memory.MemoryManager;
import io.crate.metadata.ColumnIdent;
import io.crate.metadata.FunctionType;
import io.crate.metadata.Functions;
import io.crate.metadata.NodeContext;
import io.crate.metadata.Reference;
import io.crate.metadata.RowGranularity;
import io.crate.metadata.Scalar;
import io.crate.metadata.TransactionContext;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.metadata.functions.BoundSignature;
import io.crate.metadata.functions.Signature;
import io.crate.metadata.functions.TypeVariableConstraint;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import io.crate.types.FixedWidthType;
import io.crate.types.TypeSignature;
import java.io.IOException;
import java.util.List;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.jetbrains.annotations.Nullable;

/**
 * The {@code count} aggregate function.
 *
 * <p>Two variants are registered:
 * <ul>
 *   <li>{@code count(x)} ({@link #SIGNATURE}) counts rows where {@code x} is not null.</li>
 *   <li>{@code count(*)} ({@link #COUNT_STAR_SIGNATURE}) counts every row.</li>
 * </ul>
 *
 * <p>The intermediate (partial) state is a {@link MutableLong} counter. It is streamed between
 * nodes as {@link LongStateType}.
 */
public class CountAggregation extends AggregationFunction<MutableLong, Long> {

    public static final String NAME = "count";

    /** {@code count(V)}: accepts a single argument of any type and counts its non-null values. */
    public static final Signature SIGNATURE = Signature.builder(NAME, FunctionType.AGGREGATE)
        .argumentTypes(TypeSignature.parse("V"))
        .returnType(DataTypes.LONG.getTypeSignature())
        .features(Scalar.Feature.DETERMINISTIC)
        .typeVariableConstraints(TypeVariableConstraint.typeVariable("V"))
        .build();

    /** {@code count(*)}: takes no arguments and counts every row. */
    public static final Signature COUNT_STAR_SIGNATURE = Signature.builder(NAME, FunctionType.AGGREGATE)
        .argumentTypes(new TypeSignature[0])
        .returnType(DataTypes.LONG.getTypeSignature())
        .features(Scalar.Feature.DETERMINISTIC)
        .build();

    private final Signature signature;
    private final BoundSignature boundSignature;
    /** {@code true} for {@code count(x)}, {@code false} for {@code count(*)}. */
    private final boolean hasArgs;

    /** Registers both the {@code count(x)} and the {@code count(*)} implementation. */
    public static void register(Functions.Builder builder) {
        builder.add(SIGNATURE, (signature, boundSignature) -> new CountAggregation(signature, boundSignature, true));
        builder.add(COUNT_STAR_SIGNATURE, (signature, boundSignature) -> new CountAggregation(signature, boundSignature, false));
    }

    private CountAggregation(Signature signature, BoundSignature boundSignature, boolean hasArgs) {
        this.signature = signature;
        this.boundSignature = boundSignature;
        this.hasArgs = hasArgs;
    }

    /**
     * Increments the counter for the current row.
     * {@code count(*)} counts every row; {@code count(x)} skips rows where {@code x} is null.
     */
    @Override
    public MutableLong iterate(RamAccounting ramAccounting, MemoryManager memoryManager, MutableLong state, Input<?>... args) {
        if (!hasArgs || args[0].value() != null) {
            return state.add(1L);
        }
        return state;
    }

    /** Returns a fresh, zeroed counter (never {@code null}). */
    @Override
    public MutableLong newState(RamAccounting ramAccounting, Version minNodeInCluster, MemoryManager memoryManager) {
        return newCountState(ramAccounting);
    }

    /** Creates a zeroed counter and accounts for its fixed size in {@code ramAccounting}. */
    private static MutableLong newCountState(RamAccounting ramAccounting) {
        ramAccounting.addBytes(LongStateType.INSTANCE.fixedSize());
        return new MutableLong(0L);
    }

    @Override
    public Signature signature() {
        return signature;
    }

    @Override
    public BoundSignature boundSignature() {
        return boundSignature;
    }

    /**
     * Simplifies {@code count} calls whose argument is a constant input.
     * <ul>
     *   <li>{@code count(NULL)} becomes the literal {@code 0}, because no row is ever counted.</li>
     *   <li>{@code count(<non-null constant>)} becomes {@code count(*)}, because every row is counted.</li>
     * </ul>
     * Any other call is returned unchanged.
     */
    @Override
    public Symbol normalizeSymbol(Function function, TransactionContext txnCtx, NodeContext nodeCtx) {
        List<Symbol> arguments = function.arguments();
        assert arguments.size() <= 1 : "function's number of arguments must be 0 or 1";
        if (arguments.size() == 1 && arguments.get(0) instanceof Input<?> input) {
            if (input.value() == null) {
                return Literal.of(0L);
            }
            return new Function(COUNT_STAR_SIGNATURE, List.of(), DataTypes.LONG);
        }
        return function;
    }

    @Override
    public DataType<?> partialType() {
        return LongStateType.INSTANCE;
    }

    /** Merges two partial counts by adding {@code state2} into {@code state1}. */
    @Override
    public MutableLong reduce(RamAccounting ramAccounting, MutableLong state1, MutableLong state2) {
        return state1.add(state2.value());
    }

    @Override
    public Long terminatePartial(RamAccounting ramAccounting, MutableLong state) {
        return state.value();
    }

    @Override
    public boolean isRemovableCumulative() {
        return true;
    }

    /** Reverses {@link #iterate}: decrements the counter for a row that would have been counted. */
    @Override
    public MutableLong removeFromAggregatedState(RamAccounting ramAccounting, MutableLong previousAggState, Input<?>[] stateToRemove) {
        if (!hasArgs || stateToRemove[0].value() != null) {
            return previousAggState.sub(1L);
        }
        return previousAggState;
    }

    /**
     * Builds a doc-value based counter for {@code ref}.
     * Returns {@code null} if {@code ref} has no doc values, is not DOC granularity, or has an
     * unsupported type.
     */
    @Nullable
    private DocValueAggregator<?> docValueAggregatorFor(Reference ref) {
        if (!ref.hasDocValues() || ref.granularity() != RowGranularity.DOC) {
            return null;
        }
        return switch (ref.valueType().id()) {
            // NOTE(review): numeric type ids are inlined by the decompiler. They presumably are
            // byte, double, float, short, integer, long, timestamptz, geo_point and timestamp.
            // Confirm against the DataType ID constants.
            case 2, 6, 7, 8, 9, 10, 11, 13, 15 -> new SortedNumericDocValueAggregator<MutableLong>(
                ref.storageIdent(),
                (ramAccounting, memoryManager, version) -> newCountState(ramAccounting),
                (ramAccounting, sortedNumericDocValues, state) -> state.add(1L));
            // NOTE(review): presumably string, ip and character type ids. Confirm.
            case 4, 5, 25 -> new BinaryDocValueAggregator<MutableLong>(
                ref.storageIdent(),
                (ramAccounting, memoryManager, version) -> newCountState(ramAccounting),
                (ramAccounting, sortedSetDocValues, state) -> state.add(1L));
            default -> null;
        };
    }

    /**
     * Returns a doc-value based aggregator for a single-reference {@code count}, or {@code null}
     * if none can be used.
     *
     * <p>For an object column (type id 12, presumably ObjectType.ID), objects have no doc values
     * of their own. Instead, the doc values of a NOT NULL child column are used as a proxy.
     * This relies on the assumption that such a child holds a value whenever the parent object
     * is non-null.
     */
    @Override
    @Nullable
    public DocValueAggregator<?> getDocValueAggregator(LuceneReferenceResolver referenceResolver, List<Reference> aggregationReferences, DocTableInfo table, Version shardCreatedVersion, List<Literal<?>> optionalParams) {
        if (aggregationReferences.size() != 1) {
            return null;
        }
        Reference reference = aggregationReferences.getFirst();
        if (reference == null) {
            return null;
        }
        if (reference.valueType().id() == 12) {
            for (ColumnIdent notNullCol : table.notNullColumns()) {
                if (!notNullCol.isChildOf(reference.column())) {
                    continue;
                }
                Reference notNullColRef = table.getReference(notNullCol);
                if (notNullColRef == null) {
                    continue;
                }
                DocValueAggregator<?> subColDocValAggregator = docValueAggregatorFor(notNullColRef);
                if (subColDocValAggregator != null) {
                    return subColDocValAggregator;
                }
            }
        }
        return docValueAggregatorFor(reference);
    }

    static {
        DataTypes.register(LongStateType.ID, streamInput -> LongStateType.INSTANCE);
    }

    /**
     * Internal data type of the partial {@code count} state. It is serialized as a variable-length
     * long and has a fixed in-memory size equal to {@code long}.
     */
    public static class LongStateType extends DataType<MutableLong> implements FixedWidthType, Streamer<MutableLong> {

        public static final int ID = 16384;
        public static final LongStateType INSTANCE = new LongStateType();

        @Override
        public int id() {
            return ID;
        }

        @Override
        public DataType.Precedence precedence() {
            return DataType.Precedence.CUSTOM;
        }

        @Override
        public String getName() {
            return "long_state";
        }

        @Override
        public Streamer<MutableLong> streamer() {
            return this;
        }

        @Override
        public MutableLong sanitizeValue(Object value) {
            return (MutableLong) value;
        }

        /**
         * Orders counters by value, with {@code null} sorting first.
         * Two identical references (including two nulls) compare as equal, as the comparator
         * contract requires.
         */
        @Override
        public int compare(MutableLong val1, MutableLong val2) {
            if (val1 == val2) {
                return 0;
            }
            if (val1 == null) {
                return -1;
            }
            if (val2 == null) {
                return 1;
            }
            return Long.compare(val1.value(), val2.value());
        }

        @Override
        public MutableLong readValueFrom(StreamInput in) throws IOException {
            return new MutableLong(in.readVLong());
        }

        // VLong encoding is intended for non-negative values; counts never drop below zero.
        @Override
        public void writeValueTo(StreamOutput out, MutableLong v) throws IOException {
            out.writeVLong(v.value());
        }

        @Override
        public int fixedSize() {
            return DataTypes.LONG.fixedSize();
        }

        @Override
        public long valueBytes(MutableLong value) {
            return DataTypes.LONG.fixedSize();
        }
    }
}

