/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine;

import io.crate.common.concurrent.CompletableFutures;
import io.crate.common.exceptions.Exceptions;
import io.crate.data.Bucket;
import io.crate.data.CollectingRowConsumer;
import io.crate.data.InMemoryBatchIterator;
import io.crate.data.Row1;
import io.crate.data.RowConsumer;
import io.crate.data.SentinelRow;
import io.crate.exceptions.SQLExceptions;
import io.crate.execution.dml.BulkResponse;
import io.crate.execution.dsl.phases.ExecutionPhase;
import io.crate.execution.dsl.phases.NodeOperation;
import io.crate.execution.dsl.phases.NodeOperationGrouper;
import io.crate.execution.dsl.phases.NodeOperationTree;
import io.crate.execution.engine.BucketForwarder;
import io.crate.execution.engine.InitializationTracker;
import io.crate.execution.engine.InterceptingRowConsumer;
import io.crate.execution.engine.PagingUnsupportedResultListener;
import io.crate.execution.engine.distribution.StreamBucket;
import io.crate.execution.jobs.DownstreamRXTask;
import io.crate.execution.jobs.InstrumentedIndexSearcher;
import io.crate.execution.jobs.JobSetup;
import io.crate.execution.jobs.PageBucketReceiver;
import io.crate.execution.jobs.RootTask;
import io.crate.execution.jobs.SharedShardContexts;
import io.crate.execution.jobs.TasksService;
import io.crate.execution.jobs.kill.KillJobsNodeRequest;
import io.crate.execution.jobs.kill.KillResponse;
import io.crate.execution.jobs.transport.JobRequest;
import io.crate.execution.jobs.transport.JobResponse;
import io.crate.execution.support.ActionExecutor;
import io.crate.execution.support.NodeRequest;
import io.crate.metadata.TransactionContext;
import io.crate.profile.ProfilingContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.search.profile.query.QueryProfiler;

/**
 * Launches a job described by one or more {@link NodeOperationTree}s.
 *
 * <p>Node operations assigned to the local node are prepared via {@link JobSetup#prepareOnHandler}
 * and run in a local {@link RootTask}. Once that task has started successfully, the operations of
 * every other node are sent to it as a {@link JobRequest}; each response is handed to a
 * {@link BucketForwarder} bound to the handler's {@link PageBucketReceiver}s.
 */
public class JobLauncher {
    /** Executes the per-node {@link JobRequest}s against the remote nodes. */
    private final ActionExecutor<NodeRequest<JobRequest>, JobResponse> transportJobAction;
    /** Handed to every {@link InterceptingRowConsumer} that wraps a handler consumer. */
    private final ActionExecutor<KillJobsNodeRequest, KillResponse> killNodeAction;
    /** Trees to launch: exactly one for {@link #execute}, one per bulk entry for {@link #executeBulk}. */
    private final List<NodeOperationTree> nodeOperationTrees;
    private final UUID jobId;
    private final ClusterService clusterService;
    private final JobSetup jobSetup;
    private final TasksService tasksService;
    private final IndicesService indicesService;
    /**
     * When true, local searchers are wrapped with {@link QueryProfiler}s and the flag is also
     * passed along in every {@link JobRequest}.
     */
    private final boolean enableProfiling;

    JobLauncher(UUID jobId,
                ClusterService clusterService,
                JobSetup jobSetup,
                TasksService tasksService,
                IndicesService indicesService,
                ActionExecutor<NodeRequest<JobRequest>, JobResponse> transportJobAction,
                ActionExecutor<KillJobsNodeRequest, KillResponse> killNodeAction,
                List<NodeOperationTree> nodeOperationTrees,
                boolean enableProfiling) {
        this.jobId = jobId;
        this.clusterService = clusterService;
        this.jobSetup = jobSetup;
        this.tasksService = tasksService;
        this.indicesService = indicesService;
        this.transportJobAction = transportJobAction;
        this.killNodeAction = killNodeAction;
        this.nodeOperationTrees = nodeOperationTrees;
        this.enableProfiling = enableProfiling;
    }

    /**
     * Runs the (single-tree) job, optionally without waiting for its result.
     *
     * @param consumer          receives the job result when {@code waitForCompletion} is true;
     *                          otherwise it immediately receives an in-memory batch iterator over
     *                          {@code Row1.ROW_COUNT_UNKNOWN}
     * @param txnCtx            transaction context of the calling session
     * @param waitForCompletion if false, the job is started with an internal row-counting consumer
     *                          whose outcome (including any failure) is not observed here
     */
    public void execute(RowConsumer consumer, TransactionContext txnCtx, boolean waitForCompletion) {
        if (waitForCompletion) {
            this.execute(consumer, txnCtx);
        } else {
            // Fire-and-forget: the counting consumer's result is intentionally discarded.
            this.execute((RowConsumer)new CollectingRowConsumer(Collectors.counting()), txnCtx);
            consumer.accept(InMemoryBatchIterator.of((Object)Row1.ROW_COUNT_UNKNOWN, (Object)SentinelRow.SENTINEL), null);
        }
    }

    /**
     * Runs the job described by the single {@link NodeOperationTree} and streams its result into
     * {@code consumer}.
     *
     * <p>Any {@link Throwable} thrown while setting up the tasks is delivered to
     * {@code consumer.accept(null, throwable)} instead of being thrown to the caller.
     *
     * @param consumer receives the result rows of the tree's leaf (handler) phase
     * @param txnCtx   transaction context of the calling session
     */
    public void execute(RowConsumer consumer, TransactionContext txnCtx) {
        assert (this.nodeOperationTrees.size() == 1) : "must only have 1 NodeOperationTree for non-bulk operations";
        NodeOperationTree nodeOperationTree = this.nodeOperationTrees.get(0);
        Map<String, Collection<NodeOperation>> operationByServer = NodeOperationGrouper.groupByServer(nodeOperationTree.nodeOperations());
        List<ExecutionPhase> handlerPhases = Collections.singletonList(nodeOperationTree.leaf());
        List<RowConsumer> handlerConsumers = Collections.singletonList(consumer);
        try {
            this.setupTasks(txnCtx, operationByServer, handlerPhases, handlerConsumers);
        }
        catch (Throwable throwable) {
            consumer.accept(null, throwable);
        }
    }

    /**
     * Runs all {@link NodeOperationTree}s as one job and collects a row count per tree.
     *
     * <p>Each tree gets its own consumer that sums the {@code Long} in column 0 of every result
     * row. The sum or the failure is recorded at the tree's index in the {@link BulkResponse}.
     *
     * @param txnCtx transaction context of the calling session
     * @return a future completed with the {@link BulkResponse} once the future returned by
     *         {@code CompletableFutures.allSuccessfulAsList} completes normally, or completed
     *         exceptionally otherwise; an already-failed future if task setup throws
     */
    public CompletableFuture<BulkResponse> executeBulk(TransactionContext txnCtx) {
        // Iterable backed by a single Stream: it can only be iterated once, which groupByServer does.
        Iterable<NodeOperation> nodeOperations = this.nodeOperationTrees.stream().flatMap(opTree -> opTree.nodeOperations().stream())::iterator;
        Map<String, Collection<NodeOperation>> operationByServer = NodeOperationGrouper.groupByServer(nodeOperations);
        int numTrees = this.nodeOperationTrees.size();
        List<ExecutionPhase> handlerPhases = new ArrayList<>(numTrees);
        List<RowConsumer> handlerConsumers = new ArrayList<>(numTrees);
        CompletableFuture<BulkResponse> result = new CompletableFuture<>();
        BulkResponse bulkResponse = new BulkResponse(numTrees);
        ArrayList<CompletionStage> results = new ArrayList<CompletionStage>(numTrees);
        for (int i = 0; i < numTrees; i++) {
            NodeOperationTree nodeOperationTree = this.nodeOperationTrees.get(i);
            CollectingRowConsumer consumer = new CollectingRowConsumer(Collectors.collectingAndThen(Collectors.summingLong(r -> (Long)r.get(0)), sum -> sum));
            handlerConsumers.add((RowConsumer)consumer);
            int idx = i;
            results.add(consumer.completionFuture().whenComplete((rowCount, t) -> bulkResponse.update(idx, (Long)rowCount, (Throwable)t)));
            handlerPhases.add(nodeOperationTree.leaf());
        }
        CompletableFutures.allSuccessfulAsList(results).whenComplete((list, t) -> {
            if (t == null) {
                result.complete(bulkResponse);
            } else {
                result.completeExceptionally((Throwable)t);
            }
        });
        try {
            this.setupTasks(txnCtx, operationByServer, handlerPhases, handlerConsumers);
        }
        catch (Throwable throwable) {
            return CompletableFuture.failedFuture(throwable);
        }
        return result;
    }

    /**
     * Creates and starts the local task and, once it has started, sends the job requests to all
     * remote nodes.
     *
     * <p>Note: removes the local node's entry from {@code operationByServer}, so afterwards the map
     * only contains remote nodes.
     *
     * @param txnCtx            transaction context of the calling session
     * @param operationByServer node operations grouped by node id (mutated, see above)
     * @param handlerPhases     handler phases; must have the same size as {@code handlerConsumers}
     * @param handlerConsumers  consumers receiving the results of the matching handler phase
     * @throws Throwable anything thrown while preparing or creating the local task
     */
    private void setupTasks(TransactionContext txnCtx, Map<String, Collection<NodeOperation>> operationByServer, List<ExecutionPhase> handlerPhases, List<RowConsumer> handlerConsumers) throws Throwable {
        assert (handlerPhases.size() == handlerConsumers.size()) : "handlerPhases size must match handlerConsumers size";
        String localNodeId = this.clusterService.localNode().getId();
        Collection<NodeOperation> localNodeOperations = operationByServer.remove(localNodeId);
        if (localNodeOperations == null) {
            localNodeOperations = Collections.emptyList();
        }
        // One slot per remote node plus one for the local task.
        // NOTE(review): jobInitialized() is also called once per direct response in
        // forwardDirectResponseToPageBucketRX, in addition to the call after a successful local
        // start below; confirm InitializationTracker tolerates more signals than slots.
        InitializationTracker initializationTracker = new InitializationTracker(operationByServer.size() + 1);
        List<HandlerPhase> handlerPhaseAndReceiver = this.createHandlerPhaseAndReceivers(handlerPhases, handlerConsumers, initializationTracker);
        RootTask.Builder builder = this.tasksService.newBuilder(this.jobId, txnCtx.sessionSettings().userName(), localNodeId, operationByServer.keySet());
        SharedShardContexts sharedShardContexts = this.maybeInstrumentProfiler(builder);
        List<CompletableFuture<StreamBucket>> directResponseFutures = this.jobSetup.prepareOnHandler(txnCtx.sessionSettings(), localNodeOperations, builder, handlerPhaseAndReceiver, sharedShardContexts);
        RootTask localTask = this.tasksService.createTask(builder);
        List<PageBucketReceiver> pageBucketReceivers = this.getHandlerBucketReceivers(localTask, handlerPhaseAndReceiver);
        // Direct responses of local operations use bucket 0; remote nodes take the following indices.
        int bucketIdx = 0;
        if (!directResponseFutures.isEmpty()) {
            assert (directResponseFutures.size() == pageBucketReceivers.size()) : "directResponses size must match pageBucketReceivers";
            this.forwardDirectResponseToPageBucketRX(initializationTracker, directResponseFutures, pageBucketReceivers, bucketIdx);
            ++bucketIdx;
        }
        int nextBucket = bucketIdx;
        CompletableFuture<Void> start = localTask.start();
        start.whenComplete((ignored, err) -> {
            if (err == null) {
                initializationTracker.jobInitialized();
                this.sendJobRequests(txnCtx, localNodeId, operationByServer, pageBucketReceivers, handlerPhaseAndReceiver, nextBucket, initializationTracker);
            } else {
                // Unwrap exactly like direct-response failures (forwardDirectResponseToPageBucketRX)
                // so that wrapper exceptions do not hide the actual cause on this path.
                Throwable cause = SQLExceptions.unwrap(err);
                initializationTracker.jobInitializationFailed(cause);
                int bucket = nextBucket;
                Exception e = Exceptions.toException(cause);
                // Remote job requests are never sent; report the failure once per remote node
                // through the same BucketForwarder used by sendJobRequests.
                for (int i = 0; i < operationByServer.size(); ++i) {
                    BucketForwarder listener = new BucketForwarder(pageBucketReceivers, bucket, initializationTracker);
                    listener.onFailure(e);
                    ++bucket;
                }
            }
        });
    }

    /**
     * Routes each local direct response into the page bucket receiver at the same list index.
     *
     * <p>On success the tracker is signalled and the response is set as the last page of bucket
     * {@code bucketIdx}; on failure the unwrapped error is reported to the tracker and the receiver
     * is killed.
     */
    private void forwardDirectResponseToPageBucketRX(InitializationTracker initializationTracker, List<CompletableFuture<StreamBucket>> directResponseFutures, List<PageBucketReceiver> pageBucketReceivers, int bucketIdx) {
        for (int i = 0; i < directResponseFutures.size(); ++i) {
            CompletableFuture<StreamBucket> directResponse = directResponseFutures.get(i);
            PageBucketReceiver pageBucketReceiver = pageBucketReceivers.get(i);
            directResponse.whenComplete((res, err) -> {
                if (err == null) {
                    initializationTracker.jobInitialized();
                    pageBucketReceiver.setBucket(bucketIdx, (Bucket)res, true, PagingUnsupportedResultListener.INSTANCE);
                } else {
                    Throwable cause = SQLExceptions.unwrap(err);
                    initializationTracker.jobInitializationFailed(cause);
                    pageBucketReceiver.kill(cause);
                }
            });
        }
    }

    /**
     * Creates the {@link SharedShardContexts} for the local task.
     *
     * <p>With profiling enabled, a {@link ProfilingContext} is registered on {@code builder} and
     * every searcher is wrapped in an {@link InstrumentedIndexSearcher} whose {@link QueryProfiler}
     * is stored per {@link ShardId}. Otherwise searchers are passed through unchanged.
     */
    private SharedShardContexts maybeInstrumentProfiler(RootTask.Builder builder) {
        if (!this.enableProfiling) {
            return new SharedShardContexts(this.indicesService, (shardId, indexSearcher) -> indexSearcher);
        }
        var profilers = new HashMap<ShardId, QueryProfiler>();
        builder.profilingContext(new ProfilingContext(profilers));
        return new SharedShardContexts(this.indicesService, (shardId, indexSearcher) -> {
            QueryProfiler queryProfiler = new QueryProfiler();
            profilers.put((ShardId)shardId, queryProfiler);
            return new InstrumentedIndexSearcher((Engine.Searcher)indexSearcher, queryProfiler);
        });
    }

    /**
     * Pairs each handler phase with its consumer, wrapping the consumer in an
     * {@link InterceptingRowConsumer} bound to this job and the given tracker.
     */
    private List<HandlerPhase> createHandlerPhaseAndReceivers(List<ExecutionPhase> handlerPhases, List<RowConsumer> handlerReceivers, InitializationTracker initializationTracker) {
        List<HandlerPhase> handlerPhaseAndReceiver = new ArrayList<>(handlerPhases.size());
        ListIterator<RowConsumer> consumerIt = handlerReceivers.listIterator();
        for (ExecutionPhase handlerPhase : handlerPhases) {
            InterceptingRowConsumer interceptingBatchConsumer = new InterceptingRowConsumer(this.jobId, consumerIt.next(), initializationTracker, this.killNodeAction);
            handlerPhaseAndReceiver.add(new HandlerPhase(handlerPhase, interceptingBatchConsumer));
        }
        return handlerPhaseAndReceiver;
    }

    /**
     * Sends one {@link JobRequest} per remote node. Each response is handled by a
     * {@link BucketForwarder} using its own bucket index, starting at {@code bucketIdx}.
     */
    private void sendJobRequests(TransactionContext txnCtx, String localNodeId, Map<String, Collection<NodeOperation>> operationByServer, List<PageBucketReceiver> bucketReceivers, List<HandlerPhase> handlerPhases, int bucketIdx, InitializationTracker initializationTracker) {
        int bucket = bucketIdx;
        for (Map.Entry<String, Collection<NodeOperation>> entry : operationByServer.entrySet()) {
            String serverNodeId = entry.getKey();
            NodeRequest<JobRequest> request = JobRequest.of(serverNodeId, this.jobId, txnCtx.sessionSettings(), localNodeId, entry.getValue(), this.enableProfiling);
            BucketForwarder listener = new BucketForwarder(bucketReceivers, bucket, initializationTracker);
            this.transportJobAction.execute(request).whenComplete((BiConsumer)listener);
            ++bucket;
        }
    }

    /**
     * Looks up the local task of every handler phase and collects the page bucket receiver
     * {@code 0} of those that are {@link DownstreamRXTask}s. Phases without such a task are skipped.
     */
    private List<PageBucketReceiver> getHandlerBucketReceivers(RootTask rootTask, List<HandlerPhase> handlerPhases) {
        List<PageBucketReceiver> pageBucketReceivers = new ArrayList<>(handlerPhases.size());
        for (HandlerPhase handlerPhase : handlerPhases) {
            Object ctx = rootTask.getTaskOrNull(handlerPhase.phase().phaseId());
            if (ctx instanceof DownstreamRXTask rxTask) {
                pageBucketReceivers.add(rxTask.getBucketReceiver((byte)0));
            }
        }
        return pageBucketReceivers;
    }

    /** A handler execution phase together with the consumer that receives its result rows. */
    public record HandlerPhase(ExecutionPhase phase, RowConsumer consumer) {
    }
}

