/*
 * Decompiled with CFR 0.152.
 */
package io.crate.auth;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkException;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.RSAKeyProvider;
import com.auth0.jwt.interfaces.Verification;
import io.crate.auth.AuthSettings;
import io.crate.auth.AuthenticationMethod;
import io.crate.auth.CachingJwkProvider;
import io.crate.auth.Credentials;
import io.crate.auth.LoadedRSAKeyProvider;
import io.crate.protocols.postgres.ConnectionProperties;
import io.crate.role.JwtProperties;
import io.crate.role.Role;
import io.crate.role.Roles;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import org.elasticsearch.common.settings.Settings;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

/**
 * {@link AuthenticationMethod} that authenticates a user by verifying the JSON Web Token carried in
 * the {@link Credentials}.
 *
 * <p>The token signature is checked with an RSA public key obtained from a {@link JwkProvider}.
 * One provider is created per issuer (via {@code createProvider}) and cached in a
 * {@link ConcurrentHashMap}, so it is reused across authentication attempts.
 */
public class JWTAuthenticationMethod implements AuthenticationMethod {

    public static final String NAME = "jwt";

    private final Roles roles;
    private final Function<String, JwkProvider> createProvider;
    /** JWK providers keyed by issuer; populated lazily on first use of an issuer. */
    private final ConcurrentHashMap<String, JwkProvider> cachedJwkProviders = new ConcurrentHashMap<>();
    private final Supplier<String> clusterId;
    private final Settings settings;

    /**
     * Creates the method using {@link CachingJwkProvider} to build one JWK provider per issuer.
     *
     * @param roles     used to look up the user being authenticated
     * @param settings  node settings; read for the global JWT issuer / audience
     * @param clusterId supplies the default audience when no audience setting is configured
     */
    public JWTAuthenticationMethod(Roles roles, Settings settings, Supplier<String> clusterId) {
        this(roles, settings, clusterId, CachingJwkProvider::new);
    }

    /**
     * Test hook allowing a custom JWK provider factory.
     *
     * @param createProvider maps an issuer to the {@link JwkProvider} used for that issuer
     */
    @VisibleForTesting
    JWTAuthenticationMethod(Roles roles, Settings settings, Supplier<String> clusterId, Function<String, JwkProvider> createProvider) {
        this.roles = roles;
        this.settings = settings;
        this.clusterId = clusterId;
        this.createProvider = createProvider;
    }

    /**
     * Verifies the decoded token in {@code credentials} for the user named in {@code credentials}.
     *
     * <p>If a global issuer is configured via {@link AuthSettings#AUTH_HOST_BASED_JWT_ISS_SETTING}, the
     * token must be issued by it and carry a {@code username} claim equal to the database user name.
     * Otherwise, the issuer, expected {@code username} claim and (optionally) audience come from the
     * user's own {@link JwtProperties}. The audience defaults to
     * {@link AuthSettings#AUTH_HOST_BASED_JWT_AUD_SETTING}, which itself defaults to the cluster id.
     *
     * @param credentials    must have a resolved username and a non-null decoded token
     * @param connProperties unused by this method
     * @return the authenticated user
     * @throws RuntimeException if the user does not exist or the token fails verification; in the
     *                          latter case the underlying failure is attached as the cause
     */
    @Override
    @Nullable
    public Role authenticate(Credentials credentials, ConnectionProperties connProperties) {
        String username = credentials.username();
        assert username != null : "User name must be resolved before authentication attempt";
        DecodedJWT decodedJWT = credentials.decodedToken();
        assert decodedJWT != null : "Token must be not null on jwt auth";

        Role user = roles.findUser(username);
        if (user == null) {
            throw new RuntimeException("jwt authentication failed for user \"" + username + "\"");
        }
        try {
            String expectedName;
            String issuer = settings.get(AuthSettings.AUTH_HOST_BASED_JWT_ISS_SETTING.getKey());
            String audience = settings.get(AuthSettings.AUTH_HOST_BASED_JWT_AUD_SETTING.getKey(), clusterId.get());
            if (issuer != null) {
                // Global issuer configured: the token's username claim must match the DB user name.
                expectedName = username;
            } else {
                // No global issuer: the user was matched via its per-user JWT properties.
                JwtProperties jwtProperties = user.jwtProperties();
                assert jwtProperties != null : "credentials.username was matched using jwt properties, properties cannot be null.";
                issuer = jwtProperties.iss();
                expectedName = jwtProperties.username();
                if (jwtProperties.aud() != null) {
                    audience = jwtProperties.aud();
                }
            }
            JwkProvider jwkProvider = cachedJwkProviders.computeIfAbsent(issuer, createProvider);
            Algorithm algorithm = resolveAlgorithm(jwkProvider, decodedJWT);
            JWTVerifier verifier = JWT.require(algorithm)
                .withIssuer(issuer)
                .withClaim("username", expectedName)
                .withAudience(audience)
                .build();
            verifier.verify(decodedJWT);
        } catch (Exception e) {
            // Keep the original exception as the cause so the real failure stays diagnosable.
            throw new RuntimeException(
                String.format(Locale.ENGLISH, "jwt authentication failed for user %s. Reason: %s", username, e.getMessage()),
                e);
        }
        return user;
    }

    @Override
    public String name() {
        return NAME;
    }

    /**
     * Builds the verification {@link Algorithm} from the JWK whose key id matches the token's
     * {@code kid}.
     *
     * <p>Only RSA public keys are supported. When the JWK declares no {@code alg}, RS256 is used.
     * When both the JWK and the token declare an algorithm, they must be equal.
     *
     * @throws JwkException                  if the JWK cannot be fetched or its key cannot be built
     * @throws UnsupportedOperationException if the JWK's public key is not an RSA key
     * @throws IllegalArgumentException      if the token's algorithm differs from the JWK's algorithm
     * @throws RuntimeException              if the JWK declares an algorithm other than RS256/384/512
     */
    private static Algorithm resolveAlgorithm(JwkProvider jwkProvider, @NotNull DecodedJWT decodedJWT) throws JwkException {
        Jwk jwk = jwkProvider.get(decodedJWT.getKeyId());
        PublicKey publicKey = jwk.getPublicKey();
        if (!(publicKey instanceof RSAPublicKey)) {
            throw new UnsupportedOperationException("Only RSA algorithm is supported for JWT");
        }
        RSAKeyProvider keyProvider = new LoadedRSAKeyProvider((RSAPublicKey) publicKey);

        String keyAlgorithm = jwk.getAlgorithm();
        if (keyAlgorithm == null) {
            // The JWK does not declare "alg": fall back to RS256.
            return Algorithm.RSA256(keyProvider);
        }
        // Compare the token's header algorithm (a String) with the key's algorithm. Comparing the
        // DecodedJWT object itself against the String would never match.
        String tokenAlgorithm = decodedJWT.getAlgorithm();
        if (tokenAlgorithm != null && !tokenAlgorithm.equals(keyAlgorithm)) {
            throw new IllegalArgumentException("Jwt token has algorithm not matching with the algorithm of the public key.");
        }
        return switch (keyAlgorithm) {
            case "RS256" -> Algorithm.RSA256(keyProvider);
            case "RS384" -> Algorithm.RSA384(keyProvider);
            case "RS512" -> Algorithm.RSA512(keyProvider);
            default -> throw new RuntimeException(String.format(Locale.ENGLISH, "Unsupported algorithm %s", keyAlgorithm));
        };
    }
}

