basics are working for amazon, updated README.

This commit is contained in:
2019-02-12 11:55:20 -05:00
parent 27254c14c4
commit 237471f7f3
7 changed files with 174 additions and 52 deletions

1
181227_001.MP3.json Normal file

File diff suppressed because one or more lines are too long

1
Lelandmp3.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -2,11 +2,39 @@ Transcribe All The Things™
`tatt` creates a uniform API for multiple speech-to-text (STT) services.
# Services Supported (planned)
# Installation
`pip install git+https://github.com/zevaverbach/tatt`
# Usage
## List All Commands
`transcribe --help`
## List All STT Services
`transcribe services`
## Get Something Transcribed
`transcribe this <path_to_media_file> <service_name>`
## List Transcripts
`transcribe list # a full list of all transcripts, completed and in_progress`
`transcribe list <transcript_basename> # the status of a particular transcript`
## Get A Completed Transcript
`transcribe get <transcript_basename> # prints to stdout`
`transcribe get -f <transcript_basename> # saves to a file in the format <basename>.json`
# Services Supported
- Watson
- Amazon Transcribe
## Planned
- Watson
- Google Speech
- Kaldi and/or things built on it
- Speechmatics
- Mozilla's new open-source STT thing

9
exceptions.py Normal file
View File

@@ -0,0 +1,9 @@
class ConfigError(Exception):
    """Exception type named for configuration errors.

    Adds no attributes or behavior beyond the built-in ``Exception``.
    """
class AlreadyExistsError(Exception):
    """Exception type signalling that something already exists.

    Adds no attributes or behavior beyond the built-in ``Exception``.
    """

View File

@@ -1,5 +1,7 @@
import config
from tatt import vendors
def print_all_services(free_only=False, print_=True):
# TODO: make a jinja template for this
@@ -21,3 +23,55 @@ def print_all_services(free_only=False, print_=True):
if print_:
print(all_services_string)
return all_services_string
def get_service(service_name):
    """Return the service object registered under *service_name*.

    Looks up the attribute called ``service_name`` on ``tatt.vendors`` and
    returns that object's attribute named by ``config.SERVICE_CLASS_NAME``.
    Both lookups use ``getattr`` without a default, so a missing name
    raises ``AttributeError``.
    """
    vendor = getattr(vendors, service_name)
    return getattr(vendor, config.SERVICE_CLASS_NAME)
def print_transcription_jobs(jobs):
    """Print a three-column table (service, job name, status) to stdout.

    Args:
        jobs: mapping of service name -> list of job dicts. Each job dict
            is read for its ``'name'`` and ``'status'`` keys.

    The service and job-name columns are padded to the widest entry (or to
    the header label, whichever is longer) so the columns line up even when
    service names differ in length. An empty mapping, or one whose job
    lists are all empty, prints just the header instead of raising.
    """
    headers = ('Service', 'Job Name', 'Status')
    rows = [(provider_name, job['name'], job['status'])
            for provider_name, job_list in jobs.items()
            for job in job_list]

    # Seed each width with the header label so max() never sees an empty
    # sequence and the header itself is never truncated/misaligned.
    service_width = max([len(headers[0])] + [len(row[0]) for row in rows])
    job_name_width = max([len(headers[1])] + [len(row[1]) for row in rows])

    def format_row(service, job_name, status):
        # The last column is left unpadded to avoid trailing whitespace.
        return (f'{service.ljust(service_width)}  '
                f'{job_name.ljust(job_name_width)}  {status}')

    print()
    print(format_row(*headers))
    print(format_row(*['-' * len(header) for header in headers]))
    for row in rows:
        print(format_row(*row))
    print()
def get_transcription_jobs(service_name=None, name=None, status=None):
    """Collect transcription jobs per service.

    Iterates over ``config.STT_SERVICES``; when ``service_name`` is given,
    only the entry with that exact name is queried. For each service the
    object returned by ``get_service`` is asked for its jobs via
    ``get_transcription_jobs(job_name_query=name, status=status)``.

    Returns a dict mapping service name -> that service's jobs, omitting
    services whose result is falsy.
    """
    jobs_by_service = {}
    for stt_name, _ in config.STT_SERVICES.items():
        # Skip services that were not asked for.
        if service_name is not None and service_name != stt_name:
            continue
        found = get_service(stt_name).get_transcription_jobs(
            job_name_query=name,
            status=status)
        if found:
            jobs_by_service[stt_name] = found
    return jobs_by_service
def get_transcription_jobs_dict():
    """Index every transcription job by its name.

    Calls ``get_transcription_jobs()`` with no filters and flattens the
    result into ``{job_name: {'service_name': ..., 'status': ...}}``.
    If two jobs share a name, the one encountered later wins.
    """
    jobs_by_name = {}
    for service_name, job_list in get_transcription_jobs().items():
        for job in job_list:
            job_name = job['name']
            jobs_by_name[job_name] = {
                'service_name': service_name,
                'status': job['status'],
            }
    return jobs_by_name

View File

@@ -7,6 +7,7 @@ import uuid
import boto3
import config
import exceptions
NAME = 'amazon'
BUCKET_NAME_MEDIA = config.AWS_BUCKET_NAME_FMTR_MEDIA.format(NAME)
@@ -15,14 +16,11 @@ tr = boto3.client('transcribe')
s3 = boto3.resource('s3')
class ConfigError(Exception):
pass
class transcribe:
bucket_names = {'media': BUCKET_NAME_MEDIA,
'transcript': BUCKET_NAME_TRANSCRIPT}
service_name = 'amazon'
def __init__(self, filepath):
self._setup()
@@ -49,7 +47,11 @@ class transcribe:
def transcribe(self):
    """Upload the media file, then request a transcription job.

    Returns whatever ``self._request_transcription()`` returns.

    Raises:
        exceptions.AlreadyExistsError: when the transcription request
            raises ``tr.exceptions.ConflictException``.
    """
    self._upload_file()
    try:
        job_name = self._request_transcription()
    except tr.exceptions.ConflictException:
        message = f'{self.basename} already exists on {self.service_name}'
        raise exceptions.AlreadyExistsError(message)
    return job_name
def _upload_file(self):
s3.Bucket(self.bucket_names['media']).upload_file(
@@ -57,7 +59,7 @@ class transcribe:
self.basename)
def _request_transcription(self, language_code='en-US'):
job_name = str(uuid.uuid4())
job_name = self.basename
tr.start_transcription_job(
TranscriptionJobName=job_name,
LanguageCode=language_code,
@@ -70,22 +72,28 @@ class transcribe:
return job_name
@staticmethod
def get_completed_jobs():
return transcribe.get_transcription_jobs(status='completed')
def get_completed_jobs(job_name_query=None):
    """Return jobs with status 'completed', optionally filtered by name.

    Thin wrapper over ``transcribe.get_transcription_jobs``.
    """
    completed = transcribe.get_transcription_jobs(
        job_name_query=job_name_query,
        status='completed')
    return completed
@staticmethod
def get_pending_jobs():
return transcribe.get_transcription_jobs(status='in_progress')
def get_pending_jobs(job_name_query=None):
    """Return jobs with status 'in_progress', optionally filtered by name.

    Thin wrapper over ``transcribe.get_transcription_jobs``.
    """
    pending = transcribe.get_transcription_jobs(
        job_name_query=job_name_query,
        status='in_progress')
    return pending
@staticmethod
def get_all_jobs():
return transcribe.get_transcription_jobs()
def get_all_jobs(job_name_query=None):
    """Return jobs of every status, optionally filtered by name.

    ``job_name_query`` is passed by keyword: the first positional
    parameter of ``get_transcription_jobs`` is ``status``, so a positional
    call would treat the name query as a status filter.
    """
    return transcribe.get_transcription_jobs(job_name_query=job_name_query)
@staticmethod
def get_transcription_jobs(status=None):
def get_transcription_jobs(status=None, job_name_query=None):
kwargs = {'MaxResults': 100}
if status is not None:
kwargs['Status'] = status.upper()
if job_name_query is not None:
kwargs['JobNameContains'] = job_name_query
jobs_data = tr.list_transcription_jobs(**kwargs)
jobs = homogenize_transcription_job_data(jobs_data['TranscriptionJobSummaries'])
while jobs_data.get('NextToken'):
@@ -94,17 +102,8 @@ class transcribe:
jobs_data['TranscriptionJobSummaries'])
return jobs
def homogenize_transcription_job_data(transcription_job_data):
return [{
'created': jd['CreationTime'],
'name': jd['TranscriptionJobName'],
'status': jd['TranscriptionJobStatus']
}
for jd in transcription_job_data]
def retrieve_transcript(transcription_job_name):
@staticmethod
def retrieve_transcript(transcription_job_name):
job = tr.get_transcription_job(
TranscriptionJobName=transcription_job_name
)['TranscriptionJob']
@@ -123,6 +122,17 @@ def retrieve_transcript(transcription_job_name):
return json.loads(transcript_json)
def homogenize_transcription_job_data(transcription_job_data):
    """Reduce job summaries to dicts with 'created', 'name' and 'status'.

    Each input item is read for its ``'CreationTime'``,
    ``'TranscriptionJobName'`` and ``'TranscriptionJobStatus'`` keys; any
    other keys are dropped. Returns a new list in input order.
    """
    homogenized = []
    for summary in transcription_job_data:
        homogenized.append({
            'created': summary['CreationTime'],
            'name': summary['TranscriptionJobName'],
            'status': summary['TranscriptionJobStatus'],
        })
    return homogenized
def check_for_credentials():
    """Return the result of ``config.AWS_CREDENTIALS_FILEPATH.exists()``."""
    credentials_path = config.AWS_CREDENTIALS_FILEPATH
    return credentials_path.exists()

View File

@@ -1,3 +1,4 @@
import json
from pprint import pprint
import sqlite3
import sys
@@ -5,6 +6,7 @@ import sys
import click
import config
import exceptions
import helpers
from tatt import vendors
@@ -15,16 +17,35 @@ def cli():
@cli.command()
@click.argument('uid', required=False)
def retrieve(name=None, service=None):
pending_jobs = [get_service(service_name).get_pending_jobs(name)
for service_name, data in config.STT_SERVICES
if service is None
or service == service_name]
if not pending_jobs:
click.ClickException('no pending jobs currently!')
for job in pending_jobs:
print(dict(job))
@click.option('-f', '--file', is_flag=True)
@click.argument('name')
def get(name, file):
    """Print a completed transcript, or save it as <name>.json with -f.

    Raises click.ClickException when no transcript called NAME is found or
    when its status is not 'completed' (compared case-insensitively).
    """
    job = helpers.get_transcription_jobs_dict().get(name)
    if not job:
        raise click.ClickException(f'no such transcript {name}')
    if job['status'].lower() != 'completed':
        raise click.ClickException(f'transcript status is {job["status"]}')
    service = helpers.get_service(job['service_name'])
    # Fetch once and reuse the result for whichever output was requested.
    transcript = service.retrieve_transcript(name)
    if not file:
        pprint(transcript)
        return
    # Only reached with -f/--file: stdout mode must not leave a file behind.
    with open(f'{name}.json', 'w') as fout:
        fout.write(json.dumps(transcript))
@cli.command()
@click.option('-n', '--name', type=str)
@click.option('--service', type=str)
@click.option('--status', type=str)
def list(name, service, status):
    """Print a table of transcription jobs, optionally filtered.

    Raises click.ClickException when --service is not a key of
    config.STT_SERVICES, or when no jobs match the given filters.
    """
    # NOTE: the function name shadows the builtin ``list`` inside this
    # module; it is kept because the command name is derived from it.
    if service is not None and service not in config.STT_SERVICES:
        raise click.ClickException(f'no such service: {service}')
    all_jobs = helpers.get_transcription_jobs(service, name, status)
    if not all_jobs:
        # The exception must be raised, not merely constructed; otherwise
        # execution falls through to the printer with an empty mapping.
        raise click.ClickException('no transcripts currently!')
    helpers.print_transcription_jobs(all_jobs)
@cli.command()
@@ -46,7 +67,7 @@ def this(dry_run, media_filepath, service_name):
raise click.ClickException(
f'No such service! {print_all_services(print_=False)}')
service = get_service(service_name)
service = helpers.get_service(service_name)
s = service(media_filepath)
if dry_run:
@@ -57,13 +78,11 @@ def this(dry_run, media_filepath, service_name):
print(
f'Okay, transcribing {media_filepath} using {service_name}...')
try:
job_num = s.transcribe()
db.create_pending_job(job_num, s.basename, service_name)
print(f'Okay, job {job_num} is being transcribed. Use "retrieve" '
except exceptions.AlreadyExistsError as e:
raise click.ClickException(str(e))
print(f'Okay, job {job_num} is being transcribed. Use "get" '
'command to download it.')
def get_service(service_name):
return getattr(getattr(vendors, service_name), config.SERVICE_CLASS_NAME)