set up abstract base class
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as file:
|
||||
|
||||
setup(
|
||||
name="tatt",
|
||||
version="0.970",
|
||||
version="0.972",
|
||||
py_modules=['tatt'],
|
||||
url='https://github.com/zevaverbach/tatt',
|
||||
install_requires=[
|
||||
|
||||
@@ -32,7 +32,7 @@ def get(name, save, pretty):
|
||||
file = None
|
||||
if save:
|
||||
filepath = f'{name}.json'
|
||||
file = open(filepath)
|
||||
file = open(filepath, 'w')
|
||||
|
||||
click.echo(transcript, file=file)
|
||||
|
||||
|
||||
1
tatt/vendors/__init__.py
vendored
1
tatt/vendors/__init__.py
vendored
@@ -3,3 +3,4 @@ from tatt.vendors import amazon
|
||||
SERVICES = {
|
||||
'amazon': amazon,
|
||||
}
|
||||
|
||||
|
||||
97
tatt/vendors/amazon.py
vendored
97
tatt/vendors/amazon.py
vendored
@@ -9,6 +9,7 @@ import boto3
|
||||
|
||||
from tatt import config
|
||||
from tatt import exceptions
|
||||
from .vendor import TranscriberBaseClass
|
||||
|
||||
NAME = 'amazon'
|
||||
BUCKET_NAME_MEDIA = config.AWS_BUCKET_NAME_FMTR_MEDIA.format(NAME)
|
||||
@@ -16,63 +17,72 @@ BUCKET_NAME_TRANSCRIPT = config.AWS_BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
|
||||
cost_per_15_seconds = .024 / 4
|
||||
|
||||
|
||||
def check_for_config():
|
||||
def _check_for_config() -> bool:
|
||||
return (
|
||||
config.AWS_CONFIG_FILEPATH.exists()
|
||||
and config.AWS_CREDENTIALS_FILEPATH.exists()
|
||||
)
|
||||
|
||||
|
||||
if check_for_config():
|
||||
tr = boto3.client('transcribe')
|
||||
s3 = boto3.resource('s3')
|
||||
|
||||
|
||||
class Transcriber:
|
||||
class Transcriber(TranscriberBaseClass):
|
||||
|
||||
bucket_names = {'media': BUCKET_NAME_MEDIA,
|
||||
'transcript': BUCKET_NAME_TRANSCRIPT}
|
||||
|
||||
no_config_error_message = 'please run "aws configure" first'
|
||||
|
||||
if _check_for_config():
|
||||
tr = boto3.client('transcribe')
|
||||
s3 = boto3.resource('s3')
|
||||
|
||||
def __init__(self, filepath):
|
||||
super().__init__(filepath)
|
||||
self._setup()
|
||||
self.filepath = PurePath(filepath)
|
||||
self.basename = str(os.path.basename(self.filepath))
|
||||
self.media_file_uri = (
|
||||
f"https://s3-{config.AWS_REGION}.amazonaws.com/"
|
||||
f"{self.bucket_names['media']}/{self.basename}")
|
||||
|
||||
@classmethod
|
||||
def check_for_config(cls):
|
||||
return _check_for_config()
|
||||
|
||||
@property
|
||||
def media_file_uri(self):
|
||||
return (
|
||||
f"https://s3-{config.AWS_REGION}.amazonaws.com/"
|
||||
f"{self.bucket_names['media']}/{self.basename}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _setup(cls):
|
||||
if not check_for_config():
|
||||
raise exceptions.ConfigError('please run "aws configure" first')
|
||||
super()._setup()
|
||||
for bucket_name in cls.bucket_names.values():
|
||||
if not cls.check_for_bucket(bucket_name):
|
||||
cls.make_bucket(bucket_name)
|
||||
|
||||
@staticmethod
|
||||
def check_for_bucket(bucket_name):
|
||||
return bool(s3.Bucket(bucket_name).creation_date)
|
||||
@classmethod
|
||||
def check_for_bucket(cls, bucket_name: str) -> bool:
|
||||
return bool(cls.s3.Bucket(bucket_name).creation_date)
|
||||
|
||||
@staticmethod
|
||||
def make_bucket(bucket_name):
|
||||
s3.create_bucket(Bucket=bucket_name)
|
||||
@classmethod
|
||||
def make_bucket(cls, bucket_name):
|
||||
cls.s3.create_bucket(Bucket=bucket_name)
|
||||
|
||||
def transcribe(self):
|
||||
def transcribe(self) -> str:
|
||||
self._upload_file()
|
||||
try:
|
||||
return self._request_transcription()
|
||||
except tr.exceptions.ConflictException:
|
||||
except self.tr.exceptions.ConflictException:
|
||||
raise exceptions.AlreadyExistsError(
|
||||
f'{self.basename} already exists on {NAME}')
|
||||
|
||||
def _upload_file(self):
|
||||
s3.Bucket(self.bucket_names['media']).upload_file(
|
||||
self.s3.Bucket(self.bucket_names['media']).upload_file(
|
||||
str(self.filepath),
|
||||
self.basename)
|
||||
|
||||
def _request_transcription(self, language_code='en-US'):
|
||||
def _request_transcription(self, language_code='en-US') -> str:
|
||||
job_name = self.basename
|
||||
tr.start_transcription_job(
|
||||
self.tr.start_transcription_job(
|
||||
TranscriptionJobName=job_name,
|
||||
LanguageCode=language_code,
|
||||
MediaFormat=self.basename.split('.')[-1].lower(),
|
||||
@@ -83,10 +93,11 @@ class Transcriber:
|
||||
)
|
||||
return job_name
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def get_transcription_jobs(
|
||||
status=None,
|
||||
job_name_query=None
|
||||
cls,
|
||||
status:str = None,
|
||||
job_name_query:str = None,
|
||||
) -> List[dict]:
|
||||
|
||||
kwargs = {'MaxResults': 100}
|
||||
@@ -96,21 +107,21 @@ class Transcriber:
|
||||
if job_name_query is not None:
|
||||
kwargs['JobNameContains'] = job_name_query
|
||||
|
||||
jobs_data = tr.list_transcription_jobs(**kwargs)
|
||||
jobs_data = cls.tr.list_transcription_jobs(**kwargs)
|
||||
key = 'TranscriptionJobSummaries'
|
||||
|
||||
jobs = homogenize_transcription_job_data(jobs_data[key])
|
||||
jobs = cls.homogenize_transcription_job_data(jobs_data[key])
|
||||
|
||||
while jobs_data.get('NextToken'):
|
||||
token = jobs_data['NextToken']
|
||||
jobs_data = tr.list_transcription_jobs(NextToken=token)
|
||||
jobs += homogenize_transcription_job_data(jobs_data[key])
|
||||
jobs_data = cls.tr.list_transcription_jobs(NextToken=token)
|
||||
jobs += cls.homogenize_transcription_job_data(jobs_data[key])
|
||||
|
||||
return jobs
|
||||
|
||||
@staticmethod
|
||||
def retrieve_transcript(transcription_job_name):
|
||||
job = tr.get_transcription_job(
|
||||
@classmethod
|
||||
def retrieve_transcript(cls, transcription_job_name: str) -> dict:
|
||||
job = cls.tr.get_transcription_job(
|
||||
TranscriptionJobName=transcription_job_name
|
||||
)['TranscriptionJob']
|
||||
|
||||
@@ -123,17 +134,15 @@ class Transcriber:
|
||||
transcript_bucket = transcript_path.split('/', 1)[0]
|
||||
transcript_key = transcript_path.split('/', 1)[1]
|
||||
|
||||
s3_object = s3.Object(transcript_bucket, transcript_key).get()
|
||||
s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
|
||||
transcript_json = s3_object['Body'].read().decode('utf-8')
|
||||
return json.loads(transcript_json)
|
||||
|
||||
|
||||
|
||||
|
||||
def homogenize_transcription_job_data(transcription_job_data):
|
||||
return [{
|
||||
'created': jd['CreationTime'],
|
||||
'name': jd['TranscriptionJobName'],
|
||||
'status': jd['TranscriptionJobStatus']
|
||||
}
|
||||
for jd in transcription_job_data]
|
||||
@staticmethod
|
||||
def homogenize_transcription_job_data(transcription_job_data):
|
||||
return [{
|
||||
'created': jd['CreationTime'],
|
||||
'name': jd['TranscriptionJobName'],
|
||||
'status': jd['TranscriptionJobStatus']
|
||||
}
|
||||
for jd in transcription_job_data]
|
||||
|
||||
57
tatt/vendors/vendor.py
vendored
Normal file
57
tatt/vendors/vendor.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
import abc
|
||||
import os
|
||||
from pathlib import PurePath
|
||||
|
||||
|
||||
class TranscriberBaseClass:
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, filepath):
|
||||
self._setup()
|
||||
self.filepath = PurePath(filepath)
|
||||
self.basename = str(os.path.basename(self.filepath))
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def no_config_error_message(self):
|
||||
"""
|
||||
This must be defined as a class attribute, to be printed when raising
|
||||
such an error.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _setup(cls):
|
||||
if not cls.check_for_config():
|
||||
raise exceptions.ConfigError(cls.no_config_error_message)
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def check_for_config() -> bool:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def transcribe(self) -> str:
|
||||
"""
|
||||
This should do any required logic,
|
||||
then call self._request_transcription.
|
||||
It should return the job_name.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _request_transcription(self) -> str:
|
||||
"""Returns the job_name"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def retrieve_transcript(transcription_job_name: str) -> dict:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_transcription_jobs():
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user