added speaker_id conversions to Amazon and Google

This commit is contained in:
2019-03-08 10:23:01 -05:00
parent 3fc6dacfde
commit 0301b3be23
6 changed files with 130 additions and 55 deletions

38
tests/test_amazon.py Normal file
View File

@@ -0,0 +1,38 @@
import json
import os
import pytest
from transcript_processing.converters.amazon import AmazonConverter
@pytest.fixture
def transcript_data():
with open(os.getenv('AMAZON_TRANSCRIPT_TEST_FILE'), 'r') as fin:
return json.load(fin)
@pytest.fixture
def converter(transcript_data):
return AmazonConverter(transcript_data)
def test_get_word_objects(converter):
word_objects = converter.get_word_objects(converter.json_data)
assert word_objects
def test_get_speaker_segments(converter):
speaker_segments = converter.get_speaker_segments()
assert speaker_segments
def test_get_speaker_id(converter):
speaker_segments = converter.get_speaker_segments()
assert speaker_segments[54.58] == 0
assert speaker_segments[32.36] == 1
def test_convert(converter):
converter.convert()
print(converter.converted_words)

View File

@@ -1,33 +0,0 @@
import json
import os
import pytest
from transcript_processing.converters.google import (
make_json_friendly,
GoogleConverter,
)
from transcript_processing.config import GOOGLE_TRANSCRIPT_TEST_FILE
@pytest.fixture
def transcript():
with open(GOOGLE_TRANSCRIPT_TEST_FILE, 'r') as fin:
return fin.read()
def test_make_json_friendly(transcript):
friendly = make_json_friendly(transcript)
assert json.loads(friendly)
def test_pre_process(transcript):
with open(os.getenv('GOOGLE_TRANSCRIPT_TEST_FILE'), 'r') as fin:
transcript_data = fin.read()
g = GoogleConverter(transcript_data)
assert g.json_data
print(g.json_data)

39
tests/test_google.py Normal file
View File

@@ -0,0 +1,39 @@
import json
import os
import pytest
from transcript_processing.converters.google import (
make_json_friendly,
GoogleConverter,
)
@pytest.fixture
def transcript_data():
with open(os.getenv('GOOGLE_TRANSCRIPT_TEST_FILE'), 'r') as fin:
return fin.read()
@pytest.fixture
def converter(transcript_data):
return GoogleConverter(transcript_data)
def test_get_word_objects(converter):
word_objects = converter.get_word_objects(converter.json_data)
assert word_objects
def test_make_json_friendly(transcript_data):
friendly = make_json_friendly(transcript_data)
assert json.loads(friendly)
def test_pre_process(converter):
assert converter.json_data
def test_convert(converter):
converter.convert()
print(converter.converted_words)

View File

@@ -62,7 +62,7 @@ class TranscriptConverter:
@staticmethod
@abc.abstractmethod
def get_speaker_id(word_object):
def get_speaker_id(word_object, speaker_segments=None):
pass
@staticmethod
@@ -77,7 +77,14 @@ class TranscriptConverter:
word_category = tagged_words[index][1]
return word_category in helpers.PROPER_NOUN_TAGS
def get_word_object(self, word_object, index, tagged_words, word_objects):
def get_word_object(
self,
word_object,
index,
tagged_words,
word_objects,
speaker_segments=None,
):
word = self.get_word_word(word_object)
return Word(
self.get_word_start(word_object),
@@ -86,7 +93,7 @@ class TranscriptConverter:
word,
self.check_if_always_capitalized(word, index, tagged_words),
self.get_next_word(word_objects, index),
self.get_speaker_id(word_object),
self.get_speaker_id(word_object, speaker_segments),
)
def get_next_word(self, word_objects, index):

View File

@@ -1,4 +1,5 @@
import json
from typing import Dict, Optional
from ..converter import TranscriptConverter
from .. import helpers
@@ -12,9 +13,28 @@ class AmazonConverter(TranscriptConverter):
def __init__(self, json_data):
super().__init__(json_data)
def get_word_objects(self, json_data):
def get_word_objects(self, json_data) -> list:
return json_data['results']['items']
def get_speaker_segments(self) -> Optional[Dict[float, str]]:
try:
segments = self.json_data['results']['speaker_labels']['segments']
except KeyError:
return None
else:
segment_dict = {}
for segment in segments:
word_level_segment = segment['items']
for word in word_level_segment:
start_time = float(word['start_time'])
speaker_label = word['speaker_label']
speaker_id = ''
for char in speaker_label:
if char.isnumeric():
speaker_id += char
segment_dict[start_time] = int(speaker_id)
return segment_dict
@staticmethod
def get_word_start(word_object):
return float(word_object['start_time'])
@@ -35,12 +55,17 @@ class AmazonConverter(TranscriptConverter):
word_word = 'I'
return word_word
@staticmethod
def get_speaker_id(word_object):
return None
@classmethod
def get_speaker_id(cls, word_object, speaker_segments=None):
if speaker_segments is None:
return None
else:
word_start = cls.get_word_start(word_object)
return speaker_segments[word_start]
def convert_words(self, word_objects, words, tagged_words=None):
converted_words = []
speaker_segments = self.get_speaker_segments()
punc_before = False
punc_after = False
@@ -49,7 +74,13 @@ class AmazonConverter(TranscriptConverter):
if w['type'] == 'punctuation':
continue
next_word_punc_after = None
word_obj = self.get_word_object(w, i, tagged_words, word_objects)
word_obj = self.get_word_object(
w,
i,
tagged_words,
word_objects,
speaker_segments,
)
if word_obj.next_word:
next_word = self.get_word_word(word_obj.next_word)

View File

@@ -57,6 +57,10 @@ class GoogleConverter(TranscriptConverter):
return converted_words
@staticmethod
def get_speaker_id(word_object, _):
return word_object.get('speaker_tag')
@classmethod
def get_word_start(cls, word_object):
return cls.get_seconds(word_object['start_time'])
@@ -88,16 +92,6 @@ def make_json_friendly(json_string):
lines = [line.strip() for line in json_string.split('\n')]
new_string = '['
fields = [
'words {',
'start_time {',
'}',
'end_time {',
'}',
'word: ',
'confidence: '
]
start_field = 'words {'
open_braces = 0
@@ -114,13 +108,13 @@ def make_json_friendly(json_string):
if '}' in line:
open_braces -= 1
if open_braces > 0 and '{' not in line and '}' not in lines[index + 1]:
line = line + ', '
if open_braces == 0:
new_string += '}, '
continue
elif '{' not in line and '}' not in lines[index + 1]:
line = line + ', '
line = re.sub('^(?!")([0-9a-zA-Z_]+)',
'"\\1"',
line)
@@ -134,5 +128,4 @@ def make_json_friendly(json_string):
new_string = new_string.replace('\\', '')
return new_string[:-2] + ']'