/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.athena.jdbc.authentication.datazone;

import com.amazon.athena.jdbc.authentication.datazone.helpers.BrowserControlHelper;
import com.amazon.athena.jdbc.authentication.datazone.helpers.DataZoneHelper;
import com.amazon.athena.jdbc.authentication.datazone.helpers.SsoOidcHelper;
import com.amazon.athena.jdbc.authentication.datazone.httpserver.Server;
import com.amazon.athena.jdbc.authentication.utils.RandomString;
import com.amazon.athena.jdbc.support.AuthenticationException;
import com.amazon.athena.logging.AthenaLogger;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URIBuilder;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.ssooidc.model.RegisterClientResponse;

/**
 * {@link AwsCredentialsProvider} that obtains DataZone environment credentials through an
 * Identity Center OAuth authorization-code flow protected by PKCE (S256).
 *
 * <p>When no credentials are held, or the held ones have expired, the provider:
 * <ol>
 *   <li>registers (or reuses a cached) Identity Center OIDC client,</li>
 *   <li>opens the {@code authorize} page in a browser and waits on a local {@link Server}
 *       for the redirect carrying the authorization code,</li>
 *   <li>exchanges that code for an access token, and</li>
 *   <li>asks DataZone for environment credentials using the access token.</li>
 * </ol>
 *
 * <p>Registered clients and, when token caching is enabled, access tokens are stored in static
 * maps shared by every instance in the JVM. Those maps are synchronized because instances may
 * be used from different threads.
 */
public class DataZoneIdcCredentialsProvider implements AwsCredentialsProvider {
    private static final AthenaLogger logger = AthenaLogger.of(DataZoneIdcCredentialsProvider.class);
    private static final String OAUTH_CSRF_STATE_PARAMETER_NAME = "state";
    private static final String AUTH_CODE_PARAMETER_NAME = "code";
    /**
     * Joins the components of a cache key. Plain concatenation would let different
     * (domain, environment) pairs produce the same key; DataZone IDs are not expected
     * to contain '|' (NOTE(review): confirm against the ID format).
     */
    private static final String CACHE_KEY_SEPARATOR = "|";
    /** Registered OIDC clients keyed by domain, environment and issuer URL; shared JVM-wide. */
    private static final Map<String, RegisterClientResponse> registeredClientCache =
            Collections.synchronizedMap(new HashMap<String, RegisterClientResponse>());
    /** Access tokens keyed by domain and environment; shared JVM-wide. */
    private static final Map<String, String> tokenCache =
            Collections.synchronizedMap(new HashMap<String, String>());
    /** SecureRandom is thread-safe, so one instance serves every PKCE verifier. */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    private final String domainId;
    private final String environmentId;
    private final String domainRegion;
    private final Region environmentRegion;
    private final String datazoneEndpoint;
    private final String identityCenterIssuerUrl;
    private final String cachedTokenKey;
    private final int idcResponseTimeout;
    private final boolean enableTokenCaching;
    private final SsoOidcHelper ssoOidcHelper;
    private final DataZoneHelper dataZoneHelper;
    private final BrowserControlHelper browserControlHelper;
    private final Server server;
    private final String datazoneScope;
    private AwsSessionCredentials credentials;

    /**
     * Creates a provider; no network or browser activity happens until
     * {@link #resolveCredentials()} is called.
     *
     * @param domainId                DataZone domain ID (part of the cache keys)
     * @param environmentId           DataZone environment ID (part of the cache keys)
     * @param domainRegion            region used to build the {@code oidc.<region>.amazonaws.com} authorize URL
     * @param environmentRegion       environment region (stored; not read by this class)
     * @param datazoneEndpoint        DataZone endpoint (stored; not read by this class)
     * @param identityCenterIssuerUrl issuer URL used when registering the OIDC client
     * @param idcResponseTimeout      seconds to wait for the browser redirect
     * @param enableTokenCaching      whether access tokens are cached and reused
     * @param ssoOidcHelper           helper for client registration and token exchange
     * @param dataZoneHelper          helper that turns an access token into environment credentials
     * @param browserControlHelper    helper that opens the authorize URL
     * @param server                  local server that receives the authorization redirect
     * @param datazoneScope           value of the {@code scopes} authorize parameter
     */
    public DataZoneIdcCredentialsProvider(String domainId, String environmentId, String domainRegion, Region environmentRegion, String datazoneEndpoint, String identityCenterIssuerUrl, int idcResponseTimeout, boolean enableTokenCaching, SsoOidcHelper ssoOidcHelper, DataZoneHelper dataZoneHelper, BrowserControlHelper browserControlHelper, Server server, String datazoneScope) {
        this.domainId = domainId;
        this.environmentId = environmentId;
        this.domainRegion = domainRegion;
        this.environmentRegion = environmentRegion;
        this.datazoneEndpoint = datazoneEndpoint;
        this.identityCenterIssuerUrl = identityCenterIssuerUrl;
        this.idcResponseTimeout = idcResponseTimeout;
        this.enableTokenCaching = enableTokenCaching;
        this.cachedTokenKey = cacheKey(domainId, environmentId);
        this.ssoOidcHelper = ssoOidcHelper;
        this.dataZoneHelper = dataZoneHelper;
        this.browserControlHelper = browserControlHelper;
        this.server = server;
        this.datazoneScope = datazoneScope;
    }

    /**
     * Returns the held credentials, fetching new ones first if none are held or they have expired.
     *
     * <p>With token caching disabled, any cached token for this domain/environment is evicted
     * and the full interactive flow runs.
     *
     * @return DataZone environment session credentials
     * @throws AuthenticationException if any step of the flow fails
     */
    @Override
    public AwsCredentials resolveCredentials() {
        if (this.credentials == null || this.credentials.expirationTime().get().isBefore(Instant.now())) {
            if (this.enableTokenCaching) {
                this.credentials = this.resolveCachedTokenCredentials();
            } else {
                tokenCache.remove(this.cachedTokenKey);
                this.credentials = this.getDataZoneCredentials();
            }
        }
        return this.credentials;
    }

    /** Builds an unambiguous cache key from the given components. */
    private static String cacheKey(String... parts) {
        return String.join(CACHE_KEY_SEPARATOR, parts);
    }

    /**
     * Tries the cached access token first; falls back to the interactive flow when there is no
     * cached token or DataZone rejects it.
     */
    private AwsSessionCredentials resolveCachedTokenCredentials() {
        String cachedToken = tokenCache.get(this.cachedTokenKey);
        if (cachedToken == null) {
            return this.getDataZoneCredentials();
        }
        try {
            return this.dataZoneHelper.getEnvironmentCredentials(cachedToken);
        }
        catch (AuthenticationException e) {
            logger.info(String.format("Could not re-use cached access token: %s", e.getMessage()), new Object[0]);
            return this.getDataZoneCredentials();
        }
    }

    /**
     * Runs the full interactive flow: client registration, PKCE authorization, token exchange and
     * DataZone credential retrieval. Caches the access token when caching is enabled.
     */
    private AwsSessionCredentials getDataZoneCredentials() {
        RegisterClientResponse registeredClient = this.getRegisteredClient(this.identityCenterIssuerUrl);
        String codeVerifier = generateCodeVerifier();
        String codeChallenge = generateCodeChallenge(codeVerifier);
        String authCode = this.fetchAuthorizationToken(registeredClient, codeChallenge);
        String accessToken = this.ssoOidcHelper.retrieveAccessToken(registeredClient, codeVerifier, authCode);
        if (this.enableTokenCaching) {
            tokenCache.put(this.cachedTokenKey, accessToken);
        }
        AwsSessionCredentials environmentCredentials = this.dataZoneHelper.getEnvironmentCredentials(accessToken);
        logger.trace("Retrieved DataZone Environment Credentials", new Object[0]);
        return AwsSessionCredentials.builder()
                .accessKeyId(environmentCredentials.accessKeyId())
                .secretAccessKey(environmentCredentials.secretAccessKey())
                .sessionToken(environmentCredentials.sessionToken())
                .expirationTime(environmentCredentials.expirationTime().get())
                .build();
    }

    /**
     * Returns a cached registered client for this domain/environment/issuer if its secret has not
     * expired; otherwise registers a new one and caches it.
     *
     * <p>The lookup and the insert are separate steps, so two threads may each register a client.
     * The last one written wins, which is harmless.
     */
    private RegisterClientResponse getRegisteredClient(String issuerUrl) {
        String cachedClientKey = cacheKey(this.domainId, this.environmentId, issuerUrl);
        RegisterClientResponse cachedClient = registeredClientCache.get(cachedClientKey);
        if (this.isCachedClientValid(cachedClient)) {
            logger.trace("Registered client is cached, using previously registered client", new Object[0]);
            return cachedClient;
        }
        logger.trace("Registering a new Identity Center client", new Object[0]);
        RegisterClientResponse registerClientResponse = this.ssoOidcHelper.registerClient(issuerUrl);
        registeredClientCache.put(cachedClientKey, registerClientResponse);
        return registerClientResponse;
    }

    /**
     * Returns true if the client exists and its secret expiry (epoch seconds, multiplied by 1000
     * and compared against {@link System#currentTimeMillis()}) is still in the future.
     */
    private boolean isCachedClientValid(RegisterClientResponse cachedClient) {
        if (cachedClient == null || cachedClient.clientSecretExpiresAt() == null) {
            return false;
        }
        return System.currentTimeMillis() < cachedClient.clientSecretExpiresAt() * 1000L;
    }

    /** Generates a PKCE code verifier: 32 random bytes, base64url-encoded without padding. */
    private static String generateCodeVerifier() {
        byte[] codeVerifier = new byte[32];
        SECURE_RANDOM.nextBytes(codeVerifier);
        return Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifier);
    }

    /**
     * Computes the S256 PKCE challenge: base64url(SHA-256(ASCII(verifier))) without padding.
     * The ASCII encoding is explicit so the result does not depend on the platform charset.
     */
    private static String generateCodeChallenge(String codeVerifier) {
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            byte[] hash = digest.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
            return Base64.getUrlEncoder().withoutPadding().encodeToString(hash);
        }
        catch (NoSuchAlgorithmException e) {
            throw new AuthenticationException("Unable to generate code challenge", e);
        }
    }

    /**
     * Opens the Identity Center authorize page and waits for the redirect to the local server.
     *
     * <p>The returned {@code state} must equal the one sent (CSRF protection). The listener future
     * is cancelled and the server shut down on every exit path.
     *
     * @return the non-empty authorization code from the redirect
     * @throws AuthenticationException on timeout, interruption, state mismatch, missing code, or
     *                                 failure to build the URI or launch the browser
     */
    private String fetchAuthorizationToken(RegisterClientResponse registeredClient, String codeChallenge) {
        String state = RandomString.generateRandomString(10);
        Future<List<NameValuePair>> future = this.server.listenForResponse();
        try {
            URI requestUri = this.buildAuthorizationUri(registeredClient, state, codeChallenge);
            this.browserControlHelper.launchBrowser(requestUri);
            List<NameValuePair> response = future.get(this.idcResponseTimeout, TimeUnit.SECONDS);
            String receivedState = this.findValueInNameValuePairs(OAUTH_CSRF_STATE_PARAMETER_NAME, response)
                    .orElseThrow(() -> new AuthenticationException("State is not found in the response"));
            if (!receivedState.equals(state)) {
                String stateErrorMessage = "State " + receivedState + " does not match the outgoing state " + state;
                throw new AuthenticationException(stateErrorMessage);
            }
            return this.findValueInNameValuePairs(AUTH_CODE_PARAMETER_NAME, response)
                    .filter(value -> !value.isEmpty())
                    .orElseThrow(() -> new AuthenticationException("Authorization code is not found or empty"));
        }
        catch (TimeoutException e) {
            throw new AuthenticationException("Timeout while fetching authorization code", e);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.debug("Main thread got interrupted: {}", e.getMessage());
            throw new AuthenticationException("Main thread got interrupted", e);
        }
        catch (ExecutionException e) {
            logger.debug("Server thread threw an exception: {}", e.getMessage());
            throw new AuthenticationException(e.getMessage(), e);
        }
        catch (IOException | URISyntaxException e) {
            logger.debug("Failed to start the authorization request: {}", e.getMessage());
            throw new AuthenticationException(e.getMessage(), e);
        }
        finally {
            // No-op when the future already completed; otherwise stop waiting for a redirect nobody will read.
            future.cancel(true);
            logger.trace("Shutdown listening server", new Object[0]);
            this.server.shutdownServer();
        }
    }

    /**
     * Builds the {@code https://oidc.<domainRegion>.amazonaws.com/authorize} URI with the PKCE and
     * CSRF parameters. The redirect target is the local server's listen port on 127.0.0.1.
     */
    private URI buildAuthorizationUri(RegisterClientResponse registeredClient, String state, String codeChallenge) throws IOException, URISyntaxException {
        // The host is supplied through a "//host/path" path; kept as-is to preserve the exact URI produced.
        String oidcUrl = String.format("//oidc.%s.amazonaws.com/authorize", this.domainRegion);
        return new URIBuilder()
                .setScheme("https")
                .setPath(oidcUrl)
                .addParameter("response_type", AUTH_CODE_PARAMETER_NAME)
                .addParameter("client_id", registeredClient.clientId())
                .addParameter("redirect_uri", "http://127.0.0.1:" + this.server.getListenPort())
                .addParameter("scopes", this.datazoneScope)
                .addParameter(OAUTH_CSRF_STATE_PARAMETER_NAME, state)
                .addParameter("code_challenge", codeChallenge)
                .addParameter("code_challenge_method", "S256")
                .build();
    }

    /**
     * Returns the value of the first pair named {@code name}. Returns empty when the list is null,
     * when no pair matches, or when the first matching pair has a null value (e.g. {@code ?state}).
     */
    private Optional<String> findValueInNameValuePairs(String name, List<NameValuePair> list) {
        if (list == null) {
            return Optional.empty();
        }
        return list.stream()
                .filter(pair -> name.equals(pair.getName()))
                .findFirst()
                .map(NameValuePair::getValue);
    }
}

