diff --git a/src/main/java/se/su/dsv/oauth2/staging/CustomAuthorizationEndpointFilter.java b/src/main/java/se/su/dsv/oauth2/staging/CustomAuthorizationEndpointFilter.java index 4b946bb..036b685 100644 --- a/src/main/java/se/su/dsv/oauth2/staging/CustomAuthorizationEndpointFilter.java +++ b/src/main/java/se/su/dsv/oauth2/staging/CustomAuthorizationEndpointFilter.java @@ -9,11 +9,14 @@ import jakarta.servlet.http.HttpFilter; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter; import org.springframework.security.web.DefaultRedirectStrategy; @@ -72,7 +75,11 @@ public class CustomAuthorizationEndpointFilter extends HttpFilter { } if (loggedInUser.getAuthorities().contains(new Entitlement(developerEntitlement))) { - proceedWithDeveloperAuthorization(request, response, loggedInUser); + try { + proceedWithDeveloperAuthorization(request, response, loggedInUser); + } catch (OAuth2AuthorizationCodeRequestAuthenticationException exception) { + sendAuthorizationError(request, response, exception); + } } else { chain.doFilter(request, response); } @@ -193,4 +200,40 @@ public class CustomAuthorizationEndpointFilter extends HttpFilter { String redirectUri = uriBuilder.build(true).toUriString(); this.redirectStrategy.sendRedirect(request, response, redirectUri); } + + private void sendAuthorizationError( + HttpServletRequest request, + HttpServletResponse response, + OAuth2AuthorizationCodeRequestAuthenticationException authorizationCodeRequestAuthenticationException) + throws IOException + { + OAuth2Error error = authorizationCodeRequestAuthenticationException.getError(); + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = authorizationCodeRequestAuthenticationException + .getAuthorizationCodeRequestAuthentication(); + + if (authorizationCodeRequestAuthentication == null + || !StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) { + response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString()); + return; + } + + UriComponentsBuilder uriBuilder = UriComponentsBuilder + .fromUriString(authorizationCodeRequestAuthentication.getRedirectUri()) + .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode()); + if (StringUtils.hasText(error.getDescription())) { + uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, + UriUtils.encode(error.getDescription(), StandardCharsets.UTF_8)); + } + if (StringUtils.hasText(error.getUri())) { + uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, + UriUtils.encode(error.getUri(), StandardCharsets.UTF_8)); + } + if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { + uriBuilder.queryParam(OAuth2ParameterNames.STATE, + UriUtils.encode(authorizationCodeRequestAuthentication.getState(), StandardCharsets.UTF_8)); + } + // build(true) -> Components are explicitly encoded + String redirectUri = uriBuilder.build(true).toUriString(); + this.redirectStrategy.sendRedirect(request, response, redirectUri); + } } diff --git a/src/main/resources/templates/error.jte b/src/main/resources/templates/error.jte new file mode 100644 index 0000000..c9bd052 --- /dev/null +++ b/src/main/resources/templates/error.jte @@ -0,0 +1,8 @@ +@param Integer status +@param String error +@param String message + +@template.base(title = message, content = @` + <h1>${status} ${error}</h1> + <p>${message}</p> +`) diff --git a/src/test/java/se/su/dsv/oauth2/StagingProfileTest.java b/src/test/java/se/su/dsv/oauth2/StagingProfileTest.java index 4b0e836..cac604d 100644 --- a/src/test/java/se/su/dsv/oauth2/StagingProfileTest.java +++ b/src/test/java/se/su/dsv/oauth2/StagingProfileTest.java @@ -6,8 +6,14 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.test.context.ActiveProfiles; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.junit.jupiter.api.Assertions.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; import static se.su.dsv.oauth2.ShibbolethRequestProcessor.remoteUser; +import static se.su.dsv.oauth2.TestRegisteredClientConfiguration.CLIENT_ID; +import static se.su.dsv.oauth2.TestRegisteredClientConfiguration.REDIRECT_URI; @SpringBootTest( classes = TestRegisteredClientConfiguration.class, @@ -155,4 +161,67 @@ public class StagingProfileTest extends AbstractMetadataCodeFlowTest { assertTrue(claims.getStringListClaim("entitlements").contains(customEntitlement), "Does not contain custom entitlement"); } + + @Test + public void correctly_handles_missing_response_type_parameter() throws Exception { + mockMvc.perform(post(getAuthorizationEndpoint()) + .with(remoteUser("developer") + .entitlement(DEVELOPER_ENTITLEMENT)) + .queryParam("client_id", CLIENT_ID) + .queryParam("redirect_uri", REDIRECT_URI) + .formField("principal", "developer")) + .andExpect(status().isBadRequest()) + .andExpect(status().reason(containsString("response_type"))); + } + + @Test + public void redirects_back_to_client_if_there_are_errors_but_valid_redirect_uri() throws Exception { + mockMvc.perform(post(getAuthorizationEndpoint()) + .with(remoteUser("developer") + .entitlement(DEVELOPER_ENTITLEMENT)) + .queryParam("response_type", "code") + .queryParam("client_id", CLIENT_ID) + .queryParam("redirect_uri", REDIRECT_URI) + .queryParam("scope", "invalid") + .formField("principal", "developer")) + .andExpect(status().is3xxRedirection()) + .andExpect(result -> { + String redirectedUrl = result.getResponse().getRedirectedUrl(); + assertThat(redirectedUrl, containsString("error=invalid_scope")); + }); + } + + @Test + public void does_not_redirect_with_invalid_client_id() throws Exception { + mockMvc.perform(post(getAuthorizationEndpoint()) + .with(remoteUser("developer") + .entitlement(DEVELOPER_ENTITLEMENT)) + .queryParam("response_type", "code") + .queryParam("client_id", "invalid-client-id") + .queryParam("redirect_uri", REDIRECT_URI) + .formField("principal", "developer")) + .andExpect(status().isBadRequest()) + .andExpect(status().reason(containsString("client_id"))); + } + + @Test + public void maintains_state_during_error_redirect() throws Exception { + String state = "state123"; + + mockMvc.perform(post(getAuthorizationEndpoint()) + .with(remoteUser("developer") + .entitlement(DEVELOPER_ENTITLEMENT)) + .queryParam("response_type", "code") + .queryParam("client_id", CLIENT_ID) + .queryParam("redirect_uri", REDIRECT_URI) + .queryParam("state", state) + .queryParam("scope", "invalid") + .formField("principal", "developer")) + .andExpect(status().is3xxRedirection()) + .andExpect(result -> { + String redirectedUrl = result.getResponse().getRedirectedUrl(); + assertThat(redirectedUrl, containsString("error=invalid_scope")); + assertThat(redirectedUrl, containsString("state=" + state)); + }); + } }