added speaker_id conversions to Amazon and Google
This commit is contained in:
38
tests/test_amazon.py
Normal file
38
tests/test_amazon.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from transcript_processing.converters.amazon import AmazonConverter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transcript_data():
|
||||
with open(os.getenv('AMAZON_TRANSCRIPT_TEST_FILE'), 'r') as fin:
|
||||
return json.load(fin)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter(transcript_data):
|
||||
return AmazonConverter(transcript_data)
|
||||
|
||||
|
||||
def test_get_word_objects(converter):
|
||||
word_objects = converter.get_word_objects(converter.json_data)
|
||||
assert word_objects
|
||||
|
||||
|
||||
def test_get_speaker_segments(converter):
|
||||
speaker_segments = converter.get_speaker_segments()
|
||||
assert speaker_segments
|
||||
|
||||
|
||||
def test_get_speaker_id(converter):
|
||||
speaker_segments = converter.get_speaker_segments()
|
||||
assert speaker_segments[54.58] == 0
|
||||
assert speaker_segments[32.36] == 1
|
||||
|
||||
|
||||
def test_convert(converter):
|
||||
converter.convert()
|
||||
print(converter.converted_words)
|
||||
@@ -1,33 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from transcript_processing.converters.google import (
|
||||
make_json_friendly,
|
||||
GoogleConverter,
|
||||
)
|
||||
from transcript_processing.config import GOOGLE_TRANSCRIPT_TEST_FILE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transcript():
|
||||
with open(GOOGLE_TRANSCRIPT_TEST_FILE, 'r') as fin:
|
||||
return fin.read()
|
||||
|
||||
|
||||
def test_make_json_friendly(transcript):
|
||||
friendly = make_json_friendly(transcript)
|
||||
assert json.loads(friendly)
|
||||
|
||||
|
||||
def test_pre_process(transcript):
|
||||
with open(os.getenv('GOOGLE_TRANSCRIPT_TEST_FILE'), 'r') as fin:
|
||||
transcript_data = fin.read()
|
||||
|
||||
g = GoogleConverter(transcript_data)
|
||||
assert g.json_data
|
||||
print(g.json_data)
|
||||
|
||||
|
||||
|
||||
39
tests/test_google.py
Normal file
39
tests/test_google.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from transcript_processing.converters.google import (
|
||||
make_json_friendly,
|
||||
GoogleConverter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transcript_data():
|
||||
with open(os.getenv('GOOGLE_TRANSCRIPT_TEST_FILE'), 'r') as fin:
|
||||
return fin.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter(transcript_data):
|
||||
return GoogleConverter(transcript_data)
|
||||
|
||||
|
||||
def test_get_word_objects(converter):
|
||||
word_objects = converter.get_word_objects(converter.json_data)
|
||||
assert word_objects
|
||||
|
||||
|
||||
def test_make_json_friendly(transcript_data):
|
||||
friendly = make_json_friendly(transcript_data)
|
||||
assert json.loads(friendly)
|
||||
|
||||
|
||||
def test_pre_process(converter):
|
||||
assert converter.json_data
|
||||
|
||||
|
||||
def test_convert(converter):
|
||||
converter.convert()
|
||||
print(converter.converted_words)
|
||||
@@ -62,7 +62,7 @@ class TranscriptConverter:
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def get_speaker_id(word_object):
|
||||
def get_speaker_id(word_object, speaker_segments=None):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@@ -77,7 +77,14 @@ class TranscriptConverter:
|
||||
word_category = tagged_words[index][1]
|
||||
return word_category in helpers.PROPER_NOUN_TAGS
|
||||
|
||||
def get_word_object(self, word_object, index, tagged_words, word_objects):
|
||||
def get_word_object(
|
||||
self,
|
||||
word_object,
|
||||
index,
|
||||
tagged_words,
|
||||
word_objects,
|
||||
speaker_segments=None,
|
||||
):
|
||||
word = self.get_word_word(word_object)
|
||||
return Word(
|
||||
self.get_word_start(word_object),
|
||||
@@ -86,7 +93,7 @@ class TranscriptConverter:
|
||||
word,
|
||||
self.check_if_always_capitalized(word, index, tagged_words),
|
||||
self.get_next_word(word_objects, index),
|
||||
self.get_speaker_id(word_object),
|
||||
self.get_speaker_id(word_object, speaker_segments),
|
||||
)
|
||||
|
||||
def get_next_word(self, word_objects, index):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..converter import TranscriptConverter
|
||||
from .. import helpers
|
||||
@@ -12,9 +13,28 @@ class AmazonConverter(TranscriptConverter):
|
||||
def __init__(self, json_data):
|
||||
super().__init__(json_data)
|
||||
|
||||
def get_word_objects(self, json_data):
|
||||
def get_word_objects(self, json_data) -> list:
|
||||
return json_data['results']['items']
|
||||
|
||||
def get_speaker_segments(self) -> Optional[Dict[float, str]]:
|
||||
try:
|
||||
segments = self.json_data['results']['speaker_labels']['segments']
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
segment_dict = {}
|
||||
for segment in segments:
|
||||
word_level_segment = segment['items']
|
||||
for word in word_level_segment:
|
||||
start_time = float(word['start_time'])
|
||||
speaker_label = word['speaker_label']
|
||||
speaker_id = ''
|
||||
for char in speaker_label:
|
||||
if char.isnumeric():
|
||||
speaker_id += char
|
||||
segment_dict[start_time] = int(speaker_id)
|
||||
return segment_dict
|
||||
|
||||
@staticmethod
|
||||
def get_word_start(word_object):
|
||||
return float(word_object['start_time'])
|
||||
@@ -35,12 +55,17 @@ class AmazonConverter(TranscriptConverter):
|
||||
word_word = 'I'
|
||||
return word_word
|
||||
|
||||
@staticmethod
|
||||
def get_speaker_id(word_object):
|
||||
return None
|
||||
@classmethod
|
||||
def get_speaker_id(cls, word_object, speaker_segments=None):
|
||||
if speaker_segments is None:
|
||||
return None
|
||||
else:
|
||||
word_start = cls.get_word_start(word_object)
|
||||
return speaker_segments[word_start]
|
||||
|
||||
def convert_words(self, word_objects, words, tagged_words=None):
|
||||
converted_words = []
|
||||
speaker_segments = self.get_speaker_segments()
|
||||
|
||||
punc_before = False
|
||||
punc_after = False
|
||||
@@ -49,7 +74,13 @@ class AmazonConverter(TranscriptConverter):
|
||||
if w['type'] == 'punctuation':
|
||||
continue
|
||||
next_word_punc_after = None
|
||||
word_obj = self.get_word_object(w, i, tagged_words, word_objects)
|
||||
word_obj = self.get_word_object(
|
||||
w,
|
||||
i,
|
||||
tagged_words,
|
||||
word_objects,
|
||||
speaker_segments,
|
||||
)
|
||||
|
||||
if word_obj.next_word:
|
||||
next_word = self.get_word_word(word_obj.next_word)
|
||||
|
||||
@@ -57,6 +57,10 @@ class GoogleConverter(TranscriptConverter):
|
||||
|
||||
return converted_words
|
||||
|
||||
@staticmethod
|
||||
def get_speaker_id(word_object, _):
|
||||
return word_object.get('speaker_tag')
|
||||
|
||||
@classmethod
|
||||
def get_word_start(cls, word_object):
|
||||
return cls.get_seconds(word_object['start_time'])
|
||||
@@ -88,16 +92,6 @@ def make_json_friendly(json_string):
|
||||
lines = [line.strip() for line in json_string.split('\n')]
|
||||
new_string = '['
|
||||
|
||||
fields = [
|
||||
'words {',
|
||||
'start_time {',
|
||||
'}',
|
||||
'end_time {',
|
||||
'}',
|
||||
'word: ',
|
||||
'confidence: '
|
||||
]
|
||||
|
||||
start_field = 'words {'
|
||||
|
||||
open_braces = 0
|
||||
@@ -114,13 +108,13 @@ def make_json_friendly(json_string):
|
||||
if '}' in line:
|
||||
open_braces -= 1
|
||||
|
||||
if open_braces > 0 and '{' not in line and '}' not in lines[index + 1]:
|
||||
line = line + ', '
|
||||
|
||||
if open_braces == 0:
|
||||
new_string += '}, '
|
||||
continue
|
||||
|
||||
elif '{' not in line and '}' not in lines[index + 1]:
|
||||
line = line + ', '
|
||||
|
||||
line = re.sub('^(?!")([0-9a-zA-Z_]+)',
|
||||
'"\\1"',
|
||||
line)
|
||||
@@ -134,5 +128,4 @@ def make_json_friendly(json_string):
|
||||
|
||||
new_string = new_string.replace('\\', '')
|
||||
|
||||
|
||||
return new_string[:-2] + ']'
|
||||
|
||||
Reference in New Issue
Block a user