diff --git a/babi.py b/babi.py index 52c2234..e492068 100644 --- a/babi.py +++ b/babi.py @@ -3,6 +3,7 @@ import collections import contextlib import curses import enum +import hashlib import io import os import signal @@ -144,10 +145,12 @@ def _restore_lines_eof_invariant(lines: List[str]) -> None: lines.append('') -def _get_lines(sio: IO[str]) -> Tuple[List[str], str, bool]: +def _get_lines(sio: IO[str]) -> Tuple[List[str], str, bool, str]: + sha256 = hashlib.sha256() lines = [] newlines = collections.Counter({'\n': 0}) # default to `\n` for line in sio: + sha256.update(line.encode()) for ending in ('\r\n', '\n'): if line.endswith(ending): lines.append(line[:-1 * len(ending)]) @@ -158,7 +161,7 @@ def _get_lines(sio: IO[str]) -> Tuple[List[str], str, bool]: _restore_lines_eof_invariant(lines) (nl, _), = newlines.most_common(1) mixed = len({k for k, v in newlines.items() if v}) > 1 - return lines, nl, mixed + return lines, nl, mixed, sha256.hexdigest() class File: @@ -168,6 +171,7 @@ class File: self.lines: List[str] = [] self.nl = '\n' self.file_line = self.cursor_line = self.x = self.x_hint = 0 + self.sha256: Optional[str] = None def ensure_loaded(self, status: Status, margin: Margin) -> None: if self.lines: @@ -175,7 +179,7 @@ class File: if self.filename is not None and os.path.isfile(self.filename): with open(self.filename, newline='') as f: - self.lines, self.nl, mixed = _get_lines(f) + self.lines, self.nl, mixed, self.sha256 = _get_lines(f) else: if self.filename is not None: if os.path.lexists(self.filename): @@ -183,7 +187,8 @@ class File: self.filename = None else: status.update('(new file)', margin) - self.lines, self.nl, mixed = _get_lines(io.StringIO('')) + sio = io.StringIO('') + self.lines, self.nl, mixed, self.sha256 = _get_lines(sio) if mixed: status.update( @@ -337,6 +342,7 @@ class File: ord('\r'): enter, } DISPATCH_KEY = { + # movement b'^A': home, b'^E': end, b'^Y': page_up, @@ -352,6 +358,38 @@ class File: self.modified = True _restore_lines_eof_invariant(self.lines) + def save(self, status: Status, margin: Margin) -> None: + # TODO: make directories if they don't exist + # TODO: maybe use mtime / stat as a shortcut for hashing below + # TODO: strip trailing whitespace? + # TODO: save atomically? + if self.filename is None: + status.update('(no filename, not implemented)', margin) + return + + if os.path.isfile(self.filename): + with open(self.filename) as f: + *_, sha256 = _get_lines(f) + else: + sha256 = hashlib.sha256(b'').hexdigest() + + contents = self.nl.join(self.lines) + sha256_to_save = hashlib.sha256(contents.encode()).hexdigest() + + # the file on disk is the same as when we opened it + if sha256 not in (self.sha256, sha256_to_save): + status.update('(file changed on disk, not implemented)', margin) + return + + with open(self.filename, 'w') as f: + f.write(contents) + + self.modified = False + self.sha256 = sha256_to_save + num_lines = len(self.lines) - 1 + lines = 'lines' if num_lines != 1 else 'line' + status.update(f'saved! ({num_lines} {lines} written)', margin) + # positioning def cursor_y(self, margin: Margin) -> int: @@ -511,6 +549,8 @@ def _edit( file.DISPATCH[key.key](file, margin) elif key.keyname in File.DISPATCH_KEY: file.DISPATCH_KEY[key.keyname](file, margin) + elif key.keyname == b'^S': + file.save(status, margin) elif key.keyname == b'^X': return EditResult.EXIT elif key.keyname == b'kLFT3': diff --git a/tests/babi_test.py b/tests/babi_test.py index f4a6226..39aa9ec 100644 --- a/tests/babi_test.py +++ b/tests/babi_test.py @@ -22,6 +22,7 @@ def test_position_repr(): ' cursor_line=0,\n' ' x=0,\n' ' x_hint=0,\n' + ' sha256=None,\n' ')' ) @@ -37,8 +38,15 @@ def test_position_repr(): ), ) def test_get_lines(s, lines, nl, mixed): - ret = babi._get_lines(io.StringIO(s)) - assert ret == (lines, nl, mixed) + # sha256 tested below + ret_lines, ret_nl, ret_mixed, _ = babi._get_lines(io.StringIO(s)) + assert (ret_lines, ret_nl, ret_mixed) == (lines, nl, mixed) + + +def test_get_lines_sha256_checksum(): + ret = babi._get_lines(io.StringIO('')) + sha256 = 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + assert ret == ([''], '\n', False, sha256) class PrintsErrorRunner(Runner): @@ -796,3 +804,64 @@ def test_multiple_files(tmpdir): h.await_text('file_b') h.press('C-x') h.await_exit() + + +def test_saving_with_no_filename_doesnt_exist(): + # TODO: this should prompt but currently refuses + with run() as h, and_exit(h): + h.press('C-s') + h.await_text('no filename, not implemented') + + +def test_saving_file_on_disk_changes(tmpdir): + # TODO: this should show some sort of diffing thing or just allow overwrite + f = tmpdir.join('f') + + with run(str(f)) as h, and_exit(h): + f.write('hello world') + + h.press('C-s') + h.await_text('file changed on disk, not implemented') + + +def test_allows_saving_same_contents_as_modified_contents(tmpdir): + f = tmpdir.join('f') + + with run(str(f)) as h, and_exit(h): + f.write('hello world\n') + h.press('hello world') + h.await_text('hello world') + + h.press('C-s') + h.await_text('saved! (1 line written)') + h.await_text_missing('*') + + assert f.read() == 'hello world\n' + + +def test_allows_saving_if_file_on_disk_does_not_change(tmpdir): + f = tmpdir.join('f') + f.write('hello world\n') + + with run(str(f)) as h, and_exit(h): + h.await_text('hello world') + h.press('ohai') + h.press('Enter') + + h.press('C-s') + h.await_text('saved! (2 lines written)') + h.await_text_missing('*') + + assert f.read() == 'ohai\nhello world\n' + + +def test_save_file_when_it_did_not_exist(tmpdir): + f = tmpdir.join('f') + + with run(str(f)) as h, and_exit(h): + h.press('hello world') + h.press('C-s') + h.await_text('saved! (1 line written)') + h.await_text_missing('*') + + assert f.read() == 'hello world\n'