1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
def walk_sql(sql: str, start: int = 0):
"""Yield (index, char, depth) for unquoted chars, tracking paren depth."""
depth = 0
quote_char = None
escape = False
for i in range(start, len(sql)):
char = sql[i]
if escape:
escape = False
continue
if quote_char:
if char == "\\" and quote_char == "'":
escape = True
elif char == quote_char:
quote_char = None
continue
if char in {"'", '"', "`"}:
quote_char = char
continue
if char == "(":
depth += 1
elif char == ")":
depth -= 1
yield i, char, depth
def extract_parenthesized_block(sql: str, start: int) -> tuple[str, int]:
"""Return the content and closing index of the first parenthesized block."""
block_start = -1
for i, char, depth in walk_sql(sql, start):
if char == "(" and depth == 1 and block_start == -1:
block_start = i + 1
elif char == ")" and depth == 0 and block_start != -1:
return sql[block_start:i], i
raise ValueError("Could not parse parenthesized SQL block")
def split_top_level(sql: str, delimiter: str = ",") -> list[str]:
"""Split SQL on *delimiter* only at the top nesting level."""
parts = []
part_start = 0
for i, char, depth in walk_sql(sql):
if char == delimiter and depth == 0:
part = sql[part_start:i].strip()
if part:
parts.append(part)
part_start = i + 1
tail = sql[part_start:].strip()
if tail:
parts.append(tail)
return parts
def find_top_level_clause(sql: str, clauses: tuple[str, ...]) -> tuple[int, str | None]:
"""Find the first occurrence of any *clause* at top nesting level."""
upper_sql = sql.upper()
for i, _char, depth in walk_sql(sql):
if depth == 0:
for clause in clauses:
if upper_sql.startswith(clause, i):
return i, clause
return -1, None
|