Linted a few modules. Added a check that the language_code exists when calling Transcriber.transcribe, and made that method non-abstract in order to do so. Added a test for this. Extracted a few things to fixtures.

This commit is contained in:
2019-07-18 12:19:42 +02:00
parent e7258f50c9
commit 327ff18726
6 changed files with 297 additions and 157 deletions

View File

@@ -1,23 +1,30 @@
from unittest import mock
from pytest import raises, fixture
from tatt.vendors.amazon import Transcriber
def test_transcriber_instantiate():
filepath = '/Users/zev/tester.mp3'
t = Transcriber(filepath)
assert str(t.filepath) == filepath
assert t.basename == 'tester.mp3'
assert t.media_file_uri == (
f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
)
@fixture
def media_filepath():
return "/Users/zev/tester.mp3"
@mock.patch('tatt.vendors.amazon.tr.get_transcription_job')
@fixture
def transcriber_instance(media_filepath):
return Transcriber(media_filepath)
def test_transcriber_instance(media_filepath, transcriber_instance):
assert str(transcriber_instance.filepath) == media_filepath
assert transcriber_instance.basename == "tester.mp3"
assert transcriber_instance.media_file_uri == (
f"https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3"
)
@mock.patch("tatt.vendors.amazon.tr.get_transcription_job")
def test_transcriber_retrieve(get_transcription_job):
filepath = '/Users/zev/tester.mp3'
job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
job_name = "4db6808e-a7e8-4d8d-a1b7-753ab97094dc"
t = Transcriber.retrieve_transcript(job_name)
get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
@@ -29,9 +36,40 @@ def test_transcriber_get_transcription_jobs():
def test_transcriber_retrieve_transcript():
jobs = Transcriber.get_transcription_jobs()
assert jobs
for j in jobs:
if j['status'].lower() == 'completed':
to_get = j['name']
if j["status"].lower() == "completed":
to_get = j["name"]
break
transcript = Transcriber.retrieve_transcript(to_get)
assert transcript == {'jobName': 'abcd.mp3', 'accountId': '416321668733', 'results': {'transcripts': [{'transcript': 'Hello there.'}], 'items': [{'start_time': '0.0', 'end_time': '0.35', 'alternatives': [{'confidence': '0.8303', 'content': 'Hello'}], 'type': 'pronunciation'}, {'start_time': '0.35', 'end_time': '0.76', 'alternatives': [{'confidence': '1.0000', 'content': 'there'}], 'type': 'pronunciation'}, {'alternatives': [{'confidence': None, 'content': '.'}], 'type': 'punctuation'}]}, 'status': 'COMPLETED'}
assert transcript == {
"jobName": "abcd.mp3",
"accountId": "416321668733",
"results": {
"transcripts": [{"transcript": "Hello there."}],
"items": [
{
"start_time": "0.0",
"end_time": "0.35",
"alternatives": [{"confidence": "0.8303", "content": "Hello"}],
"type": "pronunciation",
},
{
"start_time": "0.35",
"end_time": "0.76",
"alternatives": [{"confidence": "1.0000", "content": "there"}],
"type": "pronunciation",
},
{
"alternatives": [{"confidence": None, "content": "."}],
"type": "punctuation",
},
],
},
"status": "COMPLETED",
}
def test_transcribe_with_nonexistent_language_code(transcriber_instance):
with raises(KeyError):
transcriber_instance.transcribe(language_code="pretend-lang")

View File

@@ -3,6 +3,6 @@ from tatt.vendors import SERVICES
def test_services():
for service in SERVICES.values():
assert hasattr(service, 'Transcriber')
assert hasattr(service, 'NAME')
assert hasattr(service.Transcriber, 'cost_per_15_seconds')
assert hasattr(service, "Transcriber")
assert hasattr(service, "NAME")
assert hasattr(service.Transcriber, "cost_per_15_seconds")

View File

@@ -1,7 +1,3 @@
from tatt.vendors import amazon, google
SERVICES = {
'amazon': amazon,
'google': google,
}
SERVICES = {"amazon": amazon, "google": google}

127
tatt/vendors/amazon.py vendored
View File

@@ -11,7 +11,7 @@ from tatt import config
from tatt import exceptions
from .vendor import TranscriberBaseClass
NAME = 'amazon'
NAME = "amazon"
BUCKET_NAME_MEDIA = config.BUCKET_NAME_FMTR_MEDIA.format(NAME)
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
TRANSCRIPT_TYPE = dict
@@ -19,28 +19,35 @@ TRANSCRIPT_TYPE = dict
def _check_for_config() -> bool:
return (
config.AWS_CONFIG_FILEPATH.exists()
and config.AWS_CREDENTIALS_FILEPATH.exists()
)
config.AWS_CONFIG_FILEPATH.exists() and config.AWS_CREDENTIALS_FILEPATH.exists()
)
class Transcriber(TranscriberBaseClass):
name = NAME
cost_per_15_seconds = .024 / 4
bucket_names = {'media': BUCKET_NAME_MEDIA,
'transcript': BUCKET_NAME_TRANSCRIPT}
cost_per_15_seconds = 0.024 / 4
bucket_names = {"media": BUCKET_NAME_MEDIA, "transcript": BUCKET_NAME_TRANSCRIPT}
no_config_error_message = 'please run "aws configure" first'
transcript_type = TRANSCRIPT_TYPE
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/transcribe.html
_language_list = ['en-US', 'es-US', 'en-AU', 'fr-CA', 'en-GB', 'de-DE',
'pt-BR', 'fr-FR', 'it-IT', 'ko-KR']
_language_list = [
"en-US",
"es-US",
"en-AU",
"fr-CA",
"en-GB",
"de-DE",
"pt-BR",
"fr-FR",
"it-IT",
"ko-KR",
]
if _check_for_config():
tr = boto3.client('transcribe')
s3 = boto3.resource('s3')
tr = boto3.client("transcribe")
s3 = boto3.resource("s3")
def __init__(self, filepath):
super().__init__(filepath)
@@ -57,7 +64,7 @@ class Transcriber(TranscriberBaseClass):
return (
f"https://s3-{config.AWS_REGION}.amazonaws.com/"
f"{self.bucket_names['media']}/{self.basename}"
)
)
@classmethod
def _setup(cls):
@@ -75,98 +82,96 @@ class Transcriber(TranscriberBaseClass):
cls.s3.create_bucket(Bucket=bucket_name)
def transcribe(self, **kwargs) -> str:
super().transcribe(**kwargs)
self._upload_file()
try:
return self._request_transcription(**kwargs)
except self.tr.exceptions.ConflictException:
raise exceptions.AlreadyExistsError(
f'{self.basename} already exists on {NAME}')
f"{self.basename} already exists on {NAME}"
)
def _upload_file(self):
self.s3.Bucket(self.bucket_names['media']).upload_file(
str(self.filepath),
self.basename)
self.s3.Bucket(self.bucket_names["media"]).upload_file(
str(self.filepath), self.basename
)
def _request_transcription(
self,
language_code='en-US',
num_speakers=2,
enable_speaker_diarization=True,
) -> str:
self, language_code="en-US", num_speakers=2, enable_speaker_diarization=True
) -> str:
job_name = self.basename
kwargs = dict(
TranscriptionJobName=job_name,
LanguageCode=language_code,
MediaFormat=self.basename.split('.')[-1].lower(),
Media={
'MediaFileUri': self.media_file_uri
},
OutputBucketName=self.bucket_names['transcript']
)
TranscriptionJobName=job_name,
LanguageCode=language_code,
MediaFormat=self.basename.split(".")[-1].lower(),
Media={"MediaFileUri": self.media_file_uri},
OutputBucketName=self.bucket_names["transcript"],
)
if enable_speaker_diarization:
kwargs.update(dict(
Settings={
'ShowSpeakerLabels': True,
'MaxSpeakerLabels': num_speakers,
kwargs.update(
dict(
Settings={
"ShowSpeakerLabels": True,
"MaxSpeakerLabels": num_speakers,
}
))
)
)
self.tr.start_transcription_job(**kwargs)
return job_name
@classmethod
def get_transcription_jobs(
    cls, status: str = None, job_name_query: str = None
) -> List[dict]:
    """
    List transcription jobs, following pagination until exhausted.

    status: optional job status filter; upper-cased before being sent.
    job_name_query: optional substring sent as ``JobNameContains``.

    Returns the job summaries from every page, each passed through
    ``cls.homogenize_transcription_job_data``.
    """
    kwargs = {"MaxResults": 100}
    if status is not None:
        kwargs["Status"] = status.upper()
    if job_name_query is not None:
        kwargs["JobNameContains"] = job_name_query

    key = "TranscriptionJobSummaries"
    jobs_data = cls.tr.list_transcription_jobs(**kwargs)
    jobs = cls.homogenize_transcription_job_data(jobs_data[key])

    while jobs_data.get("NextToken"):
        # Re-send the original filters with the token; otherwise every page
        # after the first would ignore status / job_name_query.
        jobs_data = cls.tr.list_transcription_jobs(
            NextToken=jobs_data["NextToken"], **kwargs
        )
        jobs += cls.homogenize_transcription_job_data(jobs_data[key])

    return jobs
@classmethod
def retrieve_transcript(cls, transcription_job_name: str
) -> TRANSCRIPT_TYPE:
job = cls.tr.get_transcription_job(
TranscriptionJobName=transcription_job_name
)['TranscriptionJob']
def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
job = cls.tr.get_transcription_job(TranscriptionJobName=transcription_job_name)[
"TranscriptionJob"
]
if not job['TranscriptionJobStatus'] == 'COMPLETED':
if not job["TranscriptionJobStatus"] == "COMPLETED":
return
transcript_file_uri = job['Transcript']['TranscriptFileUri']
transcript_file_uri = job["Transcript"]["TranscriptFileUri"]
transcript_path = transcript_file_uri.split("amazonaws.com/", 1)[1]
transcript_bucket = transcript_path.split('/', 1)[0]
transcript_key = transcript_path.split('/', 1)[1]
transcript_bucket = transcript_path.split("/", 1)[0]
transcript_key = transcript_path.split("/", 1)[1]
s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
transcript_json = s3_object['Body'].read().decode('utf-8')
transcript_json = s3_object["Body"].read().decode("utf-8")
return json.loads(transcript_json)
@staticmethod
def homogenize_transcription_job_data(transcription_job_data):
    """
    Reduce vendor job summaries to the common job-dict shape.

    transcription_job_data: iterable of mappings that each provide the
    ``CreationTime``, ``TranscriptionJobName`` and
    ``TranscriptionJobStatus`` keys.

    Returns a list of dicts with the keys ``created``, ``name`` and
    ``status``; any other keys on the input are dropped.
    """
    return [
        {
            "created": jd["CreationTime"],
            "name": jd["TranscriptionJobName"],
            "status": jd["TranscriptionJobStatus"],
        }
        for jd in transcription_job_data
    ]

233
tatt/vendors/google.py vendored
View File

@@ -12,51 +12,155 @@ from google.cloud import (
speech_v1p1beta1 as speech,
storage,
exceptions as gc_exceptions,
)
)
from tatt import exceptions, helpers, config as config_mod
from .vendor import TranscriberBaseClass
NAME = 'google'
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format(
'goog')
NAME = "google"
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format("goog")
TRANSCRIPT_TYPE = str
def _check_for_config():
return os.getenv('GOOGLE_APPLICATION_CREDENTIALS') is not None
return os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None
class Transcriber(TranscriberBaseClass):
name = NAME
SUPPORTED_FORMATS = ['flac']
cost_per_15_seconds = [.004, .006, .009]
SUPPORTED_FORMATS = ["flac"]
cost_per_15_seconds = [0.004, 0.006, 0.009]
no_config_error_message = (
'Please sign up for the Google Speech-to-Text API '
'and put the path to your credentials in an '
'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
)
"Please sign up for the Google Speech-to-Text API "
"and put the path to your credentials in an "
'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
)
transcript_type = TRANSCRIPT_TYPE
# https://cloud.google.com/speech-to-text/docs/languages
# Array.from(document.querySelector('.devsite-table-wrapper').querySelectorAll('table tr')).slice(1).map(row => row.children[1].innerText)
_language_list = [
'af-ZA', 'am-ET', 'hy-AM', 'az-AZ', 'id-ID', 'ms-MY',
'bn-BD', 'bn-IN', 'ca-ES', 'cs-CZ', 'da-DK', 'de-DE', 'en-AU', 'en-CA',
'en-GH', 'en-GB', 'en-IN', 'en-IE', 'en-KE', 'en-NZ', 'en-NG', 'en-PH',
'en-SG', 'en-ZA', 'en-TZ', 'en-US', 'es-AR', 'es-BO', 'es-CL', 'es-CO',
'es-CR', 'es-EC', 'es-SV', 'es-ES', 'es-US', 'es-GT', 'es-HN', 'es-MX',
'es-NI', 'es-PA', 'es-PY', 'es-PE', 'es-PR', 'es-DO', 'es-UY', 'es-VE',
'eu-ES', 'fil-PH', 'fr-CA', 'fr-FR', 'gl-ES', 'ka-GE', 'gu-IN', 'hr-HR',
'zu-ZA', 'is-IS', 'it-IT', 'jv-ID', 'kn-IN', 'km-KH', 'lo-LA', 'lv-LV',
'lt-LT', 'hu-HU', 'ml-IN', 'mr-IN', 'nl-NL', 'ne-NP', 'nb-NO', 'pl-PL',
'pt-BR', 'pt-PT', 'ro-RO', 'si-LK', 'sk-SK', 'sl-SI', 'su-ID', 'sw-TZ',
'sw-KE', 'fi-FI', 'sv-SE', 'ta-IN', 'ta-SG', 'ta-LK', 'ta-MY', 'te-IN',
'vi-VN', 'tr-TR', 'ur-PK', 'ur-IN', 'el-GR', 'bg-BG', 'ru-RU', 'sr-RS',
'uk-UA', 'he-IL', 'ar-IL', 'ar-JO', 'ar-AE', 'ar-BH', 'ar-DZ', 'ar-SA',
'ar-IQ', 'ar-KW', 'ar-MA', 'ar-TN', 'ar-OM', 'ar-PS', 'ar-QA', 'ar-LB',
'ar-EG', 'fa-IR', 'hi-IN', 'th-TH', 'ko-KR', 'zh-TW', 'yue-Hant-HK',
'ja-JP', 'zh-HK', 'zh']
"af-ZA",
"am-ET",
"hy-AM",
"az-AZ",
"id-ID",
"ms-MY",
"bn-BD",
"bn-IN",
"ca-ES",
"cs-CZ",
"da-DK",
"de-DE",
"en-AU",
"en-CA",
"en-GH",
"en-GB",
"en-IN",
"en-IE",
"en-KE",
"en-NZ",
"en-NG",
"en-PH",
"en-SG",
"en-ZA",
"en-TZ",
"en-US",
"es-AR",
"es-BO",
"es-CL",
"es-CO",
"es-CR",
"es-EC",
"es-SV",
"es-ES",
"es-US",
"es-GT",
"es-HN",
"es-MX",
"es-NI",
"es-PA",
"es-PY",
"es-PE",
"es-PR",
"es-DO",
"es-UY",
"es-VE",
"eu-ES",
"fil-PH",
"fr-CA",
"fr-FR",
"gl-ES",
"ka-GE",
"gu-IN",
"hr-HR",
"zu-ZA",
"is-IS",
"it-IT",
"jv-ID",
"kn-IN",
"km-KH",
"lo-LA",
"lv-LV",
"lt-LT",
"hu-HU",
"ml-IN",
"mr-IN",
"nl-NL",
"ne-NP",
"nb-NO",
"pl-PL",
"pt-BR",
"pt-PT",
"ro-RO",
"si-LK",
"sk-SK",
"sl-SI",
"su-ID",
"sw-TZ",
"sw-KE",
"fi-FI",
"sv-SE",
"ta-IN",
"ta-SG",
"ta-LK",
"ta-MY",
"te-IN",
"vi-VN",
"tr-TR",
"ur-PK",
"ur-IN",
"el-GR",
"bg-BG",
"ru-RU",
"sr-RS",
"uk-UA",
"he-IL",
"ar-IL",
"ar-JO",
"ar-AE",
"ar-BH",
"ar-DZ",
"ar-SA",
"ar-IQ",
"ar-KW",
"ar-MA",
"ar-TN",
"ar-OM",
"ar-PS",
"ar-QA",
"ar-LB",
"ar-EG",
"fa-IR",
"hi-IN",
"th-TH",
"ko-KR",
"zh-TW",
"yue-Hant-HK",
"ja-JP",
"zh-HK",
"zh",
]
if _check_for_config():
speech_client = speech.SpeechClient()
@@ -69,10 +173,11 @@ class Transcriber(TranscriberBaseClass):
@classmethod
def _setup(cls):
    """
    Run the base-class setup, then check the Google-specific prerequisites.

    Raises exceptions.DependencyRequired if no ``gsutil`` executable is
    found on PATH. Otherwise calls ``cls._make_bucket_if_doesnt_exist``
    for BUCKET_NAME_TRANSCRIPT.
    """
    super()._setup()
    if not shutil.which("gsutil"):
        # Trailing space matters: the two literals are concatenated, and
        # without it the URL is glued onto "here:".
        raise exceptions.DependencyRequired(
            "Please install gcloud using the steps here: "
            "https://cloud.google.com/storage/docs/gsutil_install"
        )

    cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT)
@@ -84,13 +189,13 @@ class Transcriber(TranscriberBaseClass):
# this might fail if a bucket by the name exists *anywhere* on GCS?
return
else:
print('made Google Cloud Storage Bucket for transcripts')
print("made Google Cloud Storage Bucket for transcripts")
def convert_file_format_if_needed(self):
if self.file_format not in self.SUPPORTED_FORMATS:
if not shutil.which('ffmpeg'):
raise exceptions.DependencyRequired('please install ffmpeg')
self.filepath = helpers.convert_file(self.filepath, 'flac')
if not shutil.which("ffmpeg"):
raise exceptions.DependencyRequired("please install ffmpeg")
self.filepath = helpers.convert_file(self.filepath, "flac")
@property
def file_format(self):
@@ -111,31 +216,31 @@ class Transcriber(TranscriberBaseClass):
def _check_if_transcript_exists(self, transcript_name=None):
return storage.Blob(
bucket=self.transcript_bucket,
name=transcript_name or self.basename
).exists(self.storage_client)
bucket=self.transcript_bucket, name=transcript_name or self.basename
).exists(self.storage_client)
def _request_transcription(
self,
language_code='en-US',
enable_automatic_punctuation=True,
enable_speaker_diarization=True,
num_speakers=2,
model='phone_call',
use_enhanced=True,
) -> str:
self,
language_code="en-US",
enable_automatic_punctuation=True,
enable_speaker_diarization=True,
num_speakers=2,
model="phone_call",
use_enhanced=True,
) -> str:
"""Returns the job_name"""
if self._check_if_transcript_exists():
raise exceptions.AlreadyExistsError(
f'{self.basename} already exists on {NAME}')
f"{self.basename} already exists on {NAME}"
)
num_audio_channels = helpers.get_num_audio_channels(self.filepath)
sample_rate = helpers.get_sample_rate(self.filepath)
with io.open(self.filepath, 'rb') as audio_file:
with io.open(self.filepath, "rb") as audio_file:
content = audio_file.read()
audio = speech.types.RecognitionAudio(content=content)
if language_code != 'en-US':
if language_code != "en-US":
model = None
config = speech.types.RecognitionConfig(
@@ -151,39 +256,37 @@ class Transcriber(TranscriberBaseClass):
diarization_speaker_count=num_speakers,
model=model,
use_enhanced=use_enhanced,
)
)
self.operation = self.speech_client.long_running_recognize(config,
audio)
self.operation = self.speech_client.long_running_recognize(config, audio)
print('transcribing...')
print("transcribing...")
while not self.operation.done():
sleep(1)
print('.')
print(".")
result_list = []
for result in self.operation.result().results:
result_list.append(str(result))
print('saving transcript')
transcript_path = '/tmp/transcript.txt'
with open(transcript_path, 'w') as fout:
fout.write('\n'.join(result_list))
print('uploading transcript')
print("saving transcript")
transcript_path = "/tmp/transcript.txt"
with open(transcript_path, "w") as fout:
fout.write("\n".join(result_list))
print("uploading transcript")
self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path)
os.remove(transcript_path)
return self.basename
@classmethod
def retrieve_transcript(cls, transcription_job_name: str
) -> TRANSCRIPT_TYPE:
def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
"""Get transcript from BUCKET_NAME_TRANSCRIPT"""
if not cls._check_if_transcript_exists(
cls,
transcript_name=transcription_job_name):
raise exceptions.DoesntExistError('no such transcript!')
cls, transcript_name=transcription_job_name
):
raise exceptions.DoesntExistError("no such transcript!")
blob = cls.transcript_bucket.blob(transcription_job_name)
f = tempfile.NamedTemporaryFile(delete=False)
f.close()
@@ -202,7 +305,7 @@ class Transcriber(TranscriberBaseClass):
@classmethod
def get_transcription_jobs(cls, job_name_query=None, status=None) -> List[dict]:
if status and status.lower() != 'completed':
if status and status.lower() != "completed":
return []
jobs = []
@@ -210,6 +313,6 @@ class Transcriber(TranscriberBaseClass):
for t in cls.transcript_bucket.list_blobs():
if job_name_query is not None and t.name != job_name_query:
continue
jobs.append({'name': t.name, 'status': 'COMPLETED'})
jobs.append({"name": t.name, "status": "COMPLETED"})
return jobs

View File

@@ -12,8 +12,8 @@ class TranscriberBaseClass:
def __init__(self, filepath):
self._setup()
if ' ' in filepath:
raise exceptions.FormatError('Please don\'t put any spaces in the filename.')
if " " in filepath:
raise exceptions.FormatError("Please don't put any spaces in the filename.")
self.filepath = PurePath(filepath)
self.basename = str(os.path.basename(self.filepath))
@@ -52,14 +52,14 @@ class TranscriberBaseClass:
def check_for_config(cls) -> bool:
pass
def transcribe(self, **kwargs) -> str:
    """
    Validate the requested language before a transcription is started.

    Subclasses should call this first, then do any required logic,
    then call self._request_transcription.
    The subclass implementation should return the job_name; this base
    implementation only validates and returns nothing itself.

    Raises KeyError if a ``language_code`` keyword is supplied that is
    not in ``self.language_list()``.
    """
    language_code = kwargs.get("language_code")
    # A caller that passes no language_code has nothing to validate here;
    # indexing kwargs directly would raise a misleading
    # KeyError('language_code') instead.
    if language_code is None:
        return
    if language_code not in self.language_list():
        raise KeyError(f"No such language code {language_code}")
@abc.abstractmethod
def _request_transcription(self) -> str:
@@ -72,12 +72,10 @@ class TranscriberBaseClass:
@classmethod
@abc.abstractmethod
def retrieve_transcript(cls, transcription_job_name: str
) -> Union[str, dict]:
def retrieve_transcript(cls, transcription_job_name: str) -> Union[str, dict]:
pass
@classmethod
@abc.abstractmethod
def get_transcription_jobs(
    cls, status: str = None, job_name_query: str = None
) -> List[dict]:
    """
    Return the vendor's transcription jobs as a list of dicts.

    status: optional job status to filter on.
    job_name_query: optional job-name filter.

    NOTE(review): the vendor overrides visible in this change accept these
    two keywords in different positional orders, so callers should pass
    them by keyword.
    """
    pass