fix expansion of regexes with regex-special characters

This commit is contained in:
Anthony Sottile
2020-03-22 12:43:34 -07:00
parent bf1c3d1ee1
commit bdf07b8cb3
3 changed files with 25 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ from babi.fdict import FDict
from babi.reg import _Reg
from babi.reg import _RegSet
from babi.reg import ERR_REG
from babi.reg import expand_escaped
from babi.reg import make_reg
from babi.reg import make_regset
@@ -418,7 +419,7 @@ class EndRule(NamedTuple):
next_scope = scope + self.content_name
boundary = match.end() == len(match.string)
reg = make_reg(match.expand(self.end))
reg = make_reg(expand_escaped(match, self.end))
state = state.push(Entry(next_scope, self, reg, boundary))
regions = _captures(compiler, scope, match, self.begin_captures)
return state, True, regions
@@ -479,7 +480,7 @@ class WhileRule(NamedTuple):
next_scope = scope + self.content_name
boundary = match.end() == len(match.string)
reg = make_reg(match.expand(self.while_))
reg = make_reg(expand_escaped(match, self.while_))
state = state.push_while(self, Entry(next_scope, self, reg, boundary))
regions = _captures(compiler, scope, match, self.begin_captures)
return state, True, regions

View File

@@ -1,4 +1,5 @@
import functools
import re
from typing import Match
from typing import Optional
from typing import Tuple
@@ -7,6 +8,8 @@ import onigurumacffi
from babi.cached_property import cached_property
_BACKREF_RE = re.compile(r'((?<!\\)(?:\\\\)*)\\([0-9]+)')
def _replace_esc(s: str, chars: str) -> str:
"""replace the given escape sequences of `chars` with \\uffff"""
@@ -142,6 +145,10 @@ class _RegSet:
return self._set_no_A_no_G.search(line, pos)
def expand_escaped(match: Match[str], s: str) -> str:
return _BACKREF_RE.sub(lambda m: f'{m[1]}{re.escape(match[int(m[2])])}', s)
make_reg = functools.lru_cache(maxsize=None)(_Reg)
make_regset = functools.lru_cache(maxsize=None)(_RegSet)
ERR_REG = make_reg(')this pattern always triggers an error when used(')

View File

@@ -564,3 +564,18 @@ def test_rule_with_begin_and_no_end():
Region(5, 7, ('test', 'bang', 'invalid')),
Region(7, 8, ('test', 'bang', 'invalid')),
)
def test_begin_end_substitute_special_chars():
compiler, state = _compiler_state({
'scopeName': 'test',
'patterns': [{'begin': r'(\*)', 'end': r'\1', 'name': 'italic'}],
})
state, regions = highlight_line(compiler, state, '*italic*', True)
assert regions == (
Region(0, 1, ('test', 'italic')),
Region(1, 7, ('test', 'italic')),
Region(7, 8, ('test', 'italic')),
)