Linted a few modules. Added a check that the language_code exists in the vendor's language list when calling Transcriber.transcribe, and made that method non-abstract in the base class so it can hold the check. Added a test for this. Extracted a few things in the tests into fixtures.

This commit is contained in:
2019-07-18 12:19:42 +02:00
parent e7258f50c9
commit 327ff18726
6 changed files with 297 additions and 157 deletions

View File

@@ -1,23 +1,30 @@
from unittest import mock from unittest import mock
from pytest import raises, fixture
from tatt.vendors.amazon import Transcriber from tatt.vendors.amazon import Transcriber
@fixture
def test_transcriber_instantiate(): def media_filepath():
filepath = '/Users/zev/tester.mp3' return "/Users/zev/tester.mp3"
t = Transcriber(filepath)
assert str(t.filepath) == filepath
assert t.basename == 'tester.mp3'
assert t.media_file_uri == (
f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
)
@mock.patch('tatt.vendors.amazon.tr.get_transcription_job') @fixture
def transcriber_instance(media_filepath):
return Transcriber(media_filepath)
def test_transcriber_instance(media_filepath, transcriber_instance):
assert str(transcriber_instance.filepath) == media_filepath
assert transcriber_instance.basename == "tester.mp3"
assert transcriber_instance.media_file_uri == (
f"https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3"
)
@mock.patch("tatt.vendors.amazon.tr.get_transcription_job")
def test_transcriber_retrieve(get_transcription_job): def test_transcriber_retrieve(get_transcription_job):
filepath = '/Users/zev/tester.mp3' job_name = "4db6808e-a7e8-4d8d-a1b7-753ab97094dc"
job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
t = Transcriber.retrieve_transcript(job_name) t = Transcriber.retrieve_transcript(job_name)
get_transcription_job.assert_called_with(TranscriptionJobName=job_name) get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
@@ -29,9 +36,40 @@ def test_transcriber_get_transcription_jobs():
def test_transcriber_retrieve_transcript(): def test_transcriber_retrieve_transcript():
jobs = Transcriber.get_transcription_jobs() jobs = Transcriber.get_transcription_jobs()
assert jobs
for j in jobs: for j in jobs:
if j['status'].lower() == 'completed': if j["status"].lower() == "completed":
to_get = j['name'] to_get = j["name"]
break break
transcript = Transcriber.retrieve_transcript(to_get) transcript = Transcriber.retrieve_transcript(to_get)
assert transcript == {'jobName': 'abcd.mp3', 'accountId': '416321668733', 'results': {'transcripts': [{'transcript': 'Hello there.'}], 'items': [{'start_time': '0.0', 'end_time': '0.35', 'alternatives': [{'confidence': '0.8303', 'content': 'Hello'}], 'type': 'pronunciation'}, {'start_time': '0.35', 'end_time': '0.76', 'alternatives': [{'confidence': '1.0000', 'content': 'there'}], 'type': 'pronunciation'}, {'alternatives': [{'confidence': None, 'content': '.'}], 'type': 'punctuation'}]}, 'status': 'COMPLETED'} assert transcript == {
"jobName": "abcd.mp3",
"accountId": "416321668733",
"results": {
"transcripts": [{"transcript": "Hello there."}],
"items": [
{
"start_time": "0.0",
"end_time": "0.35",
"alternatives": [{"confidence": "0.8303", "content": "Hello"}],
"type": "pronunciation",
},
{
"start_time": "0.35",
"end_time": "0.76",
"alternatives": [{"confidence": "1.0000", "content": "there"}],
"type": "pronunciation",
},
{
"alternatives": [{"confidence": None, "content": "."}],
"type": "punctuation",
},
],
},
"status": "COMPLETED",
}
def test_transcribe_with_nonexistent_language_code(transcriber_instance):
with raises(KeyError):
transcriber_instance.transcribe(language_code="pretend-lang")

View File

@@ -3,6 +3,6 @@ from tatt.vendors import SERVICES
def test_services(): def test_services():
for service in SERVICES.values(): for service in SERVICES.values():
assert hasattr(service, 'Transcriber') assert hasattr(service, "Transcriber")
assert hasattr(service, 'NAME') assert hasattr(service, "NAME")
assert hasattr(service.Transcriber, 'cost_per_15_seconds') assert hasattr(service.Transcriber, "cost_per_15_seconds")

View File

@@ -1,7 +1,3 @@
from tatt.vendors import amazon, google from tatt.vendors import amazon, google
SERVICES = { SERVICES = {"amazon": amazon, "google": google}
'amazon': amazon,
'google': google,
}

127
tatt/vendors/amazon.py vendored
View File

@@ -11,7 +11,7 @@ from tatt import config
from tatt import exceptions from tatt import exceptions
from .vendor import TranscriberBaseClass from .vendor import TranscriberBaseClass
NAME = 'amazon' NAME = "amazon"
BUCKET_NAME_MEDIA = config.BUCKET_NAME_FMTR_MEDIA.format(NAME) BUCKET_NAME_MEDIA = config.BUCKET_NAME_FMTR_MEDIA.format(NAME)
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME) BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
TRANSCRIPT_TYPE = dict TRANSCRIPT_TYPE = dict
@@ -19,28 +19,35 @@ TRANSCRIPT_TYPE = dict
def _check_for_config() -> bool: def _check_for_config() -> bool:
return ( return (
config.AWS_CONFIG_FILEPATH.exists() config.AWS_CONFIG_FILEPATH.exists() and config.AWS_CREDENTIALS_FILEPATH.exists()
and config.AWS_CREDENTIALS_FILEPATH.exists() )
)
class Transcriber(TranscriberBaseClass): class Transcriber(TranscriberBaseClass):
name = NAME name = NAME
cost_per_15_seconds = .024 / 4 cost_per_15_seconds = 0.024 / 4
bucket_names = {'media': BUCKET_NAME_MEDIA, bucket_names = {"media": BUCKET_NAME_MEDIA, "transcript": BUCKET_NAME_TRANSCRIPT}
'transcript': BUCKET_NAME_TRANSCRIPT}
no_config_error_message = 'please run "aws configure" first' no_config_error_message = 'please run "aws configure" first'
transcript_type = TRANSCRIPT_TYPE transcript_type = TRANSCRIPT_TYPE
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/transcribe.html # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/transcribe.html
_language_list = ['en-US', 'es-US', 'en-AU', 'fr-CA', 'en-GB', 'de-DE', _language_list = [
'pt-BR', 'fr-FR', 'it-IT', 'ko-KR'] "en-US",
"es-US",
"en-AU",
"fr-CA",
"en-GB",
"de-DE",
"pt-BR",
"fr-FR",
"it-IT",
"ko-KR",
]
if _check_for_config(): if _check_for_config():
tr = boto3.client('transcribe') tr = boto3.client("transcribe")
s3 = boto3.resource('s3') s3 = boto3.resource("s3")
def __init__(self, filepath): def __init__(self, filepath):
super().__init__(filepath) super().__init__(filepath)
@@ -57,7 +64,7 @@ class Transcriber(TranscriberBaseClass):
return ( return (
f"https://s3-{config.AWS_REGION}.amazonaws.com/" f"https://s3-{config.AWS_REGION}.amazonaws.com/"
f"{self.bucket_names['media']}/{self.basename}" f"{self.bucket_names['media']}/{self.basename}"
) )
@classmethod @classmethod
def _setup(cls): def _setup(cls):
@@ -75,98 +82,96 @@ class Transcriber(TranscriberBaseClass):
cls.s3.create_bucket(Bucket=bucket_name) cls.s3.create_bucket(Bucket=bucket_name)
def transcribe(self, **kwargs) -> str: def transcribe(self, **kwargs) -> str:
super().transcribe(**kwargs)
self._upload_file() self._upload_file()
try: try:
return self._request_transcription(**kwargs) return self._request_transcription(**kwargs)
except self.tr.exceptions.ConflictException: except self.tr.exceptions.ConflictException:
raise exceptions.AlreadyExistsError( raise exceptions.AlreadyExistsError(
f'{self.basename} already exists on {NAME}') f"{self.basename} already exists on {NAME}"
)
def _upload_file(self): def _upload_file(self):
self.s3.Bucket(self.bucket_names['media']).upload_file( self.s3.Bucket(self.bucket_names["media"]).upload_file(
str(self.filepath), str(self.filepath), self.basename
self.basename) )
def _request_transcription( def _request_transcription(
self, self, language_code="en-US", num_speakers=2, enable_speaker_diarization=True
language_code='en-US', ) -> str:
num_speakers=2,
enable_speaker_diarization=True,
) -> str:
job_name = self.basename job_name = self.basename
kwargs = dict( kwargs = dict(
TranscriptionJobName=job_name, TranscriptionJobName=job_name,
LanguageCode=language_code, LanguageCode=language_code,
MediaFormat=self.basename.split('.')[-1].lower(), MediaFormat=self.basename.split(".")[-1].lower(),
Media={ Media={"MediaFileUri": self.media_file_uri},
'MediaFileUri': self.media_file_uri OutputBucketName=self.bucket_names["transcript"],
}, )
OutputBucketName=self.bucket_names['transcript']
)
if enable_speaker_diarization: if enable_speaker_diarization:
kwargs.update(dict( kwargs.update(
Settings={ dict(
'ShowSpeakerLabels': True, Settings={
'MaxSpeakerLabels': num_speakers, "ShowSpeakerLabels": True,
"MaxSpeakerLabels": num_speakers,
} }
)) )
)
self.tr.start_transcription_job(**kwargs) self.tr.start_transcription_job(**kwargs)
return job_name return job_name
@classmethod @classmethod
def get_transcription_jobs( def get_transcription_jobs(
cls, cls, status: str = None, job_name_query: str = None
status:str = None, ) -> List[dict]:
job_name_query:str = None,
) -> List[dict]:
kwargs = {'MaxResults': 100} kwargs = {"MaxResults": 100}
if status is not None: if status is not None:
kwargs['Status'] = status.upper() kwargs["Status"] = status.upper()
if job_name_query is not None: if job_name_query is not None:
kwargs['JobNameContains'] = job_name_query kwargs["JobNameContains"] = job_name_query
jobs_data = cls.tr.list_transcription_jobs(**kwargs) jobs_data = cls.tr.list_transcription_jobs(**kwargs)
key = 'TranscriptionJobSummaries' key = "TranscriptionJobSummaries"
jobs = cls.homogenize_transcription_job_data(jobs_data[key]) jobs = cls.homogenize_transcription_job_data(jobs_data[key])
while jobs_data.get('NextToken'): while jobs_data.get("NextToken"):
token = jobs_data['NextToken'] token = jobs_data["NextToken"]
jobs_data = cls.tr.list_transcription_jobs(NextToken=token) jobs_data = cls.tr.list_transcription_jobs(NextToken=token)
jobs += cls.homogenize_transcription_job_data(jobs_data[key]) jobs += cls.homogenize_transcription_job_data(jobs_data[key])
return jobs return jobs
@classmethod @classmethod
def retrieve_transcript(cls, transcription_job_name: str def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
) -> TRANSCRIPT_TYPE: job = cls.tr.get_transcription_job(TranscriptionJobName=transcription_job_name)[
job = cls.tr.get_transcription_job( "TranscriptionJob"
TranscriptionJobName=transcription_job_name ]
)['TranscriptionJob']
if not job['TranscriptionJobStatus'] == 'COMPLETED': if not job["TranscriptionJobStatus"] == "COMPLETED":
return return
transcript_file_uri = job['Transcript']['TranscriptFileUri'] transcript_file_uri = job["Transcript"]["TranscriptFileUri"]
transcript_path = transcript_file_uri.split("amazonaws.com/", 1)[1] transcript_path = transcript_file_uri.split("amazonaws.com/", 1)[1]
transcript_bucket = transcript_path.split('/', 1)[0] transcript_bucket = transcript_path.split("/", 1)[0]
transcript_key = transcript_path.split('/', 1)[1] transcript_key = transcript_path.split("/", 1)[1]
s3_object = cls.s3.Object(transcript_bucket, transcript_key).get() s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
transcript_json = s3_object['Body'].read().decode('utf-8') transcript_json = s3_object["Body"].read().decode("utf-8")
return json.loads(transcript_json) return json.loads(transcript_json)
@staticmethod @staticmethod
def homogenize_transcription_job_data(transcription_job_data): def homogenize_transcription_job_data(transcription_job_data):
return [{ return [
'created': jd['CreationTime'], {
'name': jd['TranscriptionJobName'], "created": jd["CreationTime"],
'status': jd['TranscriptionJobStatus'] "name": jd["TranscriptionJobName"],
} "status": jd["TranscriptionJobStatus"],
for jd in transcription_job_data] }
for jd in transcription_job_data
]

233
tatt/vendors/google.py vendored
View File

@@ -12,51 +12,155 @@ from google.cloud import (
speech_v1p1beta1 as speech, speech_v1p1beta1 as speech,
storage, storage,
exceptions as gc_exceptions, exceptions as gc_exceptions,
) )
from tatt import exceptions, helpers, config as config_mod from tatt import exceptions, helpers, config as config_mod
from .vendor import TranscriberBaseClass from .vendor import TranscriberBaseClass
NAME = 'google' NAME = "google"
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format( BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format("goog")
'goog')
TRANSCRIPT_TYPE = str TRANSCRIPT_TYPE = str
def _check_for_config(): def _check_for_config():
return os.getenv('GOOGLE_APPLICATION_CREDENTIALS') is not None return os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None
class Transcriber(TranscriberBaseClass): class Transcriber(TranscriberBaseClass):
name = NAME name = NAME
SUPPORTED_FORMATS = ['flac'] SUPPORTED_FORMATS = ["flac"]
cost_per_15_seconds = [.004, .006, .009] cost_per_15_seconds = [0.004, 0.006, 0.009]
no_config_error_message = ( no_config_error_message = (
'Please sign up for the Google Speech-to-Text API ' "Please sign up for the Google Speech-to-Text API "
'and put the path to your credentials in an ' "and put the path to your credentials in an "
'environment variable "GOOGLE_APPLICATION_CREDENTIALS"' 'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
) )
transcript_type = TRANSCRIPT_TYPE transcript_type = TRANSCRIPT_TYPE
# https://cloud.google.com/speech-to-text/docs/languages # https://cloud.google.com/speech-to-text/docs/languages
# Array.from(document.querySelector('.devsite-table-wrapper').querySelectorAll('table tr')).slice(1).map(row => row.children[1].innerText) # Array.from(document.querySelector('.devsite-table-wrapper').querySelectorAll('table tr')).slice(1).map(row => row.children[1].innerText)
_language_list = [ _language_list = [
'af-ZA', 'am-ET', 'hy-AM', 'az-AZ', 'id-ID', 'ms-MY', "af-ZA",
'bn-BD', 'bn-IN', 'ca-ES', 'cs-CZ', 'da-DK', 'de-DE', 'en-AU', 'en-CA', "am-ET",
'en-GH', 'en-GB', 'en-IN', 'en-IE', 'en-KE', 'en-NZ', 'en-NG', 'en-PH', "hy-AM",
'en-SG', 'en-ZA', 'en-TZ', 'en-US', 'es-AR', 'es-BO', 'es-CL', 'es-CO', "az-AZ",
'es-CR', 'es-EC', 'es-SV', 'es-ES', 'es-US', 'es-GT', 'es-HN', 'es-MX', "id-ID",
'es-NI', 'es-PA', 'es-PY', 'es-PE', 'es-PR', 'es-DO', 'es-UY', 'es-VE', "ms-MY",
'eu-ES', 'fil-PH', 'fr-CA', 'fr-FR', 'gl-ES', 'ka-GE', 'gu-IN', 'hr-HR', "bn-BD",
'zu-ZA', 'is-IS', 'it-IT', 'jv-ID', 'kn-IN', 'km-KH', 'lo-LA', 'lv-LV', "bn-IN",
'lt-LT', 'hu-HU', 'ml-IN', 'mr-IN', 'nl-NL', 'ne-NP', 'nb-NO', 'pl-PL', "ca-ES",
'pt-BR', 'pt-PT', 'ro-RO', 'si-LK', 'sk-SK', 'sl-SI', 'su-ID', 'sw-TZ', "cs-CZ",
'sw-KE', 'fi-FI', 'sv-SE', 'ta-IN', 'ta-SG', 'ta-LK', 'ta-MY', 'te-IN', "da-DK",
'vi-VN', 'tr-TR', 'ur-PK', 'ur-IN', 'el-GR', 'bg-BG', 'ru-RU', 'sr-RS', "de-DE",
'uk-UA', 'he-IL', 'ar-IL', 'ar-JO', 'ar-AE', 'ar-BH', 'ar-DZ', 'ar-SA', "en-AU",
'ar-IQ', 'ar-KW', 'ar-MA', 'ar-TN', 'ar-OM', 'ar-PS', 'ar-QA', 'ar-LB', "en-CA",
'ar-EG', 'fa-IR', 'hi-IN', 'th-TH', 'ko-KR', 'zh-TW', 'yue-Hant-HK', "en-GH",
'ja-JP', 'zh-HK', 'zh'] "en-GB",
"en-IN",
"en-IE",
"en-KE",
"en-NZ",
"en-NG",
"en-PH",
"en-SG",
"en-ZA",
"en-TZ",
"en-US",
"es-AR",
"es-BO",
"es-CL",
"es-CO",
"es-CR",
"es-EC",
"es-SV",
"es-ES",
"es-US",
"es-GT",
"es-HN",
"es-MX",
"es-NI",
"es-PA",
"es-PY",
"es-PE",
"es-PR",
"es-DO",
"es-UY",
"es-VE",
"eu-ES",
"fil-PH",
"fr-CA",
"fr-FR",
"gl-ES",
"ka-GE",
"gu-IN",
"hr-HR",
"zu-ZA",
"is-IS",
"it-IT",
"jv-ID",
"kn-IN",
"km-KH",
"lo-LA",
"lv-LV",
"lt-LT",
"hu-HU",
"ml-IN",
"mr-IN",
"nl-NL",
"ne-NP",
"nb-NO",
"pl-PL",
"pt-BR",
"pt-PT",
"ro-RO",
"si-LK",
"sk-SK",
"sl-SI",
"su-ID",
"sw-TZ",
"sw-KE",
"fi-FI",
"sv-SE",
"ta-IN",
"ta-SG",
"ta-LK",
"ta-MY",
"te-IN",
"vi-VN",
"tr-TR",
"ur-PK",
"ur-IN",
"el-GR",
"bg-BG",
"ru-RU",
"sr-RS",
"uk-UA",
"he-IL",
"ar-IL",
"ar-JO",
"ar-AE",
"ar-BH",
"ar-DZ",
"ar-SA",
"ar-IQ",
"ar-KW",
"ar-MA",
"ar-TN",
"ar-OM",
"ar-PS",
"ar-QA",
"ar-LB",
"ar-EG",
"fa-IR",
"hi-IN",
"th-TH",
"ko-KR",
"zh-TW",
"yue-Hant-HK",
"ja-JP",
"zh-HK",
"zh",
]
if _check_for_config(): if _check_for_config():
speech_client = speech.SpeechClient() speech_client = speech.SpeechClient()
@@ -69,10 +173,11 @@ class Transcriber(TranscriberBaseClass):
@classmethod @classmethod
def _setup(cls): def _setup(cls):
super()._setup() super()._setup()
if not shutil.which('gsutil'): if not shutil.which("gsutil"):
raise exceptions.DependencyRequired( raise exceptions.DependencyRequired(
'Please install gcloud using the steps here:' "Please install gcloud using the steps here:"
'https://cloud.google.com/storage/docs/gsutil_install') "https://cloud.google.com/storage/docs/gsutil_install"
)
cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT) cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT)
@@ -84,13 +189,13 @@ class Transcriber(TranscriberBaseClass):
# this might fail if a bucket by the name exists *anywhere* on GCS? # this might fail if a bucket by the name exists *anywhere* on GCS?
return return
else: else:
print('made Google Cloud Storage Bucket for transcripts') print("made Google Cloud Storage Bucket for transcripts")
def convert_file_format_if_needed(self): def convert_file_format_if_needed(self):
if self.file_format not in self.SUPPORTED_FORMATS: if self.file_format not in self.SUPPORTED_FORMATS:
if not shutil.which('ffmpeg'): if not shutil.which("ffmpeg"):
raise exceptions.DependencyRequired('please install ffmpeg') raise exceptions.DependencyRequired("please install ffmpeg")
self.filepath = helpers.convert_file(self.filepath, 'flac') self.filepath = helpers.convert_file(self.filepath, "flac")
@property @property
def file_format(self): def file_format(self):
@@ -111,31 +216,31 @@ class Transcriber(TranscriberBaseClass):
def _check_if_transcript_exists(self, transcript_name=None): def _check_if_transcript_exists(self, transcript_name=None):
return storage.Blob( return storage.Blob(
bucket=self.transcript_bucket, bucket=self.transcript_bucket, name=transcript_name or self.basename
name=transcript_name or self.basename ).exists(self.storage_client)
).exists(self.storage_client)
def _request_transcription( def _request_transcription(
self, self,
language_code='en-US', language_code="en-US",
enable_automatic_punctuation=True, enable_automatic_punctuation=True,
enable_speaker_diarization=True, enable_speaker_diarization=True,
num_speakers=2, num_speakers=2,
model='phone_call', model="phone_call",
use_enhanced=True, use_enhanced=True,
) -> str: ) -> str:
"""Returns the job_name""" """Returns the job_name"""
if self._check_if_transcript_exists(): if self._check_if_transcript_exists():
raise exceptions.AlreadyExistsError( raise exceptions.AlreadyExistsError(
f'{self.basename} already exists on {NAME}') f"{self.basename} already exists on {NAME}"
)
num_audio_channels = helpers.get_num_audio_channels(self.filepath) num_audio_channels = helpers.get_num_audio_channels(self.filepath)
sample_rate = helpers.get_sample_rate(self.filepath) sample_rate = helpers.get_sample_rate(self.filepath)
with io.open(self.filepath, 'rb') as audio_file: with io.open(self.filepath, "rb") as audio_file:
content = audio_file.read() content = audio_file.read()
audio = speech.types.RecognitionAudio(content=content) audio = speech.types.RecognitionAudio(content=content)
if language_code != 'en-US': if language_code != "en-US":
model = None model = None
config = speech.types.RecognitionConfig( config = speech.types.RecognitionConfig(
@@ -151,39 +256,37 @@ class Transcriber(TranscriberBaseClass):
diarization_speaker_count=num_speakers, diarization_speaker_count=num_speakers,
model=model, model=model,
use_enhanced=use_enhanced, use_enhanced=use_enhanced,
) )
self.operation = self.speech_client.long_running_recognize(config, self.operation = self.speech_client.long_running_recognize(config, audio)
audio)
print('transcribing...') print("transcribing...")
while not self.operation.done(): while not self.operation.done():
sleep(1) sleep(1)
print('.') print(".")
result_list = [] result_list = []
for result in self.operation.result().results: for result in self.operation.result().results:
result_list.append(str(result)) result_list.append(str(result))
print('saving transcript') print("saving transcript")
transcript_path = '/tmp/transcript.txt' transcript_path = "/tmp/transcript.txt"
with open(transcript_path, 'w') as fout: with open(transcript_path, "w") as fout:
fout.write('\n'.join(result_list)) fout.write("\n".join(result_list))
print('uploading transcript') print("uploading transcript")
self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path) self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path)
os.remove(transcript_path) os.remove(transcript_path)
return self.basename return self.basename
@classmethod @classmethod
def retrieve_transcript(cls, transcription_job_name: str def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
) -> TRANSCRIPT_TYPE:
"""Get transcript from BUCKET_NAME_TRANSCRIPT""" """Get transcript from BUCKET_NAME_TRANSCRIPT"""
if not cls._check_if_transcript_exists( if not cls._check_if_transcript_exists(
cls, cls, transcript_name=transcription_job_name
transcript_name=transcription_job_name): ):
raise exceptions.DoesntExistError('no such transcript!') raise exceptions.DoesntExistError("no such transcript!")
blob = cls.transcript_bucket.blob(transcription_job_name) blob = cls.transcript_bucket.blob(transcription_job_name)
f = tempfile.NamedTemporaryFile(delete=False) f = tempfile.NamedTemporaryFile(delete=False)
f.close() f.close()
@@ -202,7 +305,7 @@ class Transcriber(TranscriberBaseClass):
@classmethod @classmethod
def get_transcription_jobs(cls, job_name_query=None, status=None) -> List[dict]: def get_transcription_jobs(cls, job_name_query=None, status=None) -> List[dict]:
if status and status.lower() != 'completed': if status and status.lower() != "completed":
return [] return []
jobs = [] jobs = []
@@ -210,6 +313,6 @@ class Transcriber(TranscriberBaseClass):
for t in cls.transcript_bucket.list_blobs(): for t in cls.transcript_bucket.list_blobs():
if job_name_query is not None and t.name != job_name_query: if job_name_query is not None and t.name != job_name_query:
continue continue
jobs.append({'name': t.name, 'status': 'COMPLETED'}) jobs.append({"name": t.name, "status": "COMPLETED"})
return jobs return jobs

View File

@@ -12,8 +12,8 @@ class TranscriberBaseClass:
def __init__(self, filepath): def __init__(self, filepath):
self._setup() self._setup()
if ' ' in filepath: if " " in filepath:
raise exceptions.FormatError('Please don\'t put any spaces in the filename.') raise exceptions.FormatError("Please don't put any spaces in the filename.")
self.filepath = PurePath(filepath) self.filepath = PurePath(filepath)
self.basename = str(os.path.basename(self.filepath)) self.basename = str(os.path.basename(self.filepath))
@@ -52,14 +52,14 @@ class TranscriberBaseClass:
def check_for_config(cls) -> bool: def check_for_config(cls) -> bool:
pass pass
@abc.abstractmethod def transcribe(self, **kwargs) -> str:
def transcribe(self) -> str:
""" """
This should do any required logic, This should do any required logic,
then call self._request_transcription. then call self._request_transcription.
It should return the job_name. It should return the job_name.
""" """
pass if kwargs["language_code"] not in self.language_list():
raise KeyError(f"No such language code {kwargs['language_code']}")
@abc.abstractmethod @abc.abstractmethod
def _request_transcription(self) -> str: def _request_transcription(self) -> str:
@@ -72,12 +72,10 @@ class TranscriberBaseClass:
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def retrieve_transcript(cls, transcription_job_name: str def retrieve_transcript(cls, transcription_job_name: str) -> Union[str, dict]:
) -> Union[str, dict]:
pass pass
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def get_transcription_jobs() -> List[dict]: def get_transcription_jobs() -> List[dict]:
pass pass