Persist tokens between restarts

Utilize Java serialization to turn the entire OAuth2Authorization to a binary blob and store that in the database. Could not find a better way to do it given the types involved (like Map<String, Object> properties). Sure, Java serialization can fail on arbitrary objects but hopefully since OAuth2Authorization implements java.io.Serializable any properties put in are serializable as well.
This commit is contained in:
Andreas Svanberg 2025-03-28 11:58:35 +01:00
parent f0947c5ff8
commit 9a6e21a396
Signed by: ansv7779
GPG Key ID: 729B051CFFD42F92
3 changed files with 191 additions and 0 deletions

@ -4,6 +4,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import org.springframework.jdbc.core.simple.JdbcClient;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@Configuration
@Profile("!embedded")
@ -12,4 +13,9 @@ public class PersistentConfiguration {
public JDBCClientRepository jdbcClientRepository(JdbcClient jdbcClient) {
return new JDBCClientRepository(jdbcClient);
}
@Bean
public OAuth2AuthorizationService authorizationService(JdbcClient jdbcClient) {
return new SerializingJDBCOAuth2AuthorizationService(jdbcClient);
}
}

@ -0,0 +1,172 @@
package se.su.dsv.oauth2;
import org.springframework.jdbc.core.simple.JdbcClient;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2UserCode;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Optional;
/// Stores [OAuth2Authorization] by first using Java serialization to convert it to a byte array,
/// and then storing that byte array in the database.
class SerializingJDBCOAuth2AuthorizationService implements OAuth2AuthorizationService {
private final System.Logger logger = System.getLogger(this.getClass().getName());
private final JdbcClient jdbc;
SerializingJDBCOAuth2AuthorizationService(final JdbcClient jdbc) {
this.jdbc = jdbc;
}
@Override
public void save(final OAuth2Authorization authorization) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (var oos = new ObjectOutputStream(out)) {
oos.writeObject(authorization);
} catch (IOException e) {
logger.log(System.Logger.Level.ERROR, "Failed to serialize OAuth2Authorization", e);
return;
}
byte[] bytes = out.toByteArray();
String state = nullIfEmpty(authorization.getAttribute(OAuth2ParameterNames.STATE));
String codeToken = getToken(authorization, OAuth2AuthorizationCode.class);
String deviceCodeToken = getToken(authorization, OAuth2DeviceCode.class);
String userCodeToken = getToken(authorization, OAuth2UserCode.class);
String accessToken = getToken(authorization, OAuth2AccessToken.class);
String refreshToken = getToken(authorization, OAuth2RefreshToken.class);
String idToken = getToken(authorization, OidcIdToken.class);
jdbc.sql("""
INSERT INTO v2_oauth2_authorization (id, serialized_data, state, code_token, device_code_token, user_code_token, access_token, refresh_token, id_token)
VALUES (:id, :serialized_data, :state, :code_token, :device_code_token, :user_code_token, :access_token, :refresh_token, :id_token)
ON DUPLICATE KEY UPDATE
serialized_data = VALUES(serialized_data),
state = VALUES(state),
code_token = VALUES(code_token),
device_code_token = VALUES(device_code_token),
user_code_token = VALUES(user_code_token),
access_token = VALUES(access_token),
refresh_token = VALUES(refresh_token),
id_token = VALUES(id_token)
""")
.param("id", authorization.getId())
.param("serialized_data", bytes)
.param("state", state)
.param("code_token", codeToken)
.param("device_code_token", deviceCodeToken)
.param("user_code_token", userCodeToken)
.param("access_token", accessToken)
.param("refresh_token", refreshToken)
.param("id_token", idToken)
.update();
}
private static <T extends OAuth2Token> String getToken(
final OAuth2Authorization authorization,
final Class<T> tokenType)
{
OAuth2Authorization.Token<T> token = authorization.getToken(tokenType);
if (token != null) {
return token.getToken().getTokenValue();
} else {
return null;
}
}
private String nullIfEmpty(final String attribute) {
return attribute == null || attribute.isBlank() ? null : attribute;
}
@Override
public void remove(final OAuth2Authorization authorization) {
final String id = authorization.getId();
jdbc.sql("DELETE FROM v2_oauth2_authorization WHERE id = :id")
.param("id", id)
.update();
}
@Override
public OAuth2Authorization findById(final String id) {
Optional<byte[]> optionalBytes = jdbc.sql("SELECT serialized_data FROM v2_oauth2_authorization WHERE id = :id")
.param("id", id)
.query(byte[].class)
.optional();
if (optionalBytes.isEmpty()) {
return null;
}
byte[] bytes = optionalBytes.get();
return readOAuth2Authorization(bytes);
}
private OAuth2Authorization readOAuth2Authorization(final byte[] bytes) {
try (var ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
Object object = ois.readObject();
if (object instanceof OAuth2Authorization authorization) {
return authorization;
} else {
logger.log(System.Logger.Level.WARNING, "Garbage OAuth2Authorization found in database");
}
} catch (IOException | ClassNotFoundException e) {
logger.log(System.Logger.Level.WARNING, "Failed to deserialize OAuth2Authorization", e);
}
return null;
}
@Override
public OAuth2Authorization findByToken(final String token, final OAuth2TokenType tokenType) {
if (tokenType == null) {
return findByAnyToken(token);
}
String sql = switch (tokenType.getValue()) {
case "state" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE state = :token";
case "code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE code_token = :token";
case "device_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE device_code_token = :token";
case "user_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE user_code_token = :token";
case "access_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE access_token = :token";
case "refresh_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE refresh_token = :token";
case "id_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE id_token = :token";
default -> throw new UnsupportedOperationException("Unknown token type: " + tokenType.getValue());
};
return jdbc
.sql(sql)
.param("token", token)
.query(byte[].class)
.optional()
.map(this::readOAuth2Authorization)
.orElse(null);
}
private OAuth2Authorization findByAnyToken(final String token) {
return jdbc.sql("""
SELECT serialized_data FROM v2_oauth2_authorization
WHERE state = :token
OR code_token = :token
OR device_code_token = :token
OR user_code_token = :token
OR access_token = :token
OR refresh_token = :token
OR id_token = :token
""")
.param("token", token)
.query(byte[].class)
.optional()
.map(this::readOAuth2Authorization)
.orElse(null);
}
}

@ -0,0 +1,13 @@
CREATE TABLE v2_oauth2_authorization
(
id varchar(100) NOT NULL,
serialized_data BLOB NOT NULL,
state TEXT,
code_token TEXT,
device_code_token TEXT,
user_code_token TEXT,
access_token TEXT,
refresh_token TEXT,
id_token TEXT,
PRIMARY KEY (id)
);