diff --git a/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java b/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java index 5bc7574..17c2ff8 100644 --- a/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java +++ b/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java @@ -91,7 +91,9 @@ public class AuthorizationServer extends SpringBootServletInitializer { .authorizationEndpoint(authorizeEndpoint -> authorizeEndpoint .consentPage("/oauth2/consent")) .withObjectPostProcessor(enableGivingConsentWithNoScopes()) - .oidc(Customizer.withDefaults())) + .oidc(oidc -> oidc + .userInfoEndpoint(userInfo -> userInfo + .userInfoMapper(new OidcUserInfoMapper())))) .with(new ShibbolethConfigurer(), Customizer.withDefaults()) .sessionManagement(session -> session // Never use the session and always rely on the Shibboleth authentication diff --git a/src/main/java/se/su/dsv/oauth2/OidcUserInfoMapper.java b/src/main/java/se/su/dsv/oauth2/OidcUserInfoMapper.java new file mode 100644 index 0000000..99670d6 --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/OidcUserInfoMapper.java @@ -0,0 +1,84 @@ +package se.su.dsv.oauth2; + +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider; +import se.su.dsv.oauth2.shibboleth.ShibbolethTokenPopulator; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +/// Straight up copied from [OidcUserInfoAuthenticationProvider.DefaultOidcUserInfoMapper] +/// with the addition of always including the [ShibbolethTokenPopulator#ENTITLEMENTS_CLAIM] claim +public class OidcUserInfoMapper implements Function { + private static final List EMAIL_CLAIMS = Arrays.asList( + StandardClaimNames.EMAIL, + StandardClaimNames.EMAIL_VERIFIED + ); + private static final List PHONE_CLAIMS = Arrays.asList( + StandardClaimNames.PHONE_NUMBER, + StandardClaimNames.PHONE_NUMBER_VERIFIED + ); + private static final List PROFILE_CLAIMS = Arrays.asList( + StandardClaimNames.NAME, + StandardClaimNames.FAMILY_NAME, + StandardClaimNames.GIVEN_NAME, + StandardClaimNames.MIDDLE_NAME, + StandardClaimNames.NICKNAME, + StandardClaimNames.PREFERRED_USERNAME, + StandardClaimNames.PROFILE, + StandardClaimNames.PICTURE, + StandardClaimNames.WEBSITE, + StandardClaimNames.GENDER, + StandardClaimNames.BIRTHDATE, + StandardClaimNames.ZONEINFO, + StandardClaimNames.LOCALE, + StandardClaimNames.UPDATED_AT + ); + + @Override + public OidcUserInfo apply(OidcUserInfoAuthenticationContext authenticationContext) { + OAuth2Authorization authorization = authenticationContext.getAuthorization(); + OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken(); + OAuth2AccessToken accessToken = authenticationContext.getAccessToken(); + Map scopeRequestedClaims = getClaimsRequestedByScope(idToken.getClaims(), + accessToken.getScopes()); + + return new OidcUserInfo(scopeRequestedClaims); + } + + private static Map getClaimsRequestedByScope(Map claims, + Set requestedScopes) { + Set scopeRequestedClaimNames = new HashSet<>(32); + scopeRequestedClaimNames.add(StandardClaimNames.SUB); + scopeRequestedClaimNames.add(ShibbolethTokenPopulator.ENTITLEMENTS_CLAIM); + + if (requestedScopes.contains(OidcScopes.ADDRESS)) { + scopeRequestedClaimNames.add(StandardClaimNames.ADDRESS); + } + if (requestedScopes.contains(OidcScopes.EMAIL)) { + scopeRequestedClaimNames.addAll(EMAIL_CLAIMS); + } + if (requestedScopes.contains(OidcScopes.PHONE)) { + scopeRequestedClaimNames.addAll(PHONE_CLAIMS); + } + if (requestedScopes.contains(OidcScopes.PROFILE)) { + scopeRequestedClaimNames.addAll(PROFILE_CLAIMS); + } + + Map requestedClaims = new HashMap<>(claims); + requestedClaims.keySet().removeIf((claimName) -> !scopeRequestedClaimNames.contains(claimName)); + + return requestedClaims; + } +} diff --git a/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethTokenPopulator.java b/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethTokenPopulator.java index 2f1d778..eae7aab 100644 --- a/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethTokenPopulator.java +++ b/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethTokenPopulator.java @@ -16,6 +16,7 @@ import java.util.Set; /// Populate the tokens with Shibboleth attributes, if available. public final class ShibbolethTokenPopulator implements OAuth2TokenCustomizer { + public static final String ENTITLEMENTS_CLAIM = "entitlements"; private static final Set EMAIL_CLAIMS = Set.of( StandardClaimNames.EMAIL, StandardClaimNames.EMAIL_VERIFIED @@ -40,19 +41,15 @@ public final class ShibbolethTokenPopulator implements OAuth2TokenCustomizer entitlements = context - .getPrincipal() - .getAuthorities() - .stream() - .filter(Entitlement.class::isInstance) - .map(Entitlement.class::cast) - .map(Entitlement::entitlement) - .toList(); + List entitlements = getEntitlements(context); - context.getClaims().claim("entitlements", entitlements); + context.getClaims().claim(ENTITLEMENTS_CLAIM, entitlements); } if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { + List entitlements = getEntitlements(context); + context.getClaims().claim(ENTITLEMENTS_CLAIM, entitlements); + if (context.getPrincipal().getDetails() instanceof ShibbolethAuthenticationDetails details) { OidcUserInfo oidcUserInfo = getOidcUserInfo(details); @@ -67,6 +64,17 @@ public final class ShibbolethTokenPopulator implements OAuth2TokenCustomizer getEntitlements(JwtEncodingContext context) { + return context + .getPrincipal() + .getAuthorities() + .stream() + .filter(Entitlement.class::isInstance) + .map(Entitlement.class::cast) + .map(Entitlement::entitlement) + .toList(); + } + private Set getAuthorizedClaims(final Set scopes) { Set authorizedClaims = new HashSet<>(); if (scopes.contains(OidcScopes.PROFILE)) { diff --git a/src/test/java/se/su/dsv/oauth2/AbstractMetadataTest.java b/src/test/java/se/su/dsv/oauth2/AbstractMetadataTest.java index d2fbfe0..998ffa7 100644 --- a/src/test/java/se/su/dsv/oauth2/AbstractMetadataTest.java +++ b/src/test/java/se/su/dsv/oauth2/AbstractMetadataTest.java @@ -37,7 +37,7 @@ public class AbstractMetadataTest { @BeforeEach public void setUp() throws Exception { // 1. Get metadata - MvcResult metadataResult = mockMvc.perform(get("/.well-known/oauth-authorization-server")) + MvcResult metadataResult = mockMvc.perform(get("/.well-known/openid-configuration")) .andExpect(status().isOk()) .andReturn(); @@ -69,6 +69,10 @@ public class AbstractMetadataTest { return metadata.get("introspection_endpoint").asText(); } + protected String getUserInfoEndpoint() { + return metadata.get("userinfo_endpoint").asText(); + } + protected JWTClaimsSet verifyToken(String token) throws Exception { return processor.process(token, null); } diff --git a/src/test/java/se/su/dsv/oauth2/IdTokenTest.java b/src/test/java/se/su/dsv/oauth2/IdTokenTest.java index b8188e4..df66b98 100644 --- a/src/test/java/se/su/dsv/oauth2/IdTokenTest.java +++ b/src/test/java/se/su/dsv/oauth2/IdTokenTest.java @@ -6,6 +6,8 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.test.context.ActiveProfiles; +import java.util.List; + import static org.junit.jupiter.api.Assertions.*; import static se.su.dsv.oauth2.ShibbolethRequestProcessor.remoteUser; @@ -108,4 +110,21 @@ public class IdTokenTest extends AbstractMetadataCodeFlowTest { assertNotNull(claimsSet.getClaim(StandardClaimNames.EMAIL_VERIFIED)); } + @Test + public void includes_entitlements_in_the_id_token() throws Exception { + TokenResponse tokenResponse = authorize(request -> request + .queryParam("scope", OidcScopes.OPENID) + .with(remoteUser("someone@university") + .entitlement("gdpr") + .entitlement("hr"))); + + String idToken = tokenResponse.idToken(); + assertNotNull(idToken); + + JWTClaimsSet claimsSet = verifyToken(idToken); + List entitlements = claimsSet.getStringListClaim("entitlements"); + assertTrue(entitlements.contains("gdpr")); + assertTrue(entitlements.contains("hr")); + } + } diff --git a/src/test/java/se/su/dsv/oauth2/UserInfoEndpointTest.java b/src/test/java/se/su/dsv/oauth2/UserInfoEndpointTest.java index 0d73c77..bc2a672 100644 --- a/src/test/java/se/su/dsv/oauth2/UserInfoEndpointTest.java +++ b/src/test/java/se/su/dsv/oauth2/UserInfoEndpointTest.java @@ -3,13 +3,18 @@ package se.su.dsv.oauth2; import com.fasterxml.jackson.databind.JsonNode; import org.junit.jupiter.api.Test; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.test.web.servlet.MvcResult; import java.net.URI; +import static org.hamcrest.Matchers.hasItem; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static se.su.dsv.oauth2.ShibbolethRequestProcessor.remoteUser; @SpringBootTest(classes = TestRegisteredClientConfiguration.class) public class UserInfoEndpointTest extends AbstractMetadataCodeFlowTest { @@ -26,4 +31,24 @@ public class UserInfoEndpointTest extends AbstractMetadataCodeFlowTest { assertEquals("/oidc/userinfo", userInfoUri.getPath()); } + + @Test + public void includes_entitlements_in_userinfo() throws Exception { + TokenResponse tokenResponse = authorize(request -> request + .queryParam("scope", OidcScopes.OPENID) + .with(remoteUser("someone@university") + .entitlement("gdpr") + .entitlement("hr"))); + + String accessToken = tokenResponse.accessToken(); + assertNotNull(accessToken); + + mockMvc.perform(get(getUserInfoEndpoint()) + .header("Authorization", "Bearer " + accessToken)) + .andExpect(status().isOk()) + .andExpectAll( + jsonPath("$.entitlements").isArray(), + jsonPath("$.entitlements").value(hasItem("gdpr")), + jsonPath("$.entitlements").value(hasItem("hr"))); + } }