diff --git a/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java b/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java index 1e595e0..2f3adb4 100644 --- a/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java +++ b/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java @@ -4,6 +4,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Profile; import org.springframework.jdbc.core.simple.JdbcClient; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @Configuration @Profile("!embedded") @@ -12,4 +13,9 @@ public class PersistentConfiguration { public JDBCClientRepository jdbcClientRepository(JdbcClient jdbcClient) { return new JDBCClientRepository(jdbcClient); } + + @Bean + public OAuth2AuthorizationService authorizationService(JdbcClient jdbcClient) { + return new SerializingJDBCOAuth2AuthorizationService(jdbcClient); + } } diff --git a/src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java b/src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java new file mode 100644 index 0000000..70f7376 --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java @@ -0,0 +1,172 @@ +package se.su.dsv.oauth2; + +import org.springframework.jdbc.core.simple.JdbcClient; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Optional; + +/// Stores [OAuth2Authorization] by first using Java serialization to convert it to a byte array, +/// and then storing that byte array in the database. +class SerializingJDBCOAuth2AuthorizationService implements OAuth2AuthorizationService { + private final System.Logger logger = System.getLogger(this.getClass().getName()); + + private final JdbcClient jdbc; + + SerializingJDBCOAuth2AuthorizationService(final JdbcClient jdbc) { + this.jdbc = jdbc; + } + + @Override + public void save(final OAuth2Authorization authorization) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (var oos = new ObjectOutputStream(out)) { + oos.writeObject(authorization); + } catch (IOException e) { + logger.log(System.Logger.Level.ERROR, "Failed to serialize OAuth2Authorization", e); + return; + } + byte[] bytes = out.toByteArray(); + + String state = nullIfEmpty(authorization.getAttribute(OAuth2ParameterNames.STATE)); + String codeToken = getToken(authorization, OAuth2AuthorizationCode.class); + String deviceCodeToken = getToken(authorization, OAuth2DeviceCode.class); + String userCodeToken = getToken(authorization, OAuth2UserCode.class); + String accessToken = getToken(authorization, OAuth2AccessToken.class); + String refreshToken = getToken(authorization, OAuth2RefreshToken.class); + String idToken = getToken(authorization, OidcIdToken.class); + jdbc.sql(""" + INSERT INTO v2_oauth2_authorization (id, serialized_data, state, code_token, device_code_token, user_code_token, access_token, refresh_token, id_token) + VALUES (:id, :serialized_data, :state, :code_token, :device_code_token, :user_code_token, :access_token, :refresh_token, :id_token) + ON DUPLICATE KEY UPDATE + serialized_data = VALUES(serialized_data), + state = VALUES(state), + code_token = VALUES(code_token), + device_code_token = VALUES(device_code_token), + user_code_token = VALUES(user_code_token), + access_token = VALUES(access_token), + refresh_token = VALUES(refresh_token), + id_token = VALUES(id_token) + """) + .param("id", authorization.getId()) + .param("serialized_data", bytes) + .param("state", state) + .param("code_token", codeToken) + .param("device_code_token", deviceCodeToken) + .param("user_code_token", userCodeToken) + .param("access_token", accessToken) + .param("refresh_token", refreshToken) + .param("id_token", idToken) + .update(); + } + + private static String getToken( + final OAuth2Authorization authorization, + final Class tokenType) + { + OAuth2Authorization.Token token = authorization.getToken(tokenType); + if (token != null) { + return token.getToken().getTokenValue(); + } else { + return null; + } + } + + private String nullIfEmpty(final String attribute) { + return attribute == null || attribute.isBlank() ? null : attribute; + } + + @Override + public void remove(final OAuth2Authorization authorization) { + final String id = authorization.getId(); + jdbc.sql("DELETE FROM v2_oauth2_authorization WHERE id = :id") + .param("id", id) + .update(); + } + + @Override + public OAuth2Authorization findById(final String id) { + Optional optionalBytes = jdbc.sql("SELECT serialized_data FROM v2_oauth2_authorization WHERE id = :id") + .param("id", id) + .query(byte[].class) + .optional(); + + if (optionalBytes.isEmpty()) { + return null; + } + + byte[] bytes = optionalBytes.get(); + return readOAuth2Authorization(bytes); + } + + private OAuth2Authorization readOAuth2Authorization(final byte[] bytes) { + try (var ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) { + Object object = ois.readObject(); + if (object instanceof OAuth2Authorization authorization) { + return authorization; + } else { + logger.log(System.Logger.Level.WARNING, "Garbage OAuth2Authorization found in database"); + } + } catch (IOException | ClassNotFoundException e) { + logger.log(System.Logger.Level.WARNING, "Failed to deserialize OAuth2Authorization", e); + } + return null; + } + + @Override + public OAuth2Authorization findByToken(final String token, final OAuth2TokenType tokenType) { + if (tokenType == null) { + return findByAnyToken(token); + } + String sql = switch (tokenType.getValue()) { + case "state" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE state = :token"; + case "code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE code_token = :token"; + case "device_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE device_code_token = :token"; + case "user_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE user_code_token = :token"; + case "access_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE access_token = :token"; + case "refresh_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE refresh_token = :token"; + case "id_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE id_token = :token"; + default -> throw new UnsupportedOperationException("Unknown token type: " + tokenType.getValue()); + }; + + return jdbc + .sql(sql) + .param("token", token) + .query(byte[].class) + .optional() + .map(this::readOAuth2Authorization) + .orElse(null); + } + + private OAuth2Authorization findByAnyToken(final String token) { + return jdbc.sql(""" + SELECT serialized_data FROM v2_oauth2_authorization + WHERE state = :token + OR code_token = :token + OR device_code_token = :token + OR user_code_token = :token + OR access_token = :token + OR refresh_token = :token + OR id_token = :token + """) + .param("token", token) + .query(byte[].class) + .optional() + .map(this::readOAuth2Authorization) + .orElse(null); + } +} diff --git a/src/main/resources/db/migration/V4__token_presistence.sql b/src/main/resources/db/migration/V4__token_presistence.sql new file mode 100644 index 0000000..8d0667d --- /dev/null +++ b/src/main/resources/db/migration/V4__token_presistence.sql @@ -0,0 +1,13 @@ +CREATE TABLE v2_oauth2_authorization +( + id varchar(100) NOT NULL, + serialized_data BLOB NOT NULL, + state TEXT, + code_token TEXT, + device_code_token TEXT, + user_code_token TEXT, + access_token TEXT, + refresh_token TEXT, + id_token TEXT, + PRIMARY KEY (id) +);