Persist tokens between restarts #3
@ -4,6 +4,7 @@ import org.springframework.context.annotation.Bean;
|
|||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.context.annotation.Profile;
|
import org.springframework.context.annotation.Profile;
|
||||||
import org.springframework.jdbc.core.simple.JdbcClient;
|
import org.springframework.jdbc.core.simple.JdbcClient;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@Profile("!embedded")
|
@Profile("!embedded")
|
||||||
@ -12,4 +13,9 @@ public class PersistentConfiguration {
|
|||||||
public JDBCClientRepository jdbcClientRepository(JdbcClient jdbcClient) {
|
public JDBCClientRepository jdbcClientRepository(JdbcClient jdbcClient) {
|
||||||
return new JDBCClientRepository(jdbcClient);
|
return new JDBCClientRepository(jdbcClient);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public OAuth2AuthorizationService authorizationService(JdbcClient jdbcClient) {
|
||||||
|
return new SerializingJDBCOAuth2AuthorizationService(jdbcClient);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,172 @@
|
|||||||
|
package se.su.dsv.oauth2;
|
||||||
|
|
||||||
|
import org.springframework.jdbc.core.simple.JdbcClient;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2Token;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2UserCode;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
|
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
|
||||||
|
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/// Stores [OAuth2Authorization] by first using Java serialization to convert it to a byte array,
|
||||||
|
/// and then storing that byte array in the database.
|
||||||
|
class SerializingJDBCOAuth2AuthorizationService implements OAuth2AuthorizationService {
|
||||||
|
private final System.Logger logger = System.getLogger(this.getClass().getName());
|
||||||
|
|
||||||
|
private final JdbcClient jdbc;
|
||||||
|
|
||||||
|
SerializingJDBCOAuth2AuthorizationService(final JdbcClient jdbc) {
|
||||||
|
this.jdbc = jdbc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(final OAuth2Authorization authorization) {
|
||||||
|
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||||
|
try (var oos = new ObjectOutputStream(out)) {
|
||||||
|
oos.writeObject(authorization);
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.log(System.Logger.Level.ERROR, "Failed to serialize OAuth2Authorization", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
byte[] bytes = out.toByteArray();
|
||||||
|
|
||||||
|
String state = nullIfEmpty(authorization.getAttribute(OAuth2ParameterNames.STATE));
|
||||||
|
String codeToken = getToken(authorization, OAuth2AuthorizationCode.class);
|
||||||
|
String deviceCodeToken = getToken(authorization, OAuth2DeviceCode.class);
|
||||||
|
String userCodeToken = getToken(authorization, OAuth2UserCode.class);
|
||||||
|
String accessToken = getToken(authorization, OAuth2AccessToken.class);
|
||||||
|
String refreshToken = getToken(authorization, OAuth2RefreshToken.class);
|
||||||
|
String idToken = getToken(authorization, OidcIdToken.class);
|
||||||
|
jdbc.sql("""
|
||||||
|
INSERT INTO v2_oauth2_authorization (id, serialized_data, state, code_token, device_code_token, user_code_token, access_token, refresh_token, id_token)
|
||||||
|
VALUES (:id, :serialized_data, :state, :code_token, :device_code_token, :user_code_token, :access_token, :refresh_token, :id_token)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
serialized_data = VALUES(serialized_data),
|
||||||
|
state = VALUES(state),
|
||||||
|
code_token = VALUES(code_token),
|
||||||
|
device_code_token = VALUES(device_code_token),
|
||||||
|
user_code_token = VALUES(user_code_token),
|
||||||
|
access_token = VALUES(access_token),
|
||||||
|
refresh_token = VALUES(refresh_token),
|
||||||
|
id_token = VALUES(id_token)
|
||||||
|
""")
|
||||||
|
.param("id", authorization.getId())
|
||||||
|
.param("serialized_data", bytes)
|
||||||
|
.param("state", state)
|
||||||
|
.param("code_token", codeToken)
|
||||||
|
.param("device_code_token", deviceCodeToken)
|
||||||
|
.param("user_code_token", userCodeToken)
|
||||||
|
.param("access_token", accessToken)
|
||||||
|
.param("refresh_token", refreshToken)
|
||||||
|
.param("id_token", idToken)
|
||||||
|
.update();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T extends OAuth2Token> String getToken(
|
||||||
|
final OAuth2Authorization authorization,
|
||||||
|
final Class<T> tokenType)
|
||||||
|
{
|
||||||
|
OAuth2Authorization.Token<T> token = authorization.getToken(tokenType);
|
||||||
|
if (token != null) {
|
||||||
|
return token.getToken().getTokenValue();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String nullIfEmpty(final String attribute) {
|
||||||
|
return attribute == null || attribute.isBlank() ? null : attribute;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void remove(final OAuth2Authorization authorization) {
|
||||||
|
final String id = authorization.getId();
|
||||||
|
jdbc.sql("DELETE FROM v2_oauth2_authorization WHERE id = :id")
|
||||||
|
.param("id", id)
|
||||||
|
.update();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OAuth2Authorization findById(final String id) {
|
||||||
|
Optional<byte[]> optionalBytes = jdbc.sql("SELECT serialized_data FROM v2_oauth2_authorization WHERE id = :id")
|
||||||
|
.param("id", id)
|
||||||
|
.query(byte[].class)
|
||||||
|
.optional();
|
||||||
|
|
||||||
|
if (optionalBytes.isEmpty()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
byte[] bytes = optionalBytes.get();
|
||||||
|
return readOAuth2Authorization(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
private OAuth2Authorization readOAuth2Authorization(final byte[] bytes) {
|
||||||
|
try (var ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
|
||||||
|
Object object = ois.readObject();
|
||||||
|
if (object instanceof OAuth2Authorization authorization) {
|
||||||
|
return authorization;
|
||||||
|
} else {
|
||||||
|
logger.log(System.Logger.Level.WARNING, "Garbage OAuth2Authorization found in database");
|
||||||
|
}
|
||||||
|
} catch (IOException | ClassNotFoundException e) {
|
||||||
|
logger.log(System.Logger.Level.WARNING, "Failed to deserialize OAuth2Authorization", e);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OAuth2Authorization findByToken(final String token, final OAuth2TokenType tokenType) {
|
||||||
|
if (tokenType == null) {
|
||||||
|
return findByAnyToken(token);
|
||||||
|
}
|
||||||
|
String sql = switch (tokenType.getValue()) {
|
||||||
|
case "state" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE state = :token";
|
||||||
|
case "code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE code_token = :token";
|
||||||
|
case "device_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE device_code_token = :token";
|
||||||
|
case "user_code" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE user_code_token = :token";
|
||||||
|
case "access_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE access_token = :token";
|
||||||
|
case "refresh_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE refresh_token = :token";
|
||||||
|
case "id_token" -> "SELECT serialized_data FROM v2_oauth2_authorization WHERE id_token = :token";
|
||||||
|
default -> throw new UnsupportedOperationException("Unknown token type: " + tokenType.getValue());
|
||||||
|
};
|
||||||
|
|
||||||
|
return jdbc
|
||||||
|
.sql(sql)
|
||||||
|
.param("token", token)
|
||||||
|
.query(byte[].class)
|
||||||
|
.optional()
|
||||||
|
.map(this::readOAuth2Authorization)
|
||||||
|
.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private OAuth2Authorization findByAnyToken(final String token) {
|
||||||
|
return jdbc.sql("""
|
||||||
|
SELECT serialized_data FROM v2_oauth2_authorization
|
||||||
|
WHERE state = :token
|
||||||
|
OR code_token = :token
|
||||||
|
OR device_code_token = :token
|
||||||
|
OR user_code_token = :token
|
||||||
|
OR access_token = :token
|
||||||
|
OR refresh_token = :token
|
||||||
|
OR id_token = :token
|
||||||
|
""")
|
||||||
|
.param("token", token)
|
||||||
|
.query(byte[].class)
|
||||||
|
.optional()
|
||||||
|
.map(this::readOAuth2Authorization)
|
||||||
|
.orElse(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
13
src/main/resources/db/migration/V4__token_presistence.sql
Normal file
13
src/main/resources/db/migration/V4__token_presistence.sql
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
CREATE TABLE v2_oauth2_authorization
|
||||||
|
(
|
||||||
|
id varchar(100) NOT NULL,
|
||||||
|
serialized_data BLOB NOT NULL,
|
||||||
|
state TEXT,
|
||||||
|
code_token TEXT,
|
||||||
|
device_code_token TEXT,
|
||||||
|
user_code_token TEXT,
|
||||||
|
access_token TEXT,
|
||||||
|
refresh_token TEXT,
|
||||||
|
id_token TEXT,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
);
|
||||||
Loading…
x
Reference in New Issue
Block a user