Linted a few modules; added a check that the language_code exists when calling Transcriber.transcribe, and made that method non-abstract to do so. Added a test for this. Extracted a few things into fixtures.
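In effect, TranscriberBaseClass.transcribe is no longer a bare abstract stub: it now validates the requested language code before the vendor subclass does any work, and each vendor's transcribe calls super().transcribe(**kwargs) first. A minimal sketch of the resulting call flow, with method bodies trimmed to the relevant lines (language_list() is assumed to be the base-class helper that exposes each vendor's _language_list):

class TranscriberBaseClass:
    def transcribe(self, **kwargs) -> str:
        # Reject unknown language codes before any vendor-specific work happens.
        if kwargs["language_code"] not in self.language_list():
            raise KeyError(f"No such language code {kwargs['language_code']}")


class Transcriber(TranscriberBaseClass):  # e.g. the amazon vendor
    def transcribe(self, **kwargs) -> str:
        super().transcribe(**kwargs)  # raises KeyError for a bad code
        self._upload_file()
        return self._request_transcription(**kwargs)

So Transcriber(filepath).transcribe(language_code="pretend-lang") raises KeyError before anything is uploaded, which is what the new test_transcribe_with_nonexistent_language_code asserts.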
@@ -1,23 +1,30 @@
from unittest import mock
from pytest import raises, fixture

from tatt.vendors.amazon import Transcriber


def test_transcriber_instantiate():
    filepath = '/Users/zev/tester.mp3'
    t = Transcriber(filepath)
    assert str(t.filepath) == filepath
    assert t.basename == 'tester.mp3'
    assert t.media_file_uri == (
        f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
    )

@fixture
def media_filepath():
    return "/Users/zev/tester.mp3"


@mock.patch('tatt.vendors.amazon.tr.get_transcription_job')
@fixture
def transcriber_instance(media_filepath):
    return Transcriber(media_filepath)


def test_transcriber_instance(media_filepath, transcriber_instance):
    assert str(transcriber_instance.filepath) == media_filepath
    assert transcriber_instance.basename == "tester.mp3"
    assert transcriber_instance.media_file_uri == (
        f"https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3"
    )


@mock.patch("tatt.vendors.amazon.tr.get_transcription_job")
def test_transcriber_retrieve(get_transcription_job):
    filepath = '/Users/zev/tester.mp3'
    job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
    job_name = "4db6808e-a7e8-4d8d-a1b7-753ab97094dc"
    t = Transcriber.retrieve_transcript(job_name)
    get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
@@ -29,9 +36,40 @@ def test_transcriber_get_transcription_jobs():


def test_transcriber_retrieve_transcript():
    jobs = Transcriber.get_transcription_jobs()
    assert jobs
    for j in jobs:
        if j['status'].lower() == 'completed':
            to_get = j['name']
        if j["status"].lower() == "completed":
            to_get = j["name"]
            break
    transcript = Transcriber.retrieve_transcript(to_get)
    assert transcript == {'jobName': 'abcd.mp3', 'accountId': '416321668733', 'results': {'transcripts': [{'transcript': 'Hello there.'}], 'items': [{'start_time': '0.0', 'end_time': '0.35', 'alternatives': [{'confidence': '0.8303', 'content': 'Hello'}], 'type': 'pronunciation'}, {'start_time': '0.35', 'end_time': '0.76', 'alternatives': [{'confidence': '1.0000', 'content': 'there'}], 'type': 'pronunciation'}, {'alternatives': [{'confidence': None, 'content': '.'}], 'type': 'punctuation'}]}, 'status': 'COMPLETED'}
    assert transcript == {
        "jobName": "abcd.mp3",
        "accountId": "416321668733",
        "results": {
            "transcripts": [{"transcript": "Hello there."}],
            "items": [
                {
                    "start_time": "0.0",
                    "end_time": "0.35",
                    "alternatives": [{"confidence": "0.8303", "content": "Hello"}],
                    "type": "pronunciation",
                },
                {
                    "start_time": "0.35",
                    "end_time": "0.76",
                    "alternatives": [{"confidence": "1.0000", "content": "there"}],
                    "type": "pronunciation",
                },
                {
                    "alternatives": [{"confidence": None, "content": "."}],
                    "type": "punctuation",
                },
            ],
        },
        "status": "COMPLETED",
    }


def test_transcribe_with_nonexistent_language_code(transcriber_instance):
    with raises(KeyError):
        transcriber_instance.transcribe(language_code="pretend-lang")
@@ -3,6 +3,6 @@ from tatt.vendors import SERVICES


def test_services():
    for service in SERVICES.values():
        assert hasattr(service, 'Transcriber')
        assert hasattr(service, 'NAME')
        assert hasattr(service.Transcriber, 'cost_per_15_seconds')
        assert hasattr(service, "Transcriber")
        assert hasattr(service, "NAME")
        assert hasattr(service.Transcriber, "cost_per_15_seconds")
tatt/vendors/__init__.py
@@ -1,7 +1,3 @@
from tatt.vendors import amazon, google

SERVICES = {
    'amazon': amazon,
    'google': google,
}
SERVICES = {"amazon": amazon, "google": google}
tatt/vendors/amazon.py
@@ -11,7 +11,7 @@ from tatt import config
from tatt import exceptions
from .vendor import TranscriberBaseClass

NAME = 'amazon'
NAME = "amazon"
BUCKET_NAME_MEDIA = config.BUCKET_NAME_FMTR_MEDIA.format(NAME)
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
TRANSCRIPT_TYPE = dict
@@ -19,28 +19,35 @@ TRANSCRIPT_TYPE = dict


def _check_for_config() -> bool:
    return (
        config.AWS_CONFIG_FILEPATH.exists()
        and config.AWS_CREDENTIALS_FILEPATH.exists()
    )

        config.AWS_CONFIG_FILEPATH.exists() and config.AWS_CREDENTIALS_FILEPATH.exists()
    )


class Transcriber(TranscriberBaseClass):

    name = NAME
    cost_per_15_seconds = .024 / 4
    bucket_names = {'media': BUCKET_NAME_MEDIA,
                    'transcript': BUCKET_NAME_TRANSCRIPT}
    cost_per_15_seconds = 0.024 / 4
    bucket_names = {"media": BUCKET_NAME_MEDIA, "transcript": BUCKET_NAME_TRANSCRIPT}

    no_config_error_message = 'please run "aws configure" first'
    transcript_type = TRANSCRIPT_TYPE
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/transcribe.html
    _language_list = ['en-US', 'es-US', 'en-AU', 'fr-CA', 'en-GB', 'de-DE',
                      'pt-BR', 'fr-FR', 'it-IT', 'ko-KR']
    _language_list = [
        "en-US",
        "es-US",
        "en-AU",
        "fr-CA",
        "en-GB",
        "de-DE",
        "pt-BR",
        "fr-FR",
        "it-IT",
        "ko-KR",
    ]

    if _check_for_config():
        tr = boto3.client('transcribe')
        s3 = boto3.resource('s3')
        tr = boto3.client("transcribe")
        s3 = boto3.resource("s3")

    def __init__(self, filepath):
        super().__init__(filepath)
@@ -57,7 +64,7 @@ class Transcriber(TranscriberBaseClass):
        return (
            f"https://s3-{config.AWS_REGION}.amazonaws.com/"
            f"{self.bucket_names['media']}/{self.basename}"
            )
        )

    @classmethod
    def _setup(cls):
@@ -75,98 +82,96 @@ class Transcriber(TranscriberBaseClass):
        cls.s3.create_bucket(Bucket=bucket_name)

    def transcribe(self, **kwargs) -> str:
        super().transcribe(**kwargs)
        self._upload_file()
        try:
            return self._request_transcription(**kwargs)
        except self.tr.exceptions.ConflictException:
            raise exceptions.AlreadyExistsError(
                f'{self.basename} already exists on {NAME}')
                f"{self.basename} already exists on {NAME}"
            )

    def _upload_file(self):
        self.s3.Bucket(self.bucket_names['media']).upload_file(
            str(self.filepath),
            self.basename)
        self.s3.Bucket(self.bucket_names["media"]).upload_file(
            str(self.filepath), self.basename
        )

    def _request_transcription(
        self,
        language_code='en-US',
        num_speakers=2,
        enable_speaker_diarization=True,
    ) -> str:
        self, language_code="en-US", num_speakers=2, enable_speaker_diarization=True
    ) -> str:
        job_name = self.basename

        kwargs = dict(
            TranscriptionJobName=job_name,
            LanguageCode=language_code,
            MediaFormat=self.basename.split('.')[-1].lower(),
            Media={
                'MediaFileUri': self.media_file_uri
            },
            OutputBucketName=self.bucket_names['transcript']
        )
            TranscriptionJobName=job_name,
            LanguageCode=language_code,
            MediaFormat=self.basename.split(".")[-1].lower(),
            Media={"MediaFileUri": self.media_file_uri},
            OutputBucketName=self.bucket_names["transcript"],
        )

        if enable_speaker_diarization:
            kwargs.update(dict(
                Settings={
                    'ShowSpeakerLabels': True,
                    'MaxSpeakerLabels': num_speakers,
            kwargs.update(
                dict(
                    Settings={
                        "ShowSpeakerLabels": True,
                        "MaxSpeakerLabels": num_speakers,
                    }
            ))
                )
            )

        self.tr.start_transcription_job(**kwargs)
        return job_name
    @classmethod
    def get_transcription_jobs(
        cls,
        status:str = None,
        job_name_query:str = None,
    ) -> List[dict]:
        cls, status: str = None, job_name_query: str = None
    ) -> List[dict]:

        kwargs = {'MaxResults': 100}
        kwargs = {"MaxResults": 100}

        if status is not None:
            kwargs['Status'] = status.upper()
            kwargs["Status"] = status.upper()
        if job_name_query is not None:
            kwargs['JobNameContains'] = job_name_query
            kwargs["JobNameContains"] = job_name_query

        jobs_data = cls.tr.list_transcription_jobs(**kwargs)
        key = 'TranscriptionJobSummaries'
        key = "TranscriptionJobSummaries"

        jobs = cls.homogenize_transcription_job_data(jobs_data[key])

        while jobs_data.get('NextToken'):
            token = jobs_data['NextToken']
        while jobs_data.get("NextToken"):
            token = jobs_data["NextToken"]
            jobs_data = cls.tr.list_transcription_jobs(NextToken=token)
            jobs += cls.homogenize_transcription_job_data(jobs_data[key])

        return jobs
    @classmethod
    def retrieve_transcript(cls, transcription_job_name: str
                            ) -> TRANSCRIPT_TYPE:
        job = cls.tr.get_transcription_job(
            TranscriptionJobName=transcription_job_name
        )['TranscriptionJob']
    def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
        job = cls.tr.get_transcription_job(TranscriptionJobName=transcription_job_name)[
            "TranscriptionJob"
        ]

        if not job['TranscriptionJobStatus'] == 'COMPLETED':
        if not job["TranscriptionJobStatus"] == "COMPLETED":
            return

        transcript_file_uri = job['Transcript']['TranscriptFileUri']
        transcript_file_uri = job["Transcript"]["TranscriptFileUri"]
        transcript_path = transcript_file_uri.split("amazonaws.com/", 1)[1]

        transcript_bucket = transcript_path.split('/', 1)[0]
        transcript_key = transcript_path.split('/', 1)[1]
        transcript_bucket = transcript_path.split("/", 1)[0]
        transcript_key = transcript_path.split("/", 1)[1]

        s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
        transcript_json = s3_object['Body'].read().decode('utf-8')
        transcript_json = s3_object["Body"].read().decode("utf-8")
        return json.loads(transcript_json)

    @staticmethod
    def homogenize_transcription_job_data(transcription_job_data):
        return [{
            'created': jd['CreationTime'],
            'name': jd['TranscriptionJobName'],
            'status': jd['TranscriptionJobStatus']
        }
            for jd in transcription_job_data]
        return [
            {
                "created": jd["CreationTime"],
                "name": jd["TranscriptionJobName"],
                "status": jd["TranscriptionJobStatus"],
            }
            for jd in transcription_job_data
        ]
tatt/vendors/google.py
@@ -12,51 +12,155 @@ from google.cloud import (
    speech_v1p1beta1 as speech,
    storage,
    exceptions as gc_exceptions,
    )
)

from tatt import exceptions, helpers, config as config_mod
from .vendor import TranscriberBaseClass

NAME = 'google'
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format(
    'goog')
NAME = "google"
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format("goog")
TRANSCRIPT_TYPE = str


def _check_for_config():
    return os.getenv('GOOGLE_APPLICATION_CREDENTIALS') is not None
    return os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None


class Transcriber(TranscriberBaseClass):

    name = NAME
    SUPPORTED_FORMATS = ['flac']
    cost_per_15_seconds = [.004, .006, .009]
    SUPPORTED_FORMATS = ["flac"]
    cost_per_15_seconds = [0.004, 0.006, 0.009]
    no_config_error_message = (
        'Please sign up for the Google Speech-to-Text API '
        'and put the path to your credentials in an '
        'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
    )
        "Please sign up for the Google Speech-to-Text API "
        "and put the path to your credentials in an "
        'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
    )
    transcript_type = TRANSCRIPT_TYPE
    # https://cloud.google.com/speech-to-text/docs/languages
    # Array.from(document.querySelector('.devsite-table-wrapper').querySelectorAll('table tr')).slice(1).map(row => row.children[1].innerText)
    _language_list = [
        'af-ZA', 'am-ET', 'hy-AM', 'az-AZ', 'id-ID', 'ms-MY',
        'bn-BD', 'bn-IN', 'ca-ES', 'cs-CZ', 'da-DK', 'de-DE', 'en-AU', 'en-CA',
        'en-GH', 'en-GB', 'en-IN', 'en-IE', 'en-KE', 'en-NZ', 'en-NG', 'en-PH',
        'en-SG', 'en-ZA', 'en-TZ', 'en-US', 'es-AR', 'es-BO', 'es-CL', 'es-CO',
        'es-CR', 'es-EC', 'es-SV', 'es-ES', 'es-US', 'es-GT', 'es-HN', 'es-MX',
        'es-NI', 'es-PA', 'es-PY', 'es-PE', 'es-PR', 'es-DO', 'es-UY', 'es-VE',
        'eu-ES', 'fil-PH', 'fr-CA', 'fr-FR', 'gl-ES', 'ka-GE', 'gu-IN', 'hr-HR',
        'zu-ZA', 'is-IS', 'it-IT', 'jv-ID', 'kn-IN', 'km-KH', 'lo-LA', 'lv-LV',
        'lt-LT', 'hu-HU', 'ml-IN', 'mr-IN', 'nl-NL', 'ne-NP', 'nb-NO', 'pl-PL',
        'pt-BR', 'pt-PT', 'ro-RO', 'si-LK', 'sk-SK', 'sl-SI', 'su-ID', 'sw-TZ',
        'sw-KE', 'fi-FI', 'sv-SE', 'ta-IN', 'ta-SG', 'ta-LK', 'ta-MY', 'te-IN',
        'vi-VN', 'tr-TR', 'ur-PK', 'ur-IN', 'el-GR', 'bg-BG', 'ru-RU', 'sr-RS',
        'uk-UA', 'he-IL', 'ar-IL', 'ar-JO', 'ar-AE', 'ar-BH', 'ar-DZ', 'ar-SA',
        'ar-IQ', 'ar-KW', 'ar-MA', 'ar-TN', 'ar-OM', 'ar-PS', 'ar-QA', 'ar-LB',
        'ar-EG', 'fa-IR', 'hi-IN', 'th-TH', 'ko-KR', 'zh-TW', 'yue-Hant-HK',
        'ja-JP', 'zh-HK', 'zh']
        "af-ZA",
        "am-ET",
        "hy-AM",
        "az-AZ",
        "id-ID",
        "ms-MY",
        "bn-BD",
        "bn-IN",
        "ca-ES",
        "cs-CZ",
        "da-DK",
        "de-DE",
        "en-AU",
        "en-CA",
        "en-GH",
        "en-GB",
        "en-IN",
        "en-IE",
        "en-KE",
        "en-NZ",
        "en-NG",
        "en-PH",
        "en-SG",
        "en-ZA",
        "en-TZ",
        "en-US",
        "es-AR",
        "es-BO",
        "es-CL",
        "es-CO",
        "es-CR",
        "es-EC",
        "es-SV",
        "es-ES",
        "es-US",
        "es-GT",
        "es-HN",
        "es-MX",
        "es-NI",
        "es-PA",
        "es-PY",
        "es-PE",
        "es-PR",
        "es-DO",
        "es-UY",
        "es-VE",
        "eu-ES",
        "fil-PH",
        "fr-CA",
        "fr-FR",
        "gl-ES",
        "ka-GE",
        "gu-IN",
        "hr-HR",
        "zu-ZA",
        "is-IS",
        "it-IT",
        "jv-ID",
        "kn-IN",
        "km-KH",
        "lo-LA",
        "lv-LV",
        "lt-LT",
        "hu-HU",
        "ml-IN",
        "mr-IN",
        "nl-NL",
        "ne-NP",
        "nb-NO",
        "pl-PL",
        "pt-BR",
        "pt-PT",
        "ro-RO",
        "si-LK",
        "sk-SK",
        "sl-SI",
        "su-ID",
        "sw-TZ",
        "sw-KE",
        "fi-FI",
        "sv-SE",
        "ta-IN",
        "ta-SG",
        "ta-LK",
        "ta-MY",
        "te-IN",
        "vi-VN",
        "tr-TR",
        "ur-PK",
        "ur-IN",
        "el-GR",
        "bg-BG",
        "ru-RU",
        "sr-RS",
        "uk-UA",
        "he-IL",
        "ar-IL",
        "ar-JO",
        "ar-AE",
        "ar-BH",
        "ar-DZ",
        "ar-SA",
        "ar-IQ",
        "ar-KW",
        "ar-MA",
        "ar-TN",
        "ar-OM",
        "ar-PS",
        "ar-QA",
        "ar-LB",
        "ar-EG",
        "fa-IR",
        "hi-IN",
        "th-TH",
        "ko-KR",
        "zh-TW",
        "yue-Hant-HK",
        "ja-JP",
        "zh-HK",
        "zh",
    ]

    if _check_for_config():
        speech_client = speech.SpeechClient()
@@ -69,10 +173,11 @@ class Transcriber(TranscriberBaseClass):
    @classmethod
    def _setup(cls):
        super()._setup()
        if not shutil.which('gsutil'):
        if not shutil.which("gsutil"):
            raise exceptions.DependencyRequired(
                'Please install gcloud using the steps here:'
                'https://cloud.google.com/storage/docs/gsutil_install')
                "Please install gcloud using the steps here:"
                "https://cloud.google.com/storage/docs/gsutil_install"
            )

        cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT)
@@ -84,13 +189,13 @@ class Transcriber(TranscriberBaseClass):
            # this might fail if a bucket by the name exists *anywhere* on GCS?
            return
        else:
            print('made Google Cloud Storage Bucket for transcripts')
            print("made Google Cloud Storage Bucket for transcripts")

    def convert_file_format_if_needed(self):
        if self.file_format not in self.SUPPORTED_FORMATS:
            if not shutil.which('ffmpeg'):
                raise exceptions.DependencyRequired('please install ffmpeg')
            self.filepath = helpers.convert_file(self.filepath, 'flac')
            if not shutil.which("ffmpeg"):
                raise exceptions.DependencyRequired("please install ffmpeg")
            self.filepath = helpers.convert_file(self.filepath, "flac")

    @property
    def file_format(self):
@@ -111,31 +216,31 @@ class Transcriber(TranscriberBaseClass):

    def _check_if_transcript_exists(self, transcript_name=None):
        return storage.Blob(
            bucket=self.transcript_bucket,
            name=transcript_name or self.basename
        ).exists(self.storage_client)
            bucket=self.transcript_bucket, name=transcript_name or self.basename
        ).exists(self.storage_client)

    def _request_transcription(
        self,
        language_code='en-US',
        enable_automatic_punctuation=True,
        enable_speaker_diarization=True,
        num_speakers=2,
        model='phone_call',
        use_enhanced=True,
    ) -> str:
        self,
        language_code="en-US",
        enable_automatic_punctuation=True,
        enable_speaker_diarization=True,
        num_speakers=2,
        model="phone_call",
        use_enhanced=True,
    ) -> str:
        """Returns the job_name"""
        if self._check_if_transcript_exists():
            raise exceptions.AlreadyExistsError(
                f'{self.basename} already exists on {NAME}')
                f"{self.basename} already exists on {NAME}"
            )
        num_audio_channels = helpers.get_num_audio_channels(self.filepath)
        sample_rate = helpers.get_sample_rate(self.filepath)

        with io.open(self.filepath, 'rb') as audio_file:
        with io.open(self.filepath, "rb") as audio_file:
            content = audio_file.read()
        audio = speech.types.RecognitionAudio(content=content)

        if language_code != 'en-US':
        if language_code != "en-US":
            model = None

        config = speech.types.RecognitionConfig(
@@ -151,39 +256,37 @@ class Transcriber(TranscriberBaseClass):
            diarization_speaker_count=num_speakers,
            model=model,
            use_enhanced=use_enhanced,
            )
        )

        self.operation = self.speech_client.long_running_recognize(config,
                                                                    audio)
        self.operation = self.speech_client.long_running_recognize(config, audio)

        print('transcribing...')
        print("transcribing...")
        while not self.operation.done():
            sleep(1)
            print('.')
            print(".")

        result_list = []

        for result in self.operation.result().results:
            result_list.append(str(result))

        print('saving transcript')
        transcript_path = '/tmp/transcript.txt'
        with open(transcript_path, 'w') as fout:
            fout.write('\n'.join(result_list))
        print('uploading transcript')
        print("saving transcript")
        transcript_path = "/tmp/transcript.txt"
        with open(transcript_path, "w") as fout:
            fout.write("\n".join(result_list))
        print("uploading transcript")
        self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path)
        os.remove(transcript_path)

        return self.basename

    @classmethod
    def retrieve_transcript(cls, transcription_job_name: str
                            ) -> TRANSCRIPT_TYPE:
    def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
        """Get transcript from BUCKET_NAME_TRANSCRIPT"""
        if not cls._check_if_transcript_exists(
            cls,
            transcript_name=transcription_job_name):
            raise exceptions.DoesntExistError('no such transcript!')
            cls, transcript_name=transcription_job_name
        ):
            raise exceptions.DoesntExistError("no such transcript!")
        blob = cls.transcript_bucket.blob(transcription_job_name)
        f = tempfile.NamedTemporaryFile(delete=False)
        f.close()
@@ -202,7 +305,7 @@ class Transcriber(TranscriberBaseClass):
    @classmethod
    def get_transcription_jobs(cls, job_name_query=None, status=None) -> List[dict]:

        if status and status.lower() != 'completed':
        if status and status.lower() != "completed":
            return []

        jobs = []
@@ -210,6 +313,6 @@ class Transcriber(TranscriberBaseClass):
        for t in cls.transcript_bucket.list_blobs():
            if job_name_query is not None and t.name != job_name_query:
                continue
            jobs.append({'name': t.name, 'status': 'COMPLETED'})
            jobs.append({"name": t.name, "status": "COMPLETED"})

        return jobs
tatt/vendors/vendor.py
@@ -12,8 +12,8 @@ class TranscriberBaseClass:

    def __init__(self, filepath):
        self._setup()
        if ' ' in filepath:
            raise exceptions.FormatError('Please don\'t put any spaces in the filename.')
        if " " in filepath:
            raise exceptions.FormatError("Please don't put any spaces in the filename.")
        self.filepath = PurePath(filepath)
        self.basename = str(os.path.basename(self.filepath))
@@ -52,14 +52,14 @@ class TranscriberBaseClass:
    def check_for_config(cls) -> bool:
        pass

    @abc.abstractmethod
    def transcribe(self) -> str:
    def transcribe(self, **kwargs) -> str:
        """
        This should do any required logic,
        then call self._request_transcription.
        It should return the job_name.
        """
        pass
        if kwargs["language_code"] not in self.language_list():
            raise KeyError(f"No such language code {kwargs['language_code']}")

    @abc.abstractmethod
    def _request_transcription(self) -> str:
@@ -72,12 +72,10 @@ class TranscriberBaseClass:

    @classmethod
    @abc.abstractmethod
    def retrieve_transcript(cls, transcription_job_name: str
                            ) -> Union[str, dict]:
    def retrieve_transcript(cls, transcription_job_name: str) -> Union[str, dict]:
        pass

    @classmethod
    @abc.abstractmethod
    def get_transcription_jobs() -> List[dict]:
        pass