From 9952f5fe4bf84737007354bacc0c6831b36af367 Mon Sep 17 00:00:00 2001
From: zevav
Date: Thu, 7 Mar 2019 23:38:29 -0500
Subject: [PATCH] added speaker diarization for amazon, not in CLI yet

---
 setup.py               |  2 +-
 tatt/vendors/amazon.py | 24 ++++++++++++++++++++----
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/setup.py b/setup.py
index 055ff96..308d7e0 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as file:
 
 setup(
     name="tatt",
-    version="0.974",
+    version="0.975",
     py_modules=['tatt'],
     url='https://github.com/zevaverbach/tatt',
     install_requires=[
diff --git a/tatt/vendors/amazon.py b/tatt/vendors/amazon.py
index 59286a5..6ed9509 100644
--- a/tatt/vendors/amazon.py
+++ b/tatt/vendors/amazon.py
@@ -70,10 +70,10 @@ class Transcriber(TranscriberBaseClass):
     def make_bucket(cls, bucket_name):
         cls.s3.create_bucket(Bucket=bucket_name)
 
-    def transcribe(self) -> str:
+    def transcribe(self, **kwargs) -> str:
         self._upload_file()
         try:
-            return self._request_transcription()
+            return self._request_transcription(**kwargs)
         except self.tr.exceptions.ConflictException:
             raise exceptions.AlreadyExistsError(
                 f'{self.basename} already exists on {NAME}')
@@ -83,9 +83,15 @@ class Transcriber(TranscriberBaseClass):
                 str(self.filepath),
                 self.basename)
 
-    def _request_transcription(self, language_code='en-US') -> str:
+    def _request_transcription(
+        self,
+        language_code='en-US',
+        num_speakers=2,
+        enable_speaker_diarization=True,
+    ) -> str:
         job_name = self.basename
-        self.tr.start_transcription_job(
+
+        kwargs = dict(
             TranscriptionJobName=job_name,
             LanguageCode=language_code,
             MediaFormat=self.basename.split('.')[-1].lower(),
@@ -94,6 +100,16 @@ class Transcriber(TranscriberBaseClass):
             },
             OutputBucketName=self.bucket_names['transcript']
         )
+
+        if enable_speaker_diarization:
+            kwargs.update(dict(
+                Settings={
+                    'ShowSpeakerLabels': True,
+                    'MaxSpeakerLabels': num_speakers,
+                }
+            ))
+
+        self.tr.start_transcription_job(**kwargs)
         return job_name
 
     @classmethod