cqlsh_rs/
colorizer.rs

1//! CQL syntax colorization for the REPL prompt and output.
2//!
3//! Provides a simple tokenizer that applies ANSI colors to CQL keywords,
4//! string literals, numbers, and comments using crossterm styling.
5//! Also provides output coloring for query result values, headers, and errors
6//! matching Python cqlsh's color scheme.
7
8use crossterm::style::Stylize;
9
10use crate::driver::types::CqlValue;
11
/// CQL keywords recognised by the syntax highlighter, stored in uppercase.
///
/// The entries MUST stay in strictly ascending order with no duplicates:
/// `is_keyword` looks words up with `binary_search`, which gives unspecified
/// results on an unsorted slice. Rows are grouped by initial letter to make
/// the ordering easy to check by eye.
#[rustfmt::skip]
const KEYWORDS: &[&str] = &[
    "ADD", "ALTER", "AND", "APPLY", "AS", "ASC", "AUTHORIZE",
    "BATCH", "BEGIN", "BY",
    "CALLED", "CLUSTERING", "COLUMN", "COMPACT", "CONTAINS", "COUNT", "CREATE", "CUSTOM",
    "DELETE", "DESC", "DESCRIBE", "DISTINCT", "DROP",
    "ENTRIES", "EXECUTE", "EXISTS",
    "FILTERING", "FROM", "FROZEN", "FULL", "FUNCTION",
    "GRANT",
    "IF", "IN", "INDEX", "INSERT", "INTO", "IS",
    "JSON",
    "KEY", "KEYSPACE", "KEYSPACES",
    "LANGUAGE", "LIKE", "LIMIT", "LIST", "LOGIN",
    "MAP", "MATERIALIZED", "MODIFY",
    "NAMESPACE", "NOT", "NULL",
    "OF", "ON", "OR", "ORDER",
    "PARTITION", "PASSWORD", "PER", "PERMISSION", "PERMISSIONS", "PRIMARY",
    "RENAME", "REPLACE", "RETURNS", "REVOKE",
    "SCHEMA", "SELECT", "SET", "STATIC", "STORAGE", "SUPERUSER",
    "TABLE", "TABLES", "TEXT", "TIMESTAMP", "TO", "TOKEN",
    "TRIGGER", "TRUNCATE", "TTL", "TUPLE", "TYPE",
    "UNLOGGED", "UPDATE", "USE", "USER", "USERS", "USING",
    "VALUES", "VIEW",
    "WHERE", "WITH", "WRITETIME",
];
110
/// CQL syntax colorizer using ANSI escape codes.
///
/// The colorizer is a plain on/off switch around the styling helpers in its
/// `impl`; it is cheap to copy and safe to print with `{:?}`.
#[derive(Debug, Clone, Copy)]
pub struct CqlColorizer {
    /// When `false`, the public `colorize_*` methods return unstyled text.
    enabled: bool,
}
115
116impl CqlColorizer {
117    /// Create a new colorizer. If `enabled` is false, all methods return input unchanged.
118    pub fn new(enabled: bool) -> Self {
119        Self { enabled }
120    }
121
122    /// Returns whether colorization is enabled.
123    pub fn is_enabled(&self) -> bool {
124        self.enabled
125    }
126
127    /// Colorize a CQL result value matching Python cqlsh's color scheme.
128    ///
129    /// Color mapping:
130    /// - Text/Ascii → yellow bold
131    /// - Numeric/boolean/uuid/timestamp/date/time/duration/inet → green bold
132    /// - Blob → dark magenta (non-bold)
133    /// - Null → red bold
134    /// - Collection delimiters → blue bold, inner values colored by type
135    pub fn colorize_value(&self, value: &CqlValue) -> String {
136        if !self.enabled {
137            return value.to_string();
138        }
139        self.colorize_value_inner(value)
140    }
141
142    /// Colorize a column header (magenta bold, matching Python cqlsh default).
143    pub fn colorize_header(&self, name: &str) -> String {
144        if !self.enabled {
145            return name.to_string();
146        }
147        format!("{}", name.magenta().bold())
148    }
149
150    /// Colorize an error message (red bold, matching Python cqlsh).
151    pub fn colorize_error(&self, msg: &str) -> String {
152        if !self.enabled {
153            return msg.to_string();
154        }
155        format!("{}", msg.red().bold())
156    }
157
158    /// Colorize a warning message (red bold, matching Python cqlsh).
159    pub fn colorize_warning(&self, msg: &str) -> String {
160        self.colorize_error(msg)
161    }
162
163    /// Colorize the "Tracing session:" label (magenta bold).
164    pub fn colorize_trace_label(&self, label: &str) -> String {
165        if !self.enabled {
166            return label.to_string();
167        }
168        format!("{}", label.magenta().bold())
169    }
170
171    /// Colorize the cluster name in the welcome message (blue bold).
172    pub fn colorize_cluster_name(&self, name: &str) -> String {
173        if !self.enabled {
174            return name.to_string();
175        }
176        format!("{}", name.blue().bold())
177    }
178
179    /// Inner recursive value colorizer.
180    fn colorize_value_inner(&self, value: &CqlValue) -> String {
181        match value {
182            CqlValue::Ascii(s) | CqlValue::Text(s) => {
183                format!("{}", s.as_str().yellow().bold())
184            }
185            CqlValue::Int(_)
186            | CqlValue::BigInt(_)
187            | CqlValue::SmallInt(_)
188            | CqlValue::TinyInt(_)
189            | CqlValue::Float(_)
190            | CqlValue::Double(_)
191            | CqlValue::Decimal(_)
192            | CqlValue::Varint(_)
193            | CqlValue::Counter(_)
194            | CqlValue::Boolean(_)
195            | CqlValue::Uuid(_)
196            | CqlValue::TimeUuid(_)
197            | CqlValue::Timestamp(_)
198            | CqlValue::Date(_)
199            | CqlValue::Time(_)
200            | CqlValue::Duration { .. }
201            | CqlValue::Inet(_) => {
202                format!("{}", value.to_string().green().bold())
203            }
204            CqlValue::Blob(_) => {
205                format!("{}", value.to_string().dark_magenta())
206            }
207            CqlValue::Null => {
208                format!("{}", "null".red().bold())
209            }
210            CqlValue::Unset => {
211                format!("{}", "<unset>".red().bold())
212            }
213            CqlValue::List(items) => {
214                let mut result = format!("{}", "[".blue().bold());
215                for (i, item) in items.iter().enumerate() {
216                    if i > 0 {
217                        result.push_str(&format!("{}", ", ".blue().bold()));
218                    }
219                    result.push_str(&self.colorize_collection_element(item));
220                }
221                result.push_str(&format!("{}", "]".blue().bold()));
222                result
223            }
224            CqlValue::Set(items) => {
225                let mut result = format!("{}", "{".blue().bold());
226                for (i, item) in items.iter().enumerate() {
227                    if i > 0 {
228                        result.push_str(&format!("{}", ", ".blue().bold()));
229                    }
230                    result.push_str(&self.colorize_collection_element(item));
231                }
232                result.push_str(&format!("{}", "}".blue().bold()));
233                result
234            }
235            CqlValue::Map(entries) => {
236                let mut result = format!("{}", "{".blue().bold());
237                for (i, (k, v)) in entries.iter().enumerate() {
238                    if i > 0 {
239                        result.push_str(&format!("{}", ", ".blue().bold()));
240                    }
241                    result.push_str(&self.colorize_collection_element(k));
242                    result.push_str(&format!("{}", ": ".blue().bold()));
243                    result.push_str(&self.colorize_collection_element(v));
244                }
245                result.push_str(&format!("{}", "}".blue().bold()));
246                result
247            }
248            CqlValue::Tuple(items) => {
249                let mut result = format!("{}", "(".blue().bold());
250                for (i, item) in items.iter().enumerate() {
251                    if i > 0 {
252                        result.push_str(&format!("{}", ", ".blue().bold()));
253                    }
254                    match item {
255                        Some(v) => result.push_str(&self.colorize_collection_element(v)),
256                        None => result.push_str(&format!("{}", "null".red().bold())),
257                    }
258                }
259                result.push_str(&format!("{}", ")".blue().bold()));
260                result
261            }
262            CqlValue::UserDefinedType { fields, .. } => {
263                let mut result = format!("{}", "{".blue().bold());
264                for (i, (name, val)) in fields.iter().enumerate() {
265                    if i > 0 {
266                        result.push_str(&format!("{}", ", ".blue().bold()));
267                    }
268                    // UDT field names are yellow (like text)
269                    result.push_str(&format!("{}", name.as_str().yellow().bold()));
270                    result.push_str(&format!("{}", ": ".blue().bold()));
271                    match val {
272                        Some(v) => result.push_str(&self.colorize_collection_element(v)),
273                        None => result.push_str(&format!("{}", "null".red().bold())),
274                    }
275                }
276                result.push_str(&format!("{}", "}".blue().bold()));
277                result
278            }
279        }
280    }
281
282    /// Colorize an element inside a collection, quoting strings like Display does.
283    fn colorize_collection_element(&self, value: &CqlValue) -> String {
284        match value {
285            CqlValue::Ascii(s) | CqlValue::Text(s) => {
286                // Inside collections, strings are quoted: 'value'
287                let quoted = format!("'{}'", s.replace('\'', "''"));
288                format!("{}", quoted.yellow().bold())
289            }
290            other => self.colorize_value_inner(other),
291        }
292    }
293
294    /// Colorize a line of CQL input for display.
295    ///
296    /// Applies colors:
297    /// - CQL keywords → bold blue
298    /// - String literals ('...') → green
299    /// - Numbers → cyan
300    /// - Comments (-- ...) → dark grey
301    pub fn colorize_line(&self, line: &str) -> String {
302        if !self.enabled {
303            return line.to_string();
304        }
305
306        let mut result = String::with_capacity(line.len() * 2);
307        let chars: Vec<char> = line.chars().collect();
308        let len = chars.len();
309        let mut i = 0;
310
311        while i < len {
312            // Comment: -- to end of line
313            if i + 1 < len && chars[i] == '-' && chars[i + 1] == '-' {
314                let rest: String = chars[i..].iter().collect();
315                result.push_str(&format!("{}", rest.dark_grey()));
316                break;
317            }
318
319            // String literal: '...'
320            if chars[i] == '\'' {
321                let start = i;
322                i += 1;
323                while i < len && chars[i] != '\'' {
324                    if chars[i] == '\\' && i + 1 < len {
325                        i += 1; // skip escaped char
326                    }
327                    i += 1;
328                }
329                if i < len {
330                    i += 1; // consume closing quote
331                }
332                let literal: String = chars[start..i].iter().collect();
333                result.push_str(&format!("{}", literal.green()));
334                continue;
335            }
336
337            // Number (simple: digits possibly with dots)
338            if chars[i].is_ascii_digit()
339                || (chars[i] == '-'
340                    && i + 1 < len
341                    && chars[i + 1].is_ascii_digit()
342                    && (i == 0 || !chars[i - 1].is_alphanumeric()))
343            {
344                let start = i;
345                if chars[i] == '-' {
346                    i += 1;
347                }
348                while i < len && (chars[i].is_ascii_digit() || chars[i] == '.') {
349                    i += 1;
350                }
351                // Make sure this isn't part of an identifier
352                if i < len && (chars[i].is_alphanumeric() || chars[i] == '_') {
353                    // It's an identifier like "table1" — don't colorize
354                    let word: String = chars[start..].iter().collect();
355                    let end = word
356                        .find(|c: char| c.is_whitespace() || c == ',' || c == ')' || c == ';')
357                        .unwrap_or(word.len());
358                    result.push_str(&chars[start..start + end].iter().collect::<String>());
359                    i = start + end;
360                } else {
361                    let num: String = chars[start..i].iter().collect();
362                    result.push_str(&format!("{}", num.cyan()));
363                }
364                continue;
365            }
366
367            // Word (potential keyword)
368            if chars[i].is_alphabetic() || chars[i] == '_' {
369                let start = i;
370                while i < len && (chars[i].is_alphanumeric() || chars[i] == '_') {
371                    i += 1;
372                }
373                let word: String = chars[start..i].iter().collect();
374                // Don't highlight as keyword if preceded by '.' (it's a qualified name)
375                let after_dot = start > 0 && chars[start - 1] == '.';
376                if !after_dot && is_keyword(&word) {
377                    result.push_str(&format!("{}", word.blue().bold()));
378                } else {
379                    result.push_str(&word);
380                }
381                continue;
382            }
383
384            // Other characters (whitespace, operators, etc.)
385            result.push(chars[i]);
386            i += 1;
387        }
388
389        result
390    }
391}
392
393/// Check if a word is a CQL keyword (case-insensitive).
394fn is_keyword(word: &str) -> bool {
395    let upper = word.to_uppercase();
396    KEYWORDS.binary_search(&upper.as_str()).is_ok()
397}
398
#[cfg(test)]
mod tests {
    //! Unit tests for the colorizer.
    //!
    //! Expected strings are built with the same crossterm styling calls the
    //! implementation uses, so the assertions check the actual color applied
    //! without hard-coding escape sequences.

    use super::*;

    /// Expected rendering of a highlighted keyword (blue bold).
    fn keyword(word: &str) -> String {
        word.blue().bold().to_string()
    }

    /// Expected rendering of collection punctuation (blue bold).
    fn punct(text: &str) -> String {
        text.blue().bold().to_string()
    }

    /// Expected rendering of a text value or UDT field name (yellow bold).
    fn text(value: &str) -> String {
        value.yellow().bold().to_string()
    }

    #[test]
    fn keywords_are_highlighted() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_line("SELECT * FROM users");
        assert!(output.contains("\x1b["), "should contain ANSI escape codes");
        assert!(output.contains(&keyword("SELECT")));
        assert!(output.contains(&keyword("FROM")));
    }

    #[test]
    fn colorizer_disabled_returns_unchanged() {
        let c = CqlColorizer::new(false);
        assert!(!c.is_enabled());
        let output = c.colorize_line("SELECT * FROM users");
        assert_eq!(output, "SELECT * FROM users");
    }

    #[test]
    fn string_literals_are_colored() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_line("INSERT INTO t (a) VALUES ('hello')");
        // The whole literal, quotes included, is one green span.
        assert!(output.contains(&"'hello'".green().to_string()));
    }

    #[test]
    fn numbers_are_colored() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_line("SELECT * FROM t LIMIT 100");
        assert!(output.contains(&"100".cyan().to_string()));
    }

    #[test]
    fn comments_are_colored() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_line("SELECT 1 -- test comment");
        // The comment span starts at the `--` marker and runs to end of line.
        assert!(output.ends_with(&"-- test comment".dark_grey().to_string()));
    }

    #[test]
    fn non_keywords_are_not_highlighted() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_line("my_table");
        // "my_table" is not a keyword, so it passes through untouched.
        assert_eq!(output, "my_table");
    }

    #[test]
    fn qualified_name_is_not_highlighted() {
        let c = CqlColorizer::new(true);
        // `table` follows a '.', so it is a name, not the TABLE keyword.
        let output = c.colorize_line("ks.table");
        assert_eq!(output, "ks.table");
    }

    #[test]
    fn mixed_case_keywords() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_line("select * from users");
        // Keywords are matched case-insensitively but echoed in their input case.
        assert!(
            output.contains(&keyword("select")),
            "lowercase keywords should also be highlighted"
        );
        assert!(output.contains(&keyword("from")));
    }

    #[test]
    fn keyword_list_is_sorted() {
        // binary_search requires a sorted, duplicate-free list
        for window in KEYWORDS.windows(2) {
            assert!(
                window[0] < window[1],
                "KEYWORDS not sorted: {:?} >= {:?}",
                window[0],
                window[1]
            );
        }
    }

    // --- Output coloring tests ---

    #[test]
    fn colorize_text_value_yellow() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_value(&CqlValue::Text("hello".to_string()));
        assert!(output.contains("\x1b["), "should contain ANSI codes");
        assert_eq!(output, text("hello"));
    }

    #[test]
    fn colorize_int_value_green() {
        let c = CqlColorizer::new(true);
        let value = CqlValue::Int(42);
        let output = c.colorize_value(&value);
        assert!(output.contains("\x1b["), "should contain ANSI codes");
        assert!(output.contains("42"));
        assert_eq!(output, value.to_string().green().bold().to_string());
    }

    #[test]
    fn colorize_null_value_red() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_value(&CqlValue::Null);
        assert!(output.contains("\x1b["), "should contain ANSI codes");
        assert_eq!(output, "null".red().bold().to_string());
    }

    #[test]
    fn colorize_blob_value_dark_magenta() {
        let c = CqlColorizer::new(true);
        let value = CqlValue::Blob(vec![0xde, 0xad]);
        let output = c.colorize_value(&value);
        assert!(output.contains("\x1b["), "should contain ANSI codes");
        assert!(output.contains("dead"));
        assert_eq!(output, value.to_string().dark_magenta().to_string());
    }

    #[test]
    fn colorize_list_with_blue_delimiters() {
        let c = CqlColorizer::new(true);
        let list = CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        let output = c.colorize_value(&list);
        assert!(output.starts_with(&punct("[")));
        assert!(output.contains(&punct(", ")));
        assert!(output.ends_with(&punct("]")));
    }

    #[test]
    fn colorize_list_quotes_text_elements() {
        let c = CqlColorizer::new(true);
        let list = CqlValue::List(vec![CqlValue::Text("it's".to_string())]);
        let output = c.colorize_value(&list);
        // Nested strings are single-quoted with embedded quotes doubled.
        assert!(output.contains(&text("'it''s'")));
    }

    #[test]
    fn colorize_value_disabled_returns_plain() {
        let c = CqlColorizer::new(false);
        let output = c.colorize_value(&CqlValue::Text("hello".to_string()));
        assert_eq!(output, "hello");
    }

    #[test]
    fn colorize_header_magenta() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_header("name");
        assert!(output.contains("\x1b["), "should contain ANSI codes");
        assert_eq!(output, "name".magenta().bold().to_string());
    }

    #[test]
    fn colorize_error_red() {
        let c = CqlColorizer::new(true);
        let output = c.colorize_error("SyntaxException: bad input");
        assert!(output.contains("\x1b["), "should contain ANSI codes");
        assert_eq!(
            output,
            "SyntaxException: bad input".red().bold().to_string()
        );
    }

    #[test]
    fn colorize_warning_matches_error() {
        let c = CqlColorizer::new(true);
        assert_eq!(c.colorize_warning("careful"), c.colorize_error("careful"));
    }

    #[test]
    fn disabled_helpers_return_input() {
        let c = CqlColorizer::new(false);
        assert_eq!(c.colorize_header("name"), "name");
        assert_eq!(c.colorize_error("boom"), "boom");
        assert_eq!(c.colorize_warning("careful"), "careful");
        assert_eq!(c.colorize_trace_label("Tracing session:"), "Tracing session:");
        assert_eq!(c.colorize_cluster_name("Test Cluster"), "Test Cluster");
    }

    #[test]
    fn colorize_map_with_colored_elements() {
        let c = CqlColorizer::new(true);
        let map = CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))]);
        let output = c.colorize_value(&map);
        assert!(output.starts_with(&punct("{")));
        assert!(output.contains(&text("'key'")));
        assert!(output.contains(&punct(": ")));
        assert!(output.ends_with(&punct("}")));
    }

    #[test]
    fn colorize_udt_field_names_yellow() {
        let c = CqlColorizer::new(true);
        let udt = CqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            type_name: "my_type".to_string(),
            fields: vec![
                (
                    "name".to_string(),
                    Some(CqlValue::Text("Alice".to_string())),
                ),
                ("age".to_string(), Some(CqlValue::Int(30))),
            ],
        };
        let output = c.colorize_value(&udt);
        // Field names are unquoted yellow; text values are quoted yellow.
        assert!(output.contains(&text("name")));
        assert!(output.contains(&text("age")));
        assert!(output.contains(&text("'Alice'")));
        assert!(output.contains(&punct(": ")));
    }
}
563}