From 7dff0121dd21d3553532a0640ae6a92dfdc4d093 Mon Sep 17 00:00:00 2001 From: zevav Date: Tue, 5 Mar 2019 19:04:48 -0500 Subject: [PATCH] set up abstract base class --- setup.py | 2 +- tatt/transcribe.py | 2 +- tatt/vendors/__init__.py | 1 + tatt/vendors/amazon.py | 97 ++++++++++++++++++++++------------------ tatt/vendors/vendor.py | 57 +++++++++++++++++++++++ 5 files changed, 113 insertions(+), 46 deletions(-) create mode 100644 tatt/vendors/vendor.py diff --git a/setup.py b/setup.py index d2b2865..8c794f7 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as file: setup( name="tatt", - version="0.970", + version="0.972", py_modules=['tatt'], url='https://github.com/zevaverbach/tatt', install_requires=[ diff --git a/tatt/transcribe.py b/tatt/transcribe.py index 2b07fb8..cb0705a 100644 --- a/tatt/transcribe.py +++ b/tatt/transcribe.py @@ -32,7 +32,7 @@ def get(name, save, pretty): file = None if save: filepath = f'{name}.json' - file = open(filepath) + file = open(filepath, 'w') click.echo(transcript, file=file) diff --git a/tatt/vendors/__init__.py b/tatt/vendors/__init__.py index 7e5b2d5..225f4b7 100644 --- a/tatt/vendors/__init__.py +++ b/tatt/vendors/__init__.py @@ -3,3 +3,4 @@ from tatt.vendors import amazon SERVICES = { 'amazon': amazon, } + diff --git a/tatt/vendors/amazon.py b/tatt/vendors/amazon.py index 805900f..43fa23a 100644 --- a/tatt/vendors/amazon.py +++ b/tatt/vendors/amazon.py @@ -9,6 +9,7 @@ import boto3 from tatt import config from tatt import exceptions +from .vendor import TranscriberBaseClass NAME = 'amazon' BUCKET_NAME_MEDIA = config.AWS_BUCKET_NAME_FMTR_MEDIA.format(NAME) @@ -16,63 +17,72 @@ BUCKET_NAME_TRANSCRIPT = config.AWS_BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME) cost_per_15_seconds = .024 / 4 -def check_for_config(): +def _check_for_config() -> bool: return ( config.AWS_CONFIG_FILEPATH.exists() and config.AWS_CREDENTIALS_FILEPATH.exists() ) -if check_for_config(): - tr = boto3.client('transcribe') - s3 = boto3.resource('s3') - - -class Transcriber: +class Transcriber(TranscriberBaseClass): bucket_names = {'media': BUCKET_NAME_MEDIA, 'transcript': BUCKET_NAME_TRANSCRIPT} + no_config_error_message = 'please run "aws configure" first' + + if _check_for_config(): + tr = boto3.client('transcribe') + s3 = boto3.resource('s3') + def __init__(self, filepath): + super().__init__(filepath) self._setup() self.filepath = PurePath(filepath) self.basename = str(os.path.basename(self.filepath)) - self.media_file_uri = ( - f"https://s3-{config.AWS_REGION}.amazonaws.com/" - f"{self.bucket_names['media']}/{self.basename}") + + @classmethod + def check_for_config(cls): + return _check_for_config() + + @property + def media_file_uri(self): + return ( + f"https://s3-{config.AWS_REGION}.amazonaws.com/" + f"{self.bucket_names['media']}/{self.basename}" + ) @classmethod def _setup(cls): - if not check_for_config(): - raise exceptions.ConfigError('please run "aws configure" first') + super()._setup() for bucket_name in cls.bucket_names.values(): if not cls.check_for_bucket(bucket_name): cls.make_bucket(bucket_name) - @staticmethod - def check_for_bucket(bucket_name): - return bool(s3.Bucket(bucket_name).creation_date) + @classmethod + def check_for_bucket(cls, bucket_name: str) -> bool: + return bool(cls.s3.Bucket(bucket_name).creation_date) - @staticmethod - def make_bucket(bucket_name): - s3.create_bucket(Bucket=bucket_name) + @classmethod + def make_bucket(cls, bucket_name): + cls.s3.create_bucket(Bucket=bucket_name) - def transcribe(self): + def transcribe(self) -> str: self._upload_file() try: return self._request_transcription() - except tr.exceptions.ConflictException: + except self.tr.exceptions.ConflictException: raise exceptions.AlreadyExistsError( f'{self.basename} already exists on {NAME}') def _upload_file(self): - s3.Bucket(self.bucket_names['media']).upload_file( + self.s3.Bucket(self.bucket_names['media']).upload_file( str(self.filepath), self.basename) - def _request_transcription(self, language_code='en-US'): + def _request_transcription(self, language_code='en-US') -> str: job_name = self.basename - tr.start_transcription_job( + self.tr.start_transcription_job( TranscriptionJobName=job_name, LanguageCode=language_code, MediaFormat=self.basename.split('.')[-1].lower(), @@ -83,10 +93,11 @@ class Transcriber: ) return job_name - @staticmethod + @classmethod def get_transcription_jobs( - status=None, - job_name_query=None + cls, + status:str = None, + job_name_query:str = None, ) -> List[dict]: kwargs = {'MaxResults': 100} @@ -96,21 +107,21 @@ class Transcriber: if job_name_query is not None: kwargs['JobNameContains'] = job_name_query - jobs_data = tr.list_transcription_jobs(**kwargs) + jobs_data = cls.tr.list_transcription_jobs(**kwargs) key = 'TranscriptionJobSummaries' - jobs = homogenize_transcription_job_data(jobs_data[key]) + jobs = cls.homogenize_transcription_job_data(jobs_data[key]) while jobs_data.get('NextToken'): token = jobs_data['NextToken'] - jobs_data = tr.list_transcription_jobs(NextToken=token) - jobs += homogenize_transcription_job_data(jobs_data[key]) + jobs_data = cls.tr.list_transcription_jobs(NextToken=token) + jobs += cls.homogenize_transcription_job_data(jobs_data[key]) return jobs - @staticmethod - def retrieve_transcript(transcription_job_name): - job = tr.get_transcription_job( + @classmethod + def retrieve_transcript(cls, transcription_job_name: str) -> dict: + job = cls.tr.get_transcription_job( TranscriptionJobName=transcription_job_name )['TranscriptionJob'] @@ -123,17 +134,15 @@ class Transcriber: transcript_bucket = transcript_path.split('/', 1)[0] transcript_key = transcript_path.split('/', 1)[1] - s3_object = s3.Object(transcript_bucket, transcript_key).get() + s3_object = cls.s3.Object(transcript_bucket, transcript_key).get() transcript_json = s3_object['Body'].read().decode('utf-8') return json.loads(transcript_json) - - - -def homogenize_transcription_job_data(transcription_job_data): - return [{ - 'created': jd['CreationTime'], - 'name': jd['TranscriptionJobName'], - 'status': jd['TranscriptionJobStatus'] - } - for jd in transcription_job_data] + @staticmethod + def homogenize_transcription_job_data(transcription_job_data): + return [{ + 'created': jd['CreationTime'], + 'name': jd['TranscriptionJobName'], + 'status': jd['TranscriptionJobStatus'] + } + for jd in transcription_job_data] diff --git a/tatt/vendors/vendor.py b/tatt/vendors/vendor.py new file mode 100644 index 0000000..e8f0247 --- /dev/null +++ b/tatt/vendors/vendor.py @@ -0,0 +1,57 @@ +import abc +import os +from pathlib import PurePath + + +class TranscriberBaseClass: + + __metaclass__ = abc.ABCMeta + + def __init__(self, filepath): + self._setup() + self.filepath = PurePath(filepath) + self.basename = str(os.path.basename(self.filepath)) + + @property + @abc.abstractmethod + def no_config_error_message(self): + """ + This must be defined as a class attribute, to be printed when raising + such an error. + """ + pass + + @classmethod + def _setup(cls): + if not cls.check_for_config(): + raise exceptions.ConfigError(cls.no_config_error_message) + + @staticmethod + @abc.abstractmethod + def check_for_config() -> bool: + pass + + @abc.abstractmethod + def transcribe(self) -> str: + """ + This should do any required logic, + then call self._request_transcription. + It should return the job_name. + """ + pass + + @abc.abstractmethod + def _request_transcription(self) -> str: + """Returns the job_name""" + pass + + @classmethod + @abc.abstractmethod + def retrieve_transcript(transcription_job_name: str) -> dict: + pass + + @classmethod + @abc.abstractmethod + def get_transcription_jobs(): + pass +