/*
 * Decompiled with CFR 0.152.
 */
package io.crate.planner.optimizer.rule;

import io.crate.common.collections.Lists;
import io.crate.execution.engine.aggregation.impl.CountAggregation;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.Symbol;
import io.crate.metadata.Reference;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.metadata.functions.Signature;
import io.crate.planner.operators.Collect;
import io.crate.planner.operators.Count;
import io.crate.planner.operators.HashAggregate;
import io.crate.planner.optimizer.Rule;
import io.crate.planner.optimizer.matcher.Capture;
import io.crate.planner.optimizer.matcher.Captures;
import io.crate.planner.optimizer.matcher.Pattern;
import io.crate.planner.optimizer.matcher.Patterns;
import java.util.List;

public final class MergeAggregateAndCollectToCount
implements Rule<HashAggregate> {
    private final Capture<Collect> collectCapture = new Capture();
    private final Pattern<HashAggregate> pattern = Pattern.typeOf(HashAggregate.class).with(Patterns.source(), Pattern.typeOf(Collect.class).capturedAs(this.collectCapture).with(collect -> collect.relation().tableInfo() instanceof DocTableInfo)).with(aggregate -> MergeAggregateAndCollectToCount.isCountAggregate(aggregate.aggregates()));

    private static boolean isCountAggregate(List<Function> aggregates) {
        Reference ref;
        Symbol symbol;
        if (aggregates.size() != 1) {
            return false;
        }
        Function aggregate = aggregates.get(0);
        Signature signature = aggregate.signature();
        if (signature.equals(CountAggregation.COUNT_STAR_SIGNATURE)) {
            return true;
        }
        return signature.getName().equals(CountAggregation.SIGNATURE.getName()) && (symbol = aggregate.arguments().get(0)) instanceof Reference && !(ref = (Reference)symbol).isNullable();
    }

    @Override
    public Pattern<HashAggregate> pattern() {
        return this.pattern;
    }

    public Count apply(HashAggregate aggregate, Captures captures, Rule.Context context) {
        Collect collect = captures.get(this.collectCapture);
        Function countAggregate = (Function)Lists.getOnlyElement(aggregate.aggregates());
        if (countAggregate.filter() != null) {
            return new Count(countAggregate, collect.relation(), collect.where().add(countAggregate.filter()));
        }
        return new Count(countAggregate, collect.relation(), collect.where());
    }
}

