Added some tests and removed some unused methods from vendors/amazon.
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as file:
|
||||
|
||||
setup(
|
||||
name="tatt",
|
||||
version="0.957",
|
||||
version="0.958",
|
||||
py_modules=['tatt'],
|
||||
url='https://github.com/zevaverbach/tatt',
|
||||
install_requires=[
|
||||
|
||||
37
tatt/tests/test_amazon.py
Normal file
37
tatt/tests/test_amazon.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from unittest import mock
|
||||
|
||||
from tatt.vendors.amazon import Transcriber
|
||||
|
||||
|
||||
|
||||
def test_transcriber_instantiate():
    """Check the filepath, basename and media URI of a new Transcriber."""
    filepath = '/Users/zev/tester.mp3'
    t = Transcriber(filepath)
    assert str(t.filepath) == filepath
    assert t.basename == 'tester.mp3'
    # Plain literal: the string has no placeholders, so no f-prefix is needed.
    assert t.media_file_uri == (
        'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
    )
|
||||
|
||||
|
||||
@mock.patch('tatt.vendors.amazon.tr.get_transcription_job')
def test_transcriber_retrieve(get_transcription_job):
    """Check that retrieve_transcript looks the job up by the given name.

    ``tr.get_transcription_job`` is patched, so only the keyword argument
    passed to it is verified here.
    """
    job_name = '4db6808e-a7e8-4d8d-a1b7-753ab97094dc'
    # The return value is irrelevant; only the outgoing call is checked.
    Transcriber.retrieve_transcript(job_name)
    get_transcription_job.assert_called_with(TranscriptionJobName=job_name)
|
||||
|
||||
|
||||
def test_transcriber_get_transcription_jobs():
    """Check that the job listing comes back truthy.

    Nothing is patched in this test, so it exercises whatever
    ``Transcriber.get_transcription_jobs`` really talks to.
    """
    assert Transcriber.get_transcription_jobs()
|
||||
|
||||
|
||||
def test_transcriber_retrieve_transcript():
    """Retrieve the first completed job's transcript and compare to a fixture.

    Nothing is patched here; the test picks the first job whose status is
    ``completed`` (case-insensitive) from the real job listing.
    """
    jobs = Transcriber.get_transcription_jobs()
    to_get = next(
        (j['name'] for j in jobs if j['status'].lower() == 'completed'),
        None,
    )
    # Fail with a readable message instead of a NameError when the listing
    # contains no completed job.
    assert to_get is not None, 'no completed transcription job to retrieve'

    transcript = Transcriber.retrieve_transcript(to_get)
    assert transcript == {
        'jobName': 'abcd.mp3',
        'accountId': '416321668733',
        'results': {
            'transcripts': [{'transcript': 'Hello there.'}],
            'items': [
                {
                    'start_time': '0.0',
                    'end_time': '0.35',
                    'alternatives': [
                        {'confidence': '0.8303', 'content': 'Hello'},
                    ],
                    'type': 'pronunciation',
                },
                {
                    'start_time': '0.35',
                    'end_time': '0.76',
                    'alternatives': [
                        {'confidence': '1.0000', 'content': 'there'},
                    ],
                    'type': 'pronunciation',
                },
                {
                    'alternatives': [{'confidence': None, 'content': '.'}],
                    'type': 'punctuation',
                },
            ],
        },
        'status': 'COMPLETED',
    }
|
||||
43
tatt/tests/test_transcribe.py
Normal file
43
tatt/tests/test_transcribe.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from unittest import mock
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from tatt.transcribe import cli
|
||||
from tatt.vendors.amazon import Transcriber
|
||||
|
||||
|
||||
def test_services():
    """Check that ``services`` exits with 0 and prints header and Amazon price."""
    outcome = CliRunner().invoke(cli, ['services'])
    assert outcome.exit_code == 0

    printed = outcome.output
    header = 'Here are all the available speech-to-text-services:'
    assert 'amazon\t\t$0.006 per 15 seconds' in printed
    assert header in printed
|
||||
|
||||
|
||||
@mock.patch('tatt.transcribe.get_transcription_jobs')
def test_status(get_transcription_jobs):
    """Check that ``status NAME`` queries the job listing by that name."""
    CliRunner().invoke(cli, ['status', 'hi'])
    # assert_called_with already fails when the mock was never called, so a
    # separate ``.called`` check is unnecessary.
    get_transcription_jobs.assert_called_with(name='hi')
|
||||
|
||||
# list, get
|
||||
|
||||
@mock.patch('tatt.transcribe.get_service')
def test_this(get_service):
    """Check that ``this FILE SERVICE`` looks the service up exactly once."""
    CliRunner().invoke(cli, ['this', 'hi.mp3', 'amazon'])
    # One assertion covers both the call count and the argument.
    get_service.assert_called_once_with('amazon')
|
||||
|
||||
|
||||
@mock.patch('tatt.transcribe.get_transcription_jobs')
def test_list(get_transcription_jobs):
    """Check that ``list`` forwards its options positionally to the listing.

    With no options all three arguments are ``None``; ``-n`` fills the
    second positional argument.
    """
    runner = CliRunner()

    runner.invoke(cli, ['list'])
    get_transcription_jobs.assert_called_once_with(None, None, None)

    # Second invocation: only the most recent call is checked.
    runner.invoke(cli, ['list', '-n', 'hi.mp3'])
    get_transcription_jobs.assert_called_with(None, 'hi.mp3', None)
|
||||
8
tatt/tests/test_vendors.py
Normal file
8
tatt/tests/test_vendors.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from tatt.vendors import SERVICES
|
||||
|
||||
|
||||
def test_services():
    """Check each SERVICES entry has Transcriber, NAME and cost_per_15_seconds."""
    required = ('Transcriber', 'NAME', 'cost_per_15_seconds')
    for service in SERVICES.values():
        for attribute in required:
            assert hasattr(service, attribute)
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import click
|
||||
|
||||
from tatt import config, exceptions, helpers, vendors
|
||||
from tatt.helpers import get_transcription_jobs, get_service
|
||||
|
||||
|
||||
@click.group()
|
||||
@@ -42,16 +43,16 @@ def get(name, save, pretty):
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option('-n', '--name', type=str, help="transcription job name")
|
||||
@click.option('--service', type=str, help="STT service name")
|
||||
@click.option('-n', '--name', type=str, help="transcription job name")
|
||||
@click.option('--status', type=str, help="completed | failed | in_progress")
|
||||
def list(name, service, status):
|
||||
def list(service, name, status):
|
||||
"""Lists available STT services."""
|
||||
if service is not None and service not in vendors.SERVICES:
|
||||
raise click.ClickException(f'no such service: {service}')
|
||||
|
||||
try:
|
||||
all_jobs = helpers.get_transcription_jobs(service, name, status)
|
||||
all_jobs = get_transcription_jobs(service, name, status)
|
||||
except exceptions.ConfigError as e:
|
||||
raise click.ClickException(str(e))
|
||||
else:
|
||||
@@ -71,8 +72,8 @@ def services(free_only):
|
||||
@cli.command()
@click.argument('job_name', type=str)
def status(job_name):
    """Print the status of the transcription job named JOB_NAME."""
    jobs = get_transcription_jobs(name=job_name)
    # The builtin ``list`` is shadowed in this module by the ``list`` CLI
    # command, so ``list(jobs.values())`` would invoke that command. Take the
    # first value with next(iter(...)) instead.
    # NOTE(review): presumably ``jobs`` maps service name -> list of job
    # dicts; confirm against helpers.get_transcription_jobs.
    first_service_jobs = next(iter(jobs.values()), None) if jobs else None
    if not first_service_jobs:
        raise click.ClickException('no job by that name')
    click.echo(first_service_jobs[0]['status'])
|
||||
|
||||
@@ -83,7 +84,7 @@ def status(job_name):
|
||||
def this(media_filepath, service_name):
|
||||
"""Sends a media file to be transcribed."""
|
||||
try:
|
||||
service = helpers.get_service(service_name)
|
||||
service = get_service(service_name)
|
||||
except KeyError as e:
|
||||
raise click.ClickException(
|
||||
f'No such service! {print_all_services(print_=False)}')
|
||||
|
||||
41
tatt/vendors/amazon.py
vendored
41
tatt/vendors/amazon.py
vendored
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
from pathlib import PurePath
|
||||
from subprocess import check_output
|
||||
from typing import List, Dict, Union
|
||||
import uuid
|
||||
|
||||
import boto3
|
||||
@@ -82,35 +83,29 @@ class Transcriber:
|
||||
)
|
||||
return job_name
|
||||
|
||||
@classmethod
|
||||
def get_completed_jobs(cls, job_name_query=None):
|
||||
return cls.get_transcription_jobs(
|
||||
status='completed',
|
||||
job_name_query=job_name_query)
|
||||
|
||||
@classmethod
|
||||
def get_pending_jobs(cls, job_name_query=None):
|
||||
return cls.get_transcription_jobs(
|
||||
status='in_progress',
|
||||
job_name_query=job_name_query)
|
||||
|
||||
@classmethod
|
||||
def get_all_jobs(cls, job_name_query=None):
|
||||
return cls.get_transcription_jobs(job_name_query)
|
||||
|
||||
@staticmethod
def get_transcription_jobs(
        status=None,
        job_name_query=None
        ) -> List[dict]:
    """Return all transcription jobs, following pagination to the end.

    Args:
        status: optional status filter; upper-cased before being sent as
            ``Status``.
        job_name_query: optional substring filter sent as
            ``JobNameContains``.

    Returns:
        The homogenized job entries from every page, concatenated.
    """
    key = 'TranscriptionJobSummaries'
    kwargs = {'MaxResults': 100}

    if status is not None:
        kwargs['Status'] = status.upper()
    if job_name_query is not None:
        kwargs['JobNameContains'] = job_name_query

    jobs_data = tr.list_transcription_jobs(**kwargs)
    jobs = homogenize_transcription_job_data(jobs_data[key])

    while jobs_data.get('NextToken'):
        token = jobs_data['NextToken']
        # Repeat the original request (same filters and page size) with the
        # continuation token, rather than sending the token alone.
        jobs_data = tr.list_transcription_jobs(NextToken=token, **kwargs)
        jobs += homogenize_transcription_job_data(jobs_data[key])

    return jobs
|
||||
|
||||
@staticmethod
|
||||
@@ -142,7 +137,3 @@ def homogenize_transcription_job_data(transcription_job_data):
|
||||
'status': jd['TranscriptionJobStatus']
|
||||
}
|
||||
for jd in transcription_job_data]
|
||||
|
||||
|
||||
def shell_call(command):
|
||||
return check_output(command, shell=True)
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from tatt.vendors.amazon import transcribe, retrieve_transcript
|
||||
|
||||
|
||||
|
||||
def test_transcribe_instantiate():
|
||||
filepath = '/Users/zev/tester.mp3'
|
||||
t = transcribe(filepath)
|
||||
assert str(t.filepath) == filepath
|
||||
assert t.basename == 'tester.mp3'
|
||||
assert t.media_file_uri == (
|
||||
f'https://s3-us-east-1.amazonaws.com/tatt-media-amazon/tester.mp3'
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve():
|
||||
filepath = '/Users/zev/tester.mp3'
|
||||
t = retrieve_transcript('4db6808e-a7e8-4d8d-a1b7-753ab97094dc')
|
||||
print(t)
|
||||
assert t is not None
|
||||
Reference in New Issue
Block a user