/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine.aggregation.impl.average.numeric;

import io.crate.data.Input;
import io.crate.data.breaker.RamAccounting;
import io.crate.execution.engine.aggregation.AggregationFunction;
import io.crate.execution.engine.aggregation.DocValueAggregator;
import io.crate.execution.engine.aggregation.impl.average.AverageAggregation;
import io.crate.execution.engine.aggregation.impl.average.numeric.NumericAverageState;
import io.crate.execution.engine.aggregation.impl.average.numeric.NumericAverageStateType;
import io.crate.execution.engine.aggregation.impl.util.BigDecimalValueWrapper;
import io.crate.execution.engine.aggregation.impl.util.OverflowAwareMutableLong;
import io.crate.expression.reference.doc.lucene.LuceneReferenceResolver;
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.Symbols;
import io.crate.memory.MemoryManager;
import io.crate.metadata.FunctionType;
import io.crate.metadata.Functions;
import io.crate.metadata.Reference;
import io.crate.metadata.Scalar;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.metadata.functions.BoundSignature;
import io.crate.metadata.functions.Signature;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import io.crate.types.NumericType;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.List;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.Version;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.jetbrains.annotations.Nullable;

public class NumericAverageAggregation
extends AggregationFunction<NumericAverageState<?>, BigDecimal> {
    /** Signature this instance was created with; handed back by {@link #signature()}. */
    private final Signature signature;
    /** Bound signature this instance was created with; handed back by {@link #boundSignature()}. */
    private final BoundSignature boundSignature;
    /** Type used to cast incoming values to {@link BigDecimal}; chosen once in the constructor. */
    private final DataType<BigDecimal> returnType;

    /**
     * Registers this aggregation once for every name in {@link AverageAggregation#NAMES}.
     * <p>
     * Each registered signature is an aggregate that takes one {@code NUMERIC}
     * argument, returns {@code NUMERIC} and carries the
     * {@link Scalar.Feature#DETERMINISTIC} feature.
     *
     * @param builder function registry builder that receives the signatures
     */
    public static void register(Functions.Builder builder) {
        for (String name : AverageAggregation.NAMES) {
            Signature avgSignature = Signature.builder(name, FunctionType.AGGREGATE)
                .argumentTypes(DataTypes.NUMERIC.getTypeSignature())
                .returnType(DataTypes.NUMERIC.getTypeSignature())
                .features(Scalar.Feature.DETERMINISTIC)
                .build();
            builder.add(avgSignature, NumericAverageAggregation::new);
        }
    }

    /**
     * Creates the aggregation for a concrete (bound) signature.
     * <p>
     * The type later used to cast input values is derived from the first bound
     * argument type: if that argument is a {@link NumericAverageStateType}, the
     * bound return type is used; otherwise the argument type itself is used.
     * NOTE(review): presumably this keeps the argument's own numeric
     * precision/scale for the result — confirm against the type definitions.
     *
     * @param signature      declared signature of the function
     * @param boundSignature signature with the actual argument and return types
     */
    private NumericAverageAggregation(Signature signature, BoundSignature boundSignature) {
        this.signature = signature;
        this.boundSignature = boundSignature;
        DataType<?> argumentType = boundSignature.argTypes().get(0);
        DataType<?> selectedType = argumentType instanceof NumericAverageStateType
            ? boundSignature.returnType()
            : argumentType;
        // A DataType<?> cannot be assigned to DataType<BigDecimal> without a cast.
        // The cast is unchecked (erased at runtime), so it is confined to this one local.
        @SuppressWarnings("unchecked")
        DataType<BigDecimal> bigDecimalType = (DataType<BigDecimal>) selectedType;
        this.returnType = bigDecimalType;
    }

    /**
     * Creates an empty state with sum {@link BigDecimal#ZERO} and count {@code 0},
     * after accounting {@link NumericAverageStateType#INIT_SIZE} bytes for it.
     * The version and memory-manager arguments are not used here.
     *
     * @param ramAccounting accounting instance charged with the initial state size
     * @return a fresh, empty average state
     */
    @Override
    @Nullable
    public NumericAverageState<?> newState(RamAccounting ramAccounting, Version indexVersionCreated, Version minNodeInCluster, MemoryManager memoryManager) {
        ramAccounting.addBytes(NumericAverageStateType.INIT_SIZE);
        BigDecimalValueWrapper zeroSum = new BigDecimalValueWrapper(BigDecimal.ZERO);
        NumericAverageState<BigDecimalValueWrapper> emptyState = new NumericAverageState<>(zeroSum, 0L);
        return emptyState;
    }

    /**
     * Folds one input value into the running state.
     * <p>
     * The first argument is cast with the aggregation's value type. A
     * {@code null} state or a {@code null} cast result leaves everything
     * untouched. Otherwise the value is added to the sum, the size change of
     * the sum is reported to {@code ramAccounting}, and the count is incremented.
     *
     * @param ramAccounting accounting instance charged with the sum's size change
     * @param memoryManager not used by this implementation
     * @param state         state to update in place; may be {@code null}
     * @param args          inputs; only {@code args[0]} is read
     * @return the same {@code state} instance that was passed in
     */
    @Override
    public NumericAverageState<?> iterate(RamAccounting ramAccounting, MemoryManager memoryManager, NumericAverageState<?> state, Input<?> ... args) throws CircuitBreakingException {
        if (state == null) {
            return null;
        }
        BigDecimal addend = returnType.implicitCast(args[0].value());
        if (addend == null) {
            return state;
        }
        BigDecimal previousSum = state.sum.value();
        BigDecimal updatedSum = previousSum.add(addend);
        ramAccounting.addBytes(NumericType.sizeDiff(updatedSum, previousSum));
        state.sum.setValue(updatedSum);
        state.count++;
        return state;
    }

    /**
     * Merges {@code state2} into {@code state1}.
     * <p>
     * If either state is {@code null} the other one is returned as is (which
     * may itself be {@code null}). Otherwise both sums and counts are added
     * into {@code state1}, and the size change of the sum is reported to
     * {@code ramAccounting}.
     *
     * @param ramAccounting accounting instance charged with the sum's size change
     * @param state1        left state; mutated and returned when both are present
     * @param state2        right state; only read
     * @return the merged state
     */
    @Override
    public NumericAverageState<?> reduce(RamAccounting ramAccounting, NumericAverageState<?> state1, NumericAverageState<?> state2) {
        if (state1 == null || state2 == null) {
            return state1 == null ? state2 : state1;
        }
        BigDecimal leftSum = state1.sum.value();
        BigDecimal mergedSum = leftSum.add(state2.sum.value());
        ramAccounting.addBytes(NumericType.sizeDiff(mergedSum, leftSum));
        state1.sum.setValue(mergedSum);
        state1.count += state2.count;
        return state1;
    }

    /**
     * Produces the result for a state by delegating to {@code state.value()}.
     * No null check is done here, so {@code state} must not be {@code null}.
     *
     * @param ramAccounting not used by this implementation
     * @param state         state to read the result from
     * @return whatever {@code state.value()} yields
     */
    @Override
    public BigDecimal terminatePartial(RamAccounting ramAccounting, NumericAverageState<?> state) {
        return state.value();
    }

    /**
     * @return the shared {@link NumericAverageStateType#INSTANCE}, reported as the type of the partial state
     */
    @Override
    public DataType<?> partialType() {
        return NumericAverageStateType.INSTANCE;
    }

    /**
     * @return the signature passed to the constructor
     */
    @Override
    public Signature signature() {
        return signature;
    }

    /**
     * @return the bound signature passed to the constructor
     */
    @Override
    public BoundSignature boundSignature() {
        return boundSignature;
    }

    /**
     * Always reports {@code true}: this class implements
     * {@code removeFromAggregatedState}, which subtracts a value from an
     * already accumulated state.
     * NOTE(review): presumably callers check this flag before invoking that
     * method — confirm against the base class contract.
     *
     * @return {@code true}
     */
    @Override
    public boolean isRemovableCumulative() {
        return true;
    }

    /**
     * Takes one previously added value back out of an accumulated state.
     * <p>
     * The first input is cast with the aggregation's value type. A
     * {@code null} state or a {@code null} cast result leaves everything
     * untouched. Otherwise the value is subtracted from the sum, the size
     * change of the sum is reported to {@code ramAccounting}, and the count
     * is decremented.
     *
     * @param ramAccounting    accounting instance charged with the sum's size change
     * @param previousAggState state to update in place; may be {@code null}
     * @param stateToRemove    inputs; only index {@code 0} is read
     * @return the same {@code previousAggState} instance that was passed in
     */
    @Override
    public NumericAverageState<?> removeFromAggregatedState(RamAccounting ramAccounting, NumericAverageState<?> previousAggState, Input<?>[] stateToRemove) {
        if (previousAggState == null) {
            return null;
        }
        BigDecimal subtrahend = returnType.implicitCast(stateToRemove[0].value());
        if (subtrahend != null) {
            BigDecimal oldSum = previousAggState.sum.value();
            BigDecimal reducedSum = oldSum.subtract(subtrahend);
            ramAccounting.addBytes(NumericType.sizeDiff(reducedSum, oldSum));
            previousAggState.count--;
            previousAggState.sum.setValue(reducedSum);
        }
        return previousAggState;
    }

    /**
     * Picks a doc-value based aggregator for the first aggregation reference.
     * <p>
     * Returns {@code null} when that reference is {@code null}, has no doc
     * values, or has a type id that none of the switch cases below cover.
     * Only {@code aggregationReferences} is read; the other parameters are
     * unused here.
     *
     * @param aggregationReferences references being aggregated; only index {@code 0} is used
     * @return an {@link AvgLong}, {@link AvgFloat} or {@link AvgDouble}, or {@code null}
     */
    @Override
    @Nullable
    public DocValueAggregator<?> getDocValueAggregator(LuceneReferenceResolver referenceResolver, List<Reference> aggregationReferences, DocTableInfo table, Version shardCreatedVersion, List<Literal<?>> optionalParams) {
        Reference column = aggregationReferences.get(0);
        if (column == null || !column.hasDocValues()) {
            return null;
        }
        int typeId = Symbols.typeView(aggregationReferences).get(0).id();
        // NOTE(review): the literals are data type ids inlined by the decompiler;
        // presumably 2/8/9/10 are the integral types, 7 is float and 6 is double —
        // confirm against the type id constants.
        return switch (typeId) {
            case 2, 8, 9, 10 -> new AvgLong(column.storageIdent());
            case 7 -> new AvgFloat(returnType, column.storageIdent());
            case 6 -> new AvgDouble(returnType, column.storageIdent());
            default -> null;
        };
    }

    static {
        // Registers a reader under id 1026 that ignores its input and always
        // yields the shared NumericAverageStateType instance.
        // NOTE(review): 1026 is presumably the id of NumericAverageStateType — confirm.
        DataTypes.register(1026, streamInput -> NumericAverageStateType.INSTANCE);
    }

    /**
     * Doc-value aggregator that adds the raw sorted-numeric values of one
     * column into an {@link OverflowAwareMutableLong} sum and counts them.
     * Documents without exactly one value for the column are skipped.
     */
    static class AvgLong
    implements DocValueAggregator<NumericAverageState<OverflowAwareMutableLong>> {
        /** Name of the doc-values field to read. */
        private final String columnName;
        /** Doc values of the current segment; assigned by {@link #loadDocValues}. */
        private SortedNumericDocValues values;

        AvgLong(String columnName) {
            this.columnName = columnName;
        }

        /**
         * Accounts the initial state size and returns an empty state (sum 0, count 0).
         */
        @Override
        public NumericAverageState<OverflowAwareMutableLong> initialState(RamAccounting ramAccounting, MemoryManager memoryManager, Version minNodeVersion) {
            ramAccounting.addBytes(NumericAverageStateType.INIT_SIZE);
            OverflowAwareMutableLong zeroSum = new OverflowAwareMutableLong(0L);
            return new NumericAverageState<>(zeroSum, 0L);
        }

        /**
         * Switches to the sorted-numeric doc values of the given segment.
         */
        @Override
        public void loadDocValues(LeafReaderContext reader) throws IOException {
            LeafReader leafReader = reader.reader();
            values = DocValues.getSortedNumeric(leafReader, columnName);
        }

        /**
         * Adds the document's value to the state if the document has exactly
         * one value and the state is not {@code null}.
         */
        @Override
        public void apply(RamAccounting ramAccounting, int doc, NumericAverageState<OverflowAwareMutableLong> state) throws IOException {
            boolean hasSingleValue = values.advanceExact(doc) && values.docValueCount() == 1;
            if (!hasSingleValue || state == null) {
                return;
            }
            OverflowAwareMutableLong sum = (OverflowAwareMutableLong) state.sum;
            sum.add(values.nextValue());
            state.count++;
        }

        /**
         * Returns the state object itself as the partial result.
         */
        @Override
        public Object partialResult(RamAccounting ramAccounting, NumericAverageState<OverflowAwareMutableLong> state) {
            return state;
        }
    }

    /**
     * Doc-value aggregator for a column whose values are stored as sortable
     * ints encoding floats. Each value is decoded, cast to {@link BigDecimal}
     * with the supplied type and added to a {@link BigDecimalValueWrapper} sum.
     * Documents without exactly one value for the column are skipped.
     */
    static class AvgFloat
    implements DocValueAggregator<NumericAverageState<BigDecimalValueWrapper>> {
        /** Type used to cast decoded floats to {@link BigDecimal}. */
        private final DataType<BigDecimal> returnType;
        /** Name of the doc-values field to read. */
        private final String columnName;
        /** Doc values of the current segment; assigned by {@link #loadDocValues}. */
        private SortedNumericDocValues values;

        AvgFloat(DataType<BigDecimal> returnType, String columnName) {
            this.returnType = returnType;
            this.columnName = columnName;
        }

        /**
         * Accounts the initial state size and returns an empty state
         * (sum {@link BigDecimal#ZERO}, count 0).
         */
        @Override
        public NumericAverageState<BigDecimalValueWrapper> initialState(RamAccounting ramAccounting, MemoryManager memoryManager, Version minNodeVersion) {
            ramAccounting.addBytes(NumericAverageStateType.INIT_SIZE);
            BigDecimalValueWrapper zeroSum = new BigDecimalValueWrapper(BigDecimal.ZERO);
            return new NumericAverageState<>(zeroSum, 0L);
        }

        /**
         * Switches to the sorted-numeric doc values of the given segment.
         */
        @Override
        public void loadDocValues(LeafReaderContext reader) throws IOException {
            LeafReader leafReader = reader.reader();
            values = DocValues.getSortedNumeric(leafReader, columnName);
        }

        /**
         * Adds the document's decoded float to the state if the document has
         * exactly one value and the state is not {@code null}; the size change
         * of the sum is reported to {@code ramAccounting}.
         */
        @Override
        public void apply(RamAccounting ramAccounting, int doc, NumericAverageState<BigDecimalValueWrapper> state) throws IOException {
            if (!values.advanceExact(doc) || values.docValueCount() != 1 || state == null) {
                return;
            }
            // The stored long carries a sortable-int encoded float in its low 32 bits.
            float decoded = NumericUtils.sortableIntToFloat((int) values.nextValue());
            BigDecimal addend = returnType.implicitCast(decoded);
            BigDecimalValueWrapper sum = (BigDecimalValueWrapper) state.sum;
            BigDecimal previousSum = sum.value();
            BigDecimal updatedSum = previousSum.add(addend);
            ramAccounting.addBytes(NumericType.sizeDiff(updatedSum, previousSum));
            sum.setValue(updatedSum);
            state.count++;
        }

        /**
         * Returns the state object itself as the partial result.
         */
        @Override
        public Object partialResult(RamAccounting ramAccounting, NumericAverageState<BigDecimalValueWrapper> state) {
            return state;
        }
    }

    /**
     * Doc-value aggregator for a column whose values are stored as sortable
     * longs encoding doubles. Each value is decoded, cast to {@link BigDecimal}
     * with the supplied type and added to a {@link BigDecimalValueWrapper} sum.
     * Documents without exactly one value for the column are skipped.
     */
    static class AvgDouble
    implements DocValueAggregator<NumericAverageState<BigDecimalValueWrapper>> {
        /** Type used to cast decoded doubles to {@link BigDecimal}. */
        private final DataType<BigDecimal> returnType;
        /** Name of the doc-values field to read. */
        private final String columnName;
        /** Doc values of the current segment; assigned by {@link #loadDocValues}. */
        private SortedNumericDocValues values;

        AvgDouble(DataType<BigDecimal> returnType, String columnName) {
            this.returnType = returnType;
            this.columnName = columnName;
        }

        /**
         * Accounts the initial state size and returns an empty state
         * (sum {@link BigDecimal#ZERO}, count 0).
         */
        @Override
        public NumericAverageState<BigDecimalValueWrapper> initialState(RamAccounting ramAccounting, MemoryManager memoryManager, Version minNodeVersion) {
            ramAccounting.addBytes(NumericAverageStateType.INIT_SIZE);
            BigDecimalValueWrapper zeroSum = new BigDecimalValueWrapper(BigDecimal.ZERO);
            return new NumericAverageState<>(zeroSum, 0L);
        }

        /**
         * Switches to the sorted-numeric doc values of the given segment.
         */
        @Override
        public void loadDocValues(LeafReaderContext reader) throws IOException {
            LeafReader leafReader = reader.reader();
            values = DocValues.getSortedNumeric(leafReader, columnName);
        }

        /**
         * Adds the document's decoded double to the state if the document has
         * exactly one value and the state is not {@code null}; the size change
         * of the sum is reported to {@code ramAccounting}.
         */
        @Override
        public void apply(RamAccounting ramAccounting, int doc, NumericAverageState<BigDecimalValueWrapper> state) throws IOException {
            if (!values.advanceExact(doc) || values.docValueCount() != 1 || state == null) {
                return;
            }
            double decoded = NumericUtils.sortableLongToDouble(values.nextValue());
            BigDecimal addend = returnType.implicitCast(decoded);
            BigDecimalValueWrapper sum = (BigDecimalValueWrapper) state.sum;
            BigDecimal previousSum = sum.value();
            BigDecimal updatedSum = previousSum.add(addend);
            ramAccounting.addBytes(NumericType.sizeDiff(updatedSum, previousSum));
            sum.setValue(updatedSum);
            state.count++;
        }

        /**
         * Returns the state object itself as the partial result.
         */
        @Override
        public Object partialResult(RamAccounting ramAccounting, NumericAverageState<BigDecimalValueWrapper> state) {
            return state;
        }
    }
}

