/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.oauth2.client.web.reactive.function.client;

import java.net.URI;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpMethod;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
 * An {@link ExchangeFilterFunction} for Servlet environments that adds an OAuth 2.0
 * bearer token to outgoing {@link WebClient} requests.
 *
 * <p>The token is taken from the {@link OAuth2AuthorizedClient} stored in the request
 * attributes (see {@link #oauth2AuthorizedClient(OAuth2AuthorizedClient)}). When
 * {@link #defaultRequest()} is applied, missing attributes are filled in from the
 * current thread's {@link RequestContextHolder} and {@link SecurityContextHolder}, and
 * the authorized client is loaded from the configured {@link OAuth2AuthorizedClientRepository}.
 * If none is stored and the registration uses the {@code client_credentials} grant, a new
 * one is obtained and saved.
 *
 * <p>If the access token is within {@code accessTokenExpiresSkew} of its expiry and a
 * refresh token is present, it is refreshed (and the result saved) before the request is sent.
 */
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
    private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
    private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
    private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
    private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
    private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
    private Clock clock = Clock.systemUTC();
    private Duration accessTokenExpiresSkew = Duration.ofMinutes(1L);
    private ClientRegistrationRepository clientRegistrationRepository;
    private OAuth2AuthorizedClientRepository authorizedClientRepository;
    private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient();
    private boolean defaultOAuth2AuthorizedClient;
    private String defaultClientRegistrationId;

    /**
     * Creates a filter without repositories.
     *
     * <p>With no {@link OAuth2AuthorizedClientRepository}, default authorized clients are
     * never resolved and tokens are never refreshed. Only an authorized client placed
     * explicitly in the request attributes is used.
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
    }

    /**
     * Creates a filter that can resolve, obtain, and refresh authorized clients.
     *
     * @param clientRegistrationRepository used to look up registrations by id
     * @param authorizedClientRepository used to load and save authorized clients
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) {
        this.clientRegistrationRepository = clientRegistrationRepository;
        this.authorizedClientRepository = authorizedClientRepository;
    }

    /**
     * Sets the client used to obtain tokens for {@code client_credentials} registrations.
     *
     * @param clientCredentialsTokenResponseClient the token response client, not null
     * @throws IllegalArgumentException if the argument is null
     */
    public void setClientCredentialsTokenResponseClient(OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
        Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
        this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
    }

    /**
     * Sets whether the registration id of the current {@link OAuth2AuthenticationToken}
     * is used when no explicit or default registration id is available.
     *
     * @param defaultOAuth2AuthorizedClient true to fall back to the logged-in client
     */
    public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
        this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
    }

    /**
     * Sets the registration id used when the request does not specify one.
     *
     * @param clientRegistrationId the default registration id, may be null
     */
    public void setDefaultClientRegistrationId(String clientRegistrationId) {
        this.defaultClientRegistrationId = clientRegistrationId;
    }

    /**
     * Returns a customizer that registers {@link #defaultRequest()} and this filter
     * on a {@link WebClient.Builder}.
     *
     * @return the builder customizer
     */
    public Consumer<WebClient.Builder> oauth2Configuration() {
        return builder -> builder.defaultRequest(defaultRequest()).filter(this);
    }

    /**
     * Returns a default-request customizer that fills in the Servlet request/response,
     * the current {@link Authentication}, and (if resolvable) the
     * {@link OAuth2AuthorizedClient} attributes, without overwriting values already set.
     *
     * @return the request customizer
     */
    public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() {
        return spec -> spec.attributes(attrs -> {
            populateDefaultRequestResponse(attrs);
            populateDefaultAuthentication(attrs);
            populateDefaultOAuth2AuthorizedClient(attrs);
        });
    }

    /**
     * Sets the authorized client request attribute, or removes it when the argument is null.
     *
     * @param authorizedClient the client whose access token should be used, may be null
     * @return an attributes customizer
     */
    public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
        return attributes -> {
            if (authorizedClient == null) {
                attributes.remove(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
            } else {
                attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
            }
        };
    }

    /**
     * Sets the registration id used to resolve the authorized client.
     *
     * @param clientRegistrationId the registration id
     * @return an attributes customizer
     */
    public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
        return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
    }

    /**
     * Sets the {@link Authentication} used to load and save the authorized client.
     *
     * @param authentication the principal
     * @return an attributes customizer
     */
    public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
        return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication);
    }

    /**
     * Sets the {@link HttpServletRequest} passed to the authorized client repository.
     *
     * @param request the Servlet request
     * @return an attributes customizer
     */
    public static Consumer<Map<String, Object>> httpServletRequest(HttpServletRequest request) {
        return attributes -> attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
    }

    /**
     * Sets the {@link HttpServletResponse} passed to the authorized client repository.
     *
     * @param response the Servlet response
     * @return an attributes customizer
     */
    public static Consumer<Map<String, Object>> httpServletResponse(HttpServletResponse response) {
        return attributes -> attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
    }

    /**
     * Sets how long before the access token's expiry a refresh is attempted.
     *
     * @param accessTokenExpiresSkew the skew, not null (default one minute)
     * @throws IllegalArgumentException if the argument is null
     */
    public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
        Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
        this.accessTokenExpiresSkew = accessTokenExpiresSkew;
    }

    /**
     * Sends the request with a bearer token when an authorized client attribute is present
     * (refreshing it first if needed); otherwise sends the request unchanged.
     */
    @Override
    public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
        Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
                .map(OAuth2AuthorizedClient.class::cast);
        return Mono.justOrEmpty(attribute)
                .flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
                .map(authorizedClient -> bearer(request, authorizedClient))
                .flatMap(next::exchange)
                // Defer so the un-authenticated exchange (and any downstream filters) is only
                // invoked when actually needed, not at assembly time on every call.
                .switchIfEmpty(Mono.defer(() -> next.exchange(request)));
    }

    /** Fills in the Servlet request/response attributes from the current thread, if absent. */
    private void populateDefaultRequestResponse(Map<String, Object> attrs) {
        if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
            return;
        }
        HttpServletRequest request = null;
        HttpServletResponse response = null;
        // RequestContextHolder may hold a non-Servlet RequestAttributes implementation;
        // an unconditional cast would throw ClassCastException in that case.
        Object context = RequestContextHolder.getRequestAttributes();
        if (context instanceof ServletRequestAttributes) {
            ServletRequestAttributes servletContext = (ServletRequestAttributes) context;
            request = servletContext.getRequest();
            response = servletContext.getResponse();
        }
        attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
        attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
    }

    /** Fills in the Authentication attribute from the SecurityContextHolder, if absent. */
    private void populateDefaultAuthentication(Map<String, Object> attrs) {
        if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
            return;
        }
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
    }

    /**
     * Resolves a registration id and sets the authorized client attribute.
     *
     * <p>The id is taken from the explicit attribute, then {@code defaultClientRegistrationId},
     * then (if enabled) the current {@link OAuth2AuthenticationToken}. The client is loaded
     * from the repository, or obtained via {@link #getAuthorizedClient(String, Map)} if absent.
     */
    private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
        if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
            return;
        }
        Authentication authentication = getAuthentication(attrs);
        String clientRegistrationId = getClientRegistrationId(attrs);
        if (clientRegistrationId == null) {
            clientRegistrationId = this.defaultClientRegistrationId;
        }
        if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient && authentication instanceof OAuth2AuthenticationToken) {
            clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
        }
        if (clientRegistrationId != null) {
            HttpServletRequest request = getRequest(attrs);
            OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, request);
            if (authorizedClient == null) {
                authorizedClient = getAuthorizedClient(clientRegistrationId, attrs);
            }
            oauth2AuthorizedClient(authorizedClient).accept(attrs);
        }
    }

    /**
     * Obtains a new authorized client for the given registration id.
     *
     * @throws IllegalArgumentException if no registration exists for the id
     * @throws ClientAuthorizationRequiredException if the registration does not use the
     *         {@code client_credentials} grant (user authorization is required instead)
     */
    private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) {
        ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
        if (clientRegistration == null) {
            throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
        }
        if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
            return getAuthorizedClient(clientRegistration, attrs);
        }
        throw new ClientAuthorizationRequiredException(clientRegistrationId);
    }

    /** Performs the client_credentials grant and saves the resulting authorized client. */
    private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, Map<String, Object> attrs) {
        HttpServletRequest request = getRequest(attrs);
        HttpServletResponse response = getResponse(attrs);
        OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
        OAuth2AccessTokenResponse tokenResponse = this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
        Authentication principal = getAuthentication(attrs);
        // With no Authentication attribute, the client is stored under "anonymousUser".
        String principalName = principal != null ? principal.getName() : "anonymousUser";
        OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principalName, tokenResponse.getAccessToken());
        this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response);
        return authorizedClient;
    }

    /** Returns the client unchanged, or a refreshed copy when {@link #shouldRefresh} says so. */
    private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
        if (shouldRefresh(authorizedClient)) {
            return refreshAuthorizedClient(request, next, authorizedClient);
        }
        return Mono.just(authorizedClient);
    }

    /**
     * Exchanges the refresh token at the provider's token URI and saves the refreshed client.
     *
     * <p>The refresh request is sent through {@code next}, with HTTP Basic client credentials.
     */
    private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
        ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
        ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
                .header("Accept", "application/json")
                .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
                .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
                .build();
        return next.exchange(refreshRequest)
                .flatMap(response -> response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse()))
                .map(accessTokenResponse -> refreshedAuthorizedClient(authorizedClient, accessTokenResponse))
                .map(refreshed -> saveRefreshedAuthorizedClient(request, refreshed))
                .publishOn(Schedulers.elastic());
    }

    /** Builds the refreshed client, keeping the old refresh token if the response omits one. */
    private static OAuth2AuthorizedClient refreshedAuthorizedClient(OAuth2AuthorizedClient authorizedClient, OAuth2AccessTokenResponse accessTokenResponse) {
        // RFC 6749 §6: the server MAY issue a new refresh token; if it does not, the
        // existing one remains valid and must not be discarded.
        OAuth2RefreshToken refreshToken = accessTokenResponse.getRefreshToken() != null
                ? accessTokenResponse.getRefreshToken()
                : authorizedClient.getRefreshToken();
        return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken);
    }

    /** Saves the refreshed client using the request's attributes and returns it. */
    private OAuth2AuthorizedClient saveRefreshedAuthorizedClient(ClientRequest request, OAuth2AuthorizedClient refreshed) {
        // Fall back to a name-only principal when no Authentication attribute was set.
        Authentication principal = request.attribute(AUTHENTICATION_ATTR_NAME)
                .map(Authentication.class::cast)
                .orElseGet(() -> new PrincipalNameAuthentication(refreshed.getPrincipalName()));
        Map<String, Object> attributes = request.attributes();
        this.authorizedClientRepository.saveAuthorizedClient(refreshed, principal, getRequest(attributes), getResponse(attributes));
        return refreshed;
    }

    /**
     * Returns true when a repository is configured, a refresh token exists, and the access
     * token expires within {@code accessTokenExpiresSkew} of now. A token without an expiry
     * time is never refreshed.
     */
    private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
        if (this.authorizedClientRepository == null) {
            return false;
        }
        OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
        if (refreshToken == null) {
            return false;
        }
        Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
        if (expiresAt == null) {
            return false;
        }
        Instant now = this.clock.instant();
        return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
    }

    /** Copies the request, setting an {@code Authorization: Bearer} header from the client's access token. */
    private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
        return ClientRequest.from(request)
                .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
                .build();
    }

    /** Builds the form body {@code grant_type=refresh_token&refresh_token=...}. */
    private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
        return BodyInserters.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
                .with("refresh_token", refreshToken);
    }

    static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
        return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
    }

    static String getClientRegistrationId(Map<String, Object> attrs) {
        return (String) attrs.get(CLIENT_REGISTRATION_ID_ATTR_NAME);
    }

    static Authentication getAuthentication(Map<String, Object> attrs) {
        return (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME);
    }

    static HttpServletRequest getRequest(Map<String, Object> attrs) {
        return (HttpServletRequest) attrs.get(HTTP_SERVLET_REQUEST_ATTR_NAME);
    }

    static HttpServletResponse getResponse(Map<String, Object> attrs) {
        return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
    }

    /**
     * Minimal {@link Authentication} that only carries a principal name. It is used to
     * save a refreshed client when no Authentication attribute is available. Every other
     * accessor throws {@link UnsupportedOperationException}.
     */
    private static class PrincipalNameAuthentication implements Authentication {
        private final String username;

        private PrincipalNameAuthentication(String username) {
            this.username = username;
        }

        @Override
        public Collection<? extends GrantedAuthority> getAuthorities() {
            throw unsupported();
        }

        @Override
        public Object getCredentials() {
            throw unsupported();
        }

        @Override
        public Object getDetails() {
            throw unsupported();
        }

        @Override
        public Object getPrincipal() {
            throw unsupported();
        }

        @Override
        public boolean isAuthenticated() {
            throw unsupported();
        }

        @Override
        public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
            throw unsupported();
        }

        @Override
        public String getName() {
            return this.username;
        }

        private UnsupportedOperationException unsupported() {
            return new UnsupportedOperationException("Not Supported");
        }
    }
}

