whisper-daemon/pipeline/whisperthread.py

from collections.abc import MutableMapping
from pathlib import Path
from queue import Queue
import multiprocessing as mp
# Commented as a workaround, see the hack comment in __init__()
#import torch
from .exceptions import ConfigurationException
from .job import Job
from .queuethread import QueueThread
from .whisperprocess import whisper_process


@QueueThread.register
class WhisperThread(QueueThread):
    def __init__(self,
                 inqueue: Queue,
                 outqueue: mp.Queue,
                 donedir: Path,
                 model: str,
                 modeldir: Path,
                 workerconfig: MutableMapping):
"""
Instantiate a WhisperThread, which will contain one or more workers.
The passed config must contain the following keys:
- model: The whisper model to be used
- modeldir: The directory where the model is stored
Worker count is governed by two configuration values: device and
count. Both are optional.
If device is not specified, the device will be determined by
GPU availability.
If only device is specified and the device is GPU, a handler is
created for each avilable GPU. If no GPUs are available, an exception
is raised.
If only device is provided and the device is CPU, a single handler
is created.
If count is provided, that number of handlers will be created. If the
device is GPU, and count is greater than the number of available GPUs,
an exception is raised.
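
        As an illustration (these values are examples, not defaults), a CPU
        setup might pass workerconfig={'device': 'cpu', 'count': 2}, while a
        single-GPU setup might pass workerconfig={'device': 'gpu', 'count': 1}.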
"""
super().__init__(inqueue)
self.outqueue = outqueue
self.donedir = donedir
self.modeldir = modeldir
self.model = model
self.highqueue = mp.Queue()
self.lowqueue = mp.Queue()
# There is a hack here due to an unfortunate confluence of bugs.
#
# - Torch can't deal with seeing all GPUs and then getting some of
# them masked by CUDA_VISIBLE_DEVICES.
# - Triton can't deal with more than one GPU being visible.
#
# As a consequence, it isn't possible to auto-detect the appropriate
# number of whisper workers, so the relevant code has been commented
# out awaiting a fix.
# Commented out as a workaround
#gpu_count = torch.cuda.device_count()
        if 'device' in workerconfig:
            device = workerconfig['device']
        else:
            # Throw an exception as a workaround
            raise ConfigurationException('device and count are mandatory '
                                         'due to library bugs')

        # Commented out as a workaround: gpu_count is unavailable while the
        # torch import above is disabled, so the device cannot be
        # auto-detected from GPU availability.
        #if gpu_count > 0:
        #    device = 'gpu'
        #else:
        #    device = 'cpu'
        if device == 'cpu':
            worker_count = int(workerconfig.get('count', 1))
            devices = [f'{device}' for i in range(worker_count)]
        else:
            # This block is commented as a workaround
            #if gpu_count == 0:
            #    raise ConfigurationException(
            #        'GPU requested but none available.')
            #
            #worker_count = int(workerconfig.get('count', gpu_count))
            #if worker_count > gpu_count:
            #    raise ConfigurationException(
            #        f'{worker_count} workers requested '
            #        f'but only {gpu_count} GPUs available.')

            # This line is introduced as a workaround
            worker_count = int(workerconfig['count'])
            devices = [f'{i}' for i in range(worker_count)]

        self.logger.debug('Spawning %s workers on %s', worker_count, device)
        self.workers = [self.spawn_worker(dev) for dev in devices]

    def spawn_worker(self, device: str):
        """
        Start a daemonised whisper worker process bound to the given device.
        """
        worker_process = mp.Process(target=whisper_process,
                                    args=(self.highqueue,
                                          self.lowqueue,
                                          self.outqueue,
                                          self.donedir,
                                          self.model,
                                          self.modeldir,
                                          self.logger.getEffectiveLevel(),
                                          device))
        worker_process.daemon = True
        worker_process.start()
        return worker_process

    def _process(self, item: Job):
        """
        Queue a job for transcription, prioritising jobs whose origin
        is "play".
        """
        if item.origin == "play":
            self.highqueue.put(item)
            self.logger.debug('%s queued to high queue', item.jobid)
        else:
            self.lowqueue.put(item)
            self.logger.debug('%s queued to low queue', item.jobid)
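

# Illustrative usage sketch (assumptions only, not part of the daemon
# wiring): it shows the argument types the constructor above expects. The
# paths, model name and workerconfig values are made up, and it assumes
# QueueThread exposes the standard threading.Thread start() interface.
if __name__ == '__main__':
    demo_in = Queue()
    demo_out = mp.Queue()
    demo_thread = WhisperThread(demo_in,
                                demo_out,
                                donedir=Path('/tmp/whisper/done'),
                                model='small',
                                modeldir=Path('/tmp/whisper/models'),
                                workerconfig={'device': 'cpu', 'count': 1})
    demo_thread.start()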