cqlsh_rs/
completer.rs

1//! Tab completion for the CQL shell.
2//!
3//! Implements rustyline's `Completer`, `Helper`, `Hinter`, `Highlighter`, and
4//! `Validator` traits to provide context-aware tab completion in the REPL.
5//! Completions include CQL keywords, shell commands, schema objects (keyspaces,
6//! tables, columns), consistency levels, DESCRIBE sub-commands, and file paths.
7
8use std::borrow::Cow;
9use std::sync::Arc;
10
11use rustyline::completion::{Completer, Pair};
12use rustyline::highlight::Highlighter;
13use rustyline::hint::Hinter;
14use rustyline::validate::Validator;
15use rustyline::{Context, Helper};
16use tokio::runtime::Handle;
17use tokio::sync::RwLock;
18
19use crate::colorizer::CqlColorizer;
20use crate::schema_cache::SchemaCache;
21
/// Keywords that may begin a CQL statement.
///
/// Offered (together with `SHELL_COMMANDS`) while the cursor is still on the
/// first word of the input.
const CQL_KEYWORDS: &[&str] = &[
    "ALTER", "APPLY", "BATCH", "BEGIN", "CREATE", "DELETE", "DESCRIBE", "DROP", "GRANT", "INSERT",
    "LIST", "REVOKE", "SELECT", "TRUNCATE", "UPDATE", "USE",
];
27
/// Keywords offered once the first word of a statement has been typed.
///
/// Kept in strict alphabetical order so new entries are easy to place.
/// Includes `GROUP` (for `GROUP BY`) and `ROLE`/`ROLES` (role-management
/// statements such as `CREATE ROLE` or `LIST ROLES`).
const CQL_CLAUSE_KEYWORDS: &[&str] = &[
    "ADD",
    "AGGREGATE",
    "ALL",
    "ALLOW",
    "AND",
    "AS",
    "ASC",
    "AUTHORIZE",
    "BATCH",
    "BY",
    "CALLED",
    "CLUSTERING",
    "COLUMN",
    "COMPACT",
    "CONTAINS",
    "COUNT",
    "CUSTOM",
    "DELETE",
    "DESC",
    "DESCRIBE",
    "DISTINCT",
    "DROP",
    "ENTRIES",
    "EXECUTE",
    "EXISTS",
    "FILTERING",
    "FINALFUNC",
    "FROM",
    "FROZEN",
    "FULL",
    "FUNCTION",
    "FUNCTIONS",
    "GROUP",
    "IF",
    "IN",
    "INDEX",
    "INITCOND",
    "INPUT",
    "INSERT",
    "INTO",
    "IS",
    "JSON",
    "KEY",
    "KEYS",
    "KEYSPACE",
    "KEYSPACES",
    "LANGUAGE",
    "LIKE",
    "LIMIT",
    "LIST",
    "LOGIN",
    "MAP",
    "MATERIALIZED",
    "MODIFY",
    "NAMESPACE",
    "NORECURSIVE",
    "NOT",
    "NULL",
    "OF",
    "ON",
    "OR",
    "ORDER",
    "PARTITION",
    "PASSWORD",
    "PER",
    "PERMISSION",
    "PERMISSIONS",
    "PRIMARY",
    "RENAME",
    "REPLACE",
    "RETURNS",
    "REVOKE",
    "ROLE",
    "ROLES",
    "SCHEMA",
    "SELECT",
    "SET",
    "SFUNC",
    "STATIC",
    "STORAGE",
    "STYPE",
    "SUPERUSER",
    "TABLE",
    "TABLES",
    "TEXT",
    "TIMESTAMP",
    "TO",
    "TOKEN",
    "TRIGGER",
    "TRUNCATE",
    "TTL",
    "TUPLE",
    "TYPE",
    "UNLOGGED",
    "UPDATE",
    "USER",
    "USERS",
    "USING",
    "VALUES",
    "VIEW",
    "WHERE",
    "WITH",
    "WRITETIME",
];
131
/// Commands handled by the shell itself rather than sent to the server.
///
/// Offered alongside `CQL_KEYWORDS` while the first word is being typed.
const SHELL_COMMANDS: &[&str] = &[
    "CAPTURE",
    "CLEAR",
    "CLS",
    "CONSISTENCY",
    "COPY",
    "DESCRIBE",
    "DESC", // short form of DESCRIBE
    "EXIT",
    "EXPAND",
    "HELP",
    "LOGIN",
    "PAGING",
    "QUIT",
    "SERIAL", // as in SERIAL CONSISTENCY
    "SHOW",
    "SOURCE",
    "TRACING",
];
152
/// Consistency level names accepted after `CONSISTENCY` and
/// `SERIAL CONSISTENCY`.
const CONSISTENCY_LEVELS: &[&str] = &[
    "ALL",
    "ANY",
    "EACH_QUORUM",
    "LOCAL_ONE",
    "LOCAL_QUORUM",
    "LOCAL_SERIAL",
    "ONE",
    "QUORUM",
    "SERIAL",
    "THREE",
    "TWO",
];
167
/// Words that may follow `DESCRIBE` / `DESC`.
const DESCRIBE_SUB_COMMANDS: &[&str] = &[
    "AGGREGATE",
    "AGGREGATES",
    "CLUSTER",
    "FULL",
    "FUNCTION",
    "FUNCTIONS",
    "INDEX",
    "KEYSPACE",
    "KEYSPACES",
    "MATERIALIZED",
    "SCHEMA",
    "TABLE",
    "TABLES",
    "TYPE",
    "TYPES",
];
186
/// Native CQL type names, intended for column definitions in `CREATE TABLE`.
///
/// Not referenced yet: no completion context produces type names.
#[allow(dead_code)] // Will be used when CqlType completion context is implemented
const CQL_TYPES: &[&str] = &[
    "ascii",
    "bigint",
    "blob",
    "boolean",
    "counter",
    "date",
    "decimal",
    "double",
    "duration",
    "float",
    "frozen",
    "inet",
    "int",
    "list",
    "map",
    "set",
    "smallint",
    "text",
    "time",
    "timestamp",
    "timeuuid",
    "tinyint",
    "tuple",
    "uuid",
    "varchar",
    "varint",
];
217
/// What kind of word the cursor is on, derived from the text before it.
///
/// Each variant selects one candidate source in the completer.
#[derive(Debug, PartialEq)]
enum CompletionContext {
    // ----- fixed word lists -----
    /// First word of the input: statement keywords plus shell commands.
    Empty,
    /// Somewhere inside a statement: clause keywords.
    ClauseKeyword,
    /// Argument of `CONSISTENCY` / `SERIAL CONSISTENCY`: consistency levels.
    ConsistencyLevel,
    /// Argument of `DESCRIBE` / `DESC`: its sub-commands.
    DescribeTarget,

    // ----- schema-backed lists -----
    /// Argument of `USE` (or `DESCRIBE KEYSPACE`): keyspace names.
    KeyspaceName,
    /// A table position such as after `FROM`, `INTO`, `UPDATE`.
    ///
    /// `keyspace` selects where to look tables up; `None` means "fall back to
    /// the session's current keyspace".
    TableName { keyspace: Option<String> },
    /// Inside a `WHERE` / `SET` clause of a statement on `table`.
    ///
    /// `keyspace` follows the same fallback rule as in `TableName`.
    ColumnName {
        keyspace: Option<String>,
        table: String,
    },

    // ----- filesystem -----
    /// Argument of `SOURCE` / `CAPTURE`: file-system paths.
    FilePath,
}
241
242/// Tab completer for the CQL shell REPL.
243pub struct CqlCompleter {
244    /// Shared schema cache for keyspace/table/column lookups.
245    cache: Arc<RwLock<SchemaCache>>,
246    /// Current keyspace (shared with session via USE command).
247    current_keyspace: Arc<RwLock<Option<String>>>,
248    /// Tokio runtime handle for blocking cache reads inside sync complete().
249    rt_handle: Handle,
250    /// Syntax colorizer for highlighting.
251    colorizer: CqlColorizer,
252}
253
254impl CqlCompleter {
255    /// Create a new completer with shared cache and keyspace state.
256    pub fn new(
257        cache: Arc<RwLock<SchemaCache>>,
258        current_keyspace: Arc<RwLock<Option<String>>>,
259        rt_handle: Handle,
260        color_enabled: bool,
261    ) -> Self {
262        Self {
263            cache,
264            current_keyspace,
265            rt_handle,
266            colorizer: CqlColorizer::new(color_enabled),
267        }
268    }
269
270    /// Detect completion context from the input line up to the cursor position.
271    fn detect_context(&self, line: &str, pos: usize) -> CompletionContext {
272        let before_cursor = &line[..pos];
273        let tokens: Vec<&str> = before_cursor.split_whitespace().collect();
274
275        if tokens.is_empty() {
276            return CompletionContext::Empty;
277        }
278
279        let first = tokens[0].to_uppercase();
280
281        // CONSISTENCY <level>
282        if first == "CONSISTENCY" && tokens.len() <= 2 {
283            return if (tokens.len() == 1 && before_cursor.ends_with(' ')) || tokens.len() == 2 {
284                CompletionContext::ConsistencyLevel
285            } else {
286                CompletionContext::Empty
287            };
288        }
289
290        // SERIAL CONSISTENCY <level>
291        if first == "SERIAL" && tokens.len() >= 2 && tokens[1].to_uppercase() == "CONSISTENCY" {
292            return CompletionContext::ConsistencyLevel;
293        }
294
295        // SOURCE / CAPTURE — file path
296        if first == "SOURCE" || first == "CAPTURE" {
297            return CompletionContext::FilePath;
298        }
299
300        // USE <keyspace>
301        if first == "USE" {
302            return CompletionContext::KeyspaceName;
303        }
304
305        // DESCRIBE / DESC
306        if first == "DESCRIBE" || first == "DESC" {
307            if tokens.len() == 1 && before_cursor.ends_with(' ') {
308                return CompletionContext::DescribeTarget;
309            }
310            if tokens.len() == 2 {
311                let sub = tokens[1].to_uppercase();
312                if before_cursor.ends_with(' ') {
313                    // After sub-command, complete with schema names
314                    return match sub.as_str() {
315                        "KEYSPACE" => CompletionContext::KeyspaceName,
316                        "TABLE" | "INDEX" | "MATERIALIZED" => {
317                            CompletionContext::TableName { keyspace: None }
318                        }
319                        _ => CompletionContext::DescribeTarget,
320                    };
321                }
322                return CompletionContext::DescribeTarget;
323            }
324            if tokens.len() == 3 {
325                let sub = tokens[1].to_uppercase();
326                return match sub.as_str() {
327                    "KEYSPACE" => CompletionContext::KeyspaceName,
328                    "TABLE" | "INDEX" => CompletionContext::TableName { keyspace: None },
329                    _ => CompletionContext::ClauseKeyword,
330                };
331            }
332            return CompletionContext::ClauseKeyword;
333        }
334
335        // Check for FROM/INTO/UPDATE keywords to trigger table name completion
336        let upper_tokens: Vec<String> = tokens.iter().map(|t| t.to_uppercase()).collect();
337        for (i, token) in upper_tokens.iter().enumerate() {
338            if (token == "FROM" || token == "INTO" || token == "UPDATE" || token == "TABLE")
339                && i + 1 >= tokens.len()
340                && before_cursor.ends_with(' ')
341            {
342                return CompletionContext::TableName { keyspace: None };
343            }
344            if (token == "FROM" || token == "INTO" || token == "UPDATE" || token == "TABLE")
345                && i + 1 < tokens.len()
346            {
347                let table_token = tokens[i + 1];
348                // Check if partially typing a qualified name (ks.)
349                if table_token.contains('.') && table_token.ends_with('.') {
350                    let ks = table_token.trim_end_matches('.').to_string();
351                    return CompletionContext::TableName { keyspace: Some(ks) };
352                }
353                // If we're past the table name, might be column context
354                if i + 2 < tokens.len() || (i + 1 < tokens.len() && before_cursor.ends_with(' ')) {
355                    // Check for WHERE clause
356                    if upper_tokens
357                        .iter()
358                        .skip(i + 2)
359                        .any(|t| t == "WHERE" || t == "SET")
360                    {
361                        let table = tokens[i + 1].to_string();
362                        let ks = tokio::task::block_in_place(|| {
363                            self.rt_handle
364                                .block_on(async { self.current_keyspace.read().await.clone() })
365                        });
366                        return CompletionContext::ColumnName {
367                            keyspace: ks,
368                            table,
369                        };
370                    }
371                }
372                // Still typing the table name
373                if !before_cursor.ends_with(' ') && i + 1 == tokens.len() - 1 {
374                    return CompletionContext::TableName { keyspace: None };
375                }
376            }
377        }
378
379        // At beginning of line, completing a keyword
380        if tokens.len() == 1 && !before_cursor.ends_with(' ') {
381            return CompletionContext::Empty;
382        }
383
384        CompletionContext::ClauseKeyword
385    }
386
387    /// Generate completions for the detected context.
388    fn complete_for_context(&self, ctx: &CompletionContext, prefix: &str) -> Vec<Pair> {
389        let prefix_upper = prefix.to_uppercase();
390
391        match ctx {
392            CompletionContext::Empty => {
393                let mut candidates: Vec<&str> = Vec::new();
394                candidates.extend_from_slice(CQL_KEYWORDS);
395                candidates.extend_from_slice(SHELL_COMMANDS);
396                filter_candidates(&candidates, &prefix_upper, true)
397            }
398            CompletionContext::ClauseKeyword => {
399                filter_candidates(CQL_CLAUSE_KEYWORDS, &prefix_upper, true)
400            }
401            CompletionContext::ConsistencyLevel => {
402                filter_candidates(CONSISTENCY_LEVELS, &prefix_upper, true)
403            }
404            CompletionContext::DescribeTarget => {
405                filter_candidates(DESCRIBE_SUB_COMMANDS, &prefix_upper, true)
406            }
407            CompletionContext::KeyspaceName => {
408                let cache =
409                    tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
410                let names = cache.keyspace_names();
411                filter_candidates(&names, prefix, false)
412            }
413            CompletionContext::TableName { keyspace } => {
414                let cache =
415                    tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
416                let ks = keyspace.clone().or_else(|| {
417                    tokio::task::block_in_place(|| {
418                        self.rt_handle
419                            .block_on(async { self.current_keyspace.read().await.clone() })
420                    })
421                });
422                match ks {
423                    Some(ref ks_name) => {
424                        let names = cache.table_names(ks_name);
425                        filter_candidates(&names, prefix, false)
426                    }
427                    None => {
428                        // No keyspace context — offer keyspace names for qualification
429                        let names = cache.keyspace_names();
430                        filter_candidates(&names, prefix, false)
431                    }
432                }
433            }
434            CompletionContext::ColumnName { keyspace, table } => {
435                let cache =
436                    tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
437                let ks = keyspace.clone().or_else(|| {
438                    tokio::task::block_in_place(|| {
439                        self.rt_handle
440                            .block_on(async { self.current_keyspace.read().await.clone() })
441                    })
442                });
443                match ks {
444                    Some(ref ks_name) => {
445                        let names = cache.column_names(ks_name, table);
446                        filter_candidates(&names, prefix, false)
447                    }
448                    None => vec![],
449                }
450            }
451            CompletionContext::FilePath => complete_file_path(prefix),
452        }
453    }
454}
455
456/// Filter candidates by prefix, returning matching `Pair`s.
457fn filter_candidates(candidates: &[&str], prefix: &str, uppercase: bool) -> Vec<Pair> {
458    candidates
459        .iter()
460        .filter(|c| {
461            if uppercase {
462                c.to_uppercase().starts_with(&prefix.to_uppercase())
463            } else {
464                c.starts_with(prefix)
465            }
466        })
467        .map(|c| {
468            let display = if uppercase {
469                c.to_uppercase()
470            } else {
471                c.to_string()
472            };
473            Pair {
474                display: display.clone(),
475                replacement: display,
476            }
477        })
478        .collect()
479}
480
481/// Complete file paths for SOURCE and CAPTURE commands.
482fn complete_file_path(prefix: &str) -> Vec<Pair> {
483    // Strip surrounding quotes if present
484    let path_str = prefix
485        .strip_prefix('\'')
486        .or_else(|| prefix.strip_prefix('"'))
487        .unwrap_or(prefix);
488
489    // Expand ~ to home directory
490    let expanded = if path_str.starts_with('~') {
491        if let Some(home) = dirs::home_dir() {
492            path_str.replacen('~', &home.to_string_lossy(), 1)
493        } else {
494            path_str.to_string()
495        }
496    } else {
497        path_str.to_string()
498    };
499
500    let (dir, file_prefix) = if expanded.ends_with('/') {
501        (expanded.as_str(), "")
502    } else {
503        let path = std::path::Path::new(&expanded);
504        let parent = path
505            .parent()
506            .map(|p| p.to_str().unwrap_or("."))
507            .unwrap_or(".");
508        let file = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
509        (parent, file)
510    };
511
512    let dir_to_read = if dir.is_empty() { "." } else { dir };
513
514    let Ok(entries) = std::fs::read_dir(dir_to_read) else {
515        return vec![];
516    };
517
518    entries
519        .filter_map(|entry| entry.ok())
520        .filter_map(|entry| {
521            let name = entry.file_name().to_string_lossy().to_string();
522            if name.starts_with(file_prefix) {
523                let is_dir = entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false);
524                let suffix = if is_dir { "/" } else { "" };
525                let full = if dir.is_empty() || dir == "." {
526                    format!("{name}{suffix}")
527                } else if dir.ends_with('/') {
528                    format!("{dir}{name}{suffix}")
529                } else {
530                    format!("{dir}/{name}{suffix}")
531                };
532                Some(Pair {
533                    display: name + suffix,
534                    replacement: full,
535                })
536            } else {
537                None
538            }
539        })
540        .collect()
541}
542
543impl Completer for CqlCompleter {
544    type Candidate = Pair;
545
546    fn complete(
547        &self,
548        line: &str,
549        pos: usize,
550        _ctx: &Context<'_>,
551    ) -> rustyline::Result<(usize, Vec<Pair>)> {
552        // block_in_place: complete() is called from within the Tokio runtime (sync rustyline trait)
553        let needs_refresh = tokio::task::block_in_place(|| {
554            self.rt_handle
555                .block_on(async { self.cache.read().await.is_stale() })
556        });
557        if needs_refresh {
558            // Best-effort refresh — don't block on errors
559            tokio::task::block_in_place(|| {
560                self.rt_handle.block_on(async {
561                    // Try to get write lock without blocking other completions
562                    if let Ok(mut cache) = self.cache.try_write() {
563                        // Re-check staleness after acquiring lock
564                        if cache.is_stale() {
565                            // We can't refresh without a session reference here.
566                            // The REPL pre-refreshes the cache; this is a fallback mark.
567                            cache.invalidate();
568                        }
569                    }
570                })
571            });
572        }
573
574        let context = self.detect_context(line, pos);
575
576        // Find the start of the word being completed
577        let before_cursor = &line[..pos];
578        let word_start = before_cursor
579            .rfind(|c: char| c.is_whitespace() || c == '.' || c == '\'' || c == '"')
580            .map(|i| i + 1)
581            .unwrap_or(0);
582        let prefix = &line[word_start..pos];
583
584        let completions = self.complete_for_context(&context, prefix);
585
586        Ok((word_start, completions))
587    }
588}
589
590impl Hinter for CqlCompleter {
591    type Hint = String;
592
593    fn hint(&self, _line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
594        None
595    }
596}
597
598impl Highlighter for CqlCompleter {
599    fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
600        let colored = self.colorizer.colorize_line(line);
601        if colored == line {
602            Cow::Borrowed(line)
603        } else {
604            Cow::Owned(colored)
605        }
606    }
607
608    fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
609        &'s self,
610        prompt: &'p str,
611        _default: bool,
612    ) -> Cow<'b, str> {
613        Cow::Borrowed(prompt)
614    }
615
616    fn highlight_char(
617        &self,
618        _line: &str,
619        _pos: usize,
620        _forced: rustyline::highlight::CmdKind,
621    ) -> bool {
622        // Return true to trigger re-highlighting on every keystroke
623        true
624    }
625}
626
// No custom input validation: this relies entirely on `Validator`'s provided
// default methods.
impl Validator for CqlCompleter {}
628
// Marker impl: combined with the Completer, Hinter, Highlighter and Validator
// impls in this file, it lets `CqlCompleter` be used as a rustyline `Helper`.
impl Helper for CqlCompleter {}
630
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a completer over an empty schema cache and no current keyspace.
    ///
    /// The runtime is returned alongside the completer. Keep it bound (`_rt`,
    /// not `_`) for as long as the completer is used: the completer only holds
    /// a `Handle`, and schema lookups `block_on` that handle.
    fn make_completer() -> (tokio::runtime::Runtime, CqlCompleter) {
        let rt = tokio::runtime::Runtime::new().expect("failed to build Tokio runtime");
        let cache = Arc::new(RwLock::new(SchemaCache::new()));
        let current_ks = Arc::new(RwLock::new(None::<String>));
        let completer = CqlCompleter::new(cache, current_ks, rt.handle().clone(), false);
        (rt, completer)
    }

    #[test]
    fn completer_can_be_created() {
        let (_rt, _c) = make_completer();
    }

    #[test]
    fn detect_empty_context() {
        let (_rt, c) = make_completer();
        assert_eq!(c.detect_context("", 0), CompletionContext::Empty);
    }

    #[test]
    fn detect_keyword_prefix() {
        let (_rt, c) = make_completer();
        assert_eq!(c.detect_context("SEL", 3), CompletionContext::Empty);
    }

    #[test]
    fn detect_consistency_context() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("CONSISTENCY ", 12),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_consistency_level_while_typing() {
        let (_rt, c) = make_completer();
        let line = "CONSISTENCY QU";
        assert_eq!(
            c.detect_context(line, line.len()),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_serial_consistency_context() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("SERIAL CONSISTENCY ", 19),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_use_keyspace_context() {
        let (_rt, c) = make_completer();
        assert_eq!(c.detect_context("USE ", 4), CompletionContext::KeyspaceName);
    }

    #[test]
    fn detect_describe_sub_command() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE ", 9),
            CompletionContext::DescribeTarget
        );
    }

    #[test]
    fn detect_describe_table_name() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE TABLE ", 15),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn detect_describe_keyspace_name() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE KEYSPACE ", 18),
            CompletionContext::KeyspaceName
        );
    }

    #[test]
    fn detect_source_file_path() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("SOURCE '/tmp/", 13),
            CompletionContext::FilePath
        );
    }

    #[test]
    fn detect_capture_file_path() {
        let (_rt, c) = make_completer();
        assert_eq!(c.detect_context("CAPTURE ", 8), CompletionContext::FilePath);
    }

    #[test]
    fn detect_from_table_context() {
        let (_rt, c) = make_completer();
        assert_eq!(
            c.detect_context("SELECT * FROM ", 14),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn detect_where_clause_column_context() {
        let (_rt, c) = make_completer();
        let line = "SELECT * FROM users WHERE ";
        assert_eq!(
            c.detect_context(line, line.len()),
            CompletionContext::ColumnName {
                keyspace: None,
                table: "users".to_string(),
            }
        );
    }

    #[test]
    fn detect_update_set_column_context() {
        let (_rt, c) = make_completer();
        let line = "UPDATE users SET ";
        assert_eq!(
            c.detect_context(line, line.len()),
            CompletionContext::ColumnName {
                keyspace: None,
                table: "users".to_string(),
            }
        );
    }

    #[test]
    fn complete_keyword_prefix() {
        let (_rt, c) = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::Empty, "SEL");
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn complete_consistency_level_prefix() {
        let (_rt, c) = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::ConsistencyLevel, "QU");
        assert!(pairs.iter().any(|p| p.replacement == "QUORUM"));
    }

    #[test]
    fn complete_describe_sub_command() {
        let (_rt, c) = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::DescribeTarget, "KEY");
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACE"));
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACES"));
    }

    #[test]
    fn schema_backed_completion_runs_on_live_runtime() {
        let (_rt, c) = make_completer();
        // Exercises the block_in_place/block_on path against the shared cache.
        let _ = c.complete_for_context(&CompletionContext::KeyspaceName, "");
    }

    #[test]
    fn filter_is_case_insensitive_for_keywords() {
        let pairs = filter_candidates(CQL_KEYWORDS, "sel", true);
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn file_path_completion_tmp() {
        // /tmp should exist on all Unix systems
        let pairs = complete_file_path("/tmp/");
        // Should return entries — exact count varies
        assert!(
            !pairs.is_empty() || std::fs::read_dir("/tmp").map(|d| d.count()).unwrap_or(0) == 0
        );
    }
}