added some tests, removed some unused methods on vendors/amazon

This commit is contained in:
2019-02-20 21:30:30 -05:00
parent 12297a321f
commit ac278f2d18
8 changed files with 112 additions and 51 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as file:
setup(
name="tatt",
version="0.957",
version="0.958",
py_modules=['tatt'],
url='https://github.com/zevaverbach/tatt',
install_requires=[

37
tatt/tests/test_amazon.py Normal file
View File

@@ -0,0 +1,37 @@
from unittest import mock
from tatt.vendors.amazon import Transcriber
def test_transcriber_instantiate():
filepath = '/Users/zev/tester.mp3'
t = Transcriber(filepath)
assert str(t.filepath) == filepath
assert t.basename == 'tester.mp3'
assert t.media_file_uri == (
f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
)
@mock.patch('tatt.vendors.amazon.tr.get_transcription_job')
def test_transcriber_retrieve(get_transcription_job):
filepath = '/Users/zev/tester.mp3'
job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
t = Transcriber.retrieve_transcript(job_name)
get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
def test_transcriber_get_transcription_jobs():
result = Transcriber.get_transcription_jobs()
assert result
def test_transcriber_retrieve_transcript():
jobs = Transcriber.get_transcription_jobs()
for j in jobs:
if j['status'].lower() == 'completed':
to_get = j['name']
break
transcript = Transcriber.retrieve_transcript(to_get)
assert transcript == {'jobName': 'abcd.mp3', 'accountId': '416321668733', 'results': {'transcripts': [{'transcript': 'Hello there.'}], 'items': [{'start_time': '0.0', 'end_time': '0.35', 'alternatives': [{'confidence': '0.8303', 'content': 'Hello'}], 'type': 'pronunciation'}, {'start_time': '0.35', 'end_time': '0.76', 'alternatives': [{'confidence': '1.0000', 'content': 'there'}], 'type': 'pronunciation'}, {'alternatives': [{'confidence': None, 'content': '.'}], 'type': 'punctuation'}]}, 'status': 'COMPLETED'}

View File

@@ -0,0 +1,43 @@
from unittest import mock
from click.testing import CliRunner
from tatt.transcribe import cli
from tatt.vendors.amazon import Transcriber
def test_services():
runner = CliRunner()
result = runner.invoke(cli, ['services'])
assert result.exit_code == 0
assert 'amazon\t\t$0.006 per 15 seconds' in result.output
assert ('Here are all the available speech-to-text-services:'
in result.output)
@mock.patch('tatt.transcribe.get_transcription_jobs')
def test_status(get_transcription_jobs):
runner = CliRunner()
result = runner.invoke(cli, ['status', 'hi'])
assert get_transcription_jobs.called
get_transcription_jobs.assert_called_with(name='hi')
# list, get
@mock.patch('tatt.transcribe.get_service')
def test_this(get_service):
runner = CliRunner()
result = runner.invoke(cli, ['this', 'hi.mp3', 'amazon'])
get_service.assert_called_once()
get_service.assert_called_with('amazon')
@mock.patch('tatt.transcribe.get_transcription_jobs')
def test_list(get_transcription_jobs):
runner = CliRunner()
result = runner.invoke(cli, ['list'])
get_transcription_jobs.assert_called_once()
get_transcription_jobs.assert_called_with(None, None, None)
result = runner.invoke(cli, ['list', '-n', 'hi.mp3'])
get_transcription_jobs.assert_called_with(None, 'hi.mp3', None)

View File

@@ -0,0 +1,8 @@
from tatt.vendors import SERVICES
def test_services():
for service in SERVICES.values():
assert hasattr(service, 'Transcriber')
assert hasattr(service, 'NAME')
assert hasattr(service, 'cost_per_15_seconds')

View File

@@ -6,6 +6,7 @@ import sys
import click
from tatt import config, exceptions, helpers, vendors
from tatt.helpers import get_transcription_jobs, get_service
@click.group()
@@ -42,16 +43,16 @@ def get(name, save, pretty):
@cli.command()
@click.option('-n', '--name', type=str, help="transcription job name")
@click.option('--service', type=str, help="STT service name")
@click.option('-n', '--name', type=str, help="transcription job name")
@click.option('--status', type=str, help="completed | failed | in_progress")
def list(name, service, status):
def list(service, name, status):
"""Lists available STT services."""
if service is not None and service not in vendors.SERVICES:
raise click.ClickException(f'no such service: {service}')
try:
all_jobs = helpers.get_transcription_jobs(service, name, status)
all_jobs = get_transcription_jobs(service, name, status)
except exceptions.ConfigError as e:
raise click.ClickException(str(e))
else:
@@ -71,8 +72,8 @@ def services(free_only):
@cli.command()
@click.argument('job_name', type=str)
def status(job_name):
jobs = helpers.get_transcription_jobs(name=job_name)
if not jobs:
jobs = get_transcription_jobs(name=job_name)
if not jobs or not list(jobs.values())[0]:
raise click.ClickException('no job by that name')
click.echo(list(jobs.values())[0][0]['status'])
@@ -83,7 +84,7 @@ def status(job_name):
def this(media_filepath, service_name):
"""Sends a media file to be transcribed."""
try:
service = helpers.get_service(service_name)
service = get_service(service_name)
except KeyError as e:
raise click.ClickException(
f'No such service! {print_all_services(print_=False)}')

View File

@@ -2,6 +2,7 @@ import json
import os
from pathlib import PurePath
from subprocess import check_output
from typing import List, Dict, Union
import uuid
import boto3
@@ -82,35 +83,29 @@ class Transcriber:
)
return job_name
@classmethod
def get_completed_jobs(cls, job_name_query=None):
return cls.get_transcription_jobs(
status='completed',
job_name_query=job_name_query)
@classmethod
def get_pending_jobs(cls, job_name_query=None):
return cls.get_transcription_jobs(
status='in_progress',
job_name_query=job_name_query)
@classmethod
def get_all_jobs(cls, job_name_query=None):
return cls.get_transcription_jobs(job_name_query)
@staticmethod
def get_transcription_jobs(status=None, job_name_query=None):
def get_transcription_jobs(
status=None,
job_name_query=None
) -> List[dict]:
kwargs = {'MaxResults': 100}
if status is not None:
kwargs['Status'] = status.upper()
if job_name_query is not None:
kwargs['JobNameContains'] = job_name_query
jobs_data = tr.list_transcription_jobs(**kwargs)
jobs = homogenize_transcription_job_data(jobs_data['TranscriptionJobSummaries'])
key = 'TranscriptionJobSummaries'
jobs = homogenize_transcription_job_data(jobs_data[key])
while jobs_data.get('NextToken'):
jobs_data = tr.list_transcription_jobs(NextToken=jobs_data['NextToken'])
jobs += homogenize_transcription_job_data(
jobs_data['TranscriptionJobSummaries'])
token = jobs_data['NextToken']
jobs_data = tr.list_transcription_jobs(NextToken=token)
jobs += homogenize_transcription_job_data(jobs_data[key])
return jobs
@staticmethod
@@ -142,7 +137,3 @@ def homogenize_transcription_job_data(transcription_job_data):
'status': jd['TranscriptionJobStatus']
}
for jd in transcription_job_data]
def shell_call(command):
return check_output(command, shell=True)

View File

@@ -1,19 +0,0 @@
from tatt.vendors.amazon import transcribe, retrieve_transcript
def test_transcribe_instantiate():
filepath = '/Users/zev/tester.mp3'
t = transcribe(filepath)
assert str(t.filepath) == filepath
assert t.basename == 'tester.mp3'
assert t.media_file_uri == (
f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
)
def test_retrieve():
filepath = '/Users/zev/tester.mp3'
t = retrieve_transcript('4db6808e-a7e8-4d8d-a1b7-753ab97094dc')
print(t)
assert t is not None