From 278fe30b0221df6655891d60a5b04757c369915f Mon Sep 17 00:00:00 2001 From: Erik Thuning <boooink@gmail.com> Date: Mon, 27 Mar 2023 16:36:27 +0200 Subject: [PATCH] Refactored subtitles handling to include generation via whisper. --- README.md | 87 ++++++++++---- config.ini.example | 8 ++ pipeline/handlers/handler.py | 2 + pipeline/handlers/subtitles.py | 201 +++++++++++++++++++++++++++------ pipeline/package.py | 40 ++++++- pipeline/utils.py | 12 +- requirements.txt | 1 + test.py | 120 +++++++++++++------- 8 files changed, 357 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index 5f7e994..391c37a 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,10 @@ Valid top-level keys and their expected values are: A JSON object representing the subtitles to be acted upon. Its format is documented in detail under the heading [Subtitles](#subtitles). + * `generate_subtitles`: `JSON` + A JSON object representing subtitle tracks to be generated. Its format is + documented in detail under the heading [Subtitles](#subtitles). + * `sources`: `JSON` A JSON object representing the sources to be acted upon. Its format is documented in detail under the heading [Job Sources](#job-sources). @@ -92,29 +96,52 @@ Valid top-level keys and their expected values are: ### Subtitles -The subtitles object consists of a number of pre-defined keys, each -corresponding to a possible subtitles track for the presentation. +There are two top-level keys that deal with subtitles: `subtitles` and +`generate_subtitles`. The `subtitles` object is a simple key-value map, +mapping subtitle track names to subtitle files to be stored. The +`generate_subtitles` object maps subtitle track names to generation tasks. +Keys must be unique across these two maps. -Allowed keys are: +If the value for a given key in `subtitles` is `null`, that track is deleted +from the presentation. Non-null values must be files located +under `upload_dir`. 
- * `Svenska`: `string` - The path to a swedish subtitles file in WEBVTT format. Relative to the - value of `upload_dir`. +Any subtitle tracks that exist in the presentation but are omitted in the job +specification are left unmodified. - * `English`: `string` - The path to an english subtitles file in WEBVTT format. Relative to the - value of `upload_dir`. +Values in `generate_subtitles` are objects with the following structure: - * `Svenska (Whisper)`: `bool` - A boolean indicating whether to generate swedish subtitles using the - Whisper auto-transcription tool. **Not yet implemented** + * `type`: `string` + The transcription engine to be used for generating the subtitles. Currently + only supports the value "whisper". - * `English (Whisper)`: `bool` - A boolean indicating whether to generate english subtitles using the - Whisper auto-transcription tool. **Not yet implemented** + * `source`: `string` + The name of one of this presentation's video streams, which will be used to + generate the subtitle track. Should preferably use a stream with a camera + feed for best synchronization results. The stream may either be an already + existing one or one created by this job. -If a key is omitted, no action is taken on that track. If a key is included -but its value is falsy, the track will be deleted. +Here is an example of valid `subtitles` and `generate_subtitles` sections: +```json +{ + "subtitles": { + "English": "path/to/subs.vtt", + "Svenska": null + }, + "generate_subtitles": { + "Generated": { + "type": "whisper", + "source": "camera" + } + } +} +``` +This example would save the provided WEBVTT file for the "English" track, +generate subtitles for the "Generated" track based on the given source and +delete the "Svenska" track. 
Note that the source "camera" must exist (either +in this job specification or in the package already on disk in case of an +update), and `upload_dir` must be provided in the job specification in order +to be able to resolve the path to the English subtitle track. ### Job Sources @@ -192,13 +219,13 @@ Empty string and list values will simply overwrite any stored data with the empty values. There are some keys where this isn't the case, they are documented above. -In order to delete a stream, pass an empty object (`{}`) as the value of the +In order to delete a stream, pass `null` as the value of the appropriate stream name key: ```json { "pkg_id": "some_package_id", "sources": { - "sourcename": {} + "sourcename": null } } ``` @@ -228,9 +255,13 @@ This is a job specification that has all keys and values: "thumb": "mythumb.jpg", "subtitles": { "English": "en.vtt", - "Swedish": "swedishsubs.vtt", - "English (Whisper)": true, - "Swedish (Whisper)": false + "Swedish": "swedishsubs.vtt" + }, + "generate_subtitles": { + "Generated": { + "type": "whisper", + "source": "main" + } }, "sources": { "main": { @@ -252,7 +283,7 @@ This is a job specification that has all keys and values: ``` This job specification creates a new package, letting the daemon generate -thumbnail and posters: +thumbnail, subtitles and posters: ```json { @@ -268,7 +299,12 @@ thumbnail and posters: "courses": ["IDSV", "PROG1"], "tags": ["programming", "python"], "thumbnail": "", - "subtitles": "mysubs.vtt", + "generate_subtitles": { + "Generated": { + "type": "whisper", + "source": "main" + } + }, "sources": { "main": { "video": "videos/myvideo.mp4", @@ -301,6 +337,9 @@ An update job making some changes that don't involve uploads: "created": 1665151669, "presenters": ["A Person", "ausername"], "courses": [], + "subtitles": { + "Generated": null + } } ``` diff --git a/config.ini.example b/config.ini.example index d3dc902..e003a50 100644 --- a/config.ini.example +++ b/config.ini.example @@ -59,6 +59,14 @@ 
url = ldaps://ldap.example.com base_dn = dc=example,dc=com +[SubtitlesHandler] +# The whisper model to use for subtitle generation +whispermodel = 'large-v2' + +# Where to store model data +modeldir = /some/path + + [ThumbnailHandler] # The base image to use when creating presentation thumbnails baseimage = /path/to/template.png diff --git a/pipeline/handlers/handler.py b/pipeline/handlers/handler.py index 9d332ea..dd5a868 100644 --- a/pipeline/handlers/handler.py +++ b/pipeline/handlers/handler.py @@ -1,3 +1,5 @@ +import logging + from abc import ABCMeta, abstractmethod from os import mkdir, path from shutil import rmtree diff --git a/pipeline/handlers/subtitles.py b/pipeline/handlers/subtitles.py index 06b6339..545106b 100644 --- a/pipeline/handlers/subtitles.py +++ b/pipeline/handlers/subtitles.py @@ -1,63 +1,192 @@ +import logging + from os import path, rename, remove from .handler import Handler from ..exceptions import ValidationException +import whisper +import whisper.utils -''' -This class handles package subtitles. -In order to accept a job as valid, the following fields are required: -pkg_id, subtitles +def _do_whisper_transcribe(inpath, outpath, modelname, modeldir): + """ + Transcribe the given file at 'inpath' to a VTT file at 'outpath' + using the Whisper engine. + + Should be called asynchronously. 
+ """ + + logger = logging.getLogger('play-daemon.SubtitlesHandler.Whisper') + logger.info(f"Starting whisper transcription job for {inpath}") + whisperModel = whisper.load_model( + modelname, + download_root=modeldir) + result = whisper.transcribe(whisperModel, inpath) + language = result['language'] + logger.info(f"Detected language '{language}' in '{inpath}'") + vttWriter = whisper.utils.WriteVTT(path.dirname(outpath)) + vttWriter.always_include_hours = True + with open(outpath, 'w') as f: + vttWriter.write_result(result, f) + logger.info(f"Finished whisper transcription job for {inpath}") -The 'subtitles' field may be either a filename or a falsy value. If the -'subtitles' field is not falsy, the field 'upload_dir' must also exist. -The value of 'upload_dir' should be the absolute path to a directory -containing the file indicated by 'subtitles'. -''' @Handler.register class SubtitlesHandler(Handler): + """ + This class handles package subtitles. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def wants(self, jobspec, existing_package): - if 'subtitles' in jobspec: + """ + Return True if this handler wants to process this jobspec. + Raises an exception if the job is wanted but doesn't pass validation. + + A job is wanted if the job specification contains a 'subtitles' or a + 'generate_subtitles' key. + """ + if 'subtitles' in jobspec or 'generate_subtitles' in jobspec: return self._validate(jobspec, existing_package) return False def _validate(self, jobspec, existing_package): + """ + Return True if the job is valid for this handler. + + Validity requirements are: + - Keys in 'subtitles' and 'generate_subtitles' must be + mututally unique. + - If any value in the 'subtitles' object is not None, the job must + contain an 'upload_dir' key which must point to an + existing directory. + - All 'subtitles' values that are not None must be existing files + under 'upload_dir'. 
+ - All 'source' values in subtitle generation specifications must be a + valid source name, either one that already exists or one provided + under 'sources' in this job. + - The 'type' value in subtitle generation specifications must be a + supported generator (currently only 'whisper'). + """ super()._validate(jobspec, existing_package) - for name, subsfile in jobspec['subtitles'].items(): - if "Whisper" in name: + # Check for duplicate track names + generate_names = jobspec.get('generate_subtitles', {}).keys() + store_names = jobspec.get('subtitles', {}).keys() + common_names = generate_names & store_names + if common_names: + names_string = ', '.join(common_names) + raise ValidationException( + f"Duplicate subtitle track name(s): {names_string}") + + # Validate generation tasks + for name, genspec in jobspec.get('generate_subtitles', {}).items(): + if genspec['type'] != 'whisper': + raise ValidationException( + "Unsupported subtitle generation job type: " + f"{genspec['type']}") + expected_source = genspec['source'] + jobsource = jobspec.get('sources', {}).get(expected_source, {}) + existing_source = existing_package['sources'].get(expected_source, + {}) + if 'video' not in jobsource and 'video' not in existing_source: + raise ValidationException(f"Subtitle track '{name}' refers " + "to a missing source: " + f"{expected_source}") + + # Validate storage tasks + for name, subsfile in jobspec.get('subtitles', {}).items(): + if subsfile is None: continue - if subsfile: - if 'upload_dir' not in jobspec: - raise ValidationException(f"upload_dir missing") - subspath = path.join(jobspec['upload_dir'], subsfile) - if not path.isfile(subspath): - raise ValidationException( - f"{subspath} is not a valid file") + if 'upload_dir' not in jobspec: + raise ValidationException("upload_dir missing") + subspath = path.join(jobspec['upload_dir'], subsfile) + if not path.isfile(subspath): + raise ValidationException( + f"{subspath} is not a valid file") return True + def 
_do_whisper_transcribe(self): + """ + Transcribe the given file at 'inpath' to a VTT file at 'outpath' + using the Whisper engine. + + Should be called asynchronously. + """ + logger = self.logger + config = self.config + def func(inpath, outpath): + logger.info(f"Starting whisper transcription job for {inpath}") + whisperModel = whisper.load_model( + config['whispermodel'], + download_root=config['modeldir']) + result = whisper.transcribe(whisperModel, inpath) + language = result['language'] + logger.info(f"Detected language '{language}' in '{inpath}'") + vttWriter = whisper.utils.WriteVTT(path.dirname(outpath)) + vttWriter.always_include_hours = True + with open(outpath, 'w') as f: + vttWriter.write_result(result, f) + logger.info(f"Finished whisper transcription job for {inpath}") + return func + def _handle(self, jobspec, existing_package, tempdir): + """ + Return a function to apply changes to the stored package. + + Any subtitle generation tasks are run before apply_func is returned. + The returned function moves subtitle files into the package's basedir + and updates the package metadata. + Replaced subtitle tracks are deleted. 
+ """ + + transcribes = [] + resultfiles = {} + + for name, item in jobspec.get('generate_subtitles', {}).items(): + sourcename = item['source'] + sourcepath = None + source_from_job = jobspec.get('sources', {}).get(sourcename, {}) + file_from_job = source_from_job.get('video', None) + if file_from_job: + sourcepath = path.join(jobspec['upload_dir'], file_from_job) + else: + existing_source = existing_package['sources'].get(name, {}) + # Sorting the available resolutions in ascending numeric order + resolutions_sorted = sorted(existing_source['video'].keys(), + key=int)[0] + # Picking the smallest resolution, + # since the sound quality is always the same + sourcepath = path.join( + existing_package.basedir, + existing_source['video'][resolutions_sorted[0]]) + + outpath = path.join(tempdir, f"{sourcename}.vtt") + + transcribe = self.asyncjob(_do_whisper_transcribe, + (sourcepath, + outpath, + self.config['whispermodel'], + self.config['modeldir'])) + transcribes.append(transcribe) + resultfiles[name] = outpath + + self.logger.debug("Waiting for transcribes") + # Wait for transcribes to finish + for item in transcribes: + item.wait() + self.logger.debug("Done, making apply_func") + def apply_func(package): - for name, subsfile in jobspec['subtitles'].items(): - if "Whisper" in name: - # TODO: Implement Whisper subs generation - continue - elif subsfile: - # Save the passed file, replacing any existing - existing = package['subtitles'].get(name, None) - if existing: - remove(path.join(package.basedir, existing)) - rename(path.join(jobspec['upload_dir'], subsfile), - path.join(package.basedir, subsfile)) - package['subtitles'][name] = subsfile - else: - # Falsy value, delete existing - if package['subtitles'].get(name, None): - remove(path.join(package.basedir, - package['subtitles'].pop(name))) + for name, subsfile in jobspec.get('subtitles', {}).items(): + subspath = None + if subsfile is not None: + subspath = path.join(jobspec['upload_dir'], subsfile) + 
package.set_subtitle_track(name, subspath) + + for name, subspath in resultfiles.items(): + package.set_subtitle_track(name, subspath) return apply_func diff --git a/pipeline/package.py b/pipeline/package.py index 6084884..687a3c0 100644 --- a/pipeline/package.py +++ b/pipeline/package.py @@ -47,11 +47,17 @@ class Package: and self.basedir == other.basedir and self._contents == other._contents) - ''' - Set duration by longest video file - (they should all be the same length though) - ''' + def get(self, key, default=None): + if key in self._contents: + return self._contents[key] + else: + return default + def set_duration(self): + """ + Set duration by longest video file + (they should all be the same length though) + """ durations = [] for source in self._contents['sources'].values(): for video in source['video'].values(): @@ -64,6 +70,32 @@ class Package: print(e.stderr) self['duration'] = max(durations) + def set_subtitle_track(self, name, inpath): + """ + Set the subtitle track indicated by 'name' to the file at 'inpath'. + + If 'inpath' is None, the track is removed. 
+ """ + subtitles = self._contents['subtitles'] + + if not inpath: + # Delete this track + if subtitles.get(name, None): + remove(path.join(self.basedir, + subtitles.pop(name))) + return + + # Check if there is a subtitles file + # already associated with this name + existing = subtitles.get(name, None) + if existing: + # A subtitles file exists for this name, remove it + remove(path.join(self.basedir, existing)) + + # Save the file to the correct place and update metadata + rename(inpath, path.join(self.basedir, f"{name}.vtt")) + subtitles[name] = f"{name}.vtt" + def asdict(self): out = deepcopy(self._contents) out['pkg_id'] = self.uuid diff --git a/pipeline/utils.py b/pipeline/utils.py index a3967c1..019863f 100644 --- a/pipeline/utils.py +++ b/pipeline/utils.py @@ -18,10 +18,9 @@ canonical_jobspec = { 'courses': [str], 'tags': [str], 'thumb': str, - 'subtitles': {'Svenska': str, - 'English': str, - 'Svenska (Whisper)': bool, - 'English (Whisper)': bool}, + 'subtitles': {str: str}, + 'generate_subtitles': {str: {'type': str, + 'source': str}}, 'sources': {str: {'poster': str, 'playAudio': bool, 'video': str}}, @@ -45,10 +44,7 @@ canonical_manifest = { 'courses': [str], 'tags': [str], 'thumb': str, - 'subtitles': {'Svenska': str, - 'English': str, - 'Svenska (Whisper)': bool, - 'English (Whisper)': bool}, + 'subtitles': {str: str}, 'sources': {str: {'poster': str, 'playAudio': bool, 'video': {'720': str, diff --git a/requirements.txt b/requirements.txt index e25ba8f..b8c8da9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ watchdog ffmpeg-python Pillow ldap3 +openai-whisper diff --git a/test.py b/test.py index 97b14ff..bcd2682 100755 --- a/test.py +++ b/test.py @@ -191,40 +191,51 @@ class PipelineTest(DaemonTest): json.dump(queuedata, f) return jobid - def wait_for_result(self, jobid, handlername, timeout=5): - resultfile = path.join(self.notify_url, - f"{jobid}.{handlername}") + def wait_for_result(self, jobid, handlers, timeout=5): + 
resultfiles = {} + for handler in handlers: + resultfiles = {path.join(self.notify_url, + f"{jobid}.{handler}"): False + for handler in handlers} for _try in range(1, timeout+1): sleep(1) - if path.exists(resultfile): + for resultfile, found in resultfiles.items(): + if path.exists(resultfile): + resultfiles[resultfile] = True + if all(resultfiles.values()): break else: self.fail(f"No result produced after {_try} seconds") - with open(resultfile) as f: - try: - result = json.load(f) - except json.decoder.JSONDecodeError: - print("¤ Contents of invalid notification file ¤") - print(f.read()) - print("¤ End invalid notification file contents ¤") - self.fail(f"Invalid JSON in result file.") - # Validate that this is the correct resultfile - self.assertEqual(jobid, result['jobid']) - self.assertEqual(handlername, result['origin']) + final_result = None + for resultfile in resultfiles: + with open(resultfile) as f: + try: + result = json.load(f) + except json.decoder.JSONDecodeError: + print("¤ Contents of invalid notification file ¤") + print(f.read()) + print("¤ End invalid notification file contents ¤") + self.fail(f"Invalid JSON in result file.") - if result['type'] == 'success': - # On success, check match of saved package and notification + # Validate that this is the correct resultfile + self.assertEqual(jobid, result['jobid']) + + if result['pending'] == []: + final_result = result + + if final_result['type'] == 'success': + # On success of all expected handlers, check match of saved + # package and notification package = PackageManager(self.pipeconf['packagedir'], result['package']['pkg_id']).read() for key, value in result['package'].items(): if key == 'pkg_id': - continue - if key == 'thumb' and not value: - continue - self.assertEqual(value, package[key]) + self.assertEqual(value, package.uuid) + else: + self.assertEqual(value, package[key]) - return result + return final_result def init_job(self, pkgid=False, subs=False, thumb=False, source_count=0, 
poster_count=0): @@ -303,7 +314,7 @@ class PipelineTest(DaemonTest): with open(path.join(self.pipeconf['queuedir'], jobid), 'x') as f: json.dump(queuedata, f) - result = self.wait_for_result(jobid, 'QueueReader') + result = self.wait_for_result(jobid, ['QueueReader']) self.assertEqual(result['type'], 'error') self.assertEqual(result['message'], '"Job specification missing \'type\' key."') @@ -316,7 +327,7 @@ class PipelineTest(DaemonTest): with open(path.join(self.pipeconf['queuedir'], jobid), 'x') as f: json.dump(queuedata, f) - result = self.wait_for_result(jobid, 'QueueReader') + result = self.wait_for_result(jobid, ['QueueReader']) self.assertEqual(result['type'], 'error') self.assertEqual(result['message'], "Invalid type 'invalid_type' in job specification.") @@ -326,18 +337,12 @@ class PipelineTest(DaemonTest): jobspec = self.init_job(subs=True) jobid = self.submit_default_job(jobspec) - result = self.wait_for_result(jobid, 'SubtitlesHandler') - - # There should be no further pending jobs - self.assertEqual([], result['pending']) + result = self.wait_for_result(jobid, ['SubtitlesHandler']) with PackageManager(self.pipeconf['packagedir'], result['package']['pkg_id']) as package: - for lang, subsfile in package['subtitles'].items(): - # Check match of saved package and jobspec - self.assertEqual(package['subtitles'][lang], - jobspec['subtitles'][lang]) - # Subsfile should be in place + for lang in jobspec['subtitles'].keys(): + # Subsfile should be in place for each key subspath = path.join(package.basedir, package['subtitles'][lang]) self.assertTrue(path.exists(subspath)) @@ -357,10 +362,7 @@ class PipelineTest(DaemonTest): jobspec['tags'] = ['foo', 'bar'] jobid = self.submit_default_job(jobspec) - result = self.wait_for_result(jobid, 'MetadataHandler') - - # There should be no further pending jobs - self.assertEqual([], result['pending']) + result = self.wait_for_result(jobid, ['MetadataHandler']) with PackageManager(self.pipeconf['packagedir'], 
result['package']['pkg_id']) as package: @@ -381,11 +383,10 @@ class PipelineTest(DaemonTest): jobspec = self.init_job(source_count=4, poster_count=2) jobid = self.submit_default_job(jobspec) - # Awaiting poster handler because it must run after transcode - result = self.wait_for_result(jobid, 'PosterHandler', timeout=180) - - # There should be no further pending handlers - self.assertEqual([], result['pending']) + result = self.wait_for_result(jobid, ['AudioHandler', + 'TranscodeHandler', + 'PosterHandler'], + timeout=180) with PackageManager(self.pipeconf['packagedir'], result['package']['pkg_id']) as package: @@ -402,6 +403,41 @@ class PipelineTest(DaemonTest): # uldir should be gone self.assertFalse(path.exists(jobspec['upload_dir'])) + #@unittest.skip("This test is very slow") + def test_generating_subs(self): + jobspec = self.init_job(source_count=1, poster_count=1) + subsource = next(iter(jobspec['sources'])) + jobspec['generate_subtitles'] = {'Generated': {'type': 'whisper', + 'source': subsource}} + + jobid = self.submit_default_job(jobspec) + result = self.wait_for_result(jobid, ['AudioHandler', + 'TranscodeHandler', + 'SubtitlesHandler'], + timeout=180) + + with PackageManager(self.pipeconf['packagedir'], + result['package']['pkg_id']) as package: + # Check match of saved package and jobspec + for name, source in jobspec['sources'].items(): + pkgsource = package['sources'][name] + self.assertEqual(source['playAudio'], pkgsource['playAudio']) + self.assertTrue(path.exists(path.join(package.basedir, + pkgsource['poster']))) + for variant, filename in pkgsource['video'].items(): + videopath = path.join(package.basedir, filename) + self.assertTrue(path.exists(videopath)) + subspath = path.join(package.basedir, + package['subtitles']['Generated']) + self.assertTrue(path.exists(subspath)) + with open(subspath) as f: + print("¤ Generated subs ¤") + print(f.read()) + print("¤ End generated subs ¤") + + # uldir should be gone + 
self.assertFalse(path.exists(jobspec['upload_dir'])) + if __name__ == '__main__': if not path.exists(f"./{filesdir}"):