Review notes (free text ahead of the first `diff --git` header):
- Line breaks and indentation were reconstructed from the hunk line counts.
- tatt/transcribe.py: `status` no longer calls `list(...)`, which at module
  level is the click command `list`, not the builtin; the `list` help text
  now describes what the command does.
- tatt/vendors/amazon.py: paging re-sends the original filters together with
  `NextToken`; the unused `Dict`/`Union` imports are dropped.
- tatt/tests: no unbound `to_get`, no unused locals/imports, and new
  assertions cover the two fixes above.
- The `index` blob ids of the edited files were not recomputed.

diff --git a/setup.py b/setup.py
index 74326f4..2dfbc5d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as file:
 
 setup(
     name="tatt",
-    version="0.957",
+    version="0.958",
     py_modules=['tatt'],
     url='https://github.com/zevaverbach/tatt',
     install_requires=[
diff --git a/tests/__init__.py b/tatt/tests/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to tatt/tests/__init__.py
diff --git a/tatt/tests/test_amazon.py b/tatt/tests/test_amazon.py
new file mode 100644
index 0000000..1d6ce29
--- /dev/null
+++ b/tatt/tests/test_amazon.py
@@ -0,0 +1,79 @@
+from unittest import mock
+
+from tatt.vendors.amazon import Transcriber
+
+
+def test_transcriber_instantiate():
+    filepath = '/Users/zev/tester.mp3'
+    t = Transcriber(filepath)
+    assert str(t.filepath) == filepath
+    assert t.basename == 'tester.mp3'
+    assert t.media_file_uri == (
+        'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
+    )
+
+
+@mock.patch('tatt.vendors.amazon.tr.get_transcription_job')
+def test_transcriber_retrieve(get_transcription_job):
+    job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
+    Transcriber.retrieve_transcript(job_name)
+    get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
+
+
+@mock.patch('tatt.vendors.amazon.tr.list_transcription_jobs')
+def test_transcriber_get_transcription_jobs_keeps_filters(list_jobs):
+    # Two empty pages: the follow-up request must repeat the filters.
+    list_jobs.side_effect = [
+        {'TranscriptionJobSummaries': [], 'NextToken': 'page-2'},
+        {'TranscriptionJobSummaries': []},
+    ]
+    assert Transcriber.get_transcription_jobs(status='completed') == []
+    list_jobs.assert_called_with(
+        MaxResults=100, Status='COMPLETED', NextToken='page-2')
+
+
+def test_transcriber_get_transcription_jobs():
+    # Not mocked: this exercises whatever account `tr` is configured for.
+    result = Transcriber.get_transcription_jobs()
+    assert result
+
+
+def test_transcriber_retrieve_transcript():
+    # Not mocked: expects the first completed job to be 'abcd.mp3'.
+    jobs = Transcriber.get_transcription_jobs()
+    completed = [
+        j['name'] for j in jobs if j['status'].lower() == 'completed'
+    ]
+    # Fail with a clear message instead of a NameError when nothing matches.
+    assert completed, 'no completed transcription job to retrieve'
+    transcript = Transcriber.retrieve_transcript(completed[0])
+    assert transcript == {
+        'jobName': 'abcd.mp3',
+        'accountId': '416321668733',
+        'results': {
+            'transcripts': [{'transcript': 'Hello there.'}],
+            'items': [
+                {
+                    'start_time': '0.0',
+                    'end_time': '0.35',
+                    'alternatives': [
+                        {'confidence': '0.8303', 'content': 'Hello'},
+                    ],
+                    'type': 'pronunciation',
+                },
+                {
+                    'start_time': '0.35',
+                    'end_time': '0.76',
+                    'alternatives': [
+                        {'confidence': '1.0000', 'content': 'there'},
+                    ],
+                    'type': 'pronunciation',
+                },
+                {
+                    'alternatives': [{'confidence': None, 'content': '.'}],
+                    'type': 'punctuation',
+                },
+            ],
+        },
+        'status': 'COMPLETED',
+    }
diff --git a/tatt/tests/test_transcribe.py b/tatt/tests/test_transcribe.py
new file mode 100644
index 0000000..e7d4b4b
--- /dev/null
+++ b/tatt/tests/test_transcribe.py
@@ -0,0 +1,46 @@
+from unittest import mock
+
+from click.testing import CliRunner
+
+from tatt.transcribe import cli
+
+
+def test_services():
+    runner = CliRunner()
+    result = runner.invoke(cli, ['services'])
+    assert result.exit_code == 0
+    assert 'amazon\t\t$0.006 per 15 seconds' in result.output
+    assert ('Here are all the available speech-to-text-services:'
+            in result.output)
+
+
+@mock.patch('tatt.transcribe.get_transcription_jobs')
+def test_status(get_transcription_jobs):
+    get_transcription_jobs.return_value = {'amazon': [{'status': 'COMPLETED'}]}
+    runner = CliRunner()
+    result = runner.invoke(cli, ['status', 'hi'])
+    get_transcription_jobs.assert_called_with(name='hi')
+    assert result.exit_code == 0
+    assert 'COMPLETED' in result.output
+
+
+# TODO: the `get` command has no test in this module yet.
+
+
+@mock.patch('tatt.transcribe.get_service')
+def test_this(get_service):
+    runner = CliRunner()
+    runner.invoke(cli, ['this', 'hi.mp3', 'amazon'])
+    get_service.assert_called_once()
+    get_service.assert_called_with('amazon')
+
+
+@mock.patch('tatt.transcribe.get_transcription_jobs')
+def test_list(get_transcription_jobs):
+    runner = CliRunner()
+    runner.invoke(cli, ['list'])
+    get_transcription_jobs.assert_called_once()
+    get_transcription_jobs.assert_called_with(None, None, None)
+
+    runner.invoke(cli, ['list', '-n', 'hi.mp3'])
+    get_transcription_jobs.assert_called_with(None, 'hi.mp3', None)
diff --git a/tatt/tests/test_vendors.py b/tatt/tests/test_vendors.py
new file mode 100644
index 0000000..10b7fdb
--- /dev/null
+++ b/tatt/tests/test_vendors.py
@@ -0,0 +1,8 @@
+from tatt.vendors import SERVICES
+
+
+def test_services():
+    for service in SERVICES.values():
+        assert hasattr(service, 'Transcriber')
+        assert hasattr(service, 'NAME')
+        assert hasattr(service, 'cost_per_15_seconds')
diff --git a/tatt/transcribe.py b/tatt/transcribe.py
index c492039..c54d847 100644
--- a/tatt/transcribe.py
+++ b/tatt/transcribe.py
@@ -6,6 +6,7 @@ import sys
 import click
 
 from tatt import config, exceptions, helpers, vendors
+from tatt.helpers import get_transcription_jobs, get_service
 
 
 @click.group()
@@ -42,16 +43,16 @@ def get(name, save, pretty):
 
 
 @cli.command()
-@click.option('-n', '--name', type=str, help="transcription job name")
 @click.option('--service', type=str, help="STT service name")
+@click.option('-n', '--name', type=str, help="transcription job name")
 @click.option('--status', type=str, help="completed | failed | in_progress")
-def list(name, service, status):
-    """Lists available STT services."""
+def list(service, name, status):
+    """Lists transcription jobs, filterable by service, name and status."""
     if service is not None and service not in vendors.SERVICES:
         raise click.ClickException(f'no such service: {service}')
 
     try:
-        all_jobs = helpers.get_transcription_jobs(service, name, status)
+        all_jobs = get_transcription_jobs(service, name, status)
     except exceptions.ConfigError as e:
         raise click.ClickException(str(e))
     else:
@@ -71,8 +72,12 @@ def services(free_only):
 @cli.command()
 @click.argument('job_name', type=str)
 def status(job_name):
-    jobs = helpers.get_transcription_jobs(name=job_name)
-    if not jobs:
+    """Prints the status of the transcription job named JOB_NAME."""
+    jobs = get_transcription_jobs(name=job_name)
+    # ``list`` at module level is the click command defined above, not the
+    # builtin, so ``list(jobs.values())`` cannot be used here.
+    matches = next(iter(jobs.values()), None) if jobs else None
+    if not matches:
         raise click.ClickException('no job by that name')
-    click.echo(list(jobs.values())[0][0]['status'])
+    click.echo(matches[0]['status'])
 
@@ -83,7 +88,7 @@ def status(job_name):
 def this(media_filepath, service_name):
     """Sends a media file to be transcribed."""
     try:
-        service = helpers.get_service(service_name)
+        service = get_service(service_name)
     except KeyError as e:
         raise click.ClickException(
             f'No such service! {print_all_services(print_=False)}')
diff --git a/tatt/vendors/amazon.py b/tatt/vendors/amazon.py
index 95571c9..805900f 100644
--- a/tatt/vendors/amazon.py
+++ b/tatt/vendors/amazon.py
@@ -2,6 +2,7 @@ import json
 import os
 from pathlib import PurePath
 from subprocess import check_output
+from typing import List
 import uuid
 
 import boto3
@@ -82,35 +83,35 @@ class Transcriber:
         )
         return job_name
 
-    @classmethod
-    def get_completed_jobs(cls, job_name_query=None):
-        return cls.get_transcription_jobs(
-            status='completed',
-            job_name_query=job_name_query)
-
-    @classmethod
-    def get_pending_jobs(cls, job_name_query=None):
-        return cls.get_transcription_jobs(
-            status='in_progress',
-            job_name_query=job_name_query)
-
-    @classmethod
-    def get_all_jobs(cls, job_name_query=None):
-        return cls.get_transcription_jobs(job_name_query)
-
     @staticmethod
-    def get_transcription_jobs(status=None, job_name_query=None):
+    def get_transcription_jobs(
+        status=None,
+        job_name_query=None,
+    ) -> List[dict]:
+        """Return all transcription jobs matching the optional filters.
+
+        ``status`` is upper-cased and sent as ``Status``; ``job_name_query``
+        is sent as ``JobNameContains``.  Every page is fetched until the
+        response carries no ``NextToken``.
+        """
         kwargs = {'MaxResults': 100}
+
         if status is not None:
             kwargs['Status'] = status.upper()
         if job_name_query is not None:
             kwargs['JobNameContains'] = job_name_query
+
+        key = 'TranscriptionJobSummaries'
         jobs_data = tr.list_transcription_jobs(**kwargs)
-        jobs = homogenize_transcription_job_data(jobs_data['TranscriptionJobSummaries'])
+        jobs = homogenize_transcription_job_data(jobs_data[key])
+
         while jobs_data.get('NextToken'):
-            jobs_data = tr.list_transcription_jobs(NextToken=jobs_data['NextToken'])
-            jobs += homogenize_transcription_job_data(
-                jobs_data['TranscriptionJobSummaries'])
+            # Re-send the original filters with each token; AWS describes
+            # paging as repeating the same request plus ``NextToken``.
+            kwargs['NextToken'] = jobs_data['NextToken']
+            jobs_data = tr.list_transcription_jobs(**kwargs)
+            jobs += homogenize_transcription_job_data(jobs_data[key])
+
         return jobs
 
     @staticmethod
@@ -142,7 +143,3 @@ def homogenize_transcription_job_data(transcription_job_data):
             'status': jd['TranscriptionJobStatus']
             }
             for jd in transcription_job_data]
-
-
-def shell_call(command):
-    return check_output(command, shell=True)
diff --git a/tests/test_amazon.py b/tests/test_amazon.py
deleted file mode 100644
index 86d3b74..0000000
--- a/tests/test_amazon.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from tatt.vendors.amazon import transcribe, retrieve_transcript
-
-
-
-def test_transcribe_instantiate():
-    filepath = '/Users/zev/tester.mp3'
-    t = transcribe(filepath)
-    assert str(t.filepath) == filepath
-    assert t.basename == 'tester.mp3'
-    assert t.media_file_uri == (
-        f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
-    )
-
-
-def test_retrieve():
-    filepath = '/Users/zev/tester.mp3'
-    t = retrieve_transcript('4db6808e-a7e8-4d8d-a1b7-753ab97094dc')
-    print(t)
-    assert t is not None