/*
 * Decompiled with CFR 0.152.
 */
package io.crate.protocols.postgres;

import io.crate.auth.AccessControl;
import io.crate.auth.Authentication;
import io.crate.auth.AuthenticationMethod;
import io.crate.auth.Credentials;
import io.crate.auth.Protocol;
import io.crate.common.collections.Lists;
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.Symbol;
import io.crate.metadata.settings.CoordinatorSessionSettings;
import io.crate.metadata.settings.session.SessionSetting;
import io.crate.metadata.settings.session.SessionSettingRegistry;
import io.crate.protocols.SSL;
import io.crate.protocols.postgres.AuthenticationContext;
import io.crate.protocols.postgres.ConnectionProperties;
import io.crate.protocols.postgres.DelayableWriteChannel;
import io.crate.protocols.postgres.FormatCodes;
import io.crate.protocols.postgres.KeyData;
import io.crate.protocols.postgres.Messages;
import io.crate.protocols.postgres.PGError;
import io.crate.protocols.postgres.PgDecoder;
import io.crate.protocols.postgres.ResultSetReceiver;
import io.crate.protocols.postgres.RowCountReceiver;
import io.crate.protocols.postgres.TransactionState;
import io.crate.protocols.postgres.types.PGType;
import io.crate.protocols.postgres.types.PGTypes;
import io.crate.role.Role;
import io.crate.session.BaseResultReceiver;
import io.crate.session.DescribeResult;
import io.crate.session.Session;
import io.crate.session.Sessions;
import io.crate.sql.tree.Statement;
import io.crate.types.DataType;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.ssl.SslContext;
import java.net.InetAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.SSLSession;
import joptsimple.ArgumentAcceptingOptionSpec;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import joptsimple.OptionSpec;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

/**
 * Server side of the PostgreSQL frontend/backend wire protocol.
 *
 * <p>Incoming frames are split by {@link PgDecoder}; the {@link MessageHandler} then dispatches
 * them by decoder state (startup parameters, cancel request, regular message) and message type
 * to the {@code handle*} methods, which drive a {@link Session} and write responses via
 * {@link Messages}.
 */
public class PostgresWireProtocol {
    private static final Logger LOGGER = LogManager.getLogger(PostgresWireProtocol.class);
    private static final String PASSWORD_AUTH_NAME = "password";
    public static final int SERVER_VERSION_NUM = 140000;
    public static final String PG_SERVER_VERSION = "14.0";
    final PgDecoder decoder;
    final MessageHandler handler;
    private final Sessions sessions;
    private final SessionSettingRegistry sessionSettingRegistry;
    private final Function<CoordinatorSessionSettings, AccessControl> getAccessControl;
    private final Authentication authService;
    private final Consumer<ChannelPipeline> addTransportHandler;
    private DelayableWriteChannel channel;
    Session session;
    // Set when processing a message failed; messages are skipped until the next Sync ('S').
    private boolean ignoreTillSync = false;
    private AuthenticationContext authContext;
    private Properties properties;

    PostgresWireProtocol(Sessions sessions,
                         SessionSettingRegistry sessionSettingRegistry,
                         Function<CoordinatorSessionSettings, AccessControl> getAccessControl,
                         Consumer<ChannelPipeline> addTransportHandler,
                         Authentication authService,
                         Supplier<SslContext> getSslContext) {
        this.sessions = sessions;
        this.sessionSettingRegistry = sessionSettingRegistry;
        this.getAccessControl = getAccessControl;
        this.addTransportHandler = addTransportHandler;
        this.authService = authService;
        this.decoder = new PgDecoder(getSslContext);
        this.handler = new MessageHandler();
    }

    /**
     * Reads a NUL-terminated UTF-8 string and consumes its terminator.
     *
     * @return the decoded string (without terminator), or {@code null} if the readable bytes
     *     contain no NUL byte; in that case nothing is consumed.
     */
    @Nullable
    static String readCString(ByteBuf buffer) {
        int length = buffer.bytesBefore((byte) 0);
        if (length < 0) {
            return null;
        }
        byte[] bytes = new byte[length + 1];
        buffer.readBytes(bytes);
        return new String(bytes, 0, length, StandardCharsets.UTF_8);
    }

    /**
     * Reads a NUL-terminated UTF-8 string as a {@code char[]} (used for passwords, so no
     * {@link String} copy is created) and consumes its terminator.
     *
     * <p>Only the decoded characters are returned: neither the NUL terminator nor the unused
     * capacity of the decoder's output buffer ends up in the result. Temporary byte and char
     * buffers are wiped before returning.
     *
     * @return the decoded characters, or {@code null} if the readable bytes contain no NUL byte.
     */
    @Nullable
    private static char[] readCharArray(ByteBuf buffer) {
        int length = buffer.bytesBefore((byte) 0);
        if (length < 0) {
            return null;
        }
        byte[] bytes = new byte[length];
        buffer.readBytes(bytes);
        buffer.skipBytes(1); // the NUL terminator
        try {
            CharBuffer decoded = StandardCharsets.UTF_8.decode(ByteBuffer.wrap(bytes));
            char[] chars = new char[decoded.remaining()];
            decoded.get(chars);
            if (decoded.hasArray()) {
                Arrays.fill(decoded.array(), '\0');
            }
            return chars;
        } finally {
            Arrays.fill(bytes, (byte) 0);
        }
    }

    /**
     * Parses the key/value pairs of a startup message body. Pairs with an empty key or an empty
     * value are skipped; reading stops when no further NUL-terminated key can be read.
     */
    private Properties readStartupMessage(ByteBuf buffer) {
        Properties startupProperties = new Properties();
        String key;
        while ((key = readCString(buffer)) != null) {
            String value = readCString(buffer);
            LOGGER.trace("payload: key={} value={}", key, value);
            if (key.isEmpty() || "".equals(value)) {
                continue;
            }
            startupProperties.setProperty(key, value);
        }
        return startupProperties;
    }

    private void handleStartupBody(ByteBuf buffer, Channel channel) {
        this.properties = readStartupMessage(buffer);
        initAuthentication(channel);
    }

    /**
     * Resolves the authentication method for the connecting user and either rejects the
     * connection, requests a cleartext password, or completes authentication right away.
     */
    private void initAuthentication(Channel channel) {
        String userName = properties.getProperty("user");
        Credentials credentials = new Credentials(userName, null);
        InetAddress address = Netty4HttpServerTransport.getRemoteAddress(channel);
        SSLSession sslSession = SSL.getSession(channel);
        ConnectionProperties connProperties =
            new ConnectionProperties(credentials, address, Protocol.POSTGRES, sslSession);
        AuthenticationMethod authMethod = authService.resolveAuthenticationType(userName, connProperties);
        if (authMethod == null) {
            String errorMessage = String.format(
                Locale.ENGLISH,
                "No valid auth.host_based entry found for host \"%s\", user \"%s\". Did you enable TLS in your client?",
                address.getHostAddress(),
                userName);
            Messages.sendAuthenticationError(channel, errorMessage);
            return;
        }
        authContext = new AuthenticationContext(authMethod, connProperties, credentials, LOGGER);
        if (PASSWORD_AUTH_NAME.equals(authMethod.name())) {
            // Password arrives later in a 'p' message, which calls finishAuthentication().
            Messages.sendAuthenticationCleartextPassword(channel);
            return;
        }
        finishAuthentication(channel);
    }

    /**
     * Authenticates the pending {@link #authContext}, creates the session, applies startup
     * {@code options}, and sends AuthenticationOK, parameter status, key data and
     * ReadyForQuery. Any exception is reported as an authentication error. The auth context
     * is always closed and cleared afterwards.
     */
    private void finishAuthentication(Channel channel) {
        assert authContext != null : "finishAuthentication() requires an authContext instance";
        try {
            Role authenticatedUser = authContext.authenticate();
            String database = properties.getProperty("database");
            Session newSession = sessions.newSession(authContext.connectionProperties(), database, authenticatedUser);
            this.session = newSession;
            String options = properties.getProperty("options");
            if (options != null) {
                applyOptions(options);
            }
            // Listeners use the captured session so they never observe a concurrently cleared field.
            Messages.sendAuthenticationOK(channel)
                .addListener(f -> sendParams(channel, newSession.sessionSettings()))
                .addListener(f -> Messages.sendKeyData(channel, newSession.id(), newSession.secret()))
                .addListener(f -> {
                    Messages.sendReadyForQuery(channel, TransactionState.IDLE);
                    if (properties.containsKey("CrateDBTransport")) {
                        switchToTransportProtocol(channel);
                    }
                });
        } catch (Exception e) {
            Messages.sendAuthenticationError(channel, e.getMessage());
        } finally {
            authContext.close();
            authContext = null;
        }
    }

    /**
     * Applies {@code -c name=value} entries of the startup {@code options} parameter to the
     * session settings. Everything after the first {@code '='} is the value; entries without a
     * value, with an empty value, or naming an unknown setting are ignored.
     */
    private void applyOptions(String options) {
        OptionParser parser = new OptionParser();
        ArgumentAcceptingOptionSpec<String> optionC = parser.accepts("c").withRequiredArg().ofType(String.class);
        OptionSet parseResult = parser.parse(options.split(" "));
        List<String> parsedOptions = parseResult.valuesOf(optionC);
        for (String parsedOption : parsedOptions) {
            String[] parts = parsedOption.split("=", 2);
            if (parts.length != 2 || parts[1].isEmpty()) {
                continue;
            }
            String key = parts[0].trim();
            String value = parts[1].trim();
            SessionSetting<?> sessionSetting = sessionSettingRegistry.settings().get(key);
            if (sessionSetting == null) {
                continue;
            }
            sessionSetting.apply(session.sessionSettings(), List.of(Literal.of(value)), symbol -> {
                if (symbol instanceof Literal) {
                    return ((Literal<?>) symbol).value();
                }
                throw new IllegalStateException("Unexpected symbol: " + symbol);
            });
        }
    }

    /** Replaces the postgres handlers in the pipeline with the CrateDB transport handler. */
    private void switchToTransportProtocol(Channel channel) {
        ChannelPipeline pipeline = channel.pipeline();
        pipeline.remove("frame-decoder");
        pipeline.remove("handler");
        addTransportHandler.accept(pipeline);
    }

    /** Sends the ParameterStatus messages announced after successful authentication. */
    private void sendParams(Channel channel, CoordinatorSessionSettings sessionSettings) {
        Messages.sendParameterStatus(channel, "crate_version", Version.CURRENT.externalNumber());
        Messages.sendParameterStatus(channel, "server_version", PG_SERVER_VERSION);
        Messages.sendParameterStatus(channel, "server_encoding", "UTF8");
        Messages.sendParameterStatus(channel, "client_encoding", "UTF8");
        Messages.sendParameterStatus(channel, "datestyle", sessionSettings.dateStyle());
        Messages.sendParameterStatus(channel, "TimeZone", "UTC");
        Messages.sendParameterStatus(channel, "integer_datetimes", "on");
        Messages.sendParameterStatus(channel, "standard_conforming_strings", "on");
    }

    /**
     * Returns the access control of the current session, or {@link AccessControl#DISABLED} if
     * no session has been created yet.
     */
    private AccessControl accessControl() {
        Session currentSession = session;
        return currentSession == null
            ? AccessControl.DISABLED
            : getAccessControl.apply(currentSession.sessionSettings());
    }

    /** Flush ('H'): runs deferred executions if there are any, otherwise flushes the channel. */
    private void handleFlush(Channel channel) {
        try {
            if (session.hasDeferredExecutions()) {
                session.flush();
            } else {
                channel.flush();
            }
        } catch (Throwable t) {
            Messages.sendErrorResponse(channel, accessControl(), t);
        }
    }

    /**
     * Parse ('P'): reads statement name, query and parameter type OIDs, then parses the
     * statement in the session.
     *
     * @throws IllegalArgumentException if a parameter OID has no CrateDB type mapping.
     */
    private void handleParseMessage(ByteBuf buffer, Channel channel) {
        String statementName = readCString(buffer);
        String query = readCString(buffer);
        // The protocol transmits the count as Int16 but allows up to 65535 parameters.
        int numParams = buffer.readUnsignedShort();
        List<DataType<?>> paramTypes = new ArrayList<>(numParams);
        for (int i = 0; i < numParams; i++) {
            int oid = buffer.readInt();
            DataType<?> dataType = PGTypes.fromOID(oid);
            if (dataType == null) {
                throw new IllegalArgumentException(
                    String.format(Locale.ENGLISH, "Can't map PGType with oid=%d to Crate type", oid));
            }
            paramTypes.add(dataType);
        }
        session.parse(statementName, query, paramTypes);
        Messages.sendParseComplete(channel);
    }

    /** PasswordMessage ('p'): stores the password (if present) and completes authentication. */
    private void handlePassword(ByteBuf buffer, Channel channel) {
        char[] passwd = readCharArray(buffer);
        if (passwd != null) {
            authContext.setSecurePassword(passwd);
        }
        finishAuthentication(channel);
    }

    /**
     * Bind ('B'): decodes parameter values (text or binary, per format code) for a prepared
     * statement and binds them to the named portal.
     */
    private void handleBindMessage(ByteBuf buffer, Channel channel) {
        String portalName = readCString(buffer);
        String statementName = readCString(buffer);
        assert portalName != null : "portalName cannot be null";
        assert statementName != null : "statementName cannot be null";
        FormatCodes.FormatCode[] formatCodes = FormatCodes.fromBuffer(buffer);
        // The protocol transmits the count as Int16 but allows up to 65535 parameters.
        int numParams = buffer.readUnsignedShort();
        List<Object> params = createList(numParams);
        for (int i = 0; i < numParams; i++) {
            int valueLength = buffer.readInt();
            if (valueLength == -1) {
                // -1 length encodes a NULL parameter value
                params.add(null);
                continue;
            }
            DataType<?> paramType = session.getParamType(statementName, i);
            PGType<?> pgType = PGTypes.get(paramType);
            FormatCodes.FormatCode formatCode = FormatCodes.getFormatCode(formatCodes, i);
            switch (formatCode) {
                case TEXT:
                    params.add(pgType.readTextValue(buffer, valueLength));
                    break;
                case BINARY:
                    params.add(pgType.readBinaryValue(buffer, valueLength));
                    break;
                default:
                    Messages.sendErrorResponse(
                        channel,
                        accessControl(),
                        new UnsupportedOperationException(String.format(
                            Locale.ENGLISH,
                            "Unsupported format code '%d' for param '%s'",
                            formatCode.ordinal(),
                            paramType.getName())));
                    return;
            }
        }
        FormatCodes.FormatCode[] resultFormatCodes = FormatCodes.fromBuffer(buffer);
        session.bind(portalName, statementName, params, resultFormatCodes);
        Messages.sendBindComplete(channel);
    }

    /** Returns an immutable empty list for size 0, otherwise a mutable list presized to {@code size}. */
    private static <T> List<T> createList(int size) {
        return size == 0 ? Collections.emptyList() : new ArrayList<>(size);
    }

    /**
     * Describe ('D'): for a statement ('S') sends ParameterDescription, then RowDescription or
     * NoData. Result format codes are only looked up for portals ('P').
     */
    private void handleDescribeMessage(ByteBuf buffer, Channel channel) {
        byte type = buffer.readByte();
        String portalOrStatement = readCString(buffer);
        DescribeResult describeResult = session.describe((char) type, portalOrStatement);
        List<Symbol> fields = describeResult.getFields();
        if (type == 'S') {
            Messages.sendParameterDescription(channel, describeResult.getParameters());
        }
        if (fields == null) {
            Messages.sendNoData(channel);
            return;
        }
        FormatCodes.FormatCode[] resultFormatCodes =
            type == 'P' ? session.getResultFormatCodes(portalOrStatement) : null;
        Messages.sendRowDescription(
            channel, fields, describeResult.getFieldNames(), resultFormatCodes, describeResult.relation());
    }

    /**
     * Execute ('E'): executes the named portal. Empty queries close the portal and answer with
     * EmptyQueryResponse; statements without output columns use a row-count receiver and
     * ignore {@code maxRows}.
     */
    private void handleExecute(ByteBuf buffer, DelayableWriteChannel channel) {
        String portalName = readCString(buffer);
        int maxRows = buffer.readInt();
        String query = session.getQuery(portalName);
        if (query.isEmpty()) {
            session.close((byte) 'P', portalName);
            Messages.sendEmptyQueryResponse(channel);
            return;
        }
        List<? extends DataType<?>> outputTypes = session.getOutputTypes(portalName);
        DelayableWriteChannel.DelayedWrites delayedWrites = channel.delayWrites();
        AccessControl accessControl = accessControl();
        BaseResultReceiver resultReceiver;
        if (outputTypes == null) {
            maxRows = 0;
            resultReceiver = new RowCountReceiver(query, channel, delayedWrites, accessControl);
        } else {
            resultReceiver = new ResultSetReceiver(
                query,
                channel,
                delayedWrites,
                accessControl,
                Lists.map(outputTypes, PGTypes::get),
                session.getResultFormatCodes(portalName));
        }
        session.execute(portalName, maxRows, resultReceiver);
    }

    /**
     * Sync ('S'): ends an extended-query cycle. After an earlier error, deferred executions are
     * discarded and ReadyForQuery(FAILED_TRANSACTION) is sent; otherwise the session is synced
     * and ReadyForQuery is sent once the sync future completes.
     */
    private void handleSync(DelayableWriteChannel channel) {
        if (ignoreTillSync) {
            ignoreTillSync = false;
            session.resetDeferredExecutions();
            channel.writePendingMessages();
            Messages.sendReadyForQuery(channel, TransactionState.FAILED_TRANSACTION);
            return;
        }
        try {
            ReadyForQueryCallback readyForQueryCallback =
                new ReadyForQueryCallback(channel, session.transactionState());
            session.sync(false).whenComplete(readyForQueryCallback);
        } catch (Throwable t) {
            channel.discardDelayedWrites();
            Messages.sendErrorResponse(channel, accessControl(), t);
            Messages.sendReadyForQuery(channel, TransactionState.FAILED_TRANSACTION);
        }
    }

    /** Close ('C'): closes a portal or statement and sends CloseComplete. */
    private void handleClose(ByteBuf buffer, Channel channel) {
        byte type = buffer.readByte();
        String portalOrStatementName = readCString(buffer);
        session.close(type, portalOrStatementName);
        Messages.sendCloseComplete(channel);
    }

    /**
     * Query ('Q'): runs every statement of the query string sequentially; a failing statement
     * stops the chain. A single ReadyForQuery(IDLE) is sent at the end.
     */
    @VisibleForTesting
    void handleSimpleQuery(ByteBuf buffer, DelayableWriteChannel channel) {
        assert session != null : "Session must be created when running a simple query";
        Session.TimeoutToken timeoutToken = session.newTimeoutToken();
        String queryString = readCString(buffer);
        assert queryString != null : "query must not be null";
        if (queryString.isEmpty() || ";".equals(queryString)) {
            Messages.sendEmptyQueryResponse(channel);
            Messages.sendReadyForQuery(channel, TransactionState.IDLE);
            return;
        }
        List<Statement> statements;
        try {
            statements = session.simpleQuery(queryString);
        } catch (Exception ex) {
            Messages.sendErrorResponse(channel, accessControl(), ex);
            Messages.sendReadyForQuery(channel, TransactionState.IDLE);
            return;
        }
        timeoutToken.check();
        CompletableFuture<?> composedFuture = CompletableFuture.completedFuture(null);
        for (Statement statement : statements) {
            composedFuture = composedFuture.thenCompose(
                ignored -> handleSingleQuery(statement, queryString, channel, timeoutToken));
        }
        composedFuture.whenComplete(new ReadyForQueryCallback(channel, TransactionState.IDLE));
    }

    /**
     * Analyzes, binds, describes and executes one statement of a simple query on the unnamed
     * portal. Failures discard delayed writes, send an error response and are returned as an
     * exceptionally completed future.
     */
    private CompletableFuture<?> handleSingleQuery(Statement statement,
                                                   String query,
                                                   DelayableWriteChannel channel,
                                                   Session.TimeoutToken timeoutToken) {
        AccessControl accessControl = accessControl();
        try {
            session.analyze("", statement, Collections.emptyList(), query, timeoutToken);
            session.bind("", "", Collections.emptyList(), null);
            DescribeResult describeResult = session.describe('P', "");
            List<Symbol> fields = describeResult.getFields();
            if (fields == null) {
                DelayableWriteChannel.DelayedWrites delayedWrites = channel.delayWrites();
                RowCountReceiver rowCountReceiver = new RowCountReceiver(query, channel, delayedWrites, accessControl);
                session.execute("", 0, rowCountReceiver);
            } else {
                Messages.sendRowDescription(
                    channel, fields, describeResult.getFieldNames(), null, describeResult.relation());
                DelayableWriteChannel.DelayedWrites delayedWrites = channel.delayWrites();
                ResultSetReceiver resultSetReceiver = new ResultSetReceiver(
                    query,
                    channel,
                    delayedWrites,
                    accessControl,
                    Lists.map(fields, x -> PGTypes.get(x.valueType())),
                    null);
                session.execute("", 0, resultSetReceiver);
            }
            return session.sync(false);
        } catch (Throwable t) {
            channel.discardDelayedWrites();
            Messages.sendErrorResponse(channel, accessControl, t);
            return CompletableFuture.failedFuture(t);
        }
    }

    /** CancelRequest: cancels the session identified by the key data and closes this connection. */
    private void handleCancelRequestBody(ByteBuf buffer, Channel channel) {
        KeyData keyData = KeyData.of(buffer);
        sessions.cancel(keyData);
        handler.closeSession();
        channel.close();
    }

    /** Netty handler that routes decoded frames to the protocol methods of the outer instance. */
    private class MessageHandler extends SimpleChannelInboundHandler<ByteBuf> {

        private MessageHandler() {
        }

        @Override
        public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
            PostgresWireProtocol.this.channel = new DelayableWriteChannel(ctx.channel());
        }

        @Override
        public boolean acceptInboundMessage(Object msg) throws Exception {
            return true;
        }

        /**
         * Dispatches a frame. Any failure switches to "ignore until Sync" mode and is reported to
         * the client. Failures while reporting are only logged.
         */
        @Override
        public void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
            assert PostgresWireProtocol.this.channel != null : "Channel must be initialized";
            try {
                dispatchState(buffer, PostgresWireProtocol.this.channel);
            } catch (Throwable t) {
                PostgresWireProtocol.this.ignoreTillSync = true;
                try {
                    Messages.sendErrorResponse(PostgresWireProtocol.this.channel, accessControl(), t);
                } catch (Throwable ti) {
                    LOGGER.error("Error trying to send error to client: {}", t, ti);
                }
            }
        }

        private void dispatchState(ByteBuf buffer, DelayableWriteChannel channel) {
            switch (decoder.state()) {
                case STARTUP_PARAMETERS:
                    handleStartupBody(buffer, channel);
                    decoder.startupDone();
                    return;
                case CANCEL:
                    handleCancelRequestBody(buffer, channel);
                    return;
                case MSG:
                    LOGGER.trace("msg={} msgLength={} readableBytes={}",
                        (char) decoder.msgType(), decoder.payloadLength(), buffer.readableBytes());
                    if (PostgresWireProtocol.this.ignoreTillSync && decoder.msgType() != 'S') {
                        buffer.skipBytes(decoder.payloadLength());
                        return;
                    }
                    dispatchMessage(buffer, channel);
                    return;
                default:
                    break;
            }
            throw new IllegalStateException("Illegal state: " + decoder.state());
        }

        private void dispatchMessage(ByteBuf buffer, DelayableWriteChannel channel) {
            switch (decoder.msgType()) {
                case 'Q': // Query (simple protocol)
                    handleSimpleQuery(buffer, channel);
                    return;
                case 'P': // Parse
                    handleParseMessage(buffer, channel);
                    return;
                case 'p': // PasswordMessage
                    handlePassword(buffer, channel);
                    return;
                case 'B': // Bind
                    handleBindMessage(buffer, channel);
                    return;
                case 'D': // Describe
                    handleDescribeMessage(buffer, channel);
                    return;
                case 'E': // Execute
                    handleExecute(buffer, channel);
                    return;
                case 'H': // Flush
                    handleFlush(channel);
                    return;
                case 'S': // Sync
                    handleSync(channel);
                    return;
                case 'C': // Close
                    handleClose(buffer, channel);
                    return;
                case 'X': // Terminate
                    closeSession();
                    channel.close();
                    return;
                default:
                    Messages.sendErrorResponse(
                        channel,
                        accessControl(),
                        new UnsupportedOperationException("Unsupported messageType: " + decoder.msgType()));
            }
        }

        private void closeSession() {
            if (PostgresWireProtocol.this.session != null) {
                PostgresWireProtocol.this.session.close();
                PostgresWireProtocol.this.session = null;
            }
        }

        /**
         * Logs or reports the exception, then always closes the session and the channel (even if
         * reporting fails).
         */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable t) throws Exception {
            try {
                if (t instanceof SocketException && "Connection reset".equals(t.getMessage())) {
                    LOGGER.info("Connection reset. Client likely terminated connection");
                } else if (t instanceof DecoderException) {
                    Throwable cause = t.getCause() == null ? t : t.getCause();
                    Messages.sendErrorResponse(
                        PostgresWireProtocol.this.channel, accessControl(), cause, PGError.Severity.FATAL);
                } else {
                    LOGGER.error("Uncaught exception: ", t);
                }
            } finally {
                closeSession();
                ctx.channel().close();
            }
        }

        @Override
        public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
            LOGGER.trace("channelDisconnected");
            PostgresWireProtocol.this.channel = null;
            closeSession();
            super.channelUnregistered(ctx);
        }
    }

    /** Sends ReadyForQuery with a fixed transaction state once a future completes, successfully or not. */
    private static class ReadyForQueryCallback implements BiConsumer<Object, Throwable> {
        private final Channel channel;
        private final TransactionState transactionState;

        private ReadyForQueryCallback(Channel channel, TransactionState transactionState) {
            this.channel = channel;
            this.transactionState = transactionState;
        }

        @Override
        public void accept(Object result, Throwable t) {
            Messages.sendReadyForQuery(channel, transactionState);
        }
    }
}

