Syntax highlighting

This commit is contained in:
Anthony Sottile
2020-02-22 16:34:47 -08:00
parent 1d06a77d44
commit 697b012027
29 changed files with 2515 additions and 18 deletions

View File

@@ -35,6 +35,6 @@ repos:
- id: pyupgrade
args: [--py36-plus]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.761
rev: v0.770
hooks:
- id: mypy

26
babi/cached_property.py Normal file
View File

@@ -0,0 +1,26 @@
import sys
if sys.version_info >= (3, 8): # pragma: no cover (>=py38)
from functools import cached_property
else: # pragma: no cover (<py38)
from typing import Callable
from typing import Generic
from typing import Optional
from typing import Type
from typing import TypeVar
TSelf = TypeVar('TSelf')
TRet = TypeVar('TRet')
class cached_property(Generic[TSelf, TRet]):
def __init__(self, func: Callable[[TSelf], TRet]) -> None:
self._func = func
def __get__(
self,
instance: Optional[TSelf],
owner: Optional[Type[TSelf]] = None,
) -> TRet:
assert instance is not None
ret = instance.__dict__[self._func.__name__] = self._func(instance)
return ret

11
babi/color.py Normal file
View File

@@ -0,0 +1,11 @@
from typing import NamedTuple
class Color(NamedTuple):
r: int
g: int
b: int
@classmethod
def parse(cls, s: str) -> 'Color':
return cls(r=int(s[1:3], 16), g=int(s[3:5], 16), b=int(s[5:7], 16))

90
babi/color_kd.py Normal file
View File

@@ -0,0 +1,90 @@
import functools
import itertools
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from babi._types import Protocol
from babi.color import Color
def _square_distance(c1: Color, c2: Color) -> int:
return (c1.r - c2.r) ** 2 + (c1.g - c2.g) ** 2 + (c1.b - c2.b) ** 2
class KD(Protocol):
@property
def color(self) -> Color: ...
@property
def n(self) -> int: ...
@property
def left(self) -> Optional['KD']: ...
@property
def right(self) -> Optional['KD']: ...
class _KD(NamedTuple):
color: Color
n: int
left: Optional[KD]
right: Optional[KD]
def _build(colors: List[Tuple[Color, int]], depth: int = 0) -> Optional[KD]:
if not colors:
return None
axis = depth % 3
colors.sort(key=lambda kv: kv[0][axis])
pivot = len(colors) // 2
return _KD(
*colors[pivot],
_build(colors[:pivot], depth=depth + 1),
_build(colors[pivot + 1:], depth=depth + 1),
)
def nearest(color: Color, colors: Optional[KD]) -> int:
best = 0
dist = 2 ** 32
def _search(kd: Optional[KD], *, depth: int) -> None:
nonlocal best
nonlocal dist
if kd is None:
return
cand_dist = _square_distance(color, kd.color)
if cand_dist < dist:
best, dist = kd.n, cand_dist
axis = depth % 3
diff = color[axis] - kd.color[axis]
if diff > 0:
_search(kd.right, depth=depth + 1)
if diff ** 2 < dist:
_search(kd.left, depth=depth + 1)
else:
_search(kd.left, depth=depth + 1)
if diff ** 2 < dist:
_search(kd.right, depth=depth + 1)
_search(colors, depth=0)
return best
@functools.lru_cache(maxsize=1)
def make_256() -> Optional[KD]:
vals = (0, 95, 135, 175, 215, 255)
colors = [
(Color(r, g, b), i)
for i, (r, g, b) in enumerate(itertools.product(vals, vals, vals), 16)
]
for i in range(24):
v = 10 * i + 8
colors.append((Color(v, v, v), 232 + i))
return _build(colors)

View File

@@ -2,12 +2,40 @@ import contextlib
import curses
from typing import Dict
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from babi import color_kd
from babi.color import Color
def _color_to_curses(color: Color) -> Tuple[int, int, int]:
factor = 1000 / 255
return int(color.r * factor), int(color.g * factor), int(color.b * factor)
class ColorManager(NamedTuple):
colors: Dict[Color, int]
raw_pairs: Dict[Tuple[int, int], int]
def init_color(self, color: Color) -> None:
if curses.COLORS < 256:
return
elif curses.can_change_color():
n = min(self.colors.values(), default=256) - 1
self.colors[color] = n
curses.init_color(n, *_color_to_curses(color))
else:
self.colors[color] = color_kd.nearest(color, color_kd.make_256())
def color_pair(self, fg: Optional[Color], bg: Optional[Color]) -> int:
if curses.COLORS < 256:
return 0
fg_i = self.colors[fg] if fg is not None else -1
bg_i = self.colors[bg] if bg is not None else -1
return self.raw_color_pair(fg_i, bg_i)
def raw_color_pair(self, fg: int, bg: int) -> int:
with contextlib.suppress(KeyError):
return self.raw_pairs[(fg, bg)]
@@ -18,4 +46,4 @@ class ColorManager(NamedTuple):
@classmethod
def make(cls) -> 'ColorManager':
return cls({})
return cls({}, {})

24
babi/fdict.py Normal file
View File

@@ -0,0 +1,24 @@
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import TypeVar
TKey = TypeVar('TKey')
TValue = TypeVar('TValue')
class FDict(Generic[TKey, TValue]):
def __init__(self, dct: Mapping[TKey, TValue]) -> None:
self._dct = dct
def __getitem__(self, k: TKey) -> TValue:
return self._dct[k]
def __contains__(self, k: TKey) -> bool:
return k in self._dct
def __repr__(self) -> str:
return f'{type(self).__name__}({self._dct})'
def values(self) -> Iterable[TValue]:
return self._dct.values()

View File

@@ -35,6 +35,7 @@ if TYPE_CHECKING:
from babi.main import Screen # XXX: circular
TCallable = TypeVar('TCallable', bound=Callable[..., Any])
HIGHLIGHT = curses.A_REVERSE | curses.A_DIM
@@ -275,7 +276,7 @@ class File:
if self.y >= self.file_y + margin.body_lines:
self.file_y += self._scroll_amount()
def _decrement_y(self, margin: Margin) -> None:
def _decrement_y(self) -> None:
self.y -= 1
if self.y < self.file_y:
self.file_y -= self._scroll_amount()
@@ -284,7 +285,7 @@ class File:
@action
def up(self, margin: Margin) -> None:
if self.y > 0:
self._decrement_y(margin)
self._decrement_y()
self._set_x_after_vertical_movement()
@action
@@ -307,7 +308,7 @@ class File:
def left(self, margin: Margin) -> None:
if self.x == 0:
if self.y > 0:
self._decrement_y(margin)
self._decrement_y()
self.x = len(self.lines[self.y])
else:
self.x -= 1
@@ -370,7 +371,7 @@ class File:
elif self.x == 0 or line[:self.x].isspace():
self.x = self.x_hint = 0
while self.y > 0 and (self.x == 0 or not self.lines[self.y]):
self._decrement_y(margin)
self._decrement_y()
self.x = self.x_hint = len(self.lines[self.y])
else:
self.x = self.x_hint = self.x - 1
@@ -502,7 +503,7 @@ class File:
pass
# backspace at the end of the file does not change the contents
elif self.y == len(self.lines) - 1:
self._decrement_y(margin)
self._decrement_y()
self.x = self.x_hint = len(self.lines[self.y])
# at the beginning of the line, we join the current line and
# the previous line
@@ -510,7 +511,7 @@ class File:
victim = self.lines.pop(self.y)
new_x = len(self.lines[self.y - 1])
self.lines[self.y - 1] += victim
self._decrement_y(margin)
self._decrement_y()
self.x = self.x_hint = new_x
else:
s = self.lines[self.y]
@@ -658,7 +659,7 @@ class File:
cut_buffer: Tuple[str, ...], margin: Margin,
) -> None:
self._uncut(cut_buffer, margin)
self._decrement_y(margin)
self._decrement_y()
self.x = self.x_hint = len(self.lines[self.y])
self.lines[self.y] += self.lines.pop(self.y + 1)
@@ -886,6 +887,7 @@ class File:
h_y = y - self.file_y + margin.header
if y == self.y:
l_x = line_x(self.x, curses.COLS)
# TODO: include edge left detection
if x < l_x:
h_x = 0
n -= l_x - x

729
babi/highlight.py Normal file
View File

@@ -0,0 +1,729 @@
import contextlib
import functools
import json
import os.path
from typing import Any
from typing import Dict
from typing import FrozenSet
from typing import List
from typing import Match
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from babi._types import Protocol
from babi.fdict import FDict
from babi.reg import _Reg
from babi.reg import _RegSet
from babi.reg import ERR_REG
from babi.reg import make_reg
from babi.reg import make_regset
Scope = Tuple[str, ...]
Regions = Tuple['Region', ...]
Captures = Tuple[Tuple[int, '_Rule'], ...]
def _split_name(s: Optional[str]) -> Tuple[str, ...]:
if s is None:
return ()
else:
return tuple(s.split())
class _Rule(Protocol):
"""hax for recursive types python/mypy#731"""
@property
def name(self) -> Tuple[str, ...]: ...
@property
def match(self) -> Optional[str]: ...
@property
def begin(self) -> Optional[str]: ...
@property
def end(self) -> Optional[str]: ...
@property
def while_(self) -> Optional[str]: ...
@property
def content_name(self) -> Tuple[str, ...]: ...
@property
def captures(self) -> Captures: ...
@property
def begin_captures(self) -> Captures: ...
@property
def end_captures(self) -> Captures: ...
@property
def while_captures(self) -> Captures: ...
@property
def include(self) -> Optional[str]: ...
@property
def patterns(self) -> 'Tuple[_Rule, ...]': ...
class Rule(NamedTuple):
name: Tuple[str, ...]
match: Optional[str]
begin: Optional[str]
end: Optional[str]
while_: Optional[str]
content_name: Tuple[str, ...]
captures: Captures
begin_captures: Captures
end_captures: Captures
while_captures: Captures
include: Optional[str]
patterns: Tuple[_Rule, ...]
@classmethod
def from_dct(cls, dct: Dict[str, Any]) -> _Rule:
name = _split_name(dct.get('name'))
match = dct.get('match')
begin = dct.get('begin')
end = dct.get('end')
while_ = dct.get('while')
content_name = _split_name(dct.get('contentName'))
if 'captures' in dct:
captures = tuple(
(int(k), Rule.from_dct(v))
for k, v in dct['captures'].items()
)
else:
captures = ()
if 'beginCaptures' in dct:
begin_captures = tuple(
(int(k), Rule.from_dct(v))
for k, v in dct['beginCaptures'].items()
)
else:
begin_captures = ()
if 'endCaptures' in dct:
end_captures = tuple(
(int(k), Rule.from_dct(v))
for k, v in dct['endCaptures'].items()
)
else:
end_captures = ()
if 'whileCaptures' in dct:
while_captures = tuple(
(int(k), Rule.from_dct(v))
for k, v in dct['whileCaptures'].items()
)
else:
while_captures = ()
# Using the captures key for a begin/end/while rule is short-hand for
# giving both beginCaptures and endCaptures with same values
if begin and end and captures:
begin_captures = end_captures = captures
captures = ()
elif begin and while_ and captures:
begin_captures = while_captures = captures
captures = ()
include = dct.get('include')
if 'patterns' in dct:
patterns = tuple(Rule.from_dct(d) for d in dct['patterns'])
else:
patterns = ()
return cls(
name=name,
match=match,
begin=begin,
end=end,
while_=while_,
content_name=content_name,
captures=captures,
begin_captures=begin_captures,
end_captures=end_captures,
while_captures=while_captures,
include=include,
patterns=patterns,
)
class Grammar(NamedTuple):
scope_name: str
first_line_match: Optional[_Reg]
file_types: FrozenSet[str]
patterns: Tuple[_Rule, ...]
repository: FDict[str, _Rule]
@classmethod
def from_data(cls, data: Dict[str, Any]) -> 'Grammar':
scope_name = data['scopeName']
if 'firstLineMatch' in data:
first_line_match: Optional[_Reg] = make_reg(data['firstLineMatch'])
else:
first_line_match = None
if 'fileTypes' in data:
file_types = frozenset(data['fileTypes'])
else:
file_types = frozenset()
patterns = tuple(Rule.from_dct(dct) for dct in data['patterns'])
if 'repository' in data:
repository = FDict({
k: Rule.from_dct(dct) for k, dct in data['repository'].items()
})
else:
repository = FDict({})
return cls(
scope_name=scope_name,
first_line_match=first_line_match,
file_types=file_types,
patterns=patterns,
repository=repository,
)
@classmethod
def parse(cls, filename: str) -> 'Grammar':
with open(filename) as f:
return cls.from_data(json.load(f))
@classmethod
def blank(cls) -> 'Grammar':
return cls(
scope_name='source.unknown',
first_line_match=None,
file_types=frozenset(),
patterns=(),
repository=FDict({}),
)
def matches_file(self, filename: str, first_line: str) -> bool:
_, ext = os.path.splitext(filename)
if ext.lstrip('.') in self.file_types:
return True
elif self.first_line_match is not None:
return bool(
self.first_line_match.match(
first_line, 0, first_line=True, boundary=True,
),
)
else:
return False
class Region(NamedTuple):
start: int
end: int
scope: Scope
class State(NamedTuple):
entries: Tuple['Entry', ...]
while_stack: Tuple[Tuple['WhileRule', int], ...]
@classmethod
def root(cls, entry: 'Entry') -> 'State':
return cls((entry,), ())
@property
def cur(self) -> 'Entry':
return self.entries[-1]
def push(self, entry: 'Entry') -> 'State':
return self._replace(entries=(*self.entries, entry))
def pop(self) -> 'State':
return self._replace(entries=self.entries[:-1])
def push_while(self, rule: 'WhileRule', entry: 'Entry') -> 'State':
entries = (*self.entries, entry)
while_stack = (*self.while_stack, (rule, len(entries)))
return self._replace(entries=entries, while_stack=while_stack)
def pop_while(self) -> 'State':
entries, while_stack = self.entries[:-1], self.while_stack[:-1]
return self._replace(entries=entries, while_stack=while_stack)
class CompiledRule(Protocol):
@property
def name(self) -> Tuple[str, ...]: ...
def start(
self,
compiler: 'Compiler',
match: Match[str],
state: State,
) -> Tuple[State, bool, Regions]:
...
def search(
self,
compiler: 'Compiler',
state: State,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Tuple[State, int, bool, Regions]]:
...
class CompiledRegsetRule(CompiledRule, Protocol):
@property
def regset(self) -> _RegSet: ...
@property
def u_rules(self) -> Tuple[_Rule, ...]: ...
class Entry(NamedTuple):
scope: Tuple[str, ...]
rule: CompiledRule
reg: _Reg = ERR_REG
boundary: bool = False
def _inner_capture_parse(
compiler: 'Compiler',
start: int,
s: str,
scope: Scope,
rule: CompiledRule,
) -> Regions:
state = State.root(Entry(scope + rule.name, rule))
_, regions = highlight_line(compiler, state, s, first_line=False)
return tuple(
r._replace(start=r.start + start, end=r.end + start) for r in regions
)
def _captures(
compiler: 'Compiler',
scope: Scope,
match: Match[str],
captures: Captures,
) -> Regions:
ret: List[Region] = []
pos, pos_end = match.span()
for i, u_rule in captures:
try:
group_s = match[i]
except IndexError: # some grammars are malformed here?
continue
if not group_s:
continue
rule = compiler.compile_rule(u_rule)
start, end = match.span(i)
if start < pos:
# TODO: could maybe bisect but this is probably fast enough
j = len(ret) - 1
while j > 0 and start < ret[j - 1].end:
j -= 1
oldtok = ret[j]
newtok = []
if start > oldtok.start:
newtok.append(oldtok._replace(end=start))
newtok.extend(
_inner_capture_parse(
compiler, start, match[i], oldtok.scope, rule,
),
)
if end < oldtok.end:
newtok.append(oldtok._replace(start=end))
ret[j:j + 1] = newtok
else:
if start > pos:
ret.append(Region(pos, start, scope))
ret.extend(
_inner_capture_parse(compiler, start, match[i], scope, rule),
)
pos = end
if pos < pos_end:
ret.append(Region(pos, pos_end, scope))
return tuple(ret)
def _do_regset(
idx: int,
match: Optional[Match[str]],
rule: CompiledRegsetRule,
compiler: 'Compiler',
state: State,
pos: int,
) -> Optional[Tuple[State, int, bool, Regions]]:
if match is None:
return None
ret = []
if match.start() > pos:
ret.append(Region(pos, match.start(), state.cur.scope))
target_rule = compiler.compile_rule(rule.u_rules[idx])
state, boundary, regions = target_rule.start(compiler, match, state)
ret.extend(regions)
return state, match.end(), boundary, tuple(ret)
class PatternRule(NamedTuple):
name: Tuple[str, ...]
regset: _RegSet
u_rules: Tuple[_Rule, ...]
def start(
self,
compiler: 'Compiler',
match: Match[str],
state: State,
) -> Tuple[State, bool, Regions]:
raise AssertionError(f'unreachable {self}')
def search(
self,
compiler: 'Compiler',
state: State,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Tuple[State, int, bool, Regions]]:
idx, match = self.regset.search(line, pos, first_line, boundary)
return _do_regset(idx, match, self, compiler, state, pos)
class MatchRule(NamedTuple):
name: Tuple[str, ...]
captures: Captures
def start(
self,
compiler: 'Compiler',
match: Match[str],
state: State,
) -> Tuple[State, bool, Regions]:
scope = state.cur.scope + self.name
return state, False, _captures(compiler, scope, match, self.captures)
def search(
self,
compiler: 'Compiler',
state: State,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Tuple[State, int, bool, Regions]]:
raise AssertionError(f'unreachable {self}')
class EndRule(NamedTuple):
name: Tuple[str, ...]
content_name: Tuple[str, ...]
begin_captures: Captures
end_captures: Captures
end: str
regset: _RegSet
u_rules: Tuple[_Rule, ...]
def start(
self,
compiler: 'Compiler',
match: Match[str],
state: State,
) -> Tuple[State, bool, Regions]:
scope = state.cur.scope + self.name
next_scope = scope + self.content_name
boundary = match.end() == len(match.string)
reg = make_reg(match.expand(self.end))
state = state.push(Entry(next_scope, self, reg, boundary))
regions = _captures(compiler, scope, match, self.begin_captures)
return state, True, regions
def _end_ret(
self,
compiler: 'Compiler',
state: State,
pos: int,
m: Match[str],
) -> Tuple[State, int, bool, Regions]:
ret = []
if m.start() > pos:
ret.append(Region(pos, m.start(), state.cur.scope))
ret.extend(_captures(compiler, state.cur.scope, m, self.end_captures))
return state.pop(), m.end(), False, tuple(ret)
def search(
self,
compiler: 'Compiler',
state: State,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Tuple[State, int, bool, Regions]]:
end_match = state.cur.reg.search(line, pos, first_line, boundary)
if end_match is not None and end_match.start() == pos:
return self._end_ret(compiler, state, pos, end_match)
elif end_match is None:
idx, match = self.regset.search(line, pos, first_line, boundary)
return _do_regset(idx, match, self, compiler, state, pos)
else:
idx, match = self.regset.search(line, pos, first_line, boundary)
if match is None or end_match.start() <= match.start():
return self._end_ret(compiler, state, pos, end_match)
else:
return _do_regset(idx, match, self, compiler, state, pos)
class WhileRule(NamedTuple):
name: Tuple[str, ...]
content_name: Tuple[str, ...]
begin_captures: Captures
while_captures: Captures
while_: str
regset: _RegSet
u_rules: Tuple[_Rule, ...]
def start(
self,
compiler: 'Compiler',
match: Match[str],
state: State,
) -> Tuple[State, bool, Regions]:
scope = state.cur.scope + self.name
next_scope = scope + self.content_name
boundary = match.end() == len(match.string)
reg = make_reg(match.expand(self.while_))
state = state.push_while(self, Entry(next_scope, self, reg, boundary))
regions = _captures(compiler, scope, match, self.begin_captures)
return state, True, regions
def continues(
self,
compiler: 'Compiler',
state: State,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Tuple[int, bool, Regions]]:
match = state.cur.reg.match(line, pos, first_line, boundary)
if match is None:
return None
ret = _captures(compiler, state.cur.scope, match, self.while_captures)
return match.end(), True, ret
def search(
self,
compiler: 'Compiler',
state: State,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Tuple[State, int, bool, Regions]]:
idx, match = self.regset.search(line, pos, first_line, boundary)
return _do_regset(idx, match, self, compiler, state, pos)
class Compiler:
def __init__(self, grammar: Grammar, grammars: Dict[str, Grammar]) -> None:
self._root_scope = grammar.scope_name
self._grammars = grammars
self._rule_to_grammar: Dict[_Rule, Grammar] = {}
self._c_rules: Dict[_Rule, CompiledRule] = {}
root = self._compile_root(grammar)
self.root_state = State.root(Entry(root.name, root))
def _visit_rule(self, grammar: Grammar, rule: _Rule) -> _Rule:
self._rule_to_grammar[rule] = grammar
return rule
@functools.lru_cache(maxsize=None)
def _include(
self,
grammar: Grammar,
s: str,
) -> Tuple[List[str], Tuple[_Rule, ...]]:
if s == '$self':
return self._patterns(grammar, grammar.patterns)
elif s == '$base':
return self._include(self._grammars[self._root_scope], '$self')
elif s.startswith('#'):
return self._patterns(grammar, (grammar.repository[s[1:]],))
elif '#' not in s:
return self._include(self._grammars[s], '$self')
else:
scope, _, s = s.partition('#')
return self._include(self._grammars[scope], f'#{s}')
@functools.lru_cache(maxsize=None)
def _patterns(
self,
grammar: Grammar,
rules: Tuple[_Rule, ...],
) -> Tuple[List[str], Tuple[_Rule, ...]]:
ret_regs = []
ret_rules: List[_Rule] = []
for rule in rules:
if rule.include is not None:
tmp_regs, tmp_rules = self._include(grammar, rule.include)
ret_regs.extend(tmp_regs)
ret_rules.extend(tmp_rules)
elif rule.match is None and rule.begin is None and rule.patterns:
tmp_regs, tmp_rules = self._patterns(grammar, rule.patterns)
ret_regs.extend(tmp_regs)
ret_rules.extend(tmp_rules)
elif rule.match is not None:
ret_regs.append(rule.match)
ret_rules.append(self._visit_rule(grammar, rule))
elif rule.begin is not None:
ret_regs.append(rule.begin)
ret_rules.append(self._visit_rule(grammar, rule))
else:
raise AssertionError(f'unreachable {rule}')
return ret_regs, tuple(ret_rules)
def _captures_ref(
self,
grammar: Grammar,
captures: Captures,
) -> Captures:
return tuple((n, self._visit_rule(grammar, r)) for n, r in captures)
def _compile_root(self, grammar: Grammar) -> PatternRule:
regs, rules = self._patterns(grammar, grammar.patterns)
return PatternRule((grammar.scope_name,), make_regset(*regs), rules)
def _compile_rule(self, grammar: Grammar, rule: _Rule) -> CompiledRule:
assert rule.include is None, rule
if rule.match is not None:
captures_ref = self._captures_ref(grammar, rule.captures)
return MatchRule(rule.name, captures_ref)
elif rule.begin is not None and rule.end is not None:
regs, rules = self._patterns(grammar, rule.patterns)
return EndRule(
rule.name,
rule.content_name,
self._captures_ref(grammar, rule.begin_captures),
self._captures_ref(grammar, rule.end_captures),
rule.end,
make_regset(*regs),
rules,
)
elif rule.begin is not None and rule.while_ is not None:
regs, rules = self._patterns(grammar, rule.patterns)
return WhileRule(
rule.name,
rule.content_name,
self._captures_ref(grammar, rule.begin_captures),
self._captures_ref(grammar, rule.while_captures),
rule.while_,
make_regset(*regs),
rules,
)
else:
regs, rules = self._patterns(grammar, rule.patterns)
return PatternRule(rule.name, make_regset(*regs), rules)
def compile_rule(self, rule: _Rule) -> CompiledRule:
with contextlib.suppress(KeyError):
return self._c_rules[rule]
grammar = self._rule_to_grammar[rule]
ret = self._c_rules[rule] = self._compile_rule(grammar, rule)
return ret
class Grammars:
def __init__(self, grammars: List[Grammar]) -> None:
self.grammars = {grammar.scope_name: grammar for grammar in grammars}
self._compilers: Dict[Grammar, Compiler] = {}
@classmethod
def from_syntax_dir(cls, syntax_dir: str) -> 'Grammars':
grammars = [Grammar.blank()]
if os.path.exists(syntax_dir):
grammars.extend(
Grammar.parse(os.path.join(syntax_dir, filename))
for filename in os.listdir(syntax_dir)
)
return cls(grammars)
def _compiler_for_grammar(self, grammar: Grammar) -> Compiler:
with contextlib.suppress(KeyError):
return self._compilers[grammar]
ret = self._compilers[grammar] = Compiler(grammar, self.grammars)
return ret
def compiler_for_scope(self, scope: str) -> Compiler:
return self._compiler_for_grammar(self.grammars[scope])
def blank_compiler(self) -> Compiler:
return self.compiler_for_scope('source.unknown')
def compiler_for_file(self, filename: str) -> Compiler:
if os.path.exists(filename):
with open(filename) as f:
first_line = next(f, '')
else:
first_line = ''
for grammar in self.grammars.values():
if grammar.matches_file(filename, first_line):
break
else:
grammar = self.grammars['source.unknown']
return self._compiler_for_grammar(grammar)
@functools.lru_cache(maxsize=None)
def highlight_line(
compiler: 'Compiler',
state: State,
line: str,
first_line: bool,
) -> Tuple[State, Regions]:
ret: List[Region] = []
pos = 0
boundary = state.cur.boundary
# TODO: this is still a little wasteful
while_stack = []
for while_rule, idx in state.while_stack:
while_stack.append((while_rule, idx))
while_state = State(state.entries[:idx], tuple(while_stack))
while_res = while_rule.continues(
compiler, while_state, line, pos, first_line, boundary,
)
if while_res is None:
state = while_state.pop_while()
break
else:
pos, boundary, regions = while_res
ret.extend(regions)
search_res = state.cur.rule.search(
compiler, state, line, pos, first_line, boundary,
)
while search_res is not None:
state, pos, boundary, regions = search_res
ret.extend(regions)
search_res = state.cur.rule.search(
compiler, state, line, pos, first_line, boundary,
)
if pos < len(line):
ret.append(Region(pos, len(line), state.cur.scope))
return state, tuple(ret)

146
babi/hl/syntax.py Normal file
View File

@@ -0,0 +1,146 @@
import curses
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple
from babi.color_manager import ColorManager
from babi.highlight import Compiler
from babi.highlight import Grammars
from babi.highlight import highlight_line
from babi.highlight import State
from babi.hl.interface import CursesRegion
from babi.hl.interface import CursesRegions
from babi.list_spy import SequenceNoSlice
from babi.theme import Style
from babi.theme import Theme
from babi.user_data import xdg_config
from babi.user_data import xdg_data
A_ITALIC = getattr(curses, 'A_ITALIC', 0x80000000) # new in py37
class FileSyntax:
def __init__(
self,
compiler: Compiler,
theme: Theme,
color_manager: ColorManager,
) -> None:
self._compiler = compiler
self._theme = theme
self._color_manager = color_manager
self.regions: List[CursesRegions] = []
self._states: List[State] = []
self._hl_cache: Dict[str, Dict[State, Tuple[State, CursesRegions]]]
self._hl_cache = {}
def attr(self, style: Style) -> int:
pair = self._color_manager.color_pair(style.fg, style.bg)
return (
curses.color_pair(pair) |
curses.A_BOLD * style.b |
A_ITALIC * style.i |
curses.A_UNDERLINE * style.u
)
def _hl(
self,
state: State,
line: str,
i: int,
) -> Tuple[State, CursesRegions]:
try:
return self._hl_cache[line][state]
except KeyError:
pass
new_state, regions = highlight_line(
self._compiler, state, f'{line}\n', first_line=i == 0,
)
# remove the trailing newline
new_end = regions[-1]._replace(end=regions[-1].end - 1)
regions = regions[:-1] + (new_end,)
regs: List[CursesRegion] = []
for r in regions:
style = self._theme.select(r.scope)
if style == self._theme.default:
continue
n = r.end - r.start
attr = self.attr(style)
if (
regs and
regs[-1]['color'] == attr and
regs[-1]['x'] + regs[-1]['n'] == r.start
):
regs[-1]['n'] += n
else:
regs.append(CursesRegion(x=r.start, n=n, color=attr))
dct = self._hl_cache.setdefault(line, {})
ret = dct[state] = (new_state, tuple(regs))
return ret
def highlight_until(self, lines: SequenceNoSlice, idx: int) -> None:
if not self._states:
state = self._compiler.root_state
else:
state = self._states[-1]
for i in range(len(self._states), idx):
state, regions = self._hl(state, lines[i], i)
self._states.append(state)
self.regions.append(regions)
def touch(self, lineno: int) -> None:
del self._states[lineno:]
del self.regions[lineno:]
class Syntax(NamedTuple):
grammars: Grammars
theme: Theme
color_manager: ColorManager
def get_file_highlighter(self, filename: str) -> FileSyntax:
compiler = self.grammars.compiler_for_file(filename)
return FileSyntax(compiler, self.theme, self.color_manager)
def get_blank_file_highlighter(self) -> FileSyntax:
compiler = self.grammars.blank_compiler()
return FileSyntax(compiler, self.theme, self.color_manager)
def _init_screen(self, stdscr: 'curses._CursesWindow') -> None:
default_fg, default_bg = self.theme.default.fg, self.theme.default.bg
all_colors = {c for c in (default_fg, default_bg) if c is not None}
todo = list(self.theme.rules.children.values())
while todo:
rule = todo.pop()
if rule.style.fg is not None:
all_colors.add(rule.style.fg)
if rule.style.bg is not None:
all_colors.add(rule.style.bg)
todo.extend(rule.children.values())
for color in sorted(all_colors):
self.color_manager.init_color(color)
pair = self.color_manager.color_pair(default_fg, default_bg)
stdscr.bkgd(' ', curses.color_pair(pair))
@classmethod
def from_screen(
cls,
stdscr: 'curses._CursesWindow',
color_manager: ColorManager,
) -> 'Syntax':
grammars = Grammars.from_syntax_dir(xdg_data('textmate_syntax'))
theme = Theme.from_filename(xdg_config('theme.json'))
ret = cls(grammars, theme, color_manager)
ret._init_screen(stdscr)
return ret

View File

@@ -32,7 +32,7 @@ def _edit(screen: Screen) -> EditResult:
screen.status.update(f'unknown key: {key}')
def c_main(stdscr: 'curses._CursesWindow', args: argparse.Namespace) -> None:
def c_main(stdscr: 'curses._CursesWindow', args: argparse.Namespace) -> int:
with perf_log(args.perf_log) as perf:
screen = Screen(stdscr, args.filenames or [None], perf)
with screen.history.save():
@@ -50,6 +50,7 @@ def c_main(stdscr: 'curses._CursesWindow', args: argparse.Namespace) -> None:
screen.status.clear()
else:
raise AssertionError(f'unreachable {res}')
return 0
def main(argv: Optional[Sequence[str]] = None) -> int:
@@ -57,9 +58,9 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
parser.add_argument('filenames', metavar='filename', nargs='*')
parser.add_argument('--perf-log')
args = parser.parse_args(argv)
with make_stdscr() as stdscr:
c_main(stdscr, args)
return 0
return c_main(stdscr, args)
if __name__ == '__main__':

147
babi/reg.py Normal file
View File

@@ -0,0 +1,147 @@
import functools
from typing import Match
from typing import Optional
from typing import Tuple
import onigurumacffi
from babi.cached_property import cached_property
def _replace_esc(s: str, chars: str) -> str:
"""replace the given escape sequences of `chars` with \\uffff"""
for c in chars:
if f'\\{c}' in s:
break
else:
return s
b = []
i = 0
length = len(s)
while i < length:
try:
sbi = s.index('\\', i)
except ValueError:
b.append(s[i:])
break
if sbi > i:
b.append(s[i:sbi])
b.append('\\')
i = sbi + 1
if i < length:
if s[i] in chars:
b.append('\uffff')
else:
b.append(s[i])
i += 1
return ''.join(b)
class _Reg:
def __init__(self, s: str) -> None:
self._pattern = s
def __repr__(self) -> str:
return f'{type(self).__name__}({self._pattern!r})'
@cached_property
def _reg(self) -> onigurumacffi._Pattern:
return onigurumacffi.compile(self._pattern)
@cached_property
def _reg_no_A(self) -> onigurumacffi._Pattern:
return onigurumacffi.compile(_replace_esc(self._pattern, 'A'))
@cached_property
def _reg_no_G(self) -> onigurumacffi._Pattern:
return onigurumacffi.compile(_replace_esc(self._pattern, 'G'))
@cached_property
def _reg_no_A_no_G(self) -> onigurumacffi._Pattern:
return onigurumacffi.compile(_replace_esc(self._pattern, 'AG'))
def _get_reg(
self,
first_line: bool,
boundary: bool,
) -> onigurumacffi._Pattern:
if boundary:
if first_line:
return self._reg
else:
return self._reg_no_A
else:
if first_line:
return self._reg_no_G
else:
return self._reg_no_A_no_G
def search(
self,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Match[str]]:
return self._get_reg(first_line, boundary).search(line, pos)
def match(
self,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Optional[Match[str]]:
return self._get_reg(first_line, boundary).match(line, pos)
class _RegSet:
def __init__(self, *s: str) -> None:
self._patterns = s
def __repr__(self) -> str:
args = ', '.join(repr(s) for s in self._patterns)
return f'{type(self).__name__}({args})'
@cached_property
def _set(self) -> onigurumacffi._RegSet:
return onigurumacffi.compile_regset(*self._patterns)
@cached_property
def _set_no_A(self) -> onigurumacffi._RegSet:
patterns = (_replace_esc(p, 'A') for p in self._patterns)
return onigurumacffi.compile_regset(*patterns)
@cached_property
def _set_no_G(self) -> onigurumacffi._RegSet:
patterns = (_replace_esc(p, 'G') for p in self._patterns)
return onigurumacffi.compile_regset(*patterns)
@cached_property
def _set_no_A_no_G(self) -> onigurumacffi._RegSet:
patterns = (_replace_esc(p, 'AG') for p in self._patterns)
return onigurumacffi.compile_regset(*patterns)
def search(
self,
line: str,
pos: int,
first_line: bool,
boundary: bool,
) -> Tuple[int, Optional[Match[str]]]:
if boundary:
if first_line:
return self._set.search(line, pos)
else:
return self._set_no_A.search(line, pos)
else:
if first_line:
return self._set_no_G.search(line, pos)
else:
return self._set_no_A_no_G.search(line, pos)
make_reg = functools.lru_cache(maxsize=None)(_Reg)
make_regset = functools.lru_cache(maxsize=None)(_RegSet)
ERR_REG = make_reg(')this pattern always triggers an error when used(')

View File

@@ -20,6 +20,7 @@ from babi.file import Action
from babi.file import File
from babi.file import get_lines
from babi.history import History
from babi.hl.syntax import Syntax
from babi.hl.trailing_whitespace import TrailingWhitespace
from babi.margin import Margin
from babi.perf import Perf
@@ -73,7 +74,10 @@ class Screen:
) -> None:
self.stdscr = stdscr
color_manager = ColorManager.make()
hl_factories = (TrailingWhitespace(color_manager),)
hl_factories = (
Syntax.from_screen(stdscr, color_manager),
TrailingWhitespace(color_manager),
)
self.files = [File(f, hl_factories) for f in filenames]
self.i = 0
self.history = History()
@@ -490,6 +494,7 @@ def _init_screen() -> 'curses._CursesWindow':
# ^S / ^Q / ^Z / ^\ are passed through
curses.raw()
stdscr.keypad(True)
with contextlib.suppress(curses.error):
curses.start_color()
curses.use_default_colors()

152
babi/theme.py Normal file
View File

@@ -0,0 +1,152 @@
import functools
import json
import os.path
import re
from typing import Any
from typing import Dict
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from babi._types import Protocol
from babi.color import Color
from babi.fdict import FDict
# yes I know this is wrong, but it's good enough for now
UN_COMMENT = re.compile(r'^\s*//.*$', re.MULTILINE)
class Style(NamedTuple):
fg: Optional[Color]
bg: Optional[Color]
b: bool
i: bool
u: bool
@classmethod
def blank(cls) -> 'Style':
return cls(fg=None, bg=None, b=False, i=False, u=False)
class PartialStyle(NamedTuple):
fg: Optional[Color] = None
bg: Optional[Color] = None
b: Optional[bool] = None
i: Optional[bool] = None
u: Optional[bool] = None
def overlay_on(self, dct: Dict[str, Any]) -> None:
for attr in self._fields:
value = getattr(self, attr)
if value is not None:
dct[attr] = value
@classmethod
def from_dct(cls, dct: Dict[str, Any]) -> 'PartialStyle':
kv = cls()._asdict()
if 'foreground' in dct:
kv['fg'] = Color.parse(dct['foreground'])
if 'background' in dct:
kv['bg'] = Color.parse(dct['background'])
if dct.get('fontStyle') == 'bold':
kv['b'] = True
elif dct.get('fontStyle') == 'italic':
kv['i'] = True
elif dct.get('fontStyle') == 'underline':
kv['u'] = True
return cls(**kv)
class _TrieNode(Protocol):
@property
def style(self) -> PartialStyle: ...
@property
def children(self) -> FDict[str, '_TrieNode']: ...
class TrieNode(NamedTuple):
style: PartialStyle
children: FDict[str, _TrieNode]
@classmethod
def from_dct(cls, dct: Dict[str, Any]) -> _TrieNode:
children = FDict({
k: TrieNode.from_dct(v) for k, v in dct['children'].items()
})
return cls(PartialStyle.from_dct(dct), children)
class Theme(NamedTuple):
default: Style
rules: _TrieNode
@functools.lru_cache(maxsize=None)
def select(self, scope: Tuple[str, ...]) -> Style:
if not scope:
return self.default
else:
style = self.select(scope[:-1])._asdict()
node = self.rules
for part in scope[-1].split('.'):
if part not in node.children:
break
else:
node = node.children[part]
node.style.overlay_on(style)
return Style(**style)
@classmethod
def from_dct(cls, data: Dict[str, Any]) -> 'Theme':
default = Style.blank()._asdict()
for k in ('foreground', 'editor.foreground'):
if k in data.get('colors', {}):
default['fg'] = Color.parse(data['colors'][k])
break
for k in ('background', 'editor.background'):
if k in data.get('colors', {}):
default['bg'] = Color.parse(data['colors'][k])
break
root: Dict[str, Any] = {'children': {}}
rules = data.get('tokenColors', []) + data.get('settings', [])
for rule in rules:
if 'scope' not in rule:
scopes = ['']
elif isinstance(rule['scope'], str):
scopes = [
# some themes have a buggy trailing comma
s.strip() for s in rule['scope'].strip(',').split(',')
]
else:
scopes = rule['scope']
for scope in scopes:
if ' ' in scope:
# TODO: implement parent scopes
continue
elif scope == '':
PartialStyle.from_dct(rule['settings']).overlay_on(default)
continue
cur = root
for part in scope.split('.'):
cur = cur['children'].setdefault(part, {'children': {}})
cur.update(rule['settings'])
return cls(Style(**default), TrieNode.from_dct(root))
@classmethod
def blank(cls) -> 'Theme':
return cls(Style.blank(), TrieNode.from_dct({'children': {}}))
@classmethod
def from_filename(cls, filename: str) -> 'Theme':
if not os.path.exists(filename):
return cls.blank()
else:
with open(filename) as f:
contents = UN_COMMENT.sub('', f.read())
return cls.from_dct(json.loads(contents))

View File

@@ -1,9 +1,16 @@
import os.path
def xdg_data(*path: str) -> str:
def _xdg(*path: str, env: str, default: str) -> str:
return os.path.join(
os.environ.get('XDG_DATA_HOME') or
os.path.expanduser('~/.local/share'),
os.environ.get(env) or os.path.expanduser(default),
'babi', *path,
)
def xdg_data(*path: str) -> str:
return _xdg(*path, env='XDG_DATA_HOME', default='~/.local/share')
def xdg_config(*path: str) -> str:
return _xdg(*path, env='XDG_CONFIG_HOME', default='~/.config')

81
bin/download-syntax Executable file
View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
import argparse
import enum
import json
import os.path
import plistlib
import urllib.request
from typing import NamedTuple
import cson # pip install cson
DEFAULT_DIR = os.path.join(
os.environ.get('XDG_DATA_HOME') or
os.path.expanduser('~/.local/share'),
'babi/textmate_syntax',
)
Ext = enum.Enum('Ext', 'CSON PLIST JSON')
def _convert_cson(src: bytes) -> str:
return json.dumps(cson.loads(src))
def _convert_json(src: bytes) -> str:
return json.dumps(json.loads(src))
def _convert_plist(src: bytes) -> str:
return json.dumps(plistlib.loads(src))
EXT_CONVERT = {
Ext.CSON: _convert_cson,
Ext.JSON: _convert_json,
Ext.PLIST: _convert_plist,
}
class Syntax(NamedTuple):
name: str
ext: Ext
url: str
SYNTAXES = (
Syntax('c', Ext.JSON, 'https://raw.githubusercontent.com/jeff-hykin/cpp-textmate-grammar/53e39b1c/syntaxes/c.tmLanguage.json'), # noqa: E501
Syntax('css', Ext.CSON, 'https://raw.githubusercontent.com/atom/language-css/9feb69c081308b63f78bb0d6a2af2ff5eb7d869b/grammars/css.cson'), # noqa: E501
Syntax('diff', Ext.PLIST, 'https://raw.githubusercontent.com/textmate/diff.tmbundle/0593bb77/Syntaxes/Diff.plist'), # noqa: E501
Syntax('html', Ext.PLIST, 'https://raw.githubusercontent.com/textmate/html.tmbundle/0c3d5ee5/Syntaxes/HTML.plist'), # noqa: E501
Syntax('html-derivative', Ext.PLIST, 'https://raw.githubusercontent.com/textmate/html.tmbundle/0c3d5ee54de3a993f747f54186b73a4d2d3c44a2/Syntaxes/HTML%20(Derivative).tmLanguage'), # noqa: E501
Syntax('ini', Ext.PLIST, 'https://raw.githubusercontent.com/textmate/ini.tmbundle/7d8c7b55/Syntaxes/Ini.plist'), # noqa: E501
Syntax('json', Ext.PLIST, 'https://raw.githubusercontent.com/microsoft/vscode-JSON.tmLanguage/d113e90937ed3ecc31ac54750aac2e8efa08d784/JSON.tmLanguage'), # noqa: E501
Syntax('markdown', Ext.PLIST, 'https://raw.githubusercontent.com/microsoft/vscode-markdown-tm-grammar/59a5962/syntaxes/markdown.tmLanguage'), # noqa: E501
Syntax('powershell', Ext.PLIST, 'https://raw.githubusercontent.com/PowerShell/EditorSyntax/4a0a0766/PowerShellSyntax.tmLanguage'), # noqa: E501
Syntax('python', Ext.PLIST, 'https://raw.githubusercontent.com/MagicStack/MagicPython/c9b3409d/grammars/MagicPython.tmLanguage'), # noqa: E501
# TODO: https://github.com/zargony/atom-language-rust/pull/149
Syntax('rust', Ext.CSON, 'https://raw.githubusercontent.com/asottile/atom-language-rust/e113ca67/grammars/rust.cson'), # noqa: E501
Syntax('shell', Ext.CSON, 'https://raw.githubusercontent.com/atom/language-shellscript/7008ea926867d8a231003e78094091471c4fccf8/grammars/shell-unix-bash.cson'), # noqa: E501
Syntax('yaml', Ext.PLIST, 'https://raw.githubusercontent.com/textmate/yaml.tmbundle/e54ceae3/Syntaxes/YAML.tmLanguage'), # noqa: E501
)
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument('--dest', default=DEFAULT_DIR)
args = parser.parse_args()
os.makedirs(args.dest, exist_ok=True)
for syntax in SYNTAXES:
print(f'downloading {syntax.name}...')
resp = urllib.request.urlopen(syntax.url).read()
converted = EXT_CONVERT[syntax.ext](resp)
with open(os.path.join(args.dest, f'{syntax.name}.json'), 'w') as f:
f.write(converted)
return 0
if __name__ == '__main__':
exit(main())

View File

@@ -21,6 +21,8 @@ classifiers =
[options]
packages = find:
install_requires =
onigurumacffi>=0.0.10
python_requires = >=3.6.1
[options.entry_points]

View File

@@ -9,6 +9,8 @@ from hecate import Runner
class Token(enum.Enum):
FG_ESC = re.compile(r'\x1b\[38;5;(\d+)m')
BG_ESC = re.compile(r'\x1b\[48;5;(\d+)m')
RESET = re.compile(r'\x1b\[0?m')
ESC = re.compile(r'\x1b\[(\d+)m')
NL = re.compile(r'\n')
@@ -36,8 +38,13 @@ def to_attrs(screen, width):
ret = [[] for _ in range(len(screen.splitlines()))]
for tp, match in tokenize_colors(screen):
if tp is Token.RESET:
fg = bg = attr = 0
if tp is Token.FG_ESC:
fg = int(match[1])
elif tp is Token.BG_ESC:
bg = int(match[1])
elif tp is Token.RESET:
fg = bg = -1
attr = 0
elif tp is Token.ESC:
if match[1] == '7':
attr |= curses.A_REVERSE

63
tests/color_kd_test.py Normal file
View File

@@ -0,0 +1,63 @@
from babi import color_kd
from babi.color import Color
def test_build_trivial():
assert color_kd._build([]) is None
def test_build_single_node():
kd = color_kd._build([(Color(0, 0, 0), 255)])
assert kd == color_kd._KD(Color(0, 0, 0), 255, left=None, right=None)
def test_build_many_colors():
kd = color_kd._build([
(Color(0, 106, 200), 255),
(Color(1, 105, 201), 254),
(Color(2, 104, 202), 253),
(Color(3, 103, 203), 252),
(Color(4, 102, 204), 251),
(Color(5, 101, 205), 250),
(Color(6, 100, 206), 249),
])
# each level is sorted by the next dimension
assert kd == color_kd._KD(
Color(3, 103, 203),
252,
left=color_kd._KD(
Color(1, 105, 201), 254,
left=color_kd._KD(Color(2, 104, 202), 253, None, None),
right=color_kd._KD(Color(0, 106, 200), 255, None, None),
),
right=color_kd._KD(
Color(5, 101, 205), 250,
left=color_kd._KD(Color(6, 100, 206), 249, None, None),
right=color_kd._KD(Color(4, 102, 204), 251, None, None),
),
)
def test_nearest_trivial():
assert color_kd.nearest(Color(0, 0, 0), None) == 0
def test_nearest_one_node():
kd = color_kd._build([(Color(100, 100, 100), 99)])
assert color_kd.nearest(Color(0, 0, 0), kd) == 99
def test_nearest_on_square_distance():
kd = color_kd._build([
(Color(50, 50, 50), 255),
(Color(50, 51, 50), 254),
])
assert color_kd.nearest(Color(0, 0, 0), kd) == 255
assert color_kd.nearest(Color(52, 52, 52), kd) == 254
def test_smoke_kd_256():
kd_256 = color_kd.make_256()
assert color_kd.nearest(Color(0, 0, 0), kd_256) == 16
assert color_kd.nearest(Color(0x1e, 0x77, 0xd3), kd_256) == 32

View File

@@ -0,0 +1,16 @@
import pytest
from babi.color import Color
from babi.color_manager import _color_to_curses
@pytest.mark.parametrize(
('color', 'expected'),
(
(Color(0x00, 0x00, 0x00), (0, 0, 0)),
(Color(0xff, 0xff, 0xff), (1000, 1000, 1000)),
(Color(0x1e, 0x77, 0xd3), (117, 466, 827)),
),
)
def test_color_to_curses(color, expected):
assert _color_to_curses(color) == expected

7
tests/fdict_test.py Normal file
View File

@@ -0,0 +1,7 @@
from babi.fdict import FDict
def test_fdict_repr():
# mostly because this shouldn't get hit elsewhere but is uesful for
# debugging purposes
assert repr(FDict({1: 2, 3: 4})) == 'FDict({1: 2, 3: 4})'

View File

@@ -23,6 +23,13 @@ def xdg_data_home(tmpdir):
yield data_home
@pytest.fixture(autouse=True)
def xdg_config_home(tmpdir):
config_home = tmpdir.join('config_home')
with mock.patch.dict(os.environ, {'XDG_CONFIG_HOME': str(config_home)}):
yield config_home
@pytest.fixture
def ten_lines(tmpdir):
f = tmpdir.join('f')
@@ -175,6 +182,10 @@ class CursesScreen:
attr = attr & ~(0xff << 8)
return (fg, bg, attr)
def bkgd(self, c, attr):
assert c == ' '
self._bkgd_attr = self._to_attr(attr)
def keypad(self, val):
pass
@@ -368,6 +379,9 @@ class DeferredRunner:
def _curses_start_color(self):
curses.COLORS = self._n_colors
def _curses_can_change_color(self):
return self._can_change_color
def _curses_init_pair(self, pair, fg, bg):
self.color_pairs[pair] = (fg, bg)

View File

@@ -0,0 +1,80 @@
import curses
import json
import pytest
from testing.runner import and_exit
THEME = json.dumps({
'colors': {'background': '#00d700', 'foreground': '#303030'},
'tokenColors': [
{'scope': 'comment', 'settings': {'foreground': '#767676'}},
{
'scope': 'diffremove',
'settings': {'foreground': '#5f0000', 'background': '#ff5f5f'},
},
{'scope': 'tqs', 'settings': {'foreground': '#00005f'}},
{'scope': 'b', 'settings': {'fontStyle': 'bold'}},
{'scope': 'i', 'settings': {'fontStyle': 'italic'}},
{'scope': 'u', 'settings': {'fontStyle': 'underline'}},
],
})
SYNTAX = json.dumps({
'scopeName': 'source.demo',
'fileTypes': ['demo'],
'firstLineMatch': '^#!/usr/bin/(env demo|demo)$',
'patterns': [
{'match': r'#.*$\n?', 'name': 'comment'},
{'match': r'^-.*$\n?', 'name': 'diffremove'},
{'begin': '"""', 'end': '"""', 'name': 'tqs'},
],
})
DEMO_S = '''\
- foo
# comment here
uncolored
"""tqs!
still more
"""
'''
@pytest.fixture(autouse=True)
def theme_and_grammar(xdg_data_home, xdg_config_home):
xdg_config_home.join('babi/theme.json').ensure().write(THEME)
xdg_data_home.join('babi/textmate_syntax/demo.json').ensure().write(SYNTAX)
@pytest.fixture
def demo(tmpdir):
f = tmpdir.join('f.demo')
f.write(DEMO_S)
yield f
def test_syntax_highlighting(run, demo):
with run(str(demo), term='screen-256color', width=20) as h, and_exit(h):
h.await_text('still more')
for i, attr in enumerate([
[(236, 40, curses.A_REVERSE)] * 20, # header
[(52, 203, 0)] * 5 + [(236, 40, 0)] * 15, # - foo
[(243, 40, 0)] * 14 + [(236, 40, 0)] * 6, # # comment here
[(236, 40, 0)] * 20, # uncolored
[(17, 40, 0)] * 7 + [(236, 40, 0)] * 13, # """tqs!
[(17, 40, 0)] * 10 + [(236, 40, 0)] * 10, # still more
[(17, 40, 0)] * 3 + [(236, 40, 0)] * 17, # """
]):
h.assert_screen_attr_equals(i, attr)
def test_syntax_highlighting_does_not_highlight_arrows(run, tmpdir):
f = tmpdir.join('f')
f.write(
f'#!/usr/bin/env demo\n'
f'# l{"o" * 15}ng comment\n',
)
with run(str(f), term='screen-256color', width=20) as h, and_exit(h):
h.await_text('loooo')
h.assert_screen_attr_equals(2, [(243, 40, 0)] * 19 + [(236, 40, 0)])

530
tests/highlight_test.py Normal file
View File

@@ -0,0 +1,530 @@
from babi.highlight import Grammar
from babi.highlight import Grammars
from babi.highlight import highlight_line
from babi.highlight import Region
def _compiler_state(grammar_dct, *others):
grammar = Grammar.from_data(grammar_dct)
grammars = [grammar, *(Grammar.from_data(dct) for dct in others)]
compiler = Grammars(grammars).compiler_for_scope(grammar.scope_name)
return compiler, compiler.root_state
def test_backslash_a():
grammar = {
'scopeName': 'test',
'patterns': [{'name': 'aaa', 'match': r'\Aa+'}],
}
compiler, state = _compiler_state(grammar)
state, (region_0,) = highlight_line(compiler, state, 'aaa', True)
state, (region_1,) = highlight_line(compiler, state, 'aaa', False)
# \A should only match at the beginning of the file
assert region_0 == Region(0, 3, ('test', 'aaa'))
assert region_1 == Region(0, 3, ('test',))
BEGIN_END_NO_NL = {
'scopeName': 'test',
'patterns': [{
'begin': 'x',
'end': 'x',
'patterns': [
{'match': r'\Ga', 'name': 'ga'},
{'match': 'a', 'name': 'noga'},
],
}],
}
def test_backslash_g_inline():
compiler, state = _compiler_state(BEGIN_END_NO_NL)
_, regions = highlight_line(compiler, state, 'xaax', True)
assert regions == (
Region(0, 1, ('test',)),
Region(1, 2, ('test', 'ga')),
Region(2, 3, ('test', 'noga')),
Region(3, 4, ('test',)),
)
def test_backslash_g_next_line():
compiler, state = _compiler_state(BEGIN_END_NO_NL)
state, regions1 = highlight_line(compiler, state, 'x\n', True)
state, regions2 = highlight_line(compiler, state, 'aax\n', False)
assert regions1 == (
Region(0, 1, ('test',)),
Region(1, 2, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'noga')),
Region(1, 2, ('test', 'noga')),
Region(2, 3, ('test',)),
Region(3, 4, ('test',)),
)
def test_end_before_other_match():
compiler, state = _compiler_state(BEGIN_END_NO_NL)
state, regions = highlight_line(compiler, state, 'xazzx', True)
assert regions == (
Region(0, 1, ('test',)),
Region(1, 2, ('test', 'ga')),
Region(2, 4, ('test',)),
Region(4, 5, ('test',)),
)
BEGIN_END_NL = {
'scopeName': 'test',
'patterns': [{
'begin': r'x$\n?',
'end': 'x',
'patterns': [
{'match': r'\Ga', 'name': 'ga'},
{'match': 'a', 'name': 'noga'},
],
}],
}
def test_backslash_g_captures_nl():
compiler, state = _compiler_state(BEGIN_END_NL)
state, regions1 = highlight_line(compiler, state, 'x\n', True)
state, regions2 = highlight_line(compiler, state, 'aax\n', False)
assert regions1 == (
Region(0, 2, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'ga')),
Region(1, 2, ('test', 'noga')),
Region(2, 3, ('test',)),
Region(3, 4, ('test',)),
)
def test_backslash_g_captures_nl_next_line():
compiler, state = _compiler_state(BEGIN_END_NL)
state, regions1 = highlight_line(compiler, state, 'x\n', True)
state, regions2 = highlight_line(compiler, state, 'aa\n', False)
state, regions3 = highlight_line(compiler, state, 'aax\n', False)
assert regions1 == (
Region(0, 2, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'ga')),
Region(1, 2, ('test', 'noga')),
Region(2, 3, ('test',)),
)
assert regions3 == (
Region(0, 1, ('test', 'ga')),
Region(1, 2, ('test', 'noga')),
Region(2, 3, ('test',)),
Region(3, 4, ('test',)),
)
def test_while_no_nl():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [{
'begin': '> ',
'while': '> ',
'contentName': 'while',
'patterns': [
{'match': r'\Ga', 'name': 'ga'},
{'match': 'a', 'name': 'noga'},
],
}],
})
state, regions1 = highlight_line(compiler, state, '> aa\n', True)
state, regions2 = highlight_line(compiler, state, '> aa\n', False)
state, regions3 = highlight_line(compiler, state, 'after\n', False)
assert regions1 == (
Region(0, 2, ('test',)),
Region(2, 3, ('test', 'while', 'ga')),
Region(3, 4, ('test', 'while', 'noga')),
Region(4, 5, ('test', 'while')),
)
assert regions2 == (
Region(0, 2, ('test', 'while')),
Region(2, 3, ('test', 'while', 'ga')),
Region(3, 4, ('test', 'while', 'noga')),
Region(4, 5, ('test', 'while')),
)
assert regions3 == (
Region(0, 6, ('test',)),
)
def test_complex_captures():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'match': '(<).([^>]+)(>)',
'captures': {
'1': {'name': 'lbracket'},
'2': {
'patterns': [
{'match': 'a', 'name': 'a'},
{'match': 'z', 'name': 'z'},
],
},
'3': {'name': 'rbracket'},
},
},
],
})
state, regions = highlight_line(compiler, state, '<qabz>', first_line=True)
assert regions == (
Region(0, 1, ('test', 'lbracket')),
Region(1, 2, ('test',)),
Region(2, 3, ('test', 'a')),
Region(3, 4, ('test',)),
Region(4, 5, ('test', 'z')),
Region(5, 6, ('test', 'rbracket')),
)
def test_captures_multiple_applied_to_same_capture():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'match': '((a)) ((b) c) (d (e)) ((f) )',
'name': 'matched',
'captures': {
'1': {'name': 'g1'},
'2': {'name': 'g2'},
'3': {'name': 'g3'},
'4': {'name': 'g4'},
'5': {'name': 'g5'},
'6': {'name': 'g6'},
'7': {
'patterns': [
{'match': 'f', 'name': 'g7f'},
{'match': ' ', 'name': 'g7space'},
],
},
# this one has to backtrack some
'8': {'name': 'g8'},
},
},
],
})
state, regions = highlight_line(compiler, state, 'a b c d e f ', True)
assert regions == (
Region(0, 1, ('test', 'matched', 'g1', 'g2')),
Region(1, 2, ('test', 'matched')),
Region(2, 3, ('test', 'matched', 'g3', 'g4')),
Region(3, 5, ('test', 'matched', 'g3')),
Region(5, 6, ('test', 'matched')),
Region(6, 8, ('test', 'matched', 'g5')),
Region(8, 9, ('test', 'matched', 'g5', 'g6')),
Region(9, 10, ('test', 'matched')),
Region(10, 11, ('test', 'matched', 'g7f', 'g8')),
Region(11, 12, ('test', 'matched', 'g7space')),
)
def test_captures_ignores_empty():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [{
'match': '(.*) hi',
'captures': {'1': {'name': 'before'}},
}],
})
state, regions1 = highlight_line(compiler, state, ' hi\n', True)
state, regions2 = highlight_line(compiler, state, 'o hi\n', False)
assert regions1 == (
Region(0, 3, ('test',)),
Region(3, 4, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'before')),
Region(1, 4, ('test',)),
Region(4, 5, ('test',)),
)
def test_captures_ignores_invalid_out_of_bounds():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [{'match': '.', 'captures': {'1': {'name': 'oob'}}}],
})
state, regions = highlight_line(compiler, state, 'x', first_line=True)
assert regions == (
Region(0, 1, ('test',)),
)
def test_captures_begin_end():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'begin': '(""")',
'end': '(""")',
'beginCaptures': {'1': {'name': 'startquote'}},
'endCaptures': {'1': {'name': 'endquote'}},
},
],
})
state, regions = highlight_line(compiler, state, '"""x"""', True)
assert regions == (
Region(0, 3, ('test', 'startquote')),
Region(3, 4, ('test',)),
Region(4, 7, ('test', 'endquote')),
)
def test_captures_while_captures():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'begin': '(>) ',
'while': '(>) ',
'beginCaptures': {'1': {'name': 'bblock'}},
'whileCaptures': {'1': {'name': 'wblock'}},
},
],
})
state, regions1 = highlight_line(compiler, state, '> x\n', True)
state, regions2 = highlight_line(compiler, state, '> x\n', False)
assert regions1 == (
Region(0, 1, ('test', 'bblock')),
Region(1, 2, ('test',)),
Region(2, 4, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'wblock')),
Region(1, 2, ('test',)),
Region(2, 4, ('test',)),
)
def test_captures_implies_begin_end_captures():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'begin': '(""")',
'end': '(""")',
'captures': {'1': {'name': 'quote'}},
},
],
})
state, regions = highlight_line(compiler, state, '"""x"""', True)
assert regions == (
Region(0, 3, ('test', 'quote')),
Region(3, 4, ('test',)),
Region(4, 7, ('test', 'quote')),
)
def test_captures_implies_begin_while_captures():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'begin': '(>) ',
'while': '(>) ',
'captures': {'1': {'name': 'block'}},
},
],
})
state, regions1 = highlight_line(compiler, state, '> x\n', True)
state, regions2 = highlight_line(compiler, state, '> x\n', False)
assert regions1 == (
Region(0, 1, ('test', 'block')),
Region(1, 2, ('test',)),
Region(2, 4, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'block')),
Region(1, 2, ('test',)),
Region(2, 4, ('test',)),
)
def test_include_self():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [
{
'begin': '<',
'end': '>',
'contentName': 'bracketed',
'patterns': [{'include': '$self'}],
},
{'match': '.', 'name': 'content'},
],
})
state, regions = highlight_line(compiler, state, '<<_>>', first_line=True)
assert regions == (
Region(0, 1, ('test',)),
Region(1, 2, ('test', 'bracketed')),
Region(2, 3, ('test', 'bracketed', 'bracketed', 'content')),
Region(3, 4, ('test', 'bracketed', 'bracketed')),
Region(4, 5, ('test', 'bracketed')),
)
def test_include_repository_rule():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [{'include': '#impl'}],
'repository': {
'impl': {
'patterns': [
{'match': 'a', 'name': 'a'},
{'match': '.', 'name': 'other'},
],
},
},
})
state, regions = highlight_line(compiler, state, 'az', first_line=True)
assert regions == (
Region(0, 1, ('test', 'a')),
Region(1, 2, ('test', 'other')),
)
def test_include_other_grammar():
compiler, state = _compiler_state(
{
'scopeName': 'test',
'patterns': [
{
'begin': '<',
'end': '>',
'name': 'angle',
'patterns': [{'include': 'other.grammar'}],
},
{
'begin': '`',
'end': '`',
'name': 'tick',
'patterns': [{'include': 'other.grammar#backtick'}],
},
],
},
{
'scopeName': 'other.grammar',
'patterns': [
{'match': 'a', 'name': 'roota'},
{'match': '.', 'name': 'rootother'},
],
'repository': {
'backtick': {
'patterns': [
{'match': 'a', 'name': 'ticka'},
{'match': '.', 'name': 'tickother'},
],
},
},
},
)
state, regions1 = highlight_line(compiler, state, '<az>\n', True)
state, regions2 = highlight_line(compiler, state, '`az`\n', False)
assert regions1 == (
Region(0, 1, ('test', 'angle')),
Region(1, 2, ('test', 'angle', 'roota')),
Region(2, 3, ('test', 'angle', 'rootother')),
Region(3, 4, ('test', 'angle')),
Region(4, 5, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'tick')),
Region(1, 2, ('test', 'tick', 'ticka')),
Region(2, 3, ('test', 'tick', 'tickother')),
Region(3, 4, ('test', 'tick')),
Region(4, 5, ('test',)),
)
def test_include_base():
compiler, state = _compiler_state(
{
'scopeName': 'test',
'patterns': [
{
'begin': '<',
'end': '>',
'name': 'bracket',
# $base from root grammar includes itself
'patterns': [{'include': '$base'}],
},
{'include': 'other.grammar'},
{'match': 'z', 'name': 'testz'},
],
},
{
'scopeName': 'other.grammar',
'patterns': [
{
'begin': '`',
'end': '`',
'name': 'tick',
# $base from included grammar includes the root
'patterns': [{'include': '$base'}],
},
],
},
)
state, regions1 = highlight_line(compiler, state, '<z>\n', True)
state, regions2 = highlight_line(compiler, state, '`z`\n', False)
assert regions1 == (
Region(0, 1, ('test', 'bracket')),
Region(1, 2, ('test', 'bracket', 'testz')),
Region(2, 3, ('test', 'bracket')),
Region(3, 4, ('test',)),
)
assert regions2 == (
Region(0, 1, ('test', 'tick')),
Region(1, 2, ('test', 'tick', 'testz')),
Region(2, 3, ('test', 'tick')),
Region(3, 4, ('test',)),
)

0
tests/hl/__init__.py Normal file
View File

147
tests/hl/syntax_test.py Normal file
View File

@@ -0,0 +1,147 @@
import contextlib
import curses
from unittest import mock
import pytest
from babi.color_manager import ColorManager
from babi.highlight import Grammar
from babi.highlight import Grammars
from babi.hl.syntax import Syntax
from babi.theme import Color
from babi.theme import Theme
class FakeCurses:
def __init__(self, *, n_colors, can_change_color):
self._n_colors = n_colors
self._can_change_color = can_change_color
self.colors = {}
self.pairs = {}
def _curses__can_change_color(self):
return self._can_change_color
def _curses__init_color(self, n, r, g, b):
self.colors[n] = (r, g, b)
def _curses__init_pair(self, n, fg, bg):
self.pairs[n] = (fg, bg)
def _curses__color_pair(self, n):
assert n == 0 or n in self.pairs
return n << 8
@classmethod
@contextlib.contextmanager
def patch(cls, **kwargs):
fake = cls(**kwargs)
with mock.patch.object(curses, 'COLORS', fake._n_colors, create=True):
with mock.patch.multiple(
curses,
can_change_color=fake._curses__can_change_color,
color_pair=fake._curses__color_pair,
init_color=fake._curses__init_color,
init_pair=fake._curses__init_pair,
):
yield fake
class FakeScreen:
def __init__(self):
self.attr = 0
def bkgd(self, c, attr):
assert c == ' '
self.attr = attr
@pytest.fixture
def stdscr():
return FakeScreen()
THEME = Theme.from_dct({
'colors': {'foreground': '#cccccc', 'background': '#333333'},
'tokenColors': [
{'scope': 'string', 'settings': {'foreground': '#009900'}},
{'scope': 'keyword', 'settings': {'background': '#000000'}},
{'scope': 'keyword', 'settings': {'fontStyle': 'bold'}},
],
})
@pytest.fixture
def syntax():
return Syntax(Grammars([Grammar.blank()]), THEME, ColorManager.make())
def test_init_screen_low_color(stdscr, syntax):
with FakeCurses.patch(n_colors=16, can_change_color=False) as fake_curses:
syntax._init_screen(stdscr)
assert syntax.color_manager.colors == {}
assert syntax.color_manager.raw_pairs == {}
assert fake_curses.colors == {}
assert fake_curses.pairs == {}
assert stdscr.attr == 0
def test_init_screen_256_color(stdscr, syntax):
with FakeCurses.patch(n_colors=256, can_change_color=False) as fake_curses:
syntax._init_screen(stdscr)
assert syntax.color_manager.colors == {
Color.parse('#cccccc'): 252,
Color.parse('#333333'): 236,
Color.parse('#000000'): 16,
Color.parse('#009900'): 28,
}
assert syntax.color_manager.raw_pairs == {(252, 236): 1}
assert fake_curses.colors == {}
assert fake_curses.pairs == {1: (252, 236)}
assert stdscr.attr == 1 << 8
def test_init_screen_true_color(stdscr, syntax):
with FakeCurses.patch(n_colors=256, can_change_color=True) as fake_curses:
syntax._init_screen(stdscr)
# weird colors happened with low color numbers so it counts down from max
assert syntax.color_manager.colors == {
Color.parse('#000000'): 255,
Color.parse('#009900'): 254,
Color.parse('#333333'): 253,
Color.parse('#cccccc'): 252,
}
assert syntax.color_manager.raw_pairs == {(252, 253): 1}
assert fake_curses.colors == {
255: (0, 0, 0),
254: (0, 600, 0),
253: (200, 200, 200),
252: (800, 800, 800),
}
assert fake_curses.pairs == {1: (252, 253)}
assert stdscr.attr == 1 << 8
def test_lazily_instantiated_pairs(stdscr, syntax):
# pairs are assigned lazily to avoid hard upper limit (256) on pairs
with FakeCurses.patch(n_colors=256, can_change_color=False) as fake_curses:
syntax._init_screen(stdscr)
assert len(syntax.color_manager.raw_pairs) == 1
assert len(fake_curses.pairs) == 1
style = THEME.select(('string.python',))
attr = syntax.get_blank_file_highlighter().attr(style)
assert attr == 2 << 8
assert len(syntax.color_manager.raw_pairs) == 2
assert len(fake_curses.pairs) == 2
def test_style_attributes_applied(stdscr, syntax):
with FakeCurses.patch(n_colors=256, can_change_color=False):
syntax._init_screen(stdscr)
style = THEME.select(('keyword.python',))
attr = syntax.get_blank_file_highlighter().attr(style)
assert attr == 2 << 8 | curses.A_BOLD

74
tests/reg_test.py Normal file
View File

@@ -0,0 +1,74 @@
import onigurumacffi
import pytest
from babi.reg import _Reg
from babi.reg import _RegSet
def test_reg_first_line():
reg = _Reg(r'\Ahello')
assert reg.match('hello', 0, first_line=True, boundary=True)
assert reg.search('hello', 0, first_line=True, boundary=True)
assert not reg.match('hello', 0, first_line=False, boundary=True)
assert not reg.search('hello', 0, first_line=False, boundary=True)
def test_reg_boundary():
reg = _Reg(r'\Ghello')
assert reg.search('ohello', 1, first_line=True, boundary=True)
assert reg.match('ohello', 1, first_line=True, boundary=True)
assert not reg.search('ohello', 1, first_line=True, boundary=False)
assert not reg.match('ohello', 1, first_line=True, boundary=False)
def test_reg_neither():
reg = _Reg(r'(\A|\G)hello')
assert not reg.search('hello', 0, first_line=False, boundary=False)
assert not reg.search('ohello', 1, first_line=False, boundary=False)
def test_reg_other_escapes_left_untouched():
reg = _Reg(r'(^|\A|\G)\w\s\w')
assert reg.match('a b', 0, first_line=False, boundary=False)
def test_reg_not_out_of_bounds_at_end():
# the only way this is triggerable is with an illegal regex, we'd rather
# produce an error about the regex being wrong than an IndexError
reg = _Reg('\\A\\')
with pytest.raises(onigurumacffi.OnigError) as excinfo:
reg.search('\\', 0, first_line=False, boundary=False)
msg, = excinfo.value.args
assert msg == 'end pattern at escape'
def test_reg_repr():
assert repr(_Reg(r'\A123')) == r"_Reg('\\A123')"
def test_regset_first_line():
regset = _RegSet(r'\Ahello', 'hello')
idx, _ = regset.search('hello', 0, first_line=True, boundary=True)
assert idx == 0
idx, _ = regset.search('hello', 0, first_line=False, boundary=True)
assert idx == 1
def test_regset_boundary():
regset = _RegSet(r'\Ghello', 'hello')
idx, _ = regset.search('ohello', 1, first_line=True, boundary=True)
assert idx == 0
idx, _ = regset.search('ohello', 1, first_line=True, boundary=False)
assert idx == 1
def test_regset_neither():
regset = _RegSet(r'\Ahello', r'\Ghello', 'hello')
idx, _ = regset.search('hello', 0, first_line=False, boundary=False)
assert idx == 2
idx, _ = regset.search('ohello', 1, first_line=False, boundary=False)
assert idx == 2
def test_regset_repr():
assert repr(_RegSet('ohai', r'\Aworld')) == r"_RegSet('ohai', '\\Aworld')"

85
tests/theme_test.py Normal file
View File

@@ -0,0 +1,85 @@
import pytest
from babi.color import Color
from babi.theme import Theme
THEME = Theme.from_dct({
'colors': {'foreground': '#100000', 'background': '#aaaaaa'},
'tokenColors': [
{'scope': 'foo.bar', 'settings': {'foreground': '#200000'}},
{'scope': 'foo', 'settings': {'foreground': '#300000'}},
{'scope': 'parent foo.bar', 'settings': {'foreground': '#400000'}},
],
})
def unhex(color):
return f'#{hex(color.r << 16 | color.g << 8 | color.b)[2:]}'
@pytest.mark.parametrize(
('scope', 'expected'),
(
pytest.param(('',), '#100000', id='trivial'),
pytest.param(('unknown',), '#100000', id='unknown'),
pytest.param(('foo.bar',), '#200000', id='exact match'),
pytest.param(('foo.baz',), '#300000', id='prefix match'),
pytest.param(('src.diff', 'foo.bar'), '#200000', id='nested scope'),
pytest.param(
('foo.bar', 'unrelated'), '#200000',
id='nested scope not last one',
),
),
)
def test_select(scope, expected):
ret = THEME.select(scope)
assert unhex(ret.fg) == expected
def test_theme_default_settings_from_no_scope():
theme = Theme.from_dct({
'tokenColors': [
{'settings': {'foreground': '#cccccc', 'background': '#333333'}},
],
})
assert theme.default.fg == Color.parse('#cccccc')
assert theme.default.bg == Color.parse('#333333')
def test_theme_default_settings_from_empty_string_scope():
theme = Theme.from_dct({
'tokenColors': [
{
'scope': '',
'settings': {'foreground': '#cccccc', 'background': '#333333'},
},
],
})
assert theme.default.fg == Color.parse('#cccccc')
assert theme.default.bg == Color.parse('#333333')
def test_theme_scope_split_by_commas():
theme = Theme.from_dct({
'colors': {'foreground': '#cccccc', 'background': '#333333'},
'tokenColors': [
{'scope': 'a, b, c', 'settings': {'fontStyle': 'italic'}},
],
})
assert theme.select(('d',)).i is False
assert theme.select(('a',)).i is True
assert theme.select(('b',)).i is True
assert theme.select(('c',)).i is True
def test_theme_scope_as_A_list():
theme = Theme.from_dct({
'colors': {'foreground': '#cccccc', 'background': '#333333'},
'tokenColors': [
{'scope': ['a', 'b', 'c'], 'settings': {'fontStyle': 'underline'}},
],
})
assert theme.select(('d',)).u is False
assert theme.select(('a',)).u is True
assert theme.select(('b',)).u is True
assert theme.select(('c',)).u is True

20
tests/user_data_test.py Normal file
View File

@@ -0,0 +1,20 @@
import os
from unittest import mock
from babi.user_data import xdg_data
def test_when_xdg_data_home_is_set():
with mock.patch.dict(os.environ, {'XDG_DATA_HOME': '/foo'}):
ret = xdg_data('history', 'command')
assert ret == '/foo/babi/history/command'
def test_when_xdg_data_home_is_not_set():
def fake_expanduser(s):
return s.replace('~', '/home/username')
with mock.patch.object(os.path, 'expanduser', fake_expanduser):
with mock.patch.dict(os.environ, clear=True):
ret = xdg_data('history')
assert ret == '/home/username/.local/share/babi/history'

View File

@@ -13,3 +13,6 @@ commands =
skip_install = true
deps = pre-commit
commands = pre-commit run --all-files --show-diff-on-failure
[pep8]
ignore = E265,E501,W504