Commit 045b724c authored by Jasha Joachimsthal's avatar Jasha Joachimsthal Committed by 陈健

OAUTH-3116 Logout from RP before sending user to OP

parent 71653881
package com.onegini.oidc; package com.onegini.oidc;
import static org.springframework.web.servlet.view.UrlBasedViewResolver.REDIRECT_URL_PREFIX;
import java.security.Principal; import java.security.Principal;
import java.util.Map; import java.util.Map;
...@@ -11,7 +13,6 @@ import org.slf4j.Logger; ...@@ -11,7 +13,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2RestOperations;
import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler; import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
...@@ -19,6 +20,7 @@ import org.springframework.util.LinkedMultiValueMap; ...@@ -19,6 +20,7 @@ import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder; import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
...@@ -27,62 +29,79 @@ import com.onegini.oidc.model.UserInfo; ...@@ -27,62 +29,79 @@ import com.onegini.oidc.model.UserInfo;
@Controller @Controller
public class LogoutController { public class LogoutController {
public static final String PAGE_LOGOUT = "/logout";
private static final Logger LOG = LoggerFactory.getLogger(LogoutController.class); private static final Logger LOG = LoggerFactory.getLogger(LogoutController.class);
@SuppressWarnings("squid:S1075")
private static final String WELL_KNOWN_CONFIG_PATH = "/.well-known/openid-configuration"; private static final String WELL_KNOWN_CONFIG_PATH = "/.well-known/openid-configuration";
private static final String KEY_END_SESSION_ENDPOINT = "end_session_endpoint";
private static final String PARAM_POST_LOGOUT_REDIRECT_URI = "post_logout_redirect_uri";
private static final String PARAM_ID_TOKEN_HINT = "id_token_hint";
private static final String PAGE_SIGNOUT_CALLBACK_OIDC = "/signout-callback-oidc"; private static final String PAGE_SIGNOUT_CALLBACK_OIDC = "/signout-callback-oidc";
public static final String PAGE_LOGOUT = "/logout"; private static final String REDIRECT_TO_INDEX = "redirect:/";
@Resource @Resource
private ApplicationProperties applicationProperties; private ApplicationProperties applicationProperties;
@Resource @Resource
private OAuth2RestOperations restTemplate; private RestTemplate restTemplate;
@GetMapping(PAGE_LOGOUT) @GetMapping(PAGE_LOGOUT)
private String logout(final HttpServletRequest request, final HttpServletResponse response, final Principal principal) { private String logout(final HttpServletRequest request, final HttpServletResponse response, final Principal principal) {
if (principal instanceof PreAuthenticatedAuthenticationToken) { // Save idToken before authentication is cleared
final String idToken = getIdToken(principal);
endSessionInSpringSecurity(request, response);
if (StringUtils.hasLength(idToken)) {
LOG.info("Has idToken {}", idToken);
final Map configuration = restTemplate.getForObject(applicationProperties.getIssuer() + WELL_KNOWN_CONFIG_PATH, Map.class); final Map configuration = restTemplate.getForObject(applicationProperties.getIssuer() + WELL_KNOWN_CONFIG_PATH, Map.class);
final String endSessionEndpoint = configuration == null ? null : (String) configuration.get("end_session_endpoint"); @SuppressWarnings("squid:S2583") final String endSessionEndpoint = configuration == null ? null : (String) configuration.get(KEY_END_SESSION_ENDPOINT);
if (StringUtils.hasLength(endSessionEndpoint)) { if (StringUtils.hasLength(endSessionEndpoint)) {
return endOpenIdSession((PreAuthenticatedAuthenticationToken) principal, endSessionEndpoint); return endOpenIdSession(idToken, endSessionEndpoint);
} }
} }
return doLogout(request, response); return REDIRECT_TO_INDEX;
} }
@GetMapping(PAGE_SIGNOUT_CALLBACK_OIDC) @GetMapping(PAGE_SIGNOUT_CALLBACK_OIDC)
public String callbackOidc(final HttpServletRequest request, final HttpServletResponse response) { public String callbackOidc() {
LOG.info("Signout callback from OP"); LOG.info("Signout callback from OP");
return doLogout(request, response); return REDIRECT_TO_INDEX;
} }
private String endOpenIdSession(final PreAuthenticatedAuthenticationToken principal, final String endSessionEndpoint) { private String getIdToken(final Principal principal) {
final UserInfo userInfo = (UserInfo) principal.getPrincipal(); if (principal instanceof PreAuthenticatedAuthenticationToken) {
final PreAuthenticatedAuthenticationToken authenticationToken = (PreAuthenticatedAuthenticationToken) principal;
final UserInfo userInfo = (UserInfo) authenticationToken.getPrincipal();
return userInfo.getIdToken();
}
return null;
}
private void endSessionInSpringSecurity(final HttpServletRequest request, final HttpServletResponse response) {
final Authentication auth = SecurityContextHolder.getContext().getAuthentication();
if (auth != null) {
LOG.info("End user session in Spring Security");
new SecurityContextLogoutHandler().logout(request, response, auth);
}
}
private String endOpenIdSession(final String idToken, final String endSessionEndpoint) {
final MultiValueMap<String, String> requestParameters = new LinkedMultiValueMap<>(); final MultiValueMap<String, String> requestParameters = new LinkedMultiValueMap<>();
final String postLogoutRedirectUri = ServletUriComponentsBuilder.fromCurrentContextPath().path(PAGE_SIGNOUT_CALLBACK_OIDC).build().toUriString(); final String postLogoutRedirectUri = ServletUriComponentsBuilder.fromCurrentContextPath().path(PAGE_SIGNOUT_CALLBACK_OIDC).build().toUriString();
requestParameters.add("post_logout_redirect_uri", postLogoutRedirectUri); requestParameters.add(PARAM_POST_LOGOUT_REDIRECT_URI, postLogoutRedirectUri);
requestParameters.add("id_token_hint", userInfo.getIdToken()); requestParameters.add(PARAM_ID_TOKEN_HINT, idToken);
final String redirectUri = UriComponentsBuilder.fromUriString(endSessionEndpoint) final String redirectUri = UriComponentsBuilder.fromUriString(endSessionEndpoint)
.queryParams(requestParameters) .queryParams(requestParameters)
.build().toUriString(); .build().toUriString();
LOG.info("Redirect to OP end session"); LOG.info("Redirect to OP end session");
return "redirect:" + redirectUri; return REDIRECT_URL_PREFIX + redirectUri;
} }
private String doLogout(final HttpServletRequest request, final HttpServletResponse response) {
final Authentication auth = SecurityContextHolder.getContext().getAuthentication();
if (auth != null) {
LOG.info("End user session in Spring Security");
new SecurityContextLogoutHandler().logout(request, response, auth);
}
return "redirect:/";
}
} }
\ No newline at end of file
package com.onegini.oidc.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
@Configuration
public class RestTemplateConfiguration {
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}
\ No newline at end of file
...@@ -28,7 +28,7 @@ import com.nimbusds.jwt.JWTParser; ...@@ -28,7 +28,7 @@ import com.nimbusds.jwt.JWTParser;
public class OpenIdConnectAuthenticationFilter extends AbstractAuthenticationProcessingFilter { public class OpenIdConnectAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
@Resource @Resource
private OAuth2RestOperations restTemplate; private OAuth2RestOperations oAuth2RestOperations;
@Resource @Resource
private OAuth2ProtectedResourceDetails details; private OAuth2ProtectedResourceDetails details;
@Resource @Resource
...@@ -60,7 +60,7 @@ public class OpenIdConnectAuthenticationFilter extends AbstractAuthenticationPro ...@@ -60,7 +60,7 @@ public class OpenIdConnectAuthenticationFilter extends AbstractAuthenticationPro
final OAuth2AccessToken accessToken; final OAuth2AccessToken accessToken;
try { try {
accessToken = restTemplate.getAccessToken(); accessToken = oAuth2RestOperations.getAccessToken();
} catch (final OAuth2Exception e) { } catch (final OAuth2Exception e) {
throw new AccessTokenRequiredException("Could not obtain access token", details, e); throw new AccessTokenRequiredException("Could not obtain access token", details, e);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment