WIP: Submit transcoding jobs via a HTTP API #6

Draft
ansv7779 wants to merge 22 commits from api-submission into master
7 changed files with 237 additions and 21 deletions
Showing only changes of commit 6852ce6d50 - Show all commits

View File

@ -1,8 +1,10 @@
package se.su.dsv.whisperapi; package se.su.dsv.whisperapi;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.simple.JdbcClient; import org.springframework.jdbc.core.simple.JdbcClient;
import se.su.dsv.whisperapi.core.Job; import se.su.dsv.whisperapi.core.Job;
import se.su.dsv.whisperapi.core.JobCompletion; import se.su.dsv.whisperapi.core.JobCompletion;
import se.su.dsv.whisperapi.core.NotificationStatus;
import se.su.dsv.whisperapi.core.OutputFormat; import se.su.dsv.whisperapi.core.OutputFormat;
import se.su.dsv.whisperapi.core.TranscriptionRepository; import se.su.dsv.whisperapi.core.TranscriptionRepository;
import se.su.dsv.whisperapi.core.Transcription; import se.su.dsv.whisperapi.core.Transcription;
@ -10,6 +12,11 @@ import se.su.dsv.whisperapi.core.Transcription;
import java.net.URI; import java.net.URI;
import java.nio.file.Path; import java.nio.file.Path;
import java.security.Principal; import java.security.Principal;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -43,12 +50,7 @@ public class JDBCTranscriptionRepository implements TranscriptionRepository {
""") """)
.param("id", uuid) .param("id", uuid)
.param("owner", owner.getName()) .param("owner", owner.getName())
.query((rs, rowNum) -> { .query(TranscriptionRowMapper.INSTANCE)
UUID id = UUID.fromString(rs.getString("id"));
URI callbackUri = URI.create(rs.getString("callback_uri"));
OutputFormat outputFormat = OutputFormat.valueOf(rs.getString("output_format"));
return new Transcription(id, owner, callbackUri, outputFormat);
})
.optional(); .optional();
} }
@ -104,21 +106,7 @@ public class JDBCTranscriptionRepository implements TranscriptionRepository {
WHERE id = :id WHERE id = :id
""") """)
.param("id", jobId) .param("id", jobId)
.query((rs, rowNum) -> { .query(JobRowMapper.INSTANCE)
UUID id = UUID.fromString(rs.getString("id"));
String resultFileAbsolutePath = rs.getString("result_file_absolute_path");
if (!rs.wasNull()) {
return new Job(id, new Job.Status.Completed(new JobCompletion.Success(Path.of(resultFileAbsolutePath))));
}
String errorMessage = rs.getString("error_message");
if (!rs.wasNull()) {
return new Job(id, new Job.Status.Completed(new JobCompletion.Failure(errorMessage)));
}
return new Job(id, new Job.Status.Pending());
})
.optional(); .optional();
} }
@ -143,4 +131,116 @@ public class JDBCTranscriptionRepository implements TranscriptionRepository {
.update(); .update();
} }
} }
@Override
public List<Transcription> getProcessingTranscriptions() {
return jdbcClient.sql("""
SELECT id, owner, callback_uri, output_format
FROM transcriptions
WHERE notification_success = FALSE
AND id IN (
SELECT transcription_id
FROM jobs
)
""")
.query(TranscriptionRowMapper.INSTANCE)
.list();
}
@Override
public List<Job> getJobs(Transcription transcription) {
return jdbcClient.sql("""
SELECT id, result_file_absolute_path, error_message
FROM jobs
WHERE transcription_id = :transcription_id
""")
.param("transcription_id", transcription.id())
.query(JobRowMapper.INSTANCE)
.list();
}
@Override
public NotificationStatus getNotificationStatus(Transcription transcription) {
return jdbcClient.sql("""
SELECT last_notification_time, notification_attempts
FROM transcriptions
WHERE id = :id
""")
.param("id", transcription.id())
.query((rs, rowNum) -> {
Timestamp lastNotificationTime = rs.getTimestamp("last_notification_time");
int notificationAttempts = rs.getInt("notification_attempts");
if (notificationAttempts == 0) {
return new NotificationStatus.Never();
} else {
return new NotificationStatus.Failed(lastNotificationTime.toInstant(), notificationAttempts);
}
})
.single();
}
@Override
public void markAsCompleted(Transcription transcription) {
jdbcClient.sql("""
UPDATE transcriptions
SET notification_successful = true
WHERE id = :id
""")
.param("id", transcription.id())
.update();
}
@Override
public void increaseFailureCount(Transcription transcription, Instant now) {
jdbcClient.sql("""
UPDATE transcriptions
SET last_notification_time = :now,
notification_attempts = notification_attempts + 1
WHERE id = :id
""")
.param("now", Time.from(now))
.param("id", transcription.id())
.update();
}
private enum TranscriptionRowMapper implements RowMapper<Transcription> {
INSTANCE;
@Override
public Transcription mapRow(ResultSet rs, int rowNum) throws SQLException {
UUID id = UUID.fromString(rs.getString("id"));
Principal owner = new SimplePrincipal(rs.getString("owner"));
URI callbackUri = URI.create(rs.getString("callback_uri"));
OutputFormat outputFormat = OutputFormat.valueOf(rs.getString("output_format"));
return new Transcription(id, owner, callbackUri, outputFormat);
}
}
private enum JobRowMapper implements RowMapper<Job> {
INSTANCE;
@Override
public Job mapRow(ResultSet rs, int rowNum) throws SQLException {
UUID id = UUID.fromString(rs.getString("id"));
return new Job(id, getStatus(rs));
}
private Job.Status getStatus(ResultSet rs) throws SQLException {
String resultFileAbsolutePath = rs.getString("result_file_absolute_path");
if (!rs.wasNull()) {
return new Job.Status.Completed(new JobCompletion.Success(Path.of(resultFileAbsolutePath)));
}
String errorMessage = rs.getString("error_message");
if (!rs.wasNull()) {
return new Job.Status.Completed(new JobCompletion.Failure(errorMessage));
}
return new Job.Status.Pending();
}
}
record SimplePrincipal(String getName) implements Principal {
}
} }

View File

@ -0,0 +1,19 @@
package se.su.dsv.whisperapi;
import org.springframework.scheduling.annotation.Scheduled;
import se.su.dsv.whisperapi.core.TranscriptionService;
import java.util.concurrent.TimeUnit;
public class SendOutCallbacksJob {
private final TranscriptionService transcriptionService;
public SendOutCallbacksJob(TranscriptionService transcriptionService) {
this.transcriptionService = transcriptionService;
}
@Scheduled(fixedRate = 5, timeUnit = TimeUnit.SECONDS)
public void sendOutCallbacks() {
transcriptionService.checkForCompletedTranscriptions();
}
}

View File

@ -6,6 +6,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.jdbc.core.simple.JdbcClient; import org.springframework.jdbc.core.simple.JdbcClient;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.security.config.Customizer; import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.SecurityFilterChain;
@ -14,6 +15,7 @@ import se.su.dsv.whisperapi.core.TranscriptionService;
@SpringBootApplication @SpringBootApplication
@EnableConfigurationProperties(WhisperFrontendConfiguration.class) @EnableConfigurationProperties(WhisperFrontendConfiguration.class)
@EnableScheduling
public class WhisperApiApplication { public class WhisperApiApplication {
public static void main(String[] args) { public static void main(String[] args) {
@ -54,4 +56,9 @@ public class WhisperApiApplication {
public JDBCTranscriptionRepository jdbcTransactionRepository(JdbcClient jdbcClient) { public JDBCTranscriptionRepository jdbcTransactionRepository(JdbcClient jdbcClient) {
return new JDBCTranscriptionRepository(jdbcClient); return new JDBCTranscriptionRepository(jdbcClient);
} }
@Bean
public SendOutCallbacksJob sendOutCallbacksJob(TranscriptionService transcriptionService) {
return new SendOutCallbacksJob(transcriptionService);
}
} }

View File

@ -0,0 +1,8 @@
package se.su.dsv.whisperapi.core;
import java.time.Instant;
public sealed interface NotificationStatus {
record Never() implements NotificationStatus {}
record Failed(Instant lastAttempt, int numberOfAttempts) implements NotificationStatus {}
}

View File

@ -1,6 +1,7 @@
package se.su.dsv.whisperapi.core; package se.su.dsv.whisperapi.core;
import java.security.Principal; import java.security.Principal;
import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -26,4 +27,14 @@ public interface TranscriptionRepository {
Optional<Job> findJobById(UUID jobId); Optional<Job> findJobById(UUID jobId);
void setJobCompleted(Job job, JobCompletion jobCompletion); void setJobCompleted(Job job, JobCompletion jobCompletion);
List<Transcription> getProcessingTranscriptions();
List<Job> getJobs(Transcription transcription);
NotificationStatus getNotificationStatus(Transcription transcription);
void markAsCompleted(Transcription transcription);
void increaseFailureCount(Transcription transcription, Instant now);
} }

View File

@ -7,15 +7,23 @@ import com.fasterxml.jackson.databind.SerializationFeature;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.net.URI; import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.security.Principal; import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
public class TranscriptionService { public class TranscriptionService {
private static final System.Logger LOGGER = System.getLogger(TranscriptionService.class.getName());
private static final Duration INITIAL_NOTIFICATION_DELAY = Duration.ofMinutes(15);
private final TranscriptionRepository transcriptionRepository; private final TranscriptionRepository transcriptionRepository;
private final Path fileDirectory; private final Path fileDirectory;
private final Path jobsDirectory; private final Path jobsDirectory;
@ -108,4 +116,63 @@ public class TranscriptionService {
} }
transcriptionRepository.setJobCompleted(job, jobCompletion); transcriptionRepository.setJobCompleted(job, jobCompletion);
} }
public void checkForCompletedTranscriptions() {
Instant now = Instant.now();
List<Transcription> processing = transcriptionRepository.getProcessingTranscriptions();
for (Transcription transcription : processing) {
List<Job> jobs = transcriptionRepository.getJobs(transcription);
boolean allJobsCompleted = jobs.stream()
.allMatch(Job::isCompleted);
if (allJobsCompleted && shouldNotifyOwner(transcription, now)) {
boolean notificationSuccessful = notifyOwner(transcription, jobs);
if (notificationSuccessful) {
markTranscriptionAsCompleted(transcription);
}
else {
increaseFailureCount(transcription, now);
}
}
}
}
private boolean shouldNotifyOwner(Transcription transcription, Instant now) {
NotificationStatus notificationStatus = transcriptionRepository.getNotificationStatus(transcription);
return switch (notificationStatus) {
case NotificationStatus.Never() -> true;
case NotificationStatus.Failed(Instant lastAttempt, int numberOfAttempts) -> {
int delayMultiplier = (int) Math.pow(2, numberOfAttempts - 1); // double the delay each time
Duration delay = INITIAL_NOTIFICATION_DELAY.multipliedBy(delayMultiplier);
yield now.isAfter(lastAttempt.plus(delay));
}
};
}
private boolean notifyOwner(final Transcription transcription, List<Job> jobs) {
URI callbackUri = transcription.callbackUri();
try (HttpClient client = HttpClient.newHttpClient()) {
HttpRequest request = HttpRequest.newBuilder()
.uri(callbackUri)
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString("""
{
"id": "%s",
"status": "completed"
}"""))
.build();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
return response.statusCode() == 200;
} catch (IOException | InterruptedException e) {
LOGGER.log(System.Logger.Level.ERROR, "Failed to notify owner", e);
return false;
}
}
private void increaseFailureCount(Transcription transcription, Instant now) {
transcriptionRepository.increaseFailureCount(transcription, now);
}
private void markTranscriptionAsCompleted(Transcription transcription) {
transcriptionRepository.markAsCompleted(transcription);
}
} }

View File

@ -0,0 +1,4 @@
ALTER TABLE transcriptions
ADD COLUMN notification_success BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN last_notification_time DATETIME DEFAULT NULL,
ADD COLUMN notification_attempts INT NOT NULL DEFAULT 0;