cqlsh_rs/
session.rs

1//! CQL session management layer.
2//!
3//! Wraps the driver with higher-level session state management including
4//! keyspace tracking, consistency level management, and tracing control.
5//! This mirrors the Python cqlsh `Shell` session state.
6
7use anyhow::{bail, Result};
8
9use crate::config::MergedConfig;
10use crate::driver::types::CqlValue;
11use crate::driver::{
12    AggregateMetadata, ConnectionConfig, Consistency, CqlDriver, CqlResult, FunctionMetadata,
13    KeyspaceMetadata, PreparedId, ScyllaDriver, SslConfig, TableMetadata, TracingSession,
14    UdtMetadata,
15};
16
17/// High-level CQL session managing driver state and user preferences.
18pub struct CqlSession {
19    driver: ScyllaDriver,
20    /// Current keyspace (updated on USE commands).
21    current_keyspace: Option<String>,
22    /// Display name for the connection (host:port).
23    pub connection_display: String,
24    /// Cluster name retrieved after connecting.
25    pub cluster_name: Option<String>,
26    /// CQL version from the connected node.
27    pub cql_version: Option<String>,
28    /// Release version of the connected node.
29    pub release_version: Option<String>,
30    /// ScyllaDB version (None if connected to Apache Cassandra).
31    pub scylla_version: Option<String>,
32}
33
34impl CqlSession {
35    /// Create a new session by connecting using the merged configuration.
36    pub async fn connect(config: &MergedConfig) -> Result<Self> {
37        let ssl_config = if config.ssl {
38            Some(SslConfig {
39                certfile: config.cqlshrc.ssl.certfile.clone(),
40                validate: config.cqlshrc.ssl.validate.unwrap_or(false),
41                userkey: config.cqlshrc.ssl.userkey.clone(),
42                usercert: config.cqlshrc.ssl.usercert.clone(),
43                host_certfiles: config.cqlshrc.certfiles.clone(),
44            })
45        } else {
46            None
47        };
48
49        let conn_config = ConnectionConfig {
50            host: config.host.clone(),
51            port: config.port,
52            username: config.username.clone(),
53            password: config.password.clone(),
54            keyspace: config.keyspace.clone(),
55            connect_timeout: config.connect_timeout,
56            request_timeout: config.request_timeout,
57            ssl: config.ssl,
58            ssl_config,
59            protocol_version: config.protocol_version,
60        };
61
62        let driver = ScyllaDriver::connect(&conn_config).await?;
63
64        let connection_display = format!("{}:{}", config.host, config.port);
65
66        // Fetch cluster metadata after connecting
67        let cluster_name = driver.get_cluster_name().await.ok().flatten();
68        let cql_version = driver.get_cql_version().await.ok().flatten();
69        let release_version = driver.get_release_version().await.ok().flatten();
70        let scylla_version = driver.get_scylla_version().await.ok().flatten();
71
72        // Set initial consistency from config
73        if let Some(cl_str) = &config.consistency_level {
74            if let Some(cl) = Consistency::from_str_cql(cl_str) {
75                driver.set_consistency(cl);
76            }
77        }
78
79        // Set initial serial consistency from config
80        if let Some(scl_str) = &config.serial_consistency_level {
81            if let Some(scl) = Consistency::from_str_cql(scl_str) {
82                driver.set_serial_consistency(Some(scl));
83            }
84        }
85
86        Ok(CqlSession {
87            driver,
88            current_keyspace: config.keyspace.clone(),
89            connection_display,
90            cluster_name,
91            cql_version,
92            release_version,
93            scylla_version,
94        })
95    }
96
97    /// Execute a CQL statement. Handles USE keyspace commands specially.
98    pub async fn execute(&mut self, query: &str) -> Result<CqlResult> {
99        let trimmed = query.trim();
100
101        // Detect USE keyspace commands
102        if let Some(keyspace) = parse_use_command(trimmed) {
103            self.use_keyspace(&keyspace).await?;
104            return Ok(CqlResult::empty());
105        }
106
107        self.driver.execute_unpaged(query).await
108    }
109
110    /// Execute a raw CQL query without USE interception.
111    ///
112    /// Used by DESCRIBE and other internal commands that need to query
113    /// system tables directly.
114    pub async fn execute_query(&self, query: &str) -> Result<CqlResult> {
115        self.driver.execute_unpaged(query).await
116    }
117
118    /// Execute a CQL statement with paging.
119    pub async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
120        self.driver.execute_paged(query, page_size).await
121    }
122
123    /// Prepare a CQL statement.
124    pub async fn prepare(&self, query: &str) -> Result<PreparedId> {
125        self.driver.prepare(query).await
126    }
127
128    /// Execute a previously prepared statement with typed bound values.
129    pub async fn execute_prepared(
130        &self,
131        id: &PreparedId,
132        values: &[CqlValue],
133    ) -> Result<CqlResult> {
134        self.driver.execute_prepared(id, values).await
135    }
136
137    /// Switch to a different keyspace.
138    pub async fn use_keyspace(&mut self, keyspace: &str) -> Result<()> {
139        self.driver.use_keyspace(keyspace).await?;
140        self.current_keyspace = Some(keyspace.to_string());
141        Ok(())
142    }
143
144    /// Get the current keyspace.
145    pub fn current_keyspace(&self) -> Option<&str> {
146        self.current_keyspace.as_deref()
147    }
148
149    /// Get the current consistency level.
150    pub fn get_consistency(&self) -> Consistency {
151        self.driver.get_consistency()
152    }
153
154    /// Set the consistency level.
155    pub fn set_consistency(&self, consistency: Consistency) {
156        self.driver.set_consistency(consistency);
157    }
158
159    /// Set the consistency level from a string. Returns error if invalid.
160    pub fn set_consistency_str(&self, level: &str) -> Result<()> {
161        let consistency = Consistency::from_str_cql(level)
162            .ok_or_else(|| anyhow::anyhow!("invalid consistency level: {level}"))?;
163        self.driver.set_consistency(consistency);
164        Ok(())
165    }
166
167    /// Get the current serial consistency level.
168    pub fn get_serial_consistency(&self) -> Option<Consistency> {
169        self.driver.get_serial_consistency()
170    }
171
172    /// Set the serial consistency level.
173    pub fn set_serial_consistency(&self, consistency: Option<Consistency>) {
174        self.driver.set_serial_consistency(consistency);
175    }
176
177    /// Set the serial consistency level from a string. Returns error if invalid.
178    pub fn set_serial_consistency_str(&self, level: &str) -> Result<()> {
179        let consistency = Consistency::from_str_cql(level)
180            .ok_or_else(|| anyhow::anyhow!("invalid serial consistency level: {level}"))?;
181        match consistency {
182            Consistency::Serial | Consistency::LocalSerial => {
183                self.driver.set_serial_consistency(Some(consistency));
184                Ok(())
185            }
186            _ => bail!("serial consistency must be SERIAL or LOCAL_SERIAL, got {level}"),
187        }
188    }
189
190    /// Enable or disable tracing.
191    pub fn set_tracing(&self, enabled: bool) {
192        self.driver.set_tracing(enabled);
193    }
194
195    /// Check if tracing is enabled.
196    pub fn is_tracing_enabled(&self) -> bool {
197        self.driver.is_tracing_enabled()
198    }
199
200    /// Get the last tracing session ID.
201    pub fn last_trace_id(&self) -> Option<uuid::Uuid> {
202        self.driver.last_trace_id()
203    }
204
205    /// Retrieve tracing session data.
206    pub async fn get_trace_session(&self, trace_id: uuid::Uuid) -> Result<Option<TracingSession>> {
207        self.driver.get_trace_session(trace_id).await
208    }
209
210    /// Get metadata for all keyspaces.
211    pub async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
212        self.driver.get_keyspaces().await
213    }
214
215    /// Get metadata for tables in a keyspace.
216    pub async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
217        self.driver.get_tables(keyspace).await
218    }
219
220    /// Get metadata for a specific table.
221    pub async fn get_table_metadata(
222        &self,
223        keyspace: &str,
224        table: &str,
225    ) -> Result<Option<TableMetadata>> {
226        self.driver.get_table_metadata(keyspace, table).await
227    }
228
229    /// Get metadata for all user-defined types in a keyspace.
230    pub async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
231        self.driver.get_udts(keyspace).await
232    }
233
234    /// Get metadata for all user-defined functions in a keyspace.
235    pub async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
236        self.driver.get_functions(keyspace).await
237    }
238
239    /// Get metadata for all user-defined aggregates in a keyspace.
240    pub async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
241        self.driver.get_aggregates(keyspace).await
242    }
243
244    /// Check if the connection is still alive.
245    pub async fn is_connected(&self) -> bool {
246        self.driver.is_connected().await
247    }
248}
249
/// Parse a USE keyspace command, returning the keyspace name if matched.
///
/// Accepted shape: optional surrounding whitespace, the keyword `USE` in any
/// ASCII letter case, at least one whitespace character, the keyspace name,
/// and optional trailing semicolons. One pair of matching double or single
/// quotes around the name is removed. The name is otherwise returned exactly
/// as typed (no case folding).
///
/// Returns `None` when the text is not a USE command or when the keyspace
/// name is empty.
fn parse_use_command(query: &str) -> Option<String> {
    // Strip surrounding whitespace and any trailing statement terminators.
    let statement = query.trim().trim_end_matches(';').trim();

    // `get` returns `None` instead of panicking when the text is shorter
    // than the keyword or when byte 3 falls inside a multi-byte character.
    let keyword = statement.get(..3)?;
    if !keyword.eq_ignore_ascii_case("USE") {
        return None;
    }

    // The keyword must be followed by whitespace; this rejects identifiers
    // that merely start with "use" (e.g. `USER`) and a bare `USE`.
    let rest = &statement[3..];
    if !rest.starts_with(char::is_whitespace) {
        return None;
    }
    let name = rest.trim();

    // Remove one pair of matching quotes if present. Stripping the prefix
    // first means a lone quote character is never treated as a quoted pair.
    let unquoted = ['"', '\'']
        .iter()
        .find_map(|&quote| name.strip_prefix(quote)?.strip_suffix(quote));
    let name = unquoted.unwrap_or(name);

    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
288
#[cfg(test)]
mod tests {
    use super::parse_use_command;

    /// Expected outcome for a command that names `keyspace`.
    fn found(keyspace: &str) -> Option<String> {
        Some(keyspace.to_owned())
    }

    /// Plain upper-case command without a terminator.
    #[test]
    fn parse_use_simple() {
        let parsed = parse_use_command("USE my_keyspace");
        assert_eq!(parsed, found("my_keyspace"));
    }

    /// A trailing semicolon is not part of the keyspace name.
    #[test]
    fn parse_use_semicolon() {
        let parsed = parse_use_command("USE my_keyspace;");
        assert_eq!(parsed, found("my_keyspace"));
    }

    /// The keyword is also recognized in lower case.
    #[test]
    fn parse_use_lowercase() {
        let parsed = parse_use_command("use test_ks");
        assert_eq!(parsed, found("test_ks"));
    }

    /// Double quotes are removed and the name keeps its letter case.
    #[test]
    fn parse_use_quoted() {
        let parsed = parse_use_command("USE \"MyKeyspace\"");
        assert_eq!(parsed, found("MyKeyspace"));
    }

    /// Single quotes are removed as well.
    #[test]
    fn parse_use_single_quoted() {
        let parsed = parse_use_command("USE 'my_ks'");
        assert_eq!(parsed, found("my_ks"));
    }

    /// Extra whitespace around the keyword, name and semicolon is ignored.
    #[test]
    fn parse_use_with_whitespace() {
        let parsed = parse_use_command("  USE  my_keyspace  ;  ");
        assert_eq!(parsed, found("my_keyspace"));
    }

    /// Statements that are not USE commands yield `None`.
    #[test]
    fn parse_not_use_command() {
        for statement in ["SELECT * FROM table", "INSERT INTO users"] {
            assert_eq!(parse_use_command(statement), None);
        }
    }

    /// A USE keyword without a keyspace name yields `None`.
    #[test]
    fn parse_use_empty() {
        for statement in ["USE ", "USE ;"] {
            assert_eq!(parse_use_command(statement), None);
        }
    }
}