# Patch overview (explanatory preamble; everything from the first "diff --git" on is unchanged).
#
# Purpose: thread per-word speaker identification through the converters.
#
# tests/test_amazon.py (new)
#   - `transcript_data` fixture json-loads the file whose path is read from the
#     AMAZON_TRANSCRIPT_TEST_FILE environment variable; `converter` wraps it in AmazonConverter.
#   - Tests check that get_word_objects / get_speaker_segments return something truthy, that two
#     hard-coded start times (54.58, 32.36) map to speaker ids 0 and 1, and that convert() runs.
# tests/test_convert_google.py (deleted) -> tests/test_google.py (new)
#   - The removed file opened a path imported from transcript_processing.config in its fixture;
#     the new file reads the path from the GOOGLE_TRANSCRIPT_TEST_FILE environment variable only.
# transcript_processing/converter.py
#   - Abstract get_speaker_id gains `speaker_segments=None`.
#   - get_word_object gains `speaker_segments=None` and forwards it to get_speaker_id.
# transcript_processing/converters/amazon.py
#   - New get_speaker_segments(): walks results -> speaker_labels -> segments -> items and builds
#     {float(start_time): int(digits of speaker_label)}; returns None when a key on the path to
#     `segments` is missing (KeyError).
#   - get_speaker_id becomes a classmethod: returns None without segments, otherwise looks the
#     word's start time up in the mapping.
#   - convert_words builds the mapping once and passes it to get_word_object for each
#     non-punctuation item.
# transcript_processing/converters/google.py
#   - Adds get_speaker_id returning word_object.get('speaker_tag').
#   - make_json_friendly: removes the `fields` list, and replaces the standalone
#     `if open_braces > 0 and ...` comma check with an `elif` placed after the
#     `if open_braces == 0:` branch. A further one-line removal sits next to the final return;
#     the collapsed text does not show unambiguously which line it is.
#
# NOTE(review): GoogleConverter.get_speaker_id(word_object, _) has no default for its second
#   parameter, while the base declares `speaker_segments=None`. The caller visible in this patch
#   always passes two arguments, but a one-argument call would raise TypeError -- consider
#   `speaker_segments=None` for parity.
# NOTE(review): get_speaker_segments is annotated Optional[Dict[float, str]] but stores
#   int(speaker_id) values, so Dict[float, int] looks like the intended type.
# NOTE(review): int(speaker_id) raises ValueError if a speaker_label contains no numeric
#   characters -- confirm labels always carry a number.
# NOTE(review): AmazonConverter.get_speaker_id subscripts the mapping by float start time; a word
#   whose start time is absent from the segments raises KeyError -- confirm every pronunciation
#   item is covered, or use .get().
# NOTE(review): both `transcript_data` fixtures call open(os.getenv(...)); when the variable is
#   unset os.getenv returns None and open() raises TypeError. A pytest skip may be preferable.
# NOTE(review): test_get_speaker_id exercises get_speaker_segments (not get_speaker_id) and
#   presumably depends on one specific fixture file; test_convert only prints, it asserts nothing.
diff --git a/tests/test_amazon.py b/tests/test_amazon.py new file mode 100644 index 0000000..dbceca3 --- /dev/null +++ b/tests/test_amazon.py @@ -0,0 +1,38 @@ +import json +import os + +import pytest + +from transcript_processing.converters.amazon import AmazonConverter + + +@pytest.fixture +def transcript_data(): + with open(os.getenv('AMAZON_TRANSCRIPT_TEST_FILE'), 'r') as fin: + return json.load(fin) + + +@pytest.fixture +def converter(transcript_data): + return AmazonConverter(transcript_data) + + +def test_get_word_objects(converter): + word_objects = converter.get_word_objects(converter.json_data) + assert word_objects + + +def test_get_speaker_segments(converter): + speaker_segments = converter.get_speaker_segments() + assert speaker_segments + + +def test_get_speaker_id(converter): + speaker_segments = converter.get_speaker_segments() + assert speaker_segments[54.58] == 0 + assert speaker_segments[32.36] == 1 + + +def test_convert(converter): + converter.convert() + print(converter.converted_words) diff --git a/tests/test_convert_google.py b/tests/test_convert_google.py deleted file mode 100644 index 545d78f..0000000 --- a/tests/test_convert_google.py +++ /dev/null @@ -1,33 +0,0 @@ -import json -import os - -import pytest - -from transcript_processing.converters.google import ( - make_json_friendly, - GoogleConverter, - ) -from transcript_processing.config import GOOGLE_TRANSCRIPT_TEST_FILE - - -@pytest.fixture -def transcript(): - with open(GOOGLE_TRANSCRIPT_TEST_FILE, 'r') as fin: - return fin.read() - - -def test_make_json_friendly(transcript): - friendly = make_json_friendly(transcript) - assert json.loads(friendly) - - -def test_pre_process(transcript): - with open(os.getenv('GOOGLE_TRANSCRIPT_TEST_FILE'), 'r') as fin: - transcript_data = fin.read() - - g = GoogleConverter(transcript_data) - assert g.json_data - print(g.json_data) - - - diff --git a/tests/test_google.py b/tests/test_google.py new file mode 100644 index 0000000..1dc8e41 --- /dev/null 
+++ b/tests/test_google.py @@ -0,0 +1,39 @@ +import json +import os + +import pytest + +from transcript_processing.converters.google import ( + make_json_friendly, + GoogleConverter, + ) + + +@pytest.fixture +def transcript_data(): + with open(os.getenv('GOOGLE_TRANSCRIPT_TEST_FILE'), 'r') as fin: + return fin.read() + + +@pytest.fixture +def converter(transcript_data): + return GoogleConverter(transcript_data) + + +def test_get_word_objects(converter): + word_objects = converter.get_word_objects(converter.json_data) + assert word_objects + + +def test_make_json_friendly(transcript_data): + friendly = make_json_friendly(transcript_data) + assert json.loads(friendly) + + +def test_pre_process(converter): + assert converter.json_data + + +def test_convert(converter): + converter.convert() + print(converter.converted_words) diff --git a/transcript_processing/converter.py b/transcript_processing/converter.py index 7e1697c..66254cf 100644 --- a/transcript_processing/converter.py +++ b/transcript_processing/converter.py @@ -62,7 +62,7 @@ class TranscriptConverter: @staticmethod @abc.abstractmethod - def get_speaker_id(word_object): + def get_speaker_id(word_object, speaker_segments=None): pass @staticmethod @@ -77,7 +77,14 @@ class TranscriptConverter: word_category = tagged_words[index][1] return word_category in helpers.PROPER_NOUN_TAGS - def get_word_object(self, word_object, index, tagged_words, word_objects): + def get_word_object( + self, + word_object, + index, + tagged_words, + word_objects, + speaker_segments=None, + ): word = self.get_word_word(word_object) return Word( self.get_word_start(word_object), @@ -86,7 +93,7 @@ class TranscriptConverter: word, self.check_if_always_capitalized(word, index, tagged_words), self.get_next_word(word_objects, index), - self.get_speaker_id(word_object), + self.get_speaker_id(word_object, speaker_segments), ) def get_next_word(self, word_objects, index): diff --git a/transcript_processing/converters/amazon.py 
b/transcript_processing/converters/amazon.py index 5ff147f..55348ed 100644 --- a/transcript_processing/converters/amazon.py +++ b/transcript_processing/converters/amazon.py @@ -1,4 +1,5 @@ import json +from typing import Dict, Optional from ..converter import TranscriptConverter from .. import helpers @@ -12,9 +13,28 @@ class AmazonConverter(TranscriptConverter): def __init__(self, json_data): super().__init__(json_data) - def get_word_objects(self, json_data): + def get_word_objects(self, json_data) -> list: return json_data['results']['items'] + def get_speaker_segments(self) -> Optional[Dict[float, str]]: + try: + segments = self.json_data['results']['speaker_labels']['segments'] + except KeyError: + return None + else: + segment_dict = {} + for segment in segments: + word_level_segment = segment['items'] + for word in word_level_segment: + start_time = float(word['start_time']) + speaker_label = word['speaker_label'] + speaker_id = '' + for char in speaker_label: + if char.isnumeric(): + speaker_id += char + segment_dict[start_time] = int(speaker_id) + return segment_dict + @staticmethod def get_word_start(word_object): return float(word_object['start_time']) @@ -35,12 +55,17 @@ class AmazonConverter(TranscriptConverter): word_word = 'I' return word_word - @staticmethod - def get_speaker_id(word_object): - return None + @classmethod + def get_speaker_id(cls, word_object, speaker_segments=None): + if speaker_segments is None: + return None + else: + word_start = cls.get_word_start(word_object) + return speaker_segments[word_start] def convert_words(self, word_objects, words, tagged_words=None): converted_words = [] + speaker_segments = self.get_speaker_segments() punc_before = False punc_after = False @@ -49,7 +74,13 @@ class AmazonConverter(TranscriptConverter): if w['type'] == 'punctuation': continue next_word_punc_after = None - word_obj = self.get_word_object(w, i, tagged_words, word_objects) + word_obj = self.get_word_object( + w, + i, + tagged_words, + 
word_objects, + speaker_segments, + ) if word_obj.next_word: next_word = self.get_word_word(word_obj.next_word) diff --git a/transcript_processing/converters/google.py b/transcript_processing/converters/google.py index 81d3411..8a8d262 100644 --- a/transcript_processing/converters/google.py +++ b/transcript_processing/converters/google.py @@ -57,6 +57,10 @@ class GoogleConverter(TranscriptConverter): return converted_words + @staticmethod + def get_speaker_id(word_object, _): + return word_object.get('speaker_tag') + @classmethod def get_word_start(cls, word_object): return cls.get_seconds(word_object['start_time']) @@ -88,16 +92,6 @@ def make_json_friendly(json_string): lines = [line.strip() for line in json_string.split('\n')] new_string = '[' - fields = [ - 'words {', - 'start_time {', - '}', - 'end_time {', - '}', - 'word: ', - 'confidence: ' - ] - start_field = 'words {' open_braces = 0 @@ -114,13 +108,13 @@ def make_json_friendly(json_string): if '}' in line: open_braces -= 1 - if open_braces > 0 and '{' not in line and '}' not in lines[index + 1]: - line = line + ', ' - if open_braces == 0: new_string += '}, ' continue + elif '{' not in line and '}' not in lines[index + 1]: + line = line + ', ' + line = re.sub('^(?!")([0-9a-zA-Z_]+)', '"\\1"', line) @@ -134,5 +128,4 @@ def make_json_friendly(json_string): new_string = new_string.replace('\\', '') - return new_string[:-2] + ']'