diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5decd7b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,46 @@ +FROM eclipse-temurin:23 AS build + +WORKDIR /build + +# Create as small a runtime as possible but Spring/Tomcat needs a lot of modules +RUN jlink \ + --output jre \ + --add-modules java.sql,java.desktop,java.management,java.naming,java.security.jgss,java.instrument + +COPY pom.xml mvnw ./ +COPY .mvn .mvn + +RUN ./mvnw dependency:copy-dependencies \ + --activate-profiles=!persistent \ + --define includeScope=compile \ + --define outputDirectory=lib + +RUN ./mvnw dependency:build-classpath \ + --activate-profiles=!persistent \ + --define includeScope=compile \ + --define mdep.outputFile=classpath \ + --define mdep.prefix=lib + +COPY src src + +RUN ./mvnw compile + +FROM debian:stable-slim AS runtime + +WORKDIR /app + +COPY --from=build /build/jre jre +COPY --from=build /build/lib lib +COPY --from=build /build/classpath classpath +COPY --from=build /build/target/classes classes + +# Adds the output of Maven compilation to output +RUN echo ":classes" >> classpath + +EXPOSE 8080 + +CMD [ "./jre/bin/java" \ + , "-cp", "@classpath" \ + , "se.su.dsv.oauth2.AuthorizationServer" \ + , "--spring.profiles.active=dev,embedded" \ + ] diff --git a/README.md b/README.md index 8f8a5c8..dddcd6d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,18 @@ +## Using as an embedded Docker Compose service + +``` +services: + oauth2: + build: https://gitea.dsv.su.se/DMC/oauth2-authorization-server.git + restart: unless-stopped + ports: + - "<host_port>:8080" + environment: + CLIENT_ID=awesome-app + CLIENT_SECRET=p4ssw0rd + CLIENT_REDIRECT_URI=http://localhost/oauth2/callback +``` + ## Development ### Prerequisites - JDK 17 (or later) diff --git a/pom.xml b/pom.xml index 4c28b99..c6fb8a1 100644 --- a/pom.xml +++ b/pom.xml @@ -23,10 +23,6 @@ </properties> <dependencies> - <dependency> - <groupId>org.springframework.boot</groupId> - <artifactId>spring-boot-starter-jdbc</artifactId> - </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-oauth2-authorization-server</artifactId> @@ -52,21 +48,6 @@ <version>${jte.version}</version> </dependency> - <dependency> - <groupId>org.flywaydb</groupId> - <artifactId>flyway-core</artifactId> - </dependency> - <dependency> - <groupId>org.flywaydb</groupId> - <artifactId>flyway-mysql</artifactId> - </dependency> - - <dependency> - <groupId>org.mariadb.jdbc</groupId> - <artifactId>mariadb-java-client</artifactId> - <scope>runtime</scope> - </dependency> - <!-- Development tools --> <dependency> <groupId>org.springframework.boot</groupId> @@ -146,7 +127,6 @@ <sourceDirectory>${project.basedir}/src/main/resources/templates</sourceDirectory> <targetDirectory>${project.build.directory}/jte-classes</targetDirectory> <contentType>Html</contentType> - <binaryStaticContent>true</binaryStaticContent> <extensions> <extension> <className>gg.jte.models.generator.ModelExtension</className> @@ -172,4 +152,32 @@ </plugins> </build> + <profiles> + <profile> + <id>persistent</id> + <activation> + <activeByDefault>true</activeByDefault> + </activation> + <dependencies> + <dependency> + <groupId>org.springframework.boot</groupId> + <artifactId>spring-boot-starter-jdbc</artifactId> + </dependency> + <dependency> + <groupId>org.flywaydb</groupId> + <artifactId>flyway-core</artifactId> + </dependency> + <dependency> + <groupId>org.flywaydb</groupId> + <artifactId>flyway-mysql</artifactId> + </dependency> + <dependency> + <groupId>org.mariadb.jdbc</groupId> + <artifactId>mariadb-java-client</artifactId> + <scope>runtime</scope> + </dependency> + </dependencies> + </profile> + </profiles> + </project> diff --git a/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java b/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java index 810cf99..478e989 100644 --- a/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java +++ b/src/main/java/se/su/dsv/oauth2/AuthorizationServer.java @@ -24,6 +24,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.intercept.AuthorizationFilter; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.preauth.j2ee.J2eePreAuthenticatedProcessingFilter; +import org.springframework.security.web.context.RequestAttributeSecurityContextRepository; import org.springframework.security.web.util.matcher.MediaTypeRequestMatcher; import se.su.dsv.oauth2.shibboleth.Entitlement; import se.su.dsv.oauth2.shibboleth.ShibbolethAuthenticationDetailsSource; @@ -153,6 +154,11 @@ public class AuthorizationServer extends SpringBootServletInitializer { // Using a custom authentication details source to extract the Shibboleth attributes // and convert them to the relevant Spring Security objects. object.setAuthenticationDetailsSource(new ShibbolethAuthenticationDetailsSource()); + + // Prevent session creation + // It can cause conflicts when running on the same host as an embedded docker container + // as it overwrites the session cookie (it does not factor in port) + object.setSecurityContextRepository(new RequestAttributeSecurityContextRepository()); return object; } }; diff --git a/src/main/java/se/su/dsv/oauth2/EmbeddedConfiguration.java b/src/main/java/se/su/dsv/oauth2/EmbeddedConfiguration.java new file mode 100644 index 0000000..2896fb9 --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/EmbeddedConfiguration.java @@ -0,0 +1,84 @@ +package se.su.dsv.oauth2; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import se.su.dsv.oauth2.admin.repository.ClientRepository; +import se.su.dsv.oauth2.admin.repository.ClientRow; + +import java.security.Principal; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +@Configuration +@Profile("embedded") +public class EmbeddedConfiguration { + @Bean + public ClientRepository clientRepository() { + ArrayList<ClientRow> clients = new ArrayList<>(); + ClientRow clientRow = getClientFromEnvironment(); + if (clientRow != null) { + clients.add(clientRow); + } + + return new InMemoryClientrepository(clients); + + } + + private static ClientRow getClientFromEnvironment() { + String clientId = System.getenv("CLIENT_ID"); + String clientSecret = System.getenv("CLIENT_SECRET"); + String redirectUri = System.getenv("CLIENT_REDIRECT_URI"); + String scopeString = System.getenv("CLIENT_SCOPES"); + + return new ClientRow(clientId, clientId, clientId, "dev@localhost", + redirectUri, scopeString, clientSecret); + } + + private static class InMemoryClientrepository implements ClientRepository { + private List<ClientRow> clientRows; + + public InMemoryClientrepository(final List<ClientRow> clients) { + this.clientRows = new ArrayList<>(clients); + } + + @Override + public void addNewClient(final ClientRow clientRow) { + clientRows.add(clientRow); + } + + @Override + public List<ClientRow> getClients(final Principal owner) { + return List.copyOf(clientRows); + } + + @Override + public void addClientOwner(final String principalName, final String id) { + } + + @Override + public void removeOwner(final String id, final String owner) { + } + + @Override + public List<String> getOwners(final String id) { + return List.of("dev@localhost"); + } + + @Override + public Optional<ClientRow> getClientRowById(final String id) { + return clientRows.stream() + .filter(clientRow -> Objects.equals(clientRow.id(), id)) + .findAny(); + } + + @Override + public Optional<ClientRow> getClientRowByClientId(final String clientId) { + return clientRows.stream() + .filter(clientRow -> Objects.equals(clientRow.clientId(), clientId)) + .findAny(); + } + } +} diff --git a/src/main/java/se/su/dsv/oauth2/JDBCClientRepository.java b/src/main/java/se/su/dsv/oauth2/JDBCClientRepository.java new file mode 100644 index 0000000..638398e --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/JDBCClientRepository.java @@ -0,0 +1,79 @@ +package se.su.dsv.oauth2; + +import org.springframework.jdbc.core.simple.JdbcClient; +import se.su.dsv.oauth2.admin.repository.ClientRepository; +import se.su.dsv.oauth2.admin.repository.ClientRow; + +import java.security.Principal; +import java.util.List; +import java.util.Optional; + +public class JDBCClientRepository implements ClientRepository { + public final JdbcClient jdbc; + + public JDBCClientRepository(final JdbcClient jdbc) { + this.jdbc = jdbc; + } + + public JdbcClient getJdbc() { + return jdbc; + } + + @Override + public void addClientOwner(final String principalName, final String id) { + getJdbc().sql("INSERT INTO client_owner (client_id, owner) VALUES (:clientId, :owner)") + .param("clientId", id) + .param("owner", principalName) + .update(); + } + + @Override + public List<ClientRow> getClients(final Principal owner) { + return getJdbc().sql("SELECT id, client_id, name, contact_email, redirect_uri, scopes, client_secret FROM client WHERE id IN (SELECT client_id FROM client_owner WHERE owner = :owner)") + .param("owner", owner.getName()) + .query(ClientRow.class) + .list(); + } + + @Override + public void removeOwner(final String id, final String owner) { + getJdbc().sql("DELETE FROM client_owner WHERE client_id = :id AND owner = :owner") + .param("id", id) + .param("owner", owner) + .update(); + } + + @Override + public Optional<ClientRow> getClientRowById(final String id) { + return getJdbc().sql("SELECT id, client_id, name, contact_email, redirect_uri, scopes, client_secret FROM client WHERE id = :id") + .param("id", id) + .query(ClientRow.class) + .optional(); + } + + @Override + public Optional<ClientRow> getClientRowByClientId(final String clientId) { + return getJdbc().sql("SELECT id, client_id, name, contact_email, redirect_uri, scopes, client_secret FROM client WHERE client_id = :clientId") + .param("clientId", clientId) + .query(ClientRow.class) + .optional(); + } + + @Override + public void addNewClient(final ClientRow clientRow) { + getJdbc().sql(""" + INSERT INTO client (id, client_id, client_secret, name, redirect_uri, contact_email, scopes) + VALUES (:id, :clientId, :clientSecret, :name, :redirectUri, :contactEmail, :scopes) + """) + .paramSource(clientRow) + .update(); + } + + @Override + public List<String> getOwners(final String id) { + return getJdbc().sql("SELECT owner FROM client_owner WHERE client_id = :clientId") + .param("clientId", id) + .query(String.class) + .list(); + } +} diff --git a/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java b/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java new file mode 100644 index 0000000..1e595e0 --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java @@ -0,0 +1,15 @@ +package se.su.dsv.oauth2; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.jdbc.core.simple.JdbcClient; + +@Configuration +@Profile("!embedded") +public class PersistentConfiguration { + @Bean + public JDBCClientRepository jdbcClientRepository(JdbcClient jdbcClient) { + return new JDBCClientRepository(jdbcClient); + } +} diff --git a/src/main/java/se/su/dsv/oauth2/admin/ClientManager.java b/src/main/java/se/su/dsv/oauth2/admin/ClientManager.java index 8ff738b..966e2c6 100644 --- a/src/main/java/se/su/dsv/oauth2/admin/ClientManager.java +++ b/src/main/java/se/su/dsv/oauth2/admin/ClientManager.java @@ -1,6 +1,5 @@ package se.su.dsv.oauth2.admin; -import org.springframework.jdbc.core.simple.JdbcClient; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -10,29 +9,29 @@ import org.springframework.security.oauth2.server.authorization.settings.ClientS import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat; import org.springframework.security.oauth2.server.authorization.settings.TokenSettings; import org.springframework.stereotype.Service; +import se.su.dsv.oauth2.admin.repository.ClientRepository; +import se.su.dsv.oauth2.admin.repository.ClientRow; import java.security.Principal; import java.time.Duration; -import java.util.Arrays; +import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.UUID; -import java.util.function.Predicate; -import java.util.stream.Collectors; @Service public class ClientManager implements RegisteredClientRepository, ClientManagementService { private final PasswordEncoder passwordEncoder; - private final JdbcClient jdbc; + private final ClientRepository clientRepository; public ClientManager( PasswordEncoder passwordEncoder, - JdbcClient jdbc) + ClientRepository clientRepository) { this.passwordEncoder = passwordEncoder; - this.jdbc = jdbc; + this.clientRepository = clientRepository; } @Override @@ -42,49 +41,32 @@ public class ClientManager implements RegisteredClientRepository, ClientManageme String clientSecret = clientData.isPublic() ? null : Util.generateAlphanumericString(32); String encodedClientSecret = clientSecret == null ? null : passwordEncoder.encode(clientSecret); String redirectURI = clientData.redirectURI() != null ? clientData.redirectURI().toString() : null; + String scopeString = String.join(" ", clientData.scopes()); - jdbc.sql(""" - INSERT INTO client (id, client_id, client_secret, name, redirect_uri, contact_email, scopes) - VALUES (:id, :clientId, :clientSecret, :name, :redirectUri, :contactEmail, :scopes) - """) - .param("id", id) - .param("clientId", clientId) - .param("clientSecret", encodedClientSecret) - .param("name", clientData.clientName()) - .param("redirectUri", redirectURI) - .param("contactEmail", clientData.contactEmail()) - .param("scopes", String.join(" ", clientData.scopes())) - .update(); + ClientRow clientRow = new ClientRow(id, clientId, clientData.clientName(), clientData.contactEmail(), + redirectURI, scopeString, encodedClientSecret); - addClientOwner(owner.getName(), id); + clientRepository.addNewClient(clientRow); + + clientRepository.addClientOwner(owner.getName(), id); return new NewClient(id, clientId, clientSecret); } - private void addClientOwner(final String principalName, final String id) { - jdbc.sql("INSERT INTO client_owner (client_id, owner) VALUES (:clientId, :owner)") - .param("clientId", id) - .param("owner", principalName) - .update(); - } - @Override public Optional<Client> getClient(final Principal principal, final String id) { - boolean ownsClient = getOwners(id).contains(principal.getName()); + boolean ownsClient = clientRepository.getOwners(id).contains(principal.getName()); if (!ownsClient) { return Optional.empty(); } - return getClientRowById(id) + return clientRepository.getClientRowById(id) .map(this::toClient); } @Override public List<Client> getClients(final Principal owner) { - return jdbc.sql("SELECT id, client_id, name, contact_email, redirect_uri, scopes, client_secret FROM client WHERE id IN (SELECT client_id FROM client_owner WHERE owner = :owner)") - .param("owner", owner.getName()) - .query(ClientRow.class) - .list() + return clientRepository.getClients(owner) .stream() .map(this::toClient) .toList(); @@ -92,48 +74,28 @@ public class ClientManager implements RegisteredClientRepository, ClientManageme @Override public void addOwner(final Principal currentUser, final String id, final String newOwnerPrincipal) { - if (!getOwners(id).contains(currentUser.getName())) { + if (!clientRepository.getOwners(id).contains(currentUser.getName())) { throw new IllegalStateException(currentUser.getName() + " is not an owner of the client"); } - jdbc.sql("INSERT INTO client_owner (client_id, owner) VALUES (:id, :owner) ON DUPLICATE KEY UPDATE owner = owner") - .param("id", id) - .param("owner", newOwnerPrincipal) - .update(); + clientRepository.addClientOwner(newOwnerPrincipal, id); } @Override public boolean removeOwner(final Principal currentUser, final String id, final String owner) { - if (!getOwners(id).contains(currentUser.getName())) { + if (!clientRepository.getOwners(id).contains(currentUser.getName())) { throw new IllegalStateException(currentUser.getName() + " is not an owner of the client"); } if (currentUser.getName().equals(owner)) { return false; } else { - jdbc.sql("DELETE FROM client_owner WHERE client_id = :id AND owner = :owner") - .param("id", id) - .param("owner", owner) - .update(); + clientRepository.removeOwner(id, owner); return true; } } - private Optional<ClientRow> getClientRowById(final String id) { - return jdbc.sql("SELECT id, client_id, name, contact_email, redirect_uri, scopes, client_secret FROM client WHERE id = :id") - .param("id", id) - .query(ClientRow.class) - .optional(); - } - - - private Optional<ClientRow> getClientRowByClientId(final String clientId) { - return jdbc.sql("SELECT id, client_id, name, contact_email, redirect_uri, scopes, client_secret FROM client WHERE client_id = :clientId") - .param("clientId", clientId) - .query(ClientRow.class) - .optional(); - } private Client toClient(final ClientRow clientRow) { - List<String> owners = getOwners(clientRow.id()); + List<String> owners = new ArrayList<>(clientRepository.getOwners(clientRow.id())); owners.sort(Comparator.naturalOrder()); Set<String> scopes = clientRow.scopeSet(); boolean isPublic = clientRow.isPublic(); @@ -148,30 +110,23 @@ public class ClientManager implements RegisteredClientRepository, ClientManageme owners); } - private List<String> getOwners(final String id) { - return jdbc.sql("SELECT owner FROM client_owner WHERE client_id = :clientId") - .param("clientId", id) - .query(String.class) - .list(); - } - // Used by various components of the OAuth 2.0 infrastructure to upgrade // the client secret if necessary based on the PasswordEncoder bean. @Override public void save(final RegisteredClient registeredClient) { - throw new UnsupportedOperationException("ClientManager#save(RegisteredClient)"); + // TODO fix support for upgrading client secrets } @Override public RegisteredClient findById(final String id) { - return getClientRowById(id) + return clientRepository.getClientRowById(id) .map(ClientManager::toRegisteredClient) .orElse(null); } @Override public RegisteredClient findByClientId(final String clientId) { - return getClientRowByClientId(clientId) + return clientRepository.getClientRowByClientId(clientId) .map(ClientManager::toRegisteredClient) .orElse(null); } @@ -216,24 +171,4 @@ public class ClientManager implements RegisteredClientRepository, ClientManageme .scopes(currentScopes -> currentScopes.addAll(clientRow.scopeSet())) .build(); } - - private record ClientRow( - String id, - String clientId, - String name, - String contactEmail, - String redirectUri, - String scopes, - String clientSecret) - { - private Set<String> scopeSet() { - return Arrays.stream(this.scopes.split(" ")) - .filter(Predicate.not(String::isBlank)) - .collect(Collectors.toUnmodifiableSet()); - } - - private boolean isPublic() { - return clientSecret == null; - } - } } diff --git a/src/main/java/se/su/dsv/oauth2/admin/repository/ClientRepository.java b/src/main/java/se/su/dsv/oauth2/admin/repository/ClientRepository.java new file mode 100644 index 0000000..e69c2db --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/admin/repository/ClientRepository.java @@ -0,0 +1,21 @@ +package se.su.dsv.oauth2.admin.repository; + +import java.security.Principal; +import java.util.List; +import java.util.Optional; + +public interface ClientRepository { + void addNewClient(ClientRow clientRow); + + List<ClientRow> getClients(Principal owner); + + void addClientOwner(String principalName, String id); + + void removeOwner(String id, String owner); + + List<String> getOwners(String id); + + Optional<ClientRow> getClientRowById(String id); + + Optional<ClientRow> getClientRowByClientId(String clientId); +} diff --git a/src/main/java/se/su/dsv/oauth2/admin/repository/ClientRow.java b/src/main/java/se/su/dsv/oauth2/admin/repository/ClientRow.java new file mode 100644 index 0000000..2764de8 --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/admin/repository/ClientRow.java @@ -0,0 +1,29 @@ +package se.su.dsv.oauth2.admin.repository; + +import java.util.Arrays; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +public record ClientRow( + String id, + String clientId, + String name, + String contactEmail, + String redirectUri, + String scopes, + String clientSecret) +{ + public Set<String> scopeSet() { + if (scopes == null) { + return Set.of(); + } + return Arrays.stream(this.scopes.split(" ")) + .filter(Predicate.not(String::isBlank)) + .collect(Collectors.toUnmodifiableSet()); + } + + public boolean isPublic() { + return clientSecret == null; + } +} diff --git a/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethConfigurer.java b/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethConfigurer.java index 982a55f..1aa0252 100644 --- a/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethConfigurer.java +++ b/src/main/java/se/su/dsv/oauth2/shibboleth/ShibbolethConfigurer.java @@ -7,6 +7,7 @@ import org.springframework.security.web.authentication.preauth.PreAuthenticatedA import org.springframework.security.web.authentication.preauth.PreAuthenticatedGrantedAuthoritiesUserDetailsService; import org.springframework.security.web.authentication.preauth.j2ee.J2eePreAuthenticatedProcessingFilter; import org.springframework.security.web.authentication.preauth.x509.X509AuthenticationFilter; +import org.springframework.security.web.context.RequestAttributeSecurityContextRepository; public class ShibbolethConfigurer extends AbstractHttpConfigurer<ShibbolethConfigurer, HttpSecurity> { @Override @@ -24,6 +25,12 @@ public class ShibbolethConfigurer extends AbstractHttpConfigurer<ShibbolethConfi filter.setAuthenticationDetailsSource(new ShibbolethAuthenticationDetailsSource()); filter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + // Do not create a session. + // 1) it is not necessary + // 2) it can cause conflicts when running on the same host as an embedded docker container + // as it overwrites the session cookie (it does not factor in port) + filter.setSecurityContextRepository(new RequestAttributeSecurityContextRepository()); + // The default filter order is X509 followed by J2EE (pre-authentication which is what Shibboleth does). // Spring Authorization server then puts the OAuth 2.0 authorization filter before J2EE, and it requires // the user to be authenticated. Then there is also the custom authorization endpoint used in staging diff --git a/src/main/resources/application-embedded.yml b/src/main/resources/application-embedded.yml new file mode 100644 index 0000000..6da4a58 --- /dev/null +++ b/src/main/resources/application-embedded.yml @@ -0,0 +1,4 @@ +gg: + jte: + developmentMode: false + usePrecompiledTemplates: true diff --git a/src/test/java/se/su/dsv/oauth2/EmbeddedContainerTest.java b/src/test/java/se/su/dsv/oauth2/EmbeddedContainerTest.java new file mode 100644 index 0000000..fb27112 --- /dev/null +++ b/src/test/java/se/su/dsv/oauth2/EmbeddedContainerTest.java @@ -0,0 +1,162 @@ +package se.su.dsv.oauth2; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.jwk.source.JWKSourceBuilder; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jose.util.DefaultResourceRetriever; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.proc.DefaultJWTProcessor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestClient; +import org.springframework.web.util.UriComponentsBuilder; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.images.builder.ImageFromDockerfile; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.net.URI; +import java.net.URL; +import java.nio.file.Paths; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +@Testcontainers +public class EmbeddedContainerTest { + + private static final String CLIENT_ID = "client-id"; + private static final String CLIENT_SECRET = "client-secret"; + private static final String CLIENT_REDIRECT_URI = "http://localhost:8080"; + private static final String AUTHORIZATION_HEADER = "Basic " + HttpHeaders.encodeBasicAuth(CLIENT_ID, CLIENT_SECRET, null); + + @Container + static GenericContainer<?> container = new GenericContainer<>( + new ImageFromDockerfile() + .withFileFromPath(".", Paths.get("."))) + .withExposedPorts(8080) + .withEnv("CLIENT_ID", CLIENT_ID) + .withEnv("CLIENT_SECRET", CLIENT_SECRET) + .withEnv("CLIENT_REDIRECT_URI", CLIENT_REDIRECT_URI) + .withEnv("CLIENT_SCOPES", "openid profile email"); + + private RestClient restClient; + private ObjectMapper objectMapper; + + @BeforeEach + public void setUp() { + String baseUri = UriComponentsBuilder.newInstance() + .scheme("http") + .host(container.getHost()) + .port(container.getMappedPort(8080)) + .toUriString(); + + restClient = RestClient.create(baseUri); + objectMapper = new ObjectMapper(); + } + + @Test + public void working_container() { + ResponseEntity<String> response = restClient + .get() + .retrieve() + .onStatus(ignored -> false) // treat all responses as successful and let asserts fail + .toEntity(String.class); + + assertThat(response.getStatusCode()) + .isEqualTo(HttpStatus.OK); + + assertThat(response.getBody()) + .contains("DSV"); + } + + @Test + public void custom_authorize_flow_via_metadata_and_public_key_verification() throws Exception { + String metadata = restClient.get() + .uri("/.well-known/oauth-authorization-server") + .retrieve() + .body(String.class); + + JsonNode parsedMetadata = objectMapper.readTree(metadata); + + // 2. Get JWKS + URL jwksUri = URI.create(parsedMetadata.required("jwks_uri").asText()).toURL(); + JWKSource<SecurityContext> jwkSource = JWKSourceBuilder + .create(jwksUri, new DefaultResourceRetriever()) + .build(); + + final DefaultJWTProcessor<SecurityContext> processor = new DefaultJWTProcessor<>(); + JWSAlgorithm acceptedAlgorithms = JWSAlgorithm.RS256; + JWSVerificationKeySelector<SecurityContext> keySelector = + new JWSVerificationKeySelector<>(acceptedAlgorithms, jwkSource); + processor.setJWSKeySelector(keySelector); + + final MultiValueMap<String, Object> form = new LinkedMultiValueMap<>(); + form.put("principal", List.of("test")); + + String authorizationEndpoint = parsedMetadata.required("authorization_endpoint").asText(); + String authorizeUri = UriComponentsBuilder.fromUriString(authorizationEndpoint) + .queryParam("response_type", "code") + .queryParam("client_id", CLIENT_ID) + .queryParam("redirect_uri", CLIENT_REDIRECT_URI) + .queryParam("scope", "openid profile email") + .build() + .toUriString(); + + ResponseEntity<Void> authorizationResponse = restClient.post() + .uri(authorizeUri) + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body(form) + .retrieve() + .toBodilessEntity(); + + URI redirectLocation = authorizationResponse.getHeaders().getLocation(); + assertThat(redirectLocation).isNotNull(); + + String query = redirectLocation.getQuery(); + assertThat(query).isNotBlank(); + + String code = query.substring("code=".length()); + assertThat(code).isNotBlank(); + + LinkedMultiValueMap<Object, Object> tokenRequestBody = new LinkedMultiValueMap<>(); + tokenRequestBody.add("code", code); + tokenRequestBody.add("grant_type", "authorization_code"); + tokenRequestBody.add("redirect_uri", CLIENT_REDIRECT_URI); + + TokenResponse tokenResponse = restClient.post() + .uri(parsedMetadata.required("token_endpoint").asText()) + .header("Authorization", AUTHORIZATION_HEADER) + .body(tokenRequestBody) + .retrieve() + .body(TokenResponse.class); + + assertThat(tokenResponse).isNotNull(); + assertThat(tokenResponse.accessToken()).isNotBlank(); + assertThat(tokenResponse.idToken()).isNotBlank(); + + JWTClaimsSet accessTokenClaims = assertDoesNotThrow( + () -> processor.process(tokenResponse.accessToken(), null), + "Failed to verify access token"); + + assertThat(accessTokenClaims.getSubject()).isEqualTo("test"); + + JWTClaimsSet idTokenClaims = assertDoesNotThrow( + () -> processor.process(tokenResponse.idToken(), null), + "Failed to verify id token"); + + assertThat(accessTokenClaims.getSubject()).isEqualTo("test"); + assertThat(idTokenClaims.getSubject()).isEqualTo("test"); + } +}