/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine.aggregation.impl;

import io.crate.common.MutableDouble;
import io.crate.common.MutableFloat;
import io.crate.common.MutableLong;
import io.crate.data.Input;
import io.crate.data.breaker.RamAccounting;
import io.crate.execution.engine.aggregation.AggregationFunction;
import io.crate.execution.engine.aggregation.DocValueAggregator;
import io.crate.execution.engine.aggregation.impl.util.KahanSummationForDouble;
import io.crate.execution.engine.aggregation.impl.util.KahanSummationForFloat;
import io.crate.expression.reference.doc.lucene.LuceneReferenceResolver;
import io.crate.expression.symbol.Literal;
import io.crate.memory.MemoryManager;
import io.crate.metadata.FunctionProvider;
import io.crate.metadata.FunctionType;
import io.crate.metadata.Functions;
import io.crate.metadata.Reference;
import io.crate.metadata.Scalar;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.metadata.functions.BoundSignature;
import io.crate.metadata.functions.Signature;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import java.io.IOException;
import java.util.List;
import java.util.function.BinaryOperator;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.Version;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

/**
 * Implementation of the SQL {@code sum} aggregate.
 *
 * <p>Registered signatures (see {@link #register(Functions.Builder)}):
 * <ul>
 *   <li>{@code sum(float) -> float} and {@code sum(double) -> double}, whose additions are delegated
 *       to the {@code KahanSummationForFloat} / {@code KahanSummationForDouble} helpers;</li>
 *   <li>{@code sum(byte | short | integer | long) -> long}, accumulated with {@link Math#addExact},
 *       so an overflow raises an {@link ArithmeticException} instead of silently wrapping.</li>
 * </ul>
 *
 * <p>The per-group state is the running sum itself. A {@code null} state means that no non-null
 * input value has been accumulated yet; null inputs never change the state.
 *
 * @param <T> the numeric state/return type ({@code Long}, {@code Float} or {@code Double})
 */
public class SumAggregation<T extends Number>
extends AggregationFunction<T, T> {
    public static final String NAME = "sum";
    private final Signature signature;
    private final BoundSignature boundSignature;
    /** Combines two non-null partial sums. */
    private final BinaryOperator<T> addition;
    /** Removes a non-null value from a non-null sum (used for window frames). */
    private final BinaryOperator<T> subtraction;
    private final DataType<T> returnType;
    /** Bytes accounted per state created by {@link #newState}. */
    private final int bytesSize;

    /**
     * Registers every {@code sum} signature with the given function registry builder.
     *
     * @param builder the builder that receives the signatures and their factories
     */
    public static void register(Functions.Builder builder) {
        builder.add(
            Signature.builder(NAME, FunctionType.AGGREGATE)
                .argumentTypes(DataTypes.FLOAT.getTypeSignature())
                .returnType(DataTypes.FLOAT.getTypeSignature())
                .features(Scalar.Feature.DETERMINISTIC)
                .build(),
            getSumAggregationForFloatFactory()
        );
        builder.add(
            Signature.builder(NAME, FunctionType.AGGREGATE)
                .argumentTypes(DataTypes.DOUBLE.getTypeSignature())
                .returnType(DataTypes.DOUBLE.getTypeSignature())
                .features(Scalar.Feature.DETERMINISTIC)
                .build(),
            getSumAggregationForDoubleFactory()
        );

        // All integral inputs are widened to LONG; the exact operations make overflow an error
        // rather than a silently wrapped result.
        BinaryOperator<Long> add = Math::addExact;
        BinaryOperator<Long> sub = Math::subtractExact;
        for (DataType<?> dataType : List.of(DataTypes.BYTE, DataTypes.SHORT, DataTypes.INTEGER, DataTypes.LONG)) {
            builder.add(
                Signature.builder(NAME, FunctionType.AGGREGATE)
                    .argumentTypes(dataType.getTypeSignature())
                    .returnType(DataTypes.LONG.getTypeSignature())
                    .features(Scalar.Feature.DETERMINISTIC)
                    .build(),
                (signature, boundSignature) ->
                    new SumAggregation<Long>(DataTypes.LONG, add, sub, signature, boundSignature)
            );
        }
    }

    /**
     * Creates a sum aggregation.
     *
     * @param returnType     type of the state and of the result; input values are sanitized to it
     * @param addition       combines two non-null partial sums
     * @param subtraction    removes a non-null value from a non-null sum
     * @param signature      declared signature of this function
     * @param boundSignature signature bound to the actual argument types
     */
    @VisibleForTesting
    private SumAggregation(DataType<T> returnType,
                           BinaryOperator<T> addition,
                           BinaryOperator<T> subtraction,
                           Signature signature,
                           BoundSignature boundSignature) {
        this.addition = addition;
        this.subtraction = subtraction;
        this.returnType = returnType;
        this.bytesSize = returnType == DataTypes.FLOAT
            ? DataTypes.FLOAT.fixedSize()
            : (returnType == DataTypes.DOUBLE ? DataTypes.DOUBLE.fixedSize() : DataTypes.LONG.fixedSize());
        this.signature = signature;
        this.boundSignature = boundSignature;
    }

    /**
     * Accounts the memory of one state and returns the empty state.
     *
     * @return always {@code null}, meaning "nothing accumulated yet"
     */
    @Override
    @Nullable
    public T newState(RamAccounting ramAccounting, Version indexVersionCreated, Version minNodeInCluster, MemoryManager memoryManager) {
        ramAccounting.addBytes((long) bytesSize);
        return null;
    }

    /**
     * Adds the first argument's value (sanitized to the return type) to {@code state}.
     * A {@code null} argument leaves the state unchanged.
     */
    @Override
    public T iterate(RamAccounting ramAccounting, MemoryManager memoryManager, T state, Input<?>... args) throws CircuitBreakingException {
        T value = returnType.sanitizeValue(args[0].value());
        return reduce(ramAccounting, state, value);
    }

    /**
     * Merges two partial sums; a {@code null} side is treated as "no contribution".
     */
    @Override
    public T reduce(RamAccounting ramAccounting, T state1, T state2) {
        if (state1 == null) {
            return state2;
        }
        if (state2 == null) {
            return state1;
        }
        return addition.apply(state1, state2);
    }

    /** The partial result is the running sum itself. */
    @Override
    public T terminatePartial(RamAccounting ramAccounting, T state) {
        return state;
    }

    @Override
    public DataType<?> partialType() {
        return boundSignature.returnType();
    }

    @Override
    public Signature signature() {
        return signature;
    }

    @Override
    public BoundSignature boundSignature() {
        return boundSignature;
    }

    /** Values can be subtracted again, which lets sliding window frames avoid recomputation. */
    @Override
    public boolean isRemovableCumulative() {
        return true;
    }

    /**
     * Removes the first argument's value from the aggregated state.
     *
     * <p>{@link #iterate} ignores {@code null} inputs, so removing a {@code null} must be a no-op.
     * Previously the subtraction was applied unconditionally and unboxing the {@code null}
     * raised a {@link NullPointerException}. For the same reason, a {@code null} previous state
     * (nothing accumulated) is returned unchanged.
     *
     * <p>NOTE(review): removing every non-null value yields {@code 0}, not {@code null}; telling
     * the two apart would require tracking a count in the state.
     */
    @Override
    public T removeFromAggregatedState(RamAccounting ramAccounting, T previousAggState, Input<?>[] stateToRemove) {
        T valueToRemove = returnType.sanitizeValue(stateToRemove[0].value());
        if (valueToRemove == null || previousAggState == null) {
            return previousAggState;
        }
        return subtraction.apply(previousAggState, valueToRemove);
    }

    /**
     * Returns a doc-values based aggregator that sums the referenced column directly from Lucene.
     *
     * @return {@code null} if there is no aggregation reference or its type is not supported
     */
    @Override
    @Nullable
    public DocValueAggregator<?> getDocValueAggregator(LuceneReferenceResolver referenceResolver, List<Reference> aggregationReferences, DocTableInfo table, Version shardCreatedVersion, List<Literal<?>> optionalParams) {
        Reference reference = getAggReference(aggregationReferences);
        if (reference == null) {
            return null;
        }
        int typeId = reference.valueType().id();
        if (typeId == DataTypes.BYTE.id()
            || typeId == DataTypes.SHORT.id()
            || typeId == DataTypes.INTEGER.id()
            || typeId == DataTypes.LONG.id()) {
            return new SumLong(reference.storageIdent());
        }
        if (typeId == DataTypes.FLOAT.id()) {
            return new SumFloat(reference.storageIdent());
        }
        if (typeId == DataTypes.DOUBLE.id()) {
            return new SumDouble(reference.storageIdent());
        }
        return null;
    }

    /**
     * Factory for {@code sum(double)}. Additions go through a {@link KahanSummationForDouble};
     * subtraction is plain {@code double} arithmetic.
     *
     * <p>NOTE(review): one helper instance is shared by every state of the created function. If
     * the helper keeps compensation internally, that compensation is shared across groups; confirm
     * this is intended.
     */
    private static FunctionProvider.FunctionFactory getSumAggregationForDoubleFactory() {
        return (signature, boundSignature) -> {
            KahanSummationForDouble kahanSummation = new KahanSummationForDouble();
            return new SumAggregation<Double>(
                DataTypes.DOUBLE,
                kahanSummation::sum,
                (n1, n2) -> n1 - n2,
                signature,
                boundSignature
            );
        };
    }

    /**
     * Factory for {@code sum(float)}. Additions go through a {@link KahanSummationForFloat};
     * subtraction is plain {@code float} arithmetic. The same shared-helper caveat as the double
     * factory applies.
     */
    private static FunctionProvider.FunctionFactory getSumAggregationForFloatFactory() {
        return (signature, boundSignature) -> {
            KahanSummationForFloat kahanSummation = new KahanSummationForFloat();
            return new SumAggregation<Float>(
                DataTypes.FLOAT,
                kahanSummation::sum,
                (n1, n2) -> n1.floatValue() - n2.floatValue(),
                signature,
                boundSignature
            );
        };
    }

    /**
     * Sums single-valued integral doc values (byte, short, integer, long) into a {@code long}.
     * Documents with no value or with multiple values are skipped. Overflow raises an
     * {@link ArithmeticException}.
     */
    @VisibleForTesting
    public static class SumLong
    implements DocValueAggregator<MutableLong> {
        private final String columnName;
        private SortedNumericDocValues values;

        SumLong(String columnName) {
            this.columnName = columnName;
        }

        @Override
        public MutableLong initialState(RamAccounting ramAccounting, MemoryManager memoryManager, Version minNodeVersion) {
            ramAccounting.addBytes((long) DataTypes.LONG.fixedSize());
            return new MutableLong(0L);
        }

        @Override
        public void loadDocValues(LeafReaderContext reader) throws IOException {
            values = DocValues.getSortedNumeric(reader.reader(), columnName);
        }

        @Override
        public void apply(RamAccounting ramAccounting, int doc, MutableLong state) throws IOException {
            if (values.advanceExact(doc) && values.docValueCount() == 1) {
                state.setValue(Math.addExact(state.value(), values.nextValue()));
            }
        }

        /** @return the sum, or {@code null} if no document contributed a value */
        @Override
        public Long partialResult(RamAccounting ramAccounting, MutableLong state) {
            return state.hasValue() ? Long.valueOf(state.value()) : null;
        }
    }

    /**
     * Sums single-valued float doc values, decoding each value from its sortable-int encoding.
     * Documents with no value or with multiple values are skipped.
     *
     * <p>NOTE(review): the Kahan helper is per aggregator instance and therefore shared by all
     * states this aggregator updates.
     */
    static class SumFloat
    implements DocValueAggregator<MutableFloat> {
        private final String columnName;
        private SortedNumericDocValues values;
        private final KahanSummationForFloat kahanSummation = new KahanSummationForFloat();

        SumFloat(String columnName) {
            this.columnName = columnName;
        }

        @Override
        public MutableFloat initialState(RamAccounting ramAccounting, MemoryManager memoryManager, Version minNodeVersion) {
            ramAccounting.addBytes((long) DataTypes.FLOAT.fixedSize());
            return new MutableFloat(0.0f);
        }

        @Override
        public void loadDocValues(LeafReaderContext reader) throws IOException {
            values = DocValues.getSortedNumeric(reader.reader(), columnName);
        }

        @Override
        public void apply(RamAccounting ramAccounting, int doc, MutableFloat state) throws IOException {
            if (values.advanceExact(doc) && values.docValueCount() == 1) {
                float decoded = NumericUtils.sortableIntToFloat((int) values.nextValue());
                state.setValue(kahanSummation.sum(state.value(), decoded));
            }
        }

        /** @return the sum, or {@code null} if no document contributed a value */
        @Override
        public Object partialResult(RamAccounting ramAccounting, MutableFloat state) {
            return state.hasValue() ? Float.valueOf(state.value()) : null;
        }
    }

    /**
     * Sums single-valued double doc values, decoding each value from its sortable-long encoding.
     * Documents with no value or with multiple values are skipped.
     *
     * <p>NOTE(review): the Kahan helper is per aggregator instance and therefore shared by all
     * states this aggregator updates.
     */
    static class SumDouble
    implements DocValueAggregator<MutableDouble> {
        private final String columnName;
        private SortedNumericDocValues values;
        private final KahanSummationForDouble kahanSummation = new KahanSummationForDouble();

        SumDouble(String columnName) {
            this.columnName = columnName;
        }

        @Override
        public MutableDouble initialState(RamAccounting ramAccounting, MemoryManager memoryManager, Version minNodeVersion) {
            ramAccounting.addBytes((long) DataTypes.DOUBLE.fixedSize());
            return new MutableDouble(0.0);
        }

        @Override
        public void loadDocValues(LeafReaderContext reader) throws IOException {
            values = DocValues.getSortedNumeric(reader.reader(), columnName);
        }

        @Override
        public void apply(RamAccounting ramAccounting, int doc, MutableDouble state) throws IOException {
            if (values.advanceExact(doc) && values.docValueCount() == 1) {
                double decoded = NumericUtils.sortableLongToDouble(values.nextValue());
                state.setValue(kahanSummation.sum(state.value(), decoded));
            }
        }

        /** @return the sum, or {@code null} if no document contributed a value */
        @Override
        public Object partialResult(RamAccounting ramAccounting, MutableDouble state) {
            return state.hasValue() ? Double.valueOf(state.value()) : null;
        }
    }
}

