google is working

This commit is contained in:
2019-03-06 22:58:06 -05:00
parent 8a29f36cf0
commit 7a61e5d729
8 changed files with 137 additions and 81 deletions

BIN
completed_google.p Normal file

Binary file not shown.

View File

@@ -5,6 +5,7 @@ import sqlite3
BUCKET_NAME_FMTR_MEDIA = 'tatt-media-{}'
BUCKET_NAME_FMTR_TRANSCRIPT = 'tatt-transcript-{}'
BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE = 'tatt_transcript_{}'
if os.getenv('AWS_CONFIG_FILEPATH'):
AWS_CONFIG_FILEPATH = Path(os.getenv('AWS_CONFIG_FILEPATH'))

View File

@@ -13,3 +13,7 @@ class DoesntExistError(Exception):
class NotAvailable(Exception):
    """Error signalling that something is not available.

    NOTE(review): no raise site is visible in this excerpt -- confirm the
    intended use against the callers.
    """
class DependencyRequired(Exception):
    """Raised when a required external command-line tool is not installed.

    The Google vendor raises this when `gsutil` or `ffmpeg` cannot be found
    on the PATH.
    """

View File

@@ -1,4 +1,5 @@
import pathlib
import re
import subprocess
from typing import Dict, List
@@ -111,7 +112,6 @@ def get_num_audio_channels(filepath):
filepath = str(filepath)
with audioread.audio_open(filepath) as f:
return f.channels
pass
def shell_call(command):
@@ -132,18 +132,5 @@ def convert_file(filepath, format_name):
convert_flags = '-c:a flac'
output_filepath = change_file_extension(filepath, format_name)
shell_call(f'ffmpeg -i {filepath} {convert_flags} {output_filepath}')
shell_call(f'ffmpeg -y -i {filepath} {convert_flags} {output_filepath}')
return output_filepath
def make_json_friendly(json_string):
    """Convert protobuf-text-style output into a valid JSON document string.

    The input is expected to hold one item per line, in one of three forms:
    ``name {`` (opens a nested block), ``}`` (closes the current block) or
    ``name: value`` (a scalar field).  Blank lines are ignored.

    The previous implementation only sprinkled ``:`` and ``,`` into the text,
    which never produced parseable JSON (keys stayed unquoted, there was no
    enclosing object and every block ended with a trailing comma).  This
    version builds the nested structure and serialises it with ``json.dumps``.

    A field name that occurs more than once inside the same block is
    collected into a list, in order of appearance; a name that occurs once
    keeps its plain value.

    Args:
        json_string: the text to convert.

    Returns:
        A JSON object serialised as a string (``'{}'`` for empty input).

    Raises:
        ValueError: if braces are unbalanced or a line matches none of the
            three forms above.
    """
    # Local import: only part of the module's import section is visible, so
    # do not rely on a module-level `json` being present.
    import json

    def parse_scalar(text):
        # Quoted strings, numbers and true/false are already JSON literals.
        # Anything else (e.g. a bare enum name) is kept as a raw string.
        try:
            return json.loads(text)
        except ValueError:
            return text

    root = {}
    # Each stack entry pairs an open block with the set of keys in it that
    # have already been promoted to lists because they were repeated.
    stack = [(root, set())]

    for raw_line in json_string.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if line == '}':
            if len(stack) == 1:
                raise ValueError("unbalanced '}' in input")
            stack.pop()
            continue

        opens_block = line.endswith('{')
        if opens_block:
            # Accept both `name {` and `name: {`.
            key = line[:-1].strip().rstrip(':').strip()
            value = {}
        else:
            # Split on the first colon only: values may contain colons.
            key, separator, text = line.partition(':')
            if not separator:
                raise ValueError(f'cannot parse line: {line!r}')
            key = key.strip()
            value = parse_scalar(text.strip())

        container, repeated_keys = stack[-1]
        if key not in container:
            container[key] = value
        elif key in repeated_keys:
            container[key].append(value)
        else:
            container[key] = [container[key], value]
            repeated_keys.add(key)

        if opens_block:
            stack.append((value, set()))

    if len(stack) != 1:
        raise ValueError("missing '}' in input")

    return json.dumps(root)

View File

@@ -1,17 +0,0 @@
import json
import pytest
from tatt.helpers import make_json_friendly
@pytest.fixture
def json_string():
return '''
results {\n alternatives {\n transcript: "Testing, this is Zev, Ivory box saying things."\n confidence: 0.8002681732177734\n words {\n start_time {\n seconds: 4\n }\n end_time {\n seconds: 5\n nanos: 500000000\n }\n word: "Testing,"\n confidence: 0.8863372206687927\n }\n words {\n start_time {\n seconds: 5\n nanos: 500000000\n }\n end_time {\n seconds: 6\n nanos: 600000000\n }\n word: "this"\n confidence: 0.8322266936302185\n }\n words {\n start_time {\n seconds: 6\n nanos: 600000000\n }\n end_time {\n seconds: 6\n nanos: 900000000\n }\n word: "is"\n confidence: 0.7659578323364258\n }\n words {\n start_time {\n seconds: 6\n nanos: 900000000\n }\n end_time {\n seconds: 7\n nanos: 300000000\n }\n word: "Zev,"\n confidence: 0.9128385782241821\n }\n words {\n start_time {\n seconds: 7\n nanos: 300000000\n }\n end_time {\n seconds: 7\n nanos: 700000000\n }\n word: "Ivory"\n confidence: 0.7265068292617798\n }\n words {\n start_time {\n seconds: 7\n nanos: 700000000\n }\n end_time {\n seconds: 7\n nanos: 900000000\n }\n word: "box"\n confidence: 0.7768470644950867\n }\n words {\n start_time {\n seconds: 7\n nanos: 900000000\n }\n end_time {\n seconds: 8\n nanos: 700000000\n }\n word: "saying"\n confidence: 0.8872994780540466\n }\n words {\n start_time {\n seconds: 8\n nanos: 700000000\n }\n end_time {\n seconds: 9\n nanos: 400000000\n }\n word: "things."\n confidence: 0.9128385782241821\n }\n }\n channel_tag: 1\n language_code: "en-us"\n}\nresults {\n alternatives {\n transcript: " 2019"\n confidence: 0.7211145758628845\n words {\n start_time {\n seconds: 10\n nanos: 300000000\n }\n end_time {\n seconds: 11\n nanos: 500000000\n }\n word: "2019"\n confidence: 0.7581846714019775\n }\n }\n channel_tag: 2\n language_code: "en-us"\n}\n
'''
def test_make_json_friendly(json_string):
    """The converted fixture text must be loadable by `json.loads`."""
    converted = make_json_friendly(json_string)
    # Printed so the converted text shows up in pytest's captured output.
    print(converted)
    loaded = json.loads(converted)
    assert loaded

130
tatt/vendors/google.py vendored
View File

@@ -2,14 +2,24 @@ import io
import json
import os
import pathlib
import shutil
import tempfile
from time import sleep
from typing import List
from google.cloud import speech_v1p1beta1 as speech
from google.api_core import operations_v1
from google.cloud import (
speech_v1p1beta1 as speech,
storage,
exceptions as gc_exceptions,
)
from tatt import exceptions, helpers, config
from .vendor import TranscriberBaseClass
NAME = 'google'
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT.format(NAME)
BUCKET_NAME_TRANSCRIPT = config.BUCKET_NAME_FMTR_TRANSCRIPT_GOOGLE.format(
'goog')
def _check_for_config():
@@ -27,53 +37,70 @@ class Transcriber(TranscriberBaseClass):
)
if _check_for_config():
client = speech.SpeechClient()
speech_client = speech.SpeechClient()
storage_client = storage.Client()
transcript_bucket = storage_client.get_bucket(BUCKET_NAME_TRANSCRIPT)
def __init__(self, filepath):
    """Initialise the base class with `filepath`, then convert the file to
    a supported format if its current format is not supported."""
    super().__init__(filepath)
    # May replace self.filepath with the path of a converted copy.
    self.convert_file_format_if_needed()
@classmethod
def _setup(cls):
super()._setup()
if not cls.check_for_bucket(BUCKET_NAME_TRANSCRIPT):
print('creating a transcript bucket on Google Cloud Storage')
cls.make_bucket(BUCKET_NAME_TRANSCRIPT)
if not shutil.which('gsutil'):
raise exceptions.DependencyRequired(
'Please install gcloud using the steps here:'
'https://cloud.google.com/storage/docs/gsutil_install')
cls._make_bucket_if_doesnt_exist(BUCKET_NAME_TRANSCRIPT)
@classmethod
def make_bucket(cls, bucket_name):
pass
@classmethod
def check_for_bucket(cls, bucket_name):
pass
def _make_bucket_if_doesnt_exist(cls, bucket_name):
try:
cls.storage_client.create_bucket(bucket_name)
except gc_exceptions.Conflict:
# this might fail if a bucket by the name exists *anywhere* on GCS?
return
else:
print('made Google Cloud Storage Bucket for transcripts')
def convert_file_format_if_needed(self):
    """Convert the audio file to FLAC when its format is not supported.

    Does nothing if `self.file_format` is listed in
    `self.SUPPORTED_FORMATS`.  Otherwise `self.filepath` is replaced with
    the path returned by `helpers.convert_file`.

    Raises:
        exceptions.DependencyRequired: if `ffmpeg` is not found on the PATH.
    """
    if self.file_format in self.SUPPORTED_FORMATS:
        return
    if not shutil.which('ffmpeg'):
        raise exceptions.DependencyRequired('please install ffmpeg')
    converted_path = helpers.convert_file(self.filepath, 'flac')
    self.filepath = converted_path
@property
def file_format(self):
    """Lower-cased file extension of `self.filepath`, without the dot."""
    suffix = pathlib.Path(self.filepath).suffix
    # Drop the leading "." (an empty suffix stays empty).
    extension = suffix[1:]
    return extension.lower()
@property
def transcript_name(self):
    """Name of the transcript: `self.basename` with '.txt' appended."""
    extension = '.txt'
    return self.basename + extension
@staticmethod
def check_for_config() -> bool:
    """Return the result of the module-level `_check_for_config()`."""
    config_found = _check_for_config()
    return config_found
def transcribe(self) -> str:
    """
    Convert the audio file to a supported format if needed, request the
    transcription, and return the job name.

    Returns:
        The value returned by `self._request_transcription()`.
    """
    self.convert_file_format_if_needed()
    # Previously the result was discarded, so this method returned None
    # despite its `-> str` annotation.
    return self._request_transcription()
def _check_if_transcript_exists(self, transcript_name=None):
    """Return the result of `exists()` for the named blob in the transcript
    bucket.

    Falls back to `self.transcript_name` when `transcript_name` is falsy.
    """
    bucket = self.transcript_bucket
    blob_name = transcript_name or self.transcript_name
    blob = storage.Blob(bucket=bucket, name=blob_name)
    return blob.exists(self.storage_client)
def _request_transcription(
self,
language_code='en-US',
model='video',
) -> str:
"""Returns the job_name"""
if self._check_if_transcript_exists():
raise exceptions.AlreadyExistsError(
f'{self.basename} already exists on {NAME}')
num_audio_channels = helpers.get_num_audio_channels(self.filepath)
with io.open(self.filepath, 'rb') as audio_file:
@@ -92,43 +119,54 @@ class Transcriber(TranscriberBaseClass):
model=model,
)
self.operation = self.client.long_running_recognize(config, audio)
self.operation = self.speech_client.long_running_recognize(config,
audio)
def my_callback(future):
result = future.result()
# save json.dumps(result) to file
# TODO: see what others have done to make this easy (BBC guy)
self.upload_file(BUCKET_NAME_TRANSCRIPT, filepath)
# delete file
print('transcribing...')
while not self.operation.done():
sleep(1)
print('.')
self.operation.add_done_callback(my_callback)
result_list = []
return self.filepath.name
for result in self.operation.result().results:
result_list.append(str(result))
print('saving transcript')
transcript_path = '/tmp/transcript.txt'
with open(transcript_path, 'w') as fout:
fout.write('\n'.join(result_list))
print('uploading transcript')
self.upload_file(BUCKET_NAME_TRANSCRIPT, transcript_path)
os.remove(transcript_path)
return self.basename
@classmethod
def retrieve_transcript(cls, transcription_job_name: str) -> dict:
    """Download a transcript from the transcript bucket and return its text.

    NOTE(review): the return annotation says `dict`, but the value returned
    is the text read from the downloaded file -- confirm which one callers
    expect.

    Args:
        transcription_job_name: name of the blob to fetch.

    Returns:
        The contents of the downloaded blob, read as text.

    Raises:
        exceptions.DoesntExistError: if no blob with that name exists.
    """
    # `_check_if_transcript_exists` is an instance method; the class is
    # passed in place of `self` so it can be used without an instance.
    if not cls._check_if_transcript_exists(
            cls,
            transcript_name=transcription_job_name):
        raise exceptions.DoesntExistError('no such transcript!')

    blob = cls.transcript_bucket.blob(transcription_job_name)

    # Create the temp file only to reserve a path: close it right away so
    # the download can write to it by name.
    f = tempfile.NamedTemporaryFile(delete=False)
    f.close()
    try:
        blob.download_to_filename(f.name)
        with open(f.name) as fin:
            return fin.read()
    finally:
        # Always clean up, including when the download or the read fails.
        os.remove(f.name)
def upload_file(self, bucket_name, path):
    """Upload the local file at `path` to the transcript bucket, stored
    under `self.transcript_name`.

    `bucket_name` is accepted but not used: the destination is always
    `self.transcript_bucket`.
    """
    self.transcript_bucket.blob(self.transcript_name).upload_from_filename(
        path)
@classmethod
def upload_file(cls, bucket_name, path):
pass
@classmethod
def get_transcription_jobs(job_name_query, status):
"""
Store pending jobs in some simple db or document,
then remove them when the transcript appears in the bucket.
"""
pass
def get_transcription_jobs(cls, job_name_query, status) -> List[dict]:
    """List every blob in the transcript bucket as a completed job.

    `job_name_query` and `status` are accepted but not used for filtering;
    each entry is reported with status 'COMPLETED'.
    """
    jobs = []
    for blob in cls.transcript_bucket.list_blobs():
        jobs.append({'name': blob.name, 'status': 'COMPLETED'})
    return jobs

View File

@@ -1,6 +1,48 @@
from pprint import pprint
import pytest
from tatt.vendors.google import Transcriber
from tatt import exceptions
def test_request_transcription():
@pytest.fixture
def audio_filepath():
    """Path of the FLAC recording used by these tests."""
    path = '/Users/zev/d/saying_things_stuff.flac'
    return path
@pytest.fixture
def transcript_name():
    """Name of the transcript requested by the retrieval test."""
    name = 'saying_things_stuff.flac.txt'
    return name
def test_request_transcription_already_exists(audio_filepath):
    """Requesting a transcript that already exists must raise
    AlreadyExistsError."""
    with pytest.raises(exceptions.AlreadyExistsError):
        # Construction stays inside the block, as before.
        Transcriber(audio_filepath)._request_transcription()
def test_make_bucket():
    """Smoke test: the bucket helper runs without raising."""
    bucket_name = 'something-uh-ok'
    Transcriber._make_bucket_if_doesnt_exist(bucket_name)
def test_setup():
    """Smoke test: `Transcriber._setup()` runs without raising."""
    Transcriber._setup()
def test_check_if_transcript_exists(audio_filepath):
    """After a transcription has been requested, the transcript must be
    reported as existing."""
    # Use the fixture instead of repeating the hard-coded path.
    t = Transcriber(audio_filepath)
    # _request_transcription raises AlreadyExistsError when the transcript
    # is already there, so only request it when missing; this keeps the
    # test re-runnable.
    if not t._check_if_transcript_exists():
        t._request_transcription()
    assert t._check_if_transcript_exists() is True
def test_retrieve_transcript(transcript_name):
    """Retrieving the transcript named by the fixture returns something."""
    retrieved = Transcriber.retrieve_transcript(transcript_name)
    assert retrieved is not None
def test_retrieve_transcript_doesnt_exist():
    """Retrieving an unknown transcript name must raise DoesntExistError."""
    missing_name = 'no_such_thing.json'
    with pytest.raises(exceptions.DoesntExistError):
        Transcriber.retrieve_transcript(missing_name)
def test_get_transcription_jobs():

View File

@@ -1,6 +1,7 @@
import abc
import os
from pathlib import PurePath
from typing import List
from tatt import exceptions
@@ -60,6 +61,6 @@ class TranscriberBaseClass:
@classmethod
@abc.abstractmethod
def get_transcription_jobs():
def get_transcription_jobs() -> List[dict]:
pass