From 9a6e21a396743ef9f3ad28eea7d41c4bf34d5814 Mon Sep 17 00:00:00 2001 From: Andreas Svanberg Date: Fri, 28 Mar 2025 11:58:35 +0100 Subject: [PATCH] Persist tokens between restarts Utilize Java serialization to turn the entire OAuth2Authorization to a binary blob and store that in the database. Could not find a better way to do it given the types involved (like Map properties). Sure, Java serialization can fail on arbitrary objects but hopefully since OAuth2Authorization implements java.io.Serializable any properties put in are serializable as well. --- .../dsv/oauth2/PersistentConfiguration.java | 6 + ...alizingJDBCOAuth2AuthorizationService.java | 172 ++++++++++++++++++ .../db/migration/V4__token_presistence.sql | 13 ++ 3 files changed, 191 insertions(+) create mode 100644 src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java create mode 100644 src/main/resources/db/migration/V4__token_presistence.sql diff --git a/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java b/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java index 1e595e0..2f3adb4 100644 --- a/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java +++ b/src/main/java/se/su/dsv/oauth2/PersistentConfiguration.java @@ -4,6 +4,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Profile; import org.springframework.jdbc.core.simple.JdbcClient; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @Configuration @Profile("!embedded") @@ -12,4 +13,9 @@ public class PersistentConfiguration { public JDBCClientRepository jdbcClientRepository(JdbcClient jdbcClient) { return new JDBCClientRepository(jdbcClient); } + + @Bean + public OAuth2AuthorizationService authorizationService(JdbcClient jdbcClient) { + return new SerializingJDBCOAuth2AuthorizationService(jdbcClient); + } } diff --git a/src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java b/src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java new file mode 100644 index 0000000..70f7376 --- /dev/null +++ b/src/main/java/se/su/dsv/oauth2/SerializingJDBCOAuth2AuthorizationService.java @@ -0,0 +1,172 @@ +package se.su.dsv.oauth2; + +import org.springframework.jdbc.core.simple.JdbcClient; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Optional; + +/// Stores [OAuth2Authorization] by first using Java serialization to convert it to a byte array, +/// and then storing that byte array in the database. +class SerializingJDBCOAuth2AuthorizationService implements OAuth2AuthorizationService { + private final System.Logger logger = System.getLogger(this.getClass().getName()); + + private final JdbcClient jdbc; + + SerializingJDBCOAuth2AuthorizationService(final JdbcClient jdbc) { + this.jdbc = jdbc; + } + + @Override + public void save(final OAuth2Authorization authorization) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (var oos = new ObjectOutputStream(out)) { + oos.writeObject(authorization); + } catch (IOException e) { + logger.log(System.Logger.Level.ERROR, "Failed to serialize OAuth2Authorization", e); + return; + } + byte[] bytes = out.toByteArray(); + + String state = nullIfEmpty(authorization.getAttribute(OAuth2ParameterNames.STATE)); + String codeToken = getToken(authorization, OAuth2AuthorizationCode.class); + String deviceCodeToken = getToken(authorization, OAuth2DeviceCode.class); + String userCodeToken = getToken(authorization, OAuth2UserCode.class); + String accessToken = getToken(authorization, OAuth2AccessToken.class); + String refreshToken = getToken(authorization, OAuth2RefreshToken.class); + String idToken = getToken(authorization, OidcIdToken.class); + jdbc.sql(""" + INSERT INTO v2_oauth2_authorization (id, serialized_data, state, code_token, device_code_token, user_code_token, access_token, refresh_token, id_token) + VALUES (:id, :serialized_data, :state, :code_token, :device_code_token, :user_code_token, :access_token, :refresh_token, :id_token) + ON DUPLICATE KEY UPDATE + serialized_data = VALUES(serialized_data), + state = VALUES(state), + code_token = VALUES(code_token), + device_code_token = VALUES(device_code_token), + user_code_token = VALUES(user_code_token), + access_token = VALUES(access_token), + refresh_token = VALUES(refresh_token), + id_token = VALUES(id_token) + """) + .param("id", authorization.getId()) + .param("serialized_data", bytes) + .param("state", state) + .param("code_token", codeToken) + .param("device_code_token", deviceCodeToken) + .param("user_code_token", userCodeToken) + .param("access_token", accessToken) + .param("refresh_token", refreshToken) + .param("id_token", idToken) + .update(); + } + + private static String getToken( + final OAuth2Authorization authorization, + final Class tokenType) + { + OAuth2Authorization.Token token = authorization.getToken(tokenType); + if (token != null) { + return token.getToken().getTokenValue(); + } else { + return null; + } + } + + private String nullIfEmpty(final String attribute) { + return attribute == null || attribute.isBlank() ? null : attribute; + } + + @Override + public void remove(final OAuth2Authorization authorization) { + final String id = authorization.getId(); + jdbc.sql("DELETE FROM v2_oauth2_authorization WHERE id = :id") + .param("id", id) + .update(); + } + + @Override + public OAuth2Authorization findById(final String id) { + Optional optionalBytes = jdbc.sql("SELECT serialized_data FROM v2_oauth2_authorization WHERE id = :id") + .param("id", id) + .query(byte[].class) + .optional(); + + if (optionalBytes.isEmpty()) { + return null; + } + + byte[] bytes = optionalBytes.get(); + return readOAuth2Authorization(bytes); + } + + private OAuth2Authorization readOAuth2Authorization(final byte[] bytes) { + try (var ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) { + Object object = ois.readObject(); + if (object instanceof OAuth2Authorization authorization) { + return authorization; + } else { + logger.log(System.Logger.Level.WARNING, "Garbage OAuth2Authorization found in database"); + } + } catch (IOException | ClassNotFoundException e) { + logger.log(System.Logger.Level.WARNING, "Failed to deserialize OAuth2Authorization", e); + } + return null; + } + + @Override + public OAuth2Authorization findByToken(final String token, final OAuth2TokenType tokenType) { + if (tokenType == null) { + return findByAnyToken(token); + } + String sql = switch (tokenType.getValue()) { + case "state" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE state = :token"; + case "code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE code_token = :token"; + case "device_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE device_code_token = :token"; + case "user_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE user_code_token = :token"; + case "access_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE access_token = :token"; + case "refresh_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE refresh_token = :token"; + case "id_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE id_token = :token"; + default -> throw new UnsupportedOperationException("Unknown token type: " + tokenType.getValue()); + }; + + return jdbc + .sql(sql) + .param("token", token) + .query(byte[].class) + .optional() + .map(this::readOAuth2Authorization) + .orElse(null); + } + + private OAuth2Authorization findByAnyToken(final String token) { + return jdbc.sql(""" + SELECT serialized_data FROM v2_oauth2_authorization + WHERE state = :token + OR code_token = :token + OR device_code_token = :token + OR user_code_token = :token + OR access_token = :token + OR refresh_token = :token + OR id_token = :token + """) + .param("token", token) + .query(byte[].class) + .optional() + .map(this::readOAuth2Authorization) + .orElse(null); + } +} diff --git a/src/main/resources/db/migration/V4__token_presistence.sql b/src/main/resources/db/migration/V4__token_presistence.sql new file mode 100644 index 0000000..8d0667d --- /dev/null +++ b/src/main/resources/db/migration/V4__token_presistence.sql @@ -0,0 +1,13 @@ +CREATE TABLE v2_oauth2_authorization +( + id varchar(100) NOT NULL, + serialized_data BLOB NOT NULL, + state TEXT, + code_token TEXT, + device_code_token TEXT, + user_code_token TEXT, + access_token TEXT, + refresh_token TEXT, + id_token TEXT, + PRIMARY KEY (id) +); -- 2.39.5