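Linted a few modules. Added a check that the given language_code is in the vendor's supported language list when calling Transcriber.transcribe (a KeyError is raised otherwise), and made that base-class method non-abstract in order to do so. Added a test for this check. Extracted a few test values into fixtures.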
This commit is contained in:
@@ -1,23 +1,30 @@
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from pytest import raises, fixture
|
||||||
|
|
||||||
from tatt.vendors.amazon import Transcriber
|
from tatt.vendors.amazon import Transcriber
|
||||||
|
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def media_filepath():
|
||||||
|
return "/Users/zev/tester.mp3"
|
||||||
|
|
||||||
def test_transcriber_instantiate():
|
|
||||||
filepath = '/Users/zev/tester.mp3'
|
@fixture
|
||||||
t = Transcriber(filepath)
|
def transcriber_instance(media_filepath):
|
||||||
assert str(t.filepath) == filepath
|
return Transcriber(media_filepath)
|
||||||
assert t.basename == 'tester.mp3'
|
|
||||||
assert t.media_file_uri == (
|
|
||||||
f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
|
def test_transcriber_instance(media_filepath, transcriber_instance):
|
||||||
|
assert str(transcriber_instance.filepath) == media_filepath
|
||||||
|
assert transcriber_instance.basename == "tester.mp3"
|
||||||
|
assert transcriber_instance.media_file_uri == (
|
||||||
|
f"https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock.patch('tatt.vendors.amazon.tr.get_transcription_job')
|
@mock.patch("tatt.vendors.amazon.tr.get_transcription_job")
|
||||||
def test_transcriber_retrieve(get_transcription_job):
|
def test_transcriber_retrieve(get_transcription_job):
|
||||||
filepath = '/Users/zev/tester.mp3'
|
job_name = "4db6808e-a7e8-4d8d-a1b7-753ab97094dc"
|
||||||
job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
|
|
||||||
t = Transcriber.retrieve_transcript(job_name)
|
t = Transcriber.retrieve_transcript(job_name)
|
||||||
get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
|
get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
|
||||||
|
|
||||||
@@ -29,9 +36,40 @@ def test_transcriber_get_transcription_jobs():
|
|||||||
|
|
||||||
def test_transcriber_retrieve_transcript():
|
def test_transcriber_retrieve_transcript():
|
||||||
jobs = Transcriber.get_transcription_jobs()
|
jobs = Transcriber.get_transcription_jobs()
|
||||||
|
assert jobs
|
||||||
for j in jobs:
|
for j in jobs:
|
||||||
if j['status'].lower() == 'completed':
|
if j["status"].lower() == "completed":
|
||||||
to_get = j['name']
|
to_get = j["name"]
|
||||||
break
|
break
|
||||||
transcript = Transcriber.retrieve_transcript(to_get)
|
transcript = Transcriber.retrieve_transcript(to_get)
|
||||||
assert transcript == {'jobName': 'abcd.mp3', 'accountId': '416321668733', 'results': {'transcripts': [{'transcript': 'Hello there.'}], 'items': [{'start_time': '0.0', 'end_time': '0.35', 'alternatives': [{'confidence': '0.8303', 'content': 'Hello'}], 'type': 'pronunciation'}, {'start_time': '0.35', 'end_time': '0.76', 'alternatives': [{'confidence': '1.0000', 'content': 'there'}], 'type': 'pronunciation'}, {'alternatives': [{'confidence': None, 'content': '.'}], 'type': 'punctuation'}]}, 'status': 'COMPLETED'}
|
assert transcript == {
|
||||||
|
"jobName": "abcd.mp3",
|
||||||
|
"accountId": "416321668733",
|
||||||
|
"results": {
|
||||||
|
"transcripts": [{"transcript": "Hello there."}],
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"start_time": "0.0",
|
||||||
|
"end_time": "0.35",
|
||||||
|
"alternatives": [{"confidence": "0.8303", "content": "Hello"}],
|
||||||
|
"type": "pronunciation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start_time": "0.35",
|
||||||
|
"end_time": "0.76",
|
||||||
|
"alternatives": [{"confidence": "1.0000", "content": "there"}],
|
||||||
|
"type": "pronunciation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"alternatives": [{"confidence": None, "content": "."}],
|
||||||
|
"type": "punctuation",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"status": "COMPLETED",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_with_nonexistent_language_code(transcriber_instance):
|
||||||
|
with raises(KeyError):
|
||||||
|
transcriber_instance.transcribe(language_code="pretend-lang")
|
||||||
|
|||||||
@@ -3,6 +3,6 @@ from tatt.vendors import SERVICES
|
|||||||
|
|
||||||
def test_services():
|
def test_services():
|
||||||
for service in SERVICES.values():
|
for service in SERVICES.values():
|
||||||
assert hasattr(service, 'Transcriber')
|
assert hasattr(service, "Transcriber")
|
||||||
assert hasattr(service, 'NAME')
|
assert hasattr(service, "NAME")
|
||||||
assert hasattr(service.Transcriber, 'cost_per_15_seconds')
|
assert hasattr(service.Transcriber, "cost_per_15_seconds")
|
||||||
|
|||||||
6
tatt/vendors/__init__.py
vendored
6
tatt/vendors/__init__.py
vendored
@@ -1,7 +1,3 @@
|
|||||||
from tatt.vendors import amazon, google
|
from tatt.vendors import amazon, google
|
||||||
|
|
||||||
SERVICES = {
|
SERVICES = {"amazon": amazon, "google": google}
|
||||||
'amazon': amazon,
|
|
||||||
'google': google,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|||||||
109
tatt/vendors/amazon.py
vendored
109
tatt/vendors/amazon.py
vendored
@@ -11,7 +11,7 @@ from tatt import config
|
|||||||
from tatt import exceptions
|
from tatt import exceptions
|
||||||
from .vendor import TranscriberBaseClass
|
from .vendor import TranscriberBaseClass
|
||||||
|
|
||||||
NAME = 'amazon'
|
NAME = "amazon"
|
||||||
BUCKET_NAME_MEDIA = config.BUCKET_NAME_FMTR_MEDIA.format(NAME)
|
BUCKET_NAME_MEDIA = config.BUCKET_NAME_FMTR_MEDIA.format(NAME)
|
||||||
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
|
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
|
||||||
TRANSCRIPT_TYPE = dict
|
TRANSCRIPT_TYPE = dict
|
||||||
@@ -19,28 +19,35 @@ TRANSCRIPT_TYPE = dict
|
|||||||
|
|
||||||
def _check_for_config() -> bool:
|
def _check_for_config() -> bool:
|
||||||
return (
|
return (
|
||||||
config.AWS_CONFIG_FILEPATH.exists()
|
config.AWS_CONFIG_FILEPATH.exists() and config.AWS_CREDENTIALS_FILEPATH.exists()
|
||||||
and config.AWS_CREDENTIALS_FILEPATH.exists()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Transcriber(TranscriberBaseClass):
|
class Transcriber(TranscriberBaseClass):
|
||||||
|
|
||||||
name = NAME
|
name = NAME
|
||||||
cost_per_15_seconds = .024 / 4
|
cost_per_15_seconds = 0.024 / 4
|
||||||
bucket_names = {'media': BUCKET_NAME_MEDIA,
|
bucket_names = {"media": BUCKET_NAME_MEDIA, "transcript": BUCKET_NAME_TRANSCRIPT}
|
||||||
'transcript': BUCKET_NAME_TRANSCRIPT}
|
|
||||||
|
|
||||||
no_config_error_message = 'please run "aws configure" first'
|
no_config_error_message = 'please run "aws configure" first'
|
||||||
transcript_type = TRANSCRIPT_TYPE
|
transcript_type = TRANSCRIPT_TYPE
|
||||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/transcribe.html
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/transcribe.html
|
||||||
_language_list = ['en-US', 'es-US', 'en-AU', 'fr-CA', 'en-GB', 'de-DE',
|
_language_list = [
|
||||||
'pt-BR', 'fr-FR', 'it-IT', 'ko-KR']
|
"en-US",
|
||||||
|
"es-US",
|
||||||
|
"en-AU",
|
||||||
|
"fr-CA",
|
||||||
|
"en-GB",
|
||||||
|
"de-DE",
|
||||||
|
"pt-BR",
|
||||||
|
"fr-FR",
|
||||||
|
"it-IT",
|
||||||
|
"ko-KR",
|
||||||
|
]
|
||||||
|
|
||||||
if _check_for_config():
|
if _check_for_config():
|
||||||
tr = boto3.client('transcribe')
|
tr = boto3.client("transcribe")
|
||||||
s3 = boto3.resource('s3')
|
s3 = boto3.resource("s3")
|
||||||
|
|
||||||
def __init__(self, filepath):
|
def __init__(self, filepath):
|
||||||
super().__init__(filepath)
|
super().__init__(filepath)
|
||||||
@@ -75,98 +82,96 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
cls.s3.create_bucket(Bucket=bucket_name)
|
cls.s3.create_bucket(Bucket=bucket_name)
|
||||||
|
|
||||||
def transcribe(self, **kwargs) -> str:
|
def transcribe(self, **kwargs) -> str:
|
||||||
|
super().transcribe(**kwargs)
|
||||||
self._upload_file()
|
self._upload_file()
|
||||||
try:
|
try:
|
||||||
return self._request_transcription(**kwargs)
|
return self._request_transcription(**kwargs)
|
||||||
except self.tr.exceptions.ConflictException:
|
except self.tr.exceptions.ConflictException:
|
||||||
raise exceptions.AlreadyExistsError(
|
raise exceptions.AlreadyExistsError(
|
||||||
f'{self.basename} already exists on {NAME}')
|
f"{self.basename} already exists on {NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
def _upload_file(self):
|
def _upload_file(self):
|
||||||
self.s3.Bucket(self.bucket_names['media']).upload_file(
|
self.s3.Bucket(self.bucket_names["media"]).upload_file(
|
||||||
str(self.filepath),
|
str(self.filepath), self.basename
|
||||||
self.basename)
|
)
|
||||||
|
|
||||||
def _request_transcription(
|
def _request_transcription(
|
||||||
self,
|
self, language_code="en-US", num_speakers=2, enable_speaker_diarization=True
|
||||||
language_code='en-US',
|
|
||||||
num_speakers=2,
|
|
||||||
enable_speaker_diarization=True,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
job_name = self.basename
|
job_name = self.basename
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
TranscriptionJobName=job_name,
|
TranscriptionJobName=job_name,
|
||||||
LanguageCode=language_code,
|
LanguageCode=language_code,
|
||||||
MediaFormat=self.basename.split('.')[-1].lower(),
|
MediaFormat=self.basename.split(".")[-1].lower(),
|
||||||
Media={
|
Media={"MediaFileUri": self.media_file_uri},
|
||||||
'MediaFileUri': self.media_file_uri
|
OutputBucketName=self.bucket_names["transcript"],
|
||||||
},
|
|
||||||
OutputBucketName=self.bucket_names['transcript']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if enable_speaker_diarization:
|
if enable_speaker_diarization:
|
||||||
kwargs.update(dict(
|
kwargs.update(
|
||||||
|
dict(
|
||||||
Settings={
|
Settings={
|
||||||
'ShowSpeakerLabels': True,
|
"ShowSpeakerLabels": True,
|
||||||
'MaxSpeakerLabels': num_speakers,
|
"MaxSpeakerLabels": num_speakers,
|
||||||
}
|
}
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.tr.start_transcription_job(**kwargs)
|
self.tr.start_transcription_job(**kwargs)
|
||||||
return job_name
|
return job_name
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_transcription_jobs(
|
def get_transcription_jobs(
|
||||||
cls,
|
cls, status: str = None, job_name_query: str = None
|
||||||
status:str = None,
|
|
||||||
job_name_query:str = None,
|
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
|
|
||||||
kwargs = {'MaxResults': 100}
|
kwargs = {"MaxResults": 100}
|
||||||
|
|
||||||
if status is not None:
|
if status is not None:
|
||||||
kwargs['Status'] = status.upper()
|
kwargs["Status"] = status.upper()
|
||||||
if job_name_query is not None:
|
if job_name_query is not None:
|
||||||
kwargs['JobNameContains'] = job_name_query
|
kwargs["JobNameContains"] = job_name_query
|
||||||
|
|
||||||
jobs_data = cls.tr.list_transcription_jobs(**kwargs)
|
jobs_data = cls.tr.list_transcription_jobs(**kwargs)
|
||||||
key = 'TranscriptionJobSummaries'
|
key = "TranscriptionJobSummaries"
|
||||||
|
|
||||||
jobs = cls.homogenize_transcription_job_data(jobs_data[key])
|
jobs = cls.homogenize_transcription_job_data(jobs_data[key])
|
||||||
|
|
||||||
while jobs_data.get('NextToken'):
|
while jobs_data.get("NextToken"):
|
||||||
token = jobs_data['NextToken']
|
token = jobs_data["NextToken"]
|
||||||
jobs_data = cls.tr.list_transcription_jobs(NextToken=token)
|
jobs_data = cls.tr.list_transcription_jobs(NextToken=token)
|
||||||
jobs += cls.homogenize_transcription_job_data(jobs_data[key])
|
jobs += cls.homogenize_transcription_job_data(jobs_data[key])
|
||||||
|
|
||||||
return jobs
|
return jobs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def retrieve_transcript(cls, transcription_job_name: str
|
def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
|
||||||
) -> TRANSCRIPT_TYPE:
|
job = cls.tr.get_transcription_job(TranscriptionJobName=transcription_job_name)[
|
||||||
job = cls.tr.get_transcription_job(
|
"TranscriptionJob"
|
||||||
TranscriptionJobName=transcription_job_name
|
]
|
||||||
)['TranscriptionJob']
|
|
||||||
|
|
||||||
if not job['TranscriptionJobStatus'] == 'COMPLETED':
|
if not job["TranscriptionJobStatus"] == "COMPLETED":
|
||||||
return
|
return
|
||||||
|
|
||||||
transcript_file_uri = job['Transcript']['TranscriptFileUri']
|
transcript_file_uri = job["Transcript"]["TranscriptFileUri"]
|
||||||
transcript_path = transcript_file_uri.split("amazonaws.com/", 1)[1]
|
transcript_path = transcript_file_uri.split("amazonaws.com/", 1)[1]
|
||||||
|
|
||||||
transcript_bucket = transcript_path.split('/', 1)[0]
|
transcript_bucket = transcript_path.split("/", 1)[0]
|
||||||
transcript_key = transcript_path.split('/', 1)[1]
|
transcript_key = transcript_path.split("/", 1)[1]
|
||||||
|
|
||||||
s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
|
s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
|
||||||
transcript_json = s3_object['Body'].read().decode('utf-8')
|
transcript_json = s3_object["Body"].read().decode("utf-8")
|
||||||
return json.loads(transcript_json)
|
return json.loads(transcript_json)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def homogenize_transcription_job_data(transcription_job_data):
|
def homogenize_transcription_job_data(transcription_job_data):
|
||||||
return [{
|
return [
|
||||||
'created': jd['CreationTime'],
|
{
|
||||||
'name': jd['TranscriptionJobName'],
|
"created": jd["CreationTime"],
|
||||||
'status': jd['TranscriptionJobStatus']
|
"name": jd["TranscriptionJobName"],
|
||||||
|
"status": jd["TranscriptionJobStatus"],
|
||||||
}
|
}
|
||||||
for jd in transcription_job_data]
|
for jd in transcription_job_data
|
||||||
|
]
|
||||||
|
|||||||
211
tatt/vendors/google.py
vendored
211
tatt/vendors/google.py
vendored
@@ -17,46 +17,150 @@ from google.cloud import (
|
|||||||
from tatt import exceptions, helpers, config as config_mod
|
from tatt import exceptions, helpers, config as config_mod
|
||||||
from .vendor import TranscriberBaseClass
|
from .vendor import TranscriberBaseClass
|
||||||
|
|
||||||
NAME = 'google'
|
NAME = "google"
|
||||||
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format(
|
BUCKET_NAME_TRANSCRIPT = config_mod.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format("goog")
|
||||||
'goog')
|
|
||||||
TRANSCRIPT_TYPE = str
|
TRANSCRIPT_TYPE = str
|
||||||
|
|
||||||
|
|
||||||
def _check_for_config():
|
def _check_for_config():
|
||||||
return os.getenv('GOOGLE_APPLICATION_CREDENTIALS') is not None
|
return os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None
|
||||||
|
|
||||||
|
|
||||||
class Transcriber(TranscriberBaseClass):
|
class Transcriber(TranscriberBaseClass):
|
||||||
|
|
||||||
name = NAME
|
name = NAME
|
||||||
SUPPORTED_FORMATS = ['flac']
|
SUPPORTED_FORMATS = ["flac"]
|
||||||
cost_per_15_seconds = [.004, .006, .009]
|
cost_per_15_seconds = [0.004, 0.006, 0.009]
|
||||||
no_config_error_message = (
|
no_config_error_message = (
|
||||||
'Please sign up for the Google Speech-to-Text API '
|
"Please sign up for the Google Speech-to-Text API "
|
||||||
'and put the path to your credentials in an '
|
"and put the path to your credentials in an "
|
||||||
'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
|
'environment variable "GOOGLE_APPLICATION_CREDENTIALS"'
|
||||||
)
|
)
|
||||||
transcript_type = TRANSCRIPT_TYPE
|
transcript_type = TRANSCRIPT_TYPE
|
||||||
# https://cloud.google.com/speech-to-text/docs/languages
|
# https://cloud.google.com/speech-to-text/docs/languages
|
||||||
# Array.from(document.querySelector('.devsite-table-wrapper').querySelectorAll('table tr')).slice(1).map(row => row.children[1].innerText)
|
# Array.from(document.querySelector('.devsite-table-wrapper').querySelectorAll('table tr')).slice(1).map(row => row.children[1].innerText)
|
||||||
_language_list = [
|
_language_list = [
|
||||||
'af-ZA', 'am-ET', 'hy-AM', 'az-AZ', 'id-ID', 'ms-MY',
|
"af-ZA",
|
||||||
'bn-BD', 'bn-IN', 'ca-ES', 'cs-CZ', 'da-DK', 'de-DE', 'en-AU', 'en-CA',
|
"am-ET",
|
||||||
'en-GH', 'en-GB', 'en-IN', 'en-IE', 'en-KE', 'en-NZ', 'en-NG', 'en-PH',
|
"hy-AM",
|
||||||
'en-SG', 'en-ZA', 'en-TZ', 'en-US', 'es-AR', 'es-BO', 'es-CL', 'es-CO',
|
"az-AZ",
|
||||||
'es-CR', 'es-EC', 'es-SV', 'es-ES', 'es-US', 'es-GT', 'es-HN', 'es-MX',
|
"id-ID",
|
||||||
'es-NI', 'es-PA', 'es-PY', 'es-PE', 'es-PR', 'es-DO', 'es-UY', 'es-VE',
|
"ms-MY",
|
||||||
'eu-ES', 'fil-PH', 'fr-CA', 'fr-FR', 'gl-ES', 'ka-GE', 'gu-IN', 'hr-HR',
|
"bn-BD",
|
||||||
'zu-ZA', 'is-IS', 'it-IT', 'jv-ID', 'kn-IN', 'km-KH', 'lo-LA', 'lv-LV',
|
"bn-IN",
|
||||||
'lt-LT', 'hu-HU', 'ml-IN', 'mr-IN', 'nl-NL', 'ne-NP', 'nb-NO', 'pl-PL',
|
"ca-ES",
|
||||||
'pt-BR', 'pt-PT', 'ro-RO', 'si-LK', 'sk-SK', 'sl-SI', 'su-ID', 'sw-TZ',
|
"cs-CZ",
|
||||||
'sw-KE', 'fi-FI', 'sv-SE', 'ta-IN', 'ta-SG', 'ta-LK', 'ta-MY', 'te-IN',
|
"da-DK",
|
||||||
'vi-VN', 'tr-TR', 'ur-PK', 'ur-IN', 'el-GR', 'bg-BG', 'ru-RU', 'sr-RS',
|
"de-DE",
|
||||||
'uk-UA', 'he-IL', 'ar-IL', 'ar-JO', 'ar-AE', 'ar-BH', 'ar-DZ', 'ar-SA',
|
"en-AU",
|
||||||
'ar-IQ', 'ar-KW', 'ar-MA', 'ar-TN', 'ar-OM', 'ar-PS', 'ar-QA', 'ar-LB',
|
"en-CA",
|
||||||
'ar-EG', 'fa-IR', 'hi-IN', 'th-TH', 'ko-KR', 'zh-TW', 'yue-Hant-HK',
|
"en-GH",
|
||||||
'ja-JP', 'zh-HK', 'zh']
|
"en-GB",
|
||||||
|
"en-IN",
|
||||||
|
"en-IE",
|
||||||
|
"en-KE",
|
||||||
|
"en-NZ",
|
||||||
|
"en-NG",
|
||||||
|
"en-PH",
|
||||||
|
"en-SG",
|
||||||
|
"en-ZA",
|
||||||
|
"en-TZ",
|
||||||
|
"en-US",
|
||||||
|
"es-AR",
|
||||||
|
"es-BO",
|
||||||
|
"es-CL",
|
||||||
|
"es-CO",
|
||||||
|
"es-CR",
|
||||||
|
"es-EC",
|
||||||
|
"es-SV",
|
||||||
|
"es-ES",
|
||||||
|
"es-US",
|
||||||
|
"es-GT",
|
||||||
|
"es-HN",
|
||||||
|
"es-MX",
|
||||||
|
"es-NI",
|
||||||
|
"es-PA",
|
||||||
|
"es-PY",
|
||||||
|
"es-PE",
|
||||||
|
"es-PR",
|
||||||
|
"es-DO",
|
||||||
|
"es-UY",
|
||||||
|
"es-VE",
|
||||||
|
"eu-ES",
|
||||||
|
"fil-PH",
|
||||||
|
"fr-CA",
|
||||||
|
"fr-FR",
|
||||||
|
"gl-ES",
|
||||||
|
"ka-GE",
|
||||||
|
"gu-IN",
|
||||||
|
"hr-HR",
|
||||||
|
"zu-ZA",
|
||||||
|
"is-IS",
|
||||||
|
"it-IT",
|
||||||
|
"jv-ID",
|
||||||
|
"kn-IN",
|
||||||
|
"km-KH",
|
||||||
|
"lo-LA",
|
||||||
|
"lv-LV",
|
||||||
|
"lt-LT",
|
||||||
|
"hu-HU",
|
||||||
|
"ml-IN",
|
||||||
|
"mr-IN",
|
||||||
|
"nl-NL",
|
||||||
|
"ne-NP",
|
||||||
|
"nb-NO",
|
||||||
|
"pl-PL",
|
||||||
|
"pt-BR",
|
||||||
|
"pt-PT",
|
||||||
|
"ro-RO",
|
||||||
|
"si-LK",
|
||||||
|
"sk-SK",
|
||||||
|
"sl-SI",
|
||||||
|
"su-ID",
|
||||||
|
"sw-TZ",
|
||||||
|
"sw-KE",
|
||||||
|
"fi-FI",
|
||||||
|
"sv-SE",
|
||||||
|
"ta-IN",
|
||||||
|
"ta-SG",
|
||||||
|
"ta-LK",
|
||||||
|
"ta-MY",
|
||||||
|
"te-IN",
|
||||||
|
"vi-VN",
|
||||||
|
"tr-TR",
|
||||||
|
"ur-PK",
|
||||||
|
"ur-IN",
|
||||||
|
"el-GR",
|
||||||
|
"bg-BG",
|
||||||
|
"ru-RU",
|
||||||
|
"sr-RS",
|
||||||
|
"uk-UA",
|
||||||
|
"he-IL",
|
||||||
|
"ar-IL",
|
||||||
|
"ar-JO",
|
||||||
|
"ar-AE",
|
||||||
|
"ar-BH",
|
||||||
|
"ar-DZ",
|
||||||
|
"ar-SA",
|
||||||
|
"ar-IQ",
|
||||||
|
"ar-KW",
|
||||||
|
"ar-MA",
|
||||||
|
"ar-TN",
|
||||||
|
"ar-OM",
|
||||||
|
"ar-PS",
|
||||||
|
"ar-QA",
|
||||||
|
"ar-LB",
|
||||||
|
"ar-EG",
|
||||||
|
"fa-IR",
|
||||||
|
"hi-IN",
|
||||||
|
"th-TH",
|
||||||
|
"ko-KR",
|
||||||
|
"zh-TW",
|
||||||
|
"yue-Hant-HK",
|
||||||
|
"ja-JP",
|
||||||
|
"zh-HK",
|
||||||
|
"zh",
|
||||||
|
]
|
||||||
|
|
||||||
if _check_for_config():
|
if _check_for_config():
|
||||||
speech_client = speech.SpeechClient()
|
speech_client = speech.SpeechClient()
|
||||||
@@ -69,10 +173,11 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _setup(cls):
|
def _setup(cls):
|
||||||
super()._setup()
|
super()._setup()
|
||||||
if not shutil.which('gsutil'):
|
if not shutil.which("gsutil"):
|
||||||
raise exceptions.DependencyRequired(
|
raise exceptions.DependencyRequired(
|
||||||
'Please install gcloud using the steps here:'
|
"Please install gcloud using the steps here:"
|
||||||
'https://cloud.google.com/storage/docs/gsutil_install')
|
"https://cloud.google.com/storage/docs/gsutil_install"
|
||||||
|
)
|
||||||
|
|
||||||
cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT)
|
cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT)
|
||||||
|
|
||||||
@@ -84,13 +189,13 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
# this might fail if a bucket by the name exists *anywhere* on GCS?
|
# this might fail if a bucket by the name exists *anywhere* on GCS?
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
print('made Google Cloud Storage Bucket for transcripts')
|
print("made Google Cloud Storage Bucket for transcripts")
|
||||||
|
|
||||||
def convert_file_format_if_needed(self):
|
def convert_file_format_if_needed(self):
|
||||||
if self.file_format not in self.SUPPORTED_FORMATS:
|
if self.file_format not in self.SUPPORTED_FORMATS:
|
||||||
if not shutil.which('ffmpeg'):
|
if not shutil.which("ffmpeg"):
|
||||||
raise exceptions.DependencyRequired('please install ffmpeg')
|
raise exceptions.DependencyRequired("please install ffmpeg")
|
||||||
self.filepath = helpers.convert_file(self.filepath, 'flac')
|
self.filepath = helpers.convert_file(self.filepath, "flac")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def file_format(self):
|
def file_format(self):
|
||||||
@@ -111,31 +216,31 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
|
|
||||||
def _check_if_transcript_exists(self, transcript_name=None):
|
def _check_if_transcript_exists(self, transcript_name=None):
|
||||||
return storage.Blob(
|
return storage.Blob(
|
||||||
bucket=self.transcript_bucket,
|
bucket=self.transcript_bucket, name=transcript_name or self.basename
|
||||||
name=transcript_name or self.basename
|
|
||||||
).exists(self.storage_client)
|
).exists(self.storage_client)
|
||||||
|
|
||||||
def _request_transcription(
|
def _request_transcription(
|
||||||
self,
|
self,
|
||||||
language_code='en-US',
|
language_code="en-US",
|
||||||
enable_automatic_punctuation=True,
|
enable_automatic_punctuation=True,
|
||||||
enable_speaker_diarization=True,
|
enable_speaker_diarization=True,
|
||||||
num_speakers=2,
|
num_speakers=2,
|
||||||
model='phone_call',
|
model="phone_call",
|
||||||
use_enhanced=True,
|
use_enhanced=True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Returns the job_name"""
|
"""Returns the job_name"""
|
||||||
if self._check_if_transcript_exists():
|
if self._check_if_transcript_exists():
|
||||||
raise exceptions.AlreadyExistsError(
|
raise exceptions.AlreadyExistsError(
|
||||||
f'{self.basename} already exists on {NAME}')
|
f"{self.basename} already exists on {NAME}"
|
||||||
|
)
|
||||||
num_audio_channels = helpers.get_num_audio_channels(self.filepath)
|
num_audio_channels = helpers.get_num_audio_channels(self.filepath)
|
||||||
sample_rate = helpers.get_sample_rate(self.filepath)
|
sample_rate = helpers.get_sample_rate(self.filepath)
|
||||||
|
|
||||||
with io.open(self.filepath, 'rb') as audio_file:
|
with io.open(self.filepath, "rb") as audio_file:
|
||||||
content = audio_file.read()
|
content = audio_file.read()
|
||||||
audio = speech.types.RecognitionAudio(content=content)
|
audio = speech.types.RecognitionAudio(content=content)
|
||||||
|
|
||||||
if language_code != 'en-US':
|
if language_code != "en-US":
|
||||||
model = None
|
model = None
|
||||||
|
|
||||||
config = speech.types.RecognitionConfig(
|
config = speech.types.RecognitionConfig(
|
||||||
@@ -153,37 +258,35 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
use_enhanced=use_enhanced,
|
use_enhanced=use_enhanced,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.operation = self.speech_client.long_running_recognize(config,
|
self.operation = self.speech_client.long_running_recognize(config, audio)
|
||||||
audio)
|
|
||||||
|
|
||||||
print('transcribing...')
|
print("transcribing...")
|
||||||
while not self.operation.done():
|
while not self.operation.done():
|
||||||
sleep(1)
|
sleep(1)
|
||||||
print('.')
|
print(".")
|
||||||
|
|
||||||
result_list = []
|
result_list = []
|
||||||
|
|
||||||
for result in self.operation.result().results:
|
for result in self.operation.result().results:
|
||||||
result_list.append(str(result))
|
result_list.append(str(result))
|
||||||
|
|
||||||
print('saving transcript')
|
print("saving transcript")
|
||||||
transcript_path = '/tmp/transcript.txt'
|
transcript_path = "/tmp/transcript.txt"
|
||||||
with open(transcript_path, 'w') as fout:
|
with open(transcript_path, "w") as fout:
|
||||||
fout.write('\n'.join(result_list))
|
fout.write("\n".join(result_list))
|
||||||
print('uploading transcript')
|
print("uploading transcript")
|
||||||
self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path)
|
self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path)
|
||||||
os.remove(transcript_path)
|
os.remove(transcript_path)
|
||||||
|
|
||||||
return self.basename
|
return self.basename
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def retrieve_transcript(cls, transcription_job_name: str
|
def retrieve_transcript(cls, transcription_job_name: str) -> TRANSCRIPT_TYPE:
|
||||||
) -> TRANSCRIPT_TYPE:
|
|
||||||
"""Get transcript from BUCKET_NAME_TRANSCRIPT"""
|
"""Get transcript from BUCKET_NAME_TRANSCRIPT"""
|
||||||
if not cls._check_if_transcript_exists(
|
if not cls._check_if_transcript_exists(
|
||||||
cls,
|
cls, transcript_name=transcription_job_name
|
||||||
transcript_name=transcription_job_name):
|
):
|
||||||
raise exceptions.DoesntExistError('no such transcript!')
|
raise exceptions.DoesntExistError("no such transcript!")
|
||||||
blob = cls.transcript_bucket.blob(transcription_job_name)
|
blob = cls.transcript_bucket.blob(transcription_job_name)
|
||||||
f = tempfile.NamedTemporaryFile(delete=False)
|
f = tempfile.NamedTemporaryFile(delete=False)
|
||||||
f.close()
|
f.close()
|
||||||
@@ -202,7 +305,7 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_transcription_jobs(cls, job_name_query=None, status=None) -> List[dict]:
|
def get_transcription_jobs(cls, job_name_query=None, status=None) -> List[dict]:
|
||||||
|
|
||||||
if status and status.lower() != 'completed':
|
if status and status.lower() != "completed":
|
||||||
return []
|
return []
|
||||||
|
|
||||||
jobs = []
|
jobs = []
|
||||||
@@ -210,6 +313,6 @@ class Transcriber(TranscriberBaseClass):
|
|||||||
for t in cls.transcript_bucket.list_blobs():
|
for t in cls.transcript_bucket.list_blobs():
|
||||||
if job_name_query is not None and t.name != job_name_query:
|
if job_name_query is not None and t.name != job_name_query:
|
||||||
continue
|
continue
|
||||||
jobs.append({'name': t.name, 'status': 'COMPLETED'})
|
jobs.append({"name": t.name, "status": "COMPLETED"})
|
||||||
|
|
||||||
return jobs
|
return jobs
|
||||||
|
|||||||
14
tatt/vendors/vendor.py
vendored
14
tatt/vendors/vendor.py
vendored
@@ -12,8 +12,8 @@ class TranscriberBaseClass:
|
|||||||
|
|
||||||
def __init__(self, filepath):
|
def __init__(self, filepath):
|
||||||
self._setup()
|
self._setup()
|
||||||
if ' ' in filepath:
|
if " " in filepath:
|
||||||
raise exceptions.FormatError('Please don\'t put any spaces in the filename.')
|
raise exceptions.FormatError("Please don't put any spaces in the filename.")
|
||||||
self.filepath = PurePath(filepath)
|
self.filepath = PurePath(filepath)
|
||||||
self.basename = str(os.path.basename(self.filepath))
|
self.basename = str(os.path.basename(self.filepath))
|
||||||
|
|
||||||
@@ -52,14 +52,14 @@ class TranscriberBaseClass:
|
|||||||
def check_for_config(cls) -> bool:
|
def check_for_config(cls) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
def transcribe(self, **kwargs) -> str:
|
||||||
def transcribe(self) -> str:
|
|
||||||
"""
|
"""
|
||||||
This should do any required logic,
|
This should do any required logic,
|
||||||
then call self._request_transcription.
|
then call self._request_transcription.
|
||||||
It should return the job_name.
|
It should return the job_name.
|
||||||
"""
|
"""
|
||||||
pass
|
if kwargs["language_code"] not in self.language_list():
|
||||||
|
raise KeyError(f"No such language code {kwargs['language_code']}")
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _request_transcription(self) -> str:
|
def _request_transcription(self) -> str:
|
||||||
@@ -72,12 +72,10 @@ class TranscriberBaseClass:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def retrieve_transcript(cls, transcription_job_name: str
|
def retrieve_transcript(cls, transcription_job_name: str) -> Union[str, dict]:
|
||||||
) -> Union[str, dict]:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_transcription_jobs() -> List[dict]:
|
def get_transcription_jobs() -> List[dict]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user