1use std::borrow::Cow;
9use std::sync::Arc;
10
11use rustyline::completion::{Completer, Pair};
12use rustyline::highlight::Highlighter;
13use rustyline::hint::Hinter;
14use rustyline::validate::Validator;
15use rustyline::{Context, Helper};
16use tokio::runtime::Handle;
17use tokio::sync::RwLock;
18
19use crate::colorizer::CqlColorizer;
20use crate::schema_cache::SchemaCache;
21
/// Statement-leading CQL keywords, offered (together with the shell commands) while the
/// first word of a line is being typed.
const CQL_KEYWORDS: &[&str] = &[
    "ALTER",
    "APPLY",
    "BATCH",
    "BEGIN",
    "CREATE",
    "DELETE",
    "DESCRIBE",
    "DROP",
    "GRANT",
    "INSERT",
    "LIST",
    "REVOKE",
    "SELECT",
    "TRUNCATE",
    "UPDATE",
    "USE",
];
27
/// CQL keywords offered once a statement is underway (the generic fallback context).
///
/// Kept in alphabetical order; each row holds the words sharing one initial letter.
const CQL_CLAUSE_KEYWORDS: &[&str] = &[
    "ADD", "AGGREGATE", "ALL", "ALLOW", "AND", "AS", "ASC", "AUTHORIZE",
    "BATCH", "BY",
    "CALLED", "CLUSTERING", "COLUMN", "COMPACT", "CONTAINS", "COUNT", "CUSTOM",
    "DELETE", "DESC", "DESCRIBE", "DISTINCT", "DROP",
    "ENTRIES", "EXECUTE", "EXISTS",
    "FILTERING", "FINALFUNC", "FROM", "FROZEN", "FULL", "FUNCTION", "FUNCTIONS",
    "IF", "IN", "INDEX", "INITCOND", "INPUT", "INSERT", "INTO", "IS",
    "JSON",
    "KEY", "KEYS", "KEYSPACE", "KEYSPACES",
    "LANGUAGE", "LIKE", "LIMIT", "LIST", "LOGIN",
    "MAP", "MATERIALIZED", "MODIFY",
    "NAMESPACE", "NORECURSIVE", "NOT", "NULL",
    "OF", "ON", "OR", "ORDER",
    "PARTITION", "PASSWORD", "PER", "PERMISSION", "PERMISSIONS", "PRIMARY",
    "RENAME", "REPLACE", "RETURNS", "REVOKE",
    "SCHEMA", "SELECT", "SET", "SFUNC", "STATIC", "STORAGE", "STYPE", "SUPERUSER",
    "TABLE", "TABLES", "TEXT", "TIMESTAMP", "TO", "TOKEN", "TRIGGER", "TRUNCATE", "TTL",
    "TUPLE", "TYPE",
    "UNLOGGED", "UPDATE", "USER", "USERS", "USING",
    "VALUES", "VIEW",
    "WHERE", "WITH", "WRITETIME",
];
131
/// Commands handled by the shell itself rather than sent as CQL; offered alongside
/// [`CQL_KEYWORDS`] for the first word of a line.
const SHELL_COMMANDS: &[&str] = &[
    "CAPTURE", "CLEAR", "CLS", "CONSISTENCY", "COPY",
    "DESCRIBE", "DESC",
    "EXIT", "EXPAND", "HELP", "LOGIN", "PAGING", "QUIT",
    "SERIAL", "SHOW", "SOURCE", "TRACING",
];
152
/// Consistency level names offered after `CONSISTENCY` and `SERIAL CONSISTENCY`.
const CONSISTENCY_LEVELS: &[&str] = &[
    "ALL", "ANY", "EACH_QUORUM",
    "LOCAL_ONE", "LOCAL_QUORUM", "LOCAL_SERIAL",
    "ONE", "QUORUM", "SERIAL", "THREE", "TWO",
];
167
/// Object kinds that may follow `DESCRIBE` / `DESC`.
const DESCRIBE_SUB_COMMANDS: &[&str] = &[
    "AGGREGATE", "AGGREGATES", "CLUSTER", "FULL", "FUNCTION", "FUNCTIONS",
    "INDEX", "KEYSPACE", "KEYSPACES", "MATERIALIZED", "SCHEMA",
    "TABLE", "TABLES", "TYPE", "TYPES",
];
186
/// Built-in CQL data type names, in lower case.
///
/// Not referenced by the completion code in this module; the `allow` keeps the table
/// available without an unused-item warning.
#[allow(dead_code)]
const CQL_TYPES: &[&str] = &[
    "ascii", "bigint", "blob", "boolean", "counter", "date", "decimal",
    "double", "duration", "float", "frozen", "inet", "int", "list",
    "map", "set", "smallint", "text", "time", "timestamp", "timeuuid",
    "tinyint", "tuple", "uuid", "varchar", "varint",
];
217
/// What kind of word the cursor is positioned to complete.
#[derive(Debug, PartialEq)]
enum CompletionContext {
    /// First word of the line: CQL statement keywords and shell commands.
    Empty,
    /// Somewhere inside a statement: generic CQL clause keywords.
    ClauseKeyword,
    /// A table name.
    TableName {
        /// Keyspace whose tables are listed; `None` falls back to the session's
        /// current keyspace at completion time.
        keyspace: Option<String>,
    },
    /// A column of `table`.
    ColumnName {
        /// Keyspace containing `table`; `None` falls back to the session's current
        /// keyspace at completion time.
        keyspace: Option<String>,
        /// Table whose columns are listed.
        table: String,
    },
    /// Argument of `CONSISTENCY` / `SERIAL CONSISTENCY`.
    ConsistencyLevel,
    /// Object kind after `DESCRIBE` / `DESC`.
    DescribeTarget,
    /// File-system path argument (`SOURCE`, `CAPTURE`).
    FilePath,
    /// A keyspace name.
    KeyspaceName,
}
241
242pub struct CqlCompleter {
244 cache: Arc<RwLock<SchemaCache>>,
246 current_keyspace: Arc<RwLock<Option<String>>>,
248 rt_handle: Handle,
250 colorizer: CqlColorizer,
252}
253
254impl CqlCompleter {
255 pub fn new(
257 cache: Arc<RwLock<SchemaCache>>,
258 current_keyspace: Arc<RwLock<Option<String>>>,
259 rt_handle: Handle,
260 color_enabled: bool,
261 ) -> Self {
262 Self {
263 cache,
264 current_keyspace,
265 rt_handle,
266 colorizer: CqlColorizer::new(color_enabled),
267 }
268 }
269
270 fn detect_context(&self, line: &str, pos: usize) -> CompletionContext {
272 let before_cursor = &line[..pos];
273 let tokens: Vec<&str> = before_cursor.split_whitespace().collect();
274
275 if tokens.is_empty() {
276 return CompletionContext::Empty;
277 }
278
279 let first = tokens[0].to_uppercase();
280
281 if first == "CONSISTENCY" && tokens.len() <= 2 {
283 return if (tokens.len() == 1 && before_cursor.ends_with(' ')) || tokens.len() == 2 {
284 CompletionContext::ConsistencyLevel
285 } else {
286 CompletionContext::Empty
287 };
288 }
289
290 if first == "SERIAL" && tokens.len() >= 2 && tokens[1].to_uppercase() == "CONSISTENCY" {
292 return CompletionContext::ConsistencyLevel;
293 }
294
295 if first == "SOURCE" || first == "CAPTURE" {
297 return CompletionContext::FilePath;
298 }
299
300 if first == "USE" {
302 return CompletionContext::KeyspaceName;
303 }
304
305 if first == "DESCRIBE" || first == "DESC" {
307 if tokens.len() == 1 && before_cursor.ends_with(' ') {
308 return CompletionContext::DescribeTarget;
309 }
310 if tokens.len() == 2 {
311 let sub = tokens[1].to_uppercase();
312 if before_cursor.ends_with(' ') {
313 return match sub.as_str() {
315 "KEYSPACE" => CompletionContext::KeyspaceName,
316 "TABLE" | "INDEX" | "MATERIALIZED" => {
317 CompletionContext::TableName { keyspace: None }
318 }
319 _ => CompletionContext::DescribeTarget,
320 };
321 }
322 return CompletionContext::DescribeTarget;
323 }
324 if tokens.len() == 3 {
325 let sub = tokens[1].to_uppercase();
326 return match sub.as_str() {
327 "KEYSPACE" => CompletionContext::KeyspaceName,
328 "TABLE" | "INDEX" => CompletionContext::TableName { keyspace: None },
329 _ => CompletionContext::ClauseKeyword,
330 };
331 }
332 return CompletionContext::ClauseKeyword;
333 }
334
335 let upper_tokens: Vec<String> = tokens.iter().map(|t| t.to_uppercase()).collect();
337 for (i, token) in upper_tokens.iter().enumerate() {
338 if (token == "FROM" || token == "INTO" || token == "UPDATE" || token == "TABLE")
339 && i + 1 >= tokens.len()
340 && before_cursor.ends_with(' ')
341 {
342 return CompletionContext::TableName { keyspace: None };
343 }
344 if (token == "FROM" || token == "INTO" || token == "UPDATE" || token == "TABLE")
345 && i + 1 < tokens.len()
346 {
347 let table_token = tokens[i + 1];
348 if table_token.contains('.') && table_token.ends_with('.') {
350 let ks = table_token.trim_end_matches('.').to_string();
351 return CompletionContext::TableName { keyspace: Some(ks) };
352 }
353 if i + 2 < tokens.len() || (i + 1 < tokens.len() && before_cursor.ends_with(' ')) {
355 if upper_tokens
357 .iter()
358 .skip(i + 2)
359 .any(|t| t == "WHERE" || t == "SET")
360 {
361 let table = tokens[i + 1].to_string();
362 let ks = tokio::task::block_in_place(|| {
363 self.rt_handle
364 .block_on(async { self.current_keyspace.read().await.clone() })
365 });
366 return CompletionContext::ColumnName {
367 keyspace: ks,
368 table,
369 };
370 }
371 }
372 if !before_cursor.ends_with(' ') && i + 1 == tokens.len() - 1 {
374 return CompletionContext::TableName { keyspace: None };
375 }
376 }
377 }
378
379 if tokens.len() == 1 && !before_cursor.ends_with(' ') {
381 return CompletionContext::Empty;
382 }
383
384 CompletionContext::ClauseKeyword
385 }
386
387 fn complete_for_context(&self, ctx: &CompletionContext, prefix: &str) -> Vec<Pair> {
389 let prefix_upper = prefix.to_uppercase();
390
391 match ctx {
392 CompletionContext::Empty => {
393 let mut candidates: Vec<&str> = Vec::new();
394 candidates.extend_from_slice(CQL_KEYWORDS);
395 candidates.extend_from_slice(SHELL_COMMANDS);
396 filter_candidates(&candidates, &prefix_upper, true)
397 }
398 CompletionContext::ClauseKeyword => {
399 filter_candidates(CQL_CLAUSE_KEYWORDS, &prefix_upper, true)
400 }
401 CompletionContext::ConsistencyLevel => {
402 filter_candidates(CONSISTENCY_LEVELS, &prefix_upper, true)
403 }
404 CompletionContext::DescribeTarget => {
405 filter_candidates(DESCRIBE_SUB_COMMANDS, &prefix_upper, true)
406 }
407 CompletionContext::KeyspaceName => {
408 let cache =
409 tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
410 let names = cache.keyspace_names();
411 filter_candidates(&names, prefix, false)
412 }
413 CompletionContext::TableName { keyspace } => {
414 let cache =
415 tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
416 let ks = keyspace.clone().or_else(|| {
417 tokio::task::block_in_place(|| {
418 self.rt_handle
419 .block_on(async { self.current_keyspace.read().await.clone() })
420 })
421 });
422 match ks {
423 Some(ref ks_name) => {
424 let names = cache.table_names(ks_name);
425 filter_candidates(&names, prefix, false)
426 }
427 None => {
428 let names = cache.keyspace_names();
430 filter_candidates(&names, prefix, false)
431 }
432 }
433 }
434 CompletionContext::ColumnName { keyspace, table } => {
435 let cache =
436 tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
437 let ks = keyspace.clone().or_else(|| {
438 tokio::task::block_in_place(|| {
439 self.rt_handle
440 .block_on(async { self.current_keyspace.read().await.clone() })
441 })
442 });
443 match ks {
444 Some(ref ks_name) => {
445 let names = cache.column_names(ks_name, table);
446 filter_candidates(&names, prefix, false)
447 }
448 None => vec![],
449 }
450 }
451 CompletionContext::FilePath => complete_file_path(prefix),
452 }
453 }
454}
455
456fn filter_candidates(candidates: &[&str], prefix: &str, uppercase: bool) -> Vec<Pair> {
458 candidates
459 .iter()
460 .filter(|c| {
461 if uppercase {
462 c.to_uppercase().starts_with(&prefix.to_uppercase())
463 } else {
464 c.starts_with(prefix)
465 }
466 })
467 .map(|c| {
468 let display = if uppercase {
469 c.to_uppercase()
470 } else {
471 c.to_string()
472 };
473 Pair {
474 display: display.clone(),
475 replacement: display,
476 }
477 })
478 .collect()
479}
480
481fn complete_file_path(prefix: &str) -> Vec<Pair> {
483 let path_str = prefix
485 .strip_prefix('\'')
486 .or_else(|| prefix.strip_prefix('"'))
487 .unwrap_or(prefix);
488
489 let expanded = if path_str.starts_with('~') {
491 if let Some(home) = dirs::home_dir() {
492 path_str.replacen('~', &home.to_string_lossy(), 1)
493 } else {
494 path_str.to_string()
495 }
496 } else {
497 path_str.to_string()
498 };
499
500 let (dir, file_prefix) = if expanded.ends_with('/') {
501 (expanded.as_str(), "")
502 } else {
503 let path = std::path::Path::new(&expanded);
504 let parent = path
505 .parent()
506 .map(|p| p.to_str().unwrap_or("."))
507 .unwrap_or(".");
508 let file = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
509 (parent, file)
510 };
511
512 let dir_to_read = if dir.is_empty() { "." } else { dir };
513
514 let Ok(entries) = std::fs::read_dir(dir_to_read) else {
515 return vec![];
516 };
517
518 entries
519 .filter_map(|entry| entry.ok())
520 .filter_map(|entry| {
521 let name = entry.file_name().to_string_lossy().to_string();
522 if name.starts_with(file_prefix) {
523 let is_dir = entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false);
524 let suffix = if is_dir { "/" } else { "" };
525 let full = if dir.is_empty() || dir == "." {
526 format!("{name}{suffix}")
527 } else if dir.ends_with('/') {
528 format!("{dir}{name}{suffix}")
529 } else {
530 format!("{dir}/{name}{suffix}")
531 };
532 Some(Pair {
533 display: name + suffix,
534 replacement: full,
535 })
536 } else {
537 None
538 }
539 })
540 .collect()
541}
542
543impl Completer for CqlCompleter {
544 type Candidate = Pair;
545
546 fn complete(
547 &self,
548 line: &str,
549 pos: usize,
550 _ctx: &Context<'_>,
551 ) -> rustyline::Result<(usize, Vec<Pair>)> {
552 let needs_refresh = tokio::task::block_in_place(|| {
554 self.rt_handle
555 .block_on(async { self.cache.read().await.is_stale() })
556 });
557 if needs_refresh {
558 tokio::task::block_in_place(|| {
560 self.rt_handle.block_on(async {
561 if let Ok(mut cache) = self.cache.try_write() {
563 if cache.is_stale() {
565 cache.invalidate();
568 }
569 }
570 })
571 });
572 }
573
574 let context = self.detect_context(line, pos);
575
576 let before_cursor = &line[..pos];
578 let word_start = before_cursor
579 .rfind(|c: char| c.is_whitespace() || c == '.' || c == '\'' || c == '"')
580 .map(|i| i + 1)
581 .unwrap_or(0);
582 let prefix = &line[word_start..pos];
583
584 let completions = self.complete_for_context(&context, prefix);
585
586 Ok((word_start, completions))
587 }
588}
589
590impl Hinter for CqlCompleter {
591 type Hint = String;
592
593 fn hint(&self, _line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
594 None
595 }
596}
597
598impl Highlighter for CqlCompleter {
599 fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
600 let colored = self.colorizer.colorize_line(line);
601 if colored == line {
602 Cow::Borrowed(line)
603 } else {
604 Cow::Owned(colored)
605 }
606 }
607
608 fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
609 &'s self,
610 prompt: &'p str,
611 _default: bool,
612 ) -> Cow<'b, str> {
613 Cow::Borrowed(prompt)
614 }
615
616 fn highlight_char(
617 &self,
618 _line: &str,
619 _pos: usize,
620 _forced: rustyline::highlight::CmdKind,
621 ) -> bool {
622 true
624 }
625}
626
627impl Validator for CqlCompleter {}
628
629impl Helper for CqlCompleter {}
630
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;

    /// Process-wide Tokio runtime shared by every test.
    ///
    /// `CqlCompleter` stores only a runtime `Handle`, so the runtime itself must stay
    /// alive while completers are in use. Building it inside `make_completer` and
    /// dropping it on return would leave the handle pointing at a shut-down runtime
    /// for any test that reaches a `block_on` path.
    fn test_runtime() -> &'static tokio::runtime::Runtime {
        static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
        RUNTIME.get_or_init(|| {
            tokio::runtime::Runtime::new().expect("failed to build Tokio runtime for tests")
        })
    }

    /// Builds a completer over an empty schema cache, with no current keyspace and
    /// colour disabled.
    fn make_completer() -> CqlCompleter {
        let cache = Arc::new(RwLock::new(SchemaCache::new()));
        let current_ks = Arc::new(RwLock::new(None::<String>));
        CqlCompleter::new(cache, current_ks, test_runtime().handle().clone(), false)
    }

    #[test]
    fn completer_can_be_created() {
        let _c = make_completer();
    }

    #[test]
    fn detect_empty_context() {
        let c = make_completer();
        assert_eq!(c.detect_context("", 0), CompletionContext::Empty);
    }

    #[test]
    fn detect_keyword_prefix() {
        let c = make_completer();
        assert_eq!(c.detect_context("SEL", 3), CompletionContext::Empty);
    }

    #[test]
    fn detect_consistency_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("CONSISTENCY ", 12),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_serial_consistency_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SERIAL CONSISTENCY ", 19),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_use_keyspace_context() {
        let c = make_completer();
        assert_eq!(c.detect_context("USE ", 4), CompletionContext::KeyspaceName);
    }

    #[test]
    fn detect_describe_sub_command() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE ", 9),
            CompletionContext::DescribeTarget
        );
    }

    #[test]
    fn detect_describe_table_name() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE TABLE ", 15),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn detect_describe_keyspace_name() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE KEYSPACE ", 18),
            CompletionContext::KeyspaceName
        );
    }

    #[test]
    fn detect_source_file_path() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SOURCE '/tmp/", 13),
            CompletionContext::FilePath
        );
    }

    #[test]
    fn detect_capture_file_path() {
        let c = make_completer();
        assert_eq!(c.detect_context("CAPTURE ", 8), CompletionContext::FilePath);
    }

    #[test]
    fn detect_from_table_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SELECT * FROM ", 14),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn complete_keyword_prefix() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::Empty, "SEL");
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn complete_consistency_level_prefix() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::ConsistencyLevel, "QU");
        assert!(pairs.iter().any(|p| p.replacement == "QUORUM"));
    }

    #[test]
    fn complete_describe_sub_command() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::DescribeTarget, "KEY");
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACE"));
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACES"));
    }

    #[test]
    fn filter_is_case_insensitive_for_keywords() {
        let pairs = filter_candidates(CQL_KEYWORDS, "sel", true);
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    /// Uses a private scratch directory so the result is deterministic and does not
    /// depend on `/tmp` existing or being non-empty.
    #[test]
    fn file_path_completion_lists_matching_entries() {
        let dir = std::env::temp_dir()
            .join(format!("cqlsh_completer_test_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("create scratch dir");
        std::fs::write(dir.join("alpha.cql"), "").expect("create alpha.cql");
        std::fs::write(dir.join("beta.cql"), "").expect("create beta.cql");

        let prefix = format!("{}/al", dir.display());
        let pairs = complete_file_path(&prefix);
        let _ = std::fs::remove_dir_all(&dir);

        let names: Vec<&str> = pairs.iter().map(|p| p.display.as_str()).collect();
        assert_eq!(names, ["alpha.cql"]);
    }
}