set up abstract base class

This commit is contained in:
2019-03-05 19:04:48 -05:00
parent b858e7def0
commit 7dff0121dd
5 changed files with 113 additions and 46 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as file:
setup(
name="tatt",
version="0.970",
version="0.972",
py_modules=['tatt'],
url='https://github.com/zevaverbach/tatt',
install_requires=[

View File

@@ -32,7 +32,7 @@ def get(name, save, pretty):
file = None
if save:
filepath = f'{name}.json'
file = open(filepath)
file = open(filepath, 'w')
click.echo(transcript, file=file)

View File

@@ -3,3 +3,4 @@ from tatt.vendors import amazon
SERVICES = {
'amazon': amazon,
}

View File

@@ -9,6 +9,7 @@ import boto3
from tatt import config
from tatt import exceptions
from .vendor import TranscriberBaseClass
NAME = 'amazon'
BUCKET_NAME_MEDIA = config.AWS_BUCKET_NAME_FMTR_MEDIA.format(NAME)
@@ -16,63 +17,72 @@ BUCKET_NAME_TRANSCRIPT = config.AWS_BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
cost_per_15_seconds = .024 / 4
def check_for_config():
def _check_for_config() -> bool:
return (
config.AWS_CONFIG_FILEPATH.exists()
and config.AWS_CREDENTIALS_FILEPATH.exists()
)
if check_for_config():
tr = boto3.client('transcribe')
s3 = boto3.resource('s3')
class Transcriber:
class Transcriber(TranscriberBaseClass):
bucket_names = {'media': BUCKET_NAME_MEDIA,
'transcript': BUCKET_NAME_TRANSCRIPT}
no_config_error_message = 'please run "aws configure" first'
if _check_for_config():
tr = boto3.client('transcribe')
s3 = boto3.resource('s3')
def __init__(self, filepath):
super().__init__(filepath)
self._setup()
self.filepath = PurePath(filepath)
self.basename = str(os.path.basename(self.filepath))
self.media_file_uri = (
f"https://s3-{config.AWS_REGION}.amazonaws.com/"
f"{self.bucket_names['media']}/{self.basename}")
@classmethod
def check_for_config(cls):
return _check_for_config()
@property
def media_file_uri(self):
return (
f"https://s3-{config.AWS_REGION}.amazonaws.com/"
f"{self.bucket_names['media']}/{self.basename}"
)
@classmethod
def _setup(cls):
if not check_for_config():
raise exceptions.ConfigError('please run "aws configure" first')
super()._setup()
for bucket_name in cls.bucket_names.values():
if not cls.check_for_bucket(bucket_name):
cls.make_bucket(bucket_name)
@staticmethod
def check_for_bucket(bucket_name):
return bool(s3.Bucket(bucket_name).creation_date)
@classmethod
def check_for_bucket(cls, bucket_name: str) -> bool:
return bool(cls.s3.Bucket(bucket_name).creation_date)
@staticmethod
def make_bucket(bucket_name):
s3.create_bucket(Bucket=bucket_name)
@classmethod
def make_bucket(cls, bucket_name):
cls.s3.create_bucket(Bucket=bucket_name)
def transcribe(self):
def transcribe(self) -> str:
self._upload_file()
try:
return self._request_transcription()
except tr.exceptions.ConflictException:
except self.tr.exceptions.ConflictException:
raise exceptions.AlreadyExistsError(
f'{self.basename} already exists on {NAME}')
def _upload_file(self):
s3.Bucket(self.bucket_names['media']).upload_file(
self.s3.Bucket(self.bucket_names['media']).upload_file(
str(self.filepath),
self.basename)
def _request_transcription(self, language_code='en-US'):
def _request_transcription(self, language_code='en-US') -> str:
job_name = self.basename
tr.start_transcription_job(
self.tr.start_transcription_job(
TranscriptionJobName=job_name,
LanguageCode=language_code,
MediaFormat=self.basename.split('.')[-1].lower(),
@@ -83,10 +93,11 @@ class Transcriber:
)
return job_name
@staticmethod
@classmethod
def get_transcription_jobs(
status=None,
job_name_query=None
cls,
status:str = None,
job_name_query:str = None,
) -> List[dict]:
kwargs = {'MaxResults': 100}
@@ -96,21 +107,21 @@ class Transcriber:
if job_name_query is not None:
kwargs['JobNameContains'] = job_name_query
jobs_data = tr.list_transcription_jobs(**kwargs)
jobs_data = cls.tr.list_transcription_jobs(**kwargs)
key = 'TranscriptionJobSummaries'
jobs = homogenize_transcription_job_data(jobs_data[key])
jobs = cls.homogenize_transcription_job_data(jobs_data[key])
while jobs_data.get('NextToken'):
token = jobs_data['NextToken']
jobs_data = tr.list_transcription_jobs(NextToken=token)
jobs += homogenize_transcription_job_data(jobs_data[key])
jobs_data = cls.tr.list_transcription_jobs(NextToken=token)
jobs += cls.homogenize_transcription_job_data(jobs_data[key])
return jobs
@staticmethod
def retrieve_transcript(transcription_job_name):
job = tr.get_transcription_job(
@classmethod
def retrieve_transcript(cls, transcription_job_name: str) -> dict:
job = cls.tr.get_transcription_job(
TranscriptionJobName=transcription_job_name
)['TranscriptionJob']
@@ -123,17 +134,15 @@ class Transcriber:
transcript_bucket = transcript_path.split('/', 1)[0]
transcript_key = transcript_path.split('/', 1)[1]
s3_object = s3.Object(transcript_bucket, transcript_key).get()
s3_object = cls.s3.Object(transcript_bucket, transcript_key).get()
transcript_json = s3_object['Body'].read().decode('utf-8')
return json.loads(transcript_json)
def homogenize_transcription_job_data(transcription_job_data):
return [{
'created': jd['CreationTime'],
'name': jd['TranscriptionJobName'],
'status': jd['TranscriptionJobStatus']
}
for jd in transcription_job_data]
@staticmethod
def homogenize_transcription_job_data(transcription_job_data):
return [{
'created': jd['CreationTime'],
'name': jd['TranscriptionJobName'],
'status': jd['TranscriptionJobStatus']
}
for jd in transcription_job_data]

57
tatt/vendors/vendor.py vendored Normal file
View File

@@ -0,0 +1,57 @@
import abc
import os
from pathlib import PurePath
class TranscriberBaseClass(metaclass=abc.ABCMeta):
    """
    Abstract base class for vendor-specific transcribers.

    A concrete subclass must provide:

    * ``no_config_error_message`` -- a class attribute holding the message
      used when the vendor's configuration is missing,
    * ``check_for_config`` -- returns whether the configuration is present,
    * ``transcribe`` and ``_request_transcription``,
    * ``retrieve_transcript`` and ``get_transcription_jobs``.
    """
    # NOTE: the metaclass must be passed as a class keyword argument.
    # Python 3 silently ignores a ``__metaclass__`` class attribute, which
    # would leave every ``@abc.abstractmethod`` below unenforced.

    def __init__(self, filepath):
        """
        Validate the vendor configuration, then record the media file.

        :param filepath: path of the media file to transcribe.
        :raises exceptions.ConfigError: if ``check_for_config`` is falsy.
        """
        # Run the configuration check first so a misconfigured vendor fails
        # before any instance state is set.
        self._setup()
        self.filepath = PurePath(filepath)
        self.basename = str(os.path.basename(self.filepath))

    @property
    @abc.abstractmethod
    def no_config_error_message(self):
        """
        This must be defined as a class attribute, to be printed when raising
        such an error.
        """

    @classmethod
    def _setup(cls):
        """
        Raise ``exceptions.ConfigError`` when the vendor is not configured.

        Subclasses may extend this with their own one-time preparation.
        """
        if not cls.check_for_config():
            # Imported here because this module has no top-level import of
            # the package's exceptions module; without it the raise below
            # would fail with NameError instead of ConfigError.
            from tatt import exceptions
            raise exceptions.ConfigError(cls.no_config_error_message)

    @staticmethod
    @abc.abstractmethod
    def check_for_config() -> bool:
        """Return True when the vendor's configuration is available."""

    @abc.abstractmethod
    def transcribe(self) -> str:
        """
        This should do any required logic,
        then call self._request_transcription.
        It should return the job_name.
        """

    @abc.abstractmethod
    def _request_transcription(self) -> str:
        """Returns the job_name"""

    @classmethod
    @abc.abstractmethod
    def retrieve_transcript(cls, transcription_job_name: str) -> dict:
        """
        Return the transcript for the job named ``transcription_job_name``.
        """

    @classmethod
    @abc.abstractmethod
    def get_transcription_jobs(
            cls,
            status: str = None,
            job_name_query: str = None,
    ):
        """
        Return the vendor's transcription jobs.

        :param status: optional status filter; its accepted values are
            defined by the implementing vendor.
        :param job_name_query: optional job-name filter; matching rules are
            defined by the implementing vendor.
        """