cqlsh_rs/driver/
mod.rs

1//! Driver abstraction layer for CQL database connectivity.
2//!
3//! Provides a trait-based abstraction over the underlying database driver,
4//! enabling testability and future flexibility. The primary implementation
5//! uses the `scylla` crate for Cassandra/ScyllaDB connectivity.
6//!
7//! Many types and trait methods are defined ahead of their use in later
8//! development phases (REPL, DESCRIBE, COPY, etc.).
9
10pub mod scylla_driver;
11pub mod types;
12
13use std::collections::HashMap;
14
15use anyhow::Result;
16use async_trait::async_trait;
17
18pub use scylla_driver::ScyllaDriver;
19#[allow(unused_imports)]
20pub use types::{CqlColumn, CqlResult, CqlRow, CqlValue};
21
22/// Configuration for establishing a database connection.
23#[derive(Debug, Clone)]
24pub struct ConnectionConfig {
25    /// Contact point host (e.g., "127.0.0.1").
26    pub host: String,
27    /// Native transport port (default: 9042).
28    pub port: u16,
29    /// Optional username for authentication.
30    pub username: Option<String>,
31    /// Optional password for authentication.
32    pub password: Option<String>,
33    /// Optional default keyspace.
34    pub keyspace: Option<String>,
35    /// Connection timeout in seconds.
36    pub connect_timeout: u64,
37    /// Per-request timeout in seconds.
38    pub request_timeout: u64,
39    /// Whether to use SSL/TLS.
40    pub ssl: bool,
41    /// SSL/TLS configuration.
42    pub ssl_config: Option<SslConfig>,
43    /// Protocol version (None = auto-negotiate).
44    pub protocol_version: Option<u8>,
45}
46
/// Certificate and verification settings for an SSL/TLS connection.
#[derive(Debug, Clone)]
pub struct SslConfig {
    /// CA certificate file used to verify the server, if any.
    pub certfile: Option<String>,
    /// `true` to validate the certificate presented by the server.
    pub validate: bool,
    /// Client private key file, used for mutual TLS.
    pub userkey: Option<String>,
    /// Client certificate file, used for mutual TLS.
    pub usercert: Option<String>,
    /// Certificate files configured per host
    /// (presumably keyed by host, valued by certificate path — verify against callers).
    pub host_certfiles: HashMap<String, String>,
}
61
/// Name and type of a single column in a result set.
#[derive(Debug, Clone)]
pub struct ColumnMetadata {
    /// Column name.
    pub name: String,
    /// Column type rendered as text.
    pub type_name: String,
}
68
/// Schema information describing a keyspace.
#[derive(Debug, Clone)]
pub struct KeyspaceMetadata {
    /// Keyspace name.
    pub name: String,
    /// Replication settings as string key/value pairs.
    pub replication: HashMap<String, String>,
    /// Value of the keyspace's `durable_writes` flag.
    pub durable_writes: bool,
}
76
77/// Metadata about a table.
78#[derive(Debug, Clone)]
79pub struct TableMetadata {
80    pub keyspace: String,
81    pub name: String,
82    pub columns: Vec<ColumnMetadata>,
83    pub partition_key: Vec<String>,
84    pub clustering_key: Vec<String>,
85}
86
/// Schema information describing a user-defined type (UDT).
#[derive(Debug, Clone)]
pub struct UdtMetadata {
    /// Name of the keyspace that owns the type.
    pub keyspace: String,
    /// Type name.
    pub name: String,
    /// Field names (presumably index-aligned with `field_types`).
    pub field_names: Vec<String>,
    /// Field types rendered as text (presumably index-aligned with `field_names`).
    pub field_types: Vec<String>,
}
95
/// Schema information describing a user-defined function (UDF).
#[derive(Debug, Clone)]
pub struct FunctionMetadata {
    /// Name of the keyspace that owns the function.
    pub keyspace: String,
    /// Function name.
    pub name: String,
    /// Argument types rendered as text, in declaration order.
    pub argument_types: Vec<String>,
    /// Return type rendered as text.
    pub return_type: String,
}
104
/// Schema information describing a user-defined aggregate (UDA).
#[derive(Debug, Clone)]
pub struct AggregateMetadata {
    /// Name of the keyspace that owns the aggregate.
    pub keyspace: String,
    /// Aggregate name.
    pub name: String,
    /// Argument types rendered as text, in declaration order.
    pub argument_types: Vec<String>,
    /// Return type rendered as text.
    pub return_type: String,
}
113
/// Consistency levels defined by the CQL specification.
///
/// [`Consistency::from_str_cql`] parses the textual CQL name, and
/// [`Consistency::as_cql_str`] (also used by `Display`) renders it back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Consistency {
    /// Rendered as `ANY`.
    Any,
    /// Rendered as `ONE`.
    One,
    /// Rendered as `TWO`.
    Two,
    /// Rendered as `THREE`.
    Three,
    /// Rendered as `QUORUM`.
    Quorum,
    /// Rendered as `ALL`.
    All,
    /// Rendered as `LOCAL_QUORUM`.
    LocalQuorum,
    /// Rendered as `EACH_QUORUM`.
    EachQuorum,
    /// Rendered as `SERIAL`.
    Serial,
    /// Rendered as `LOCAL_SERIAL`.
    LocalSerial,
    /// Rendered as `LOCAL_ONE`.
    LocalOne,
}

impl Consistency {
    /// Every consistency level, in declaration order.
    ///
    /// This table is the single list that [`Consistency::from_str_cql`] searches.
    /// When adding a variant, add it here as well as to
    /// [`Consistency::as_cql_str`] (whose exhaustive `match` forces the latter).
    pub const ALL: [Self; 11] = [
        Self::Any,
        Self::One,
        Self::Two,
        Self::Three,
        Self::Quorum,
        Self::All,
        Self::LocalQuorum,
        Self::EachQuorum,
        Self::Serial,
        Self::LocalSerial,
        Self::LocalOne,
    ];

    /// Parse a consistency level from its CQL name, ignoring ASCII case.
    ///
    /// Returns `None` when `s` is not a known level name. Matching is
    /// ASCII-only. Non-ASCII lookalikes such as `"ſerial"` (which upper-cases
    /// to `"SERIAL"` under full Unicode rules) are rejected. No allocation is
    /// performed.
    pub fn from_str_cql(s: &str) -> Option<Self> {
        // Compare against the rendered names so parsing and display can never
        // disagree.
        Self::ALL
            .iter()
            .copied()
            .find(|level| level.as_cql_str().eq_ignore_ascii_case(s))
    }

    /// Return the canonical upper-case CQL name of this level.
    pub fn as_cql_str(&self) -> &'static str {
        match self {
            Self::Any => "ANY",
            Self::One => "ONE",
            Self::Two => "TWO",
            Self::Three => "THREE",
            Self::Quorum => "QUORUM",
            Self::All => "ALL",
            Self::LocalQuorum => "LOCAL_QUORUM",
            Self::EachQuorum => "EACH_QUORUM",
            Self::Serial => "SERIAL",
            Self::LocalSerial => "LOCAL_SERIAL",
            Self::LocalOne => "LOCAL_ONE",
        }
    }
}

impl std::fmt::Display for Consistency {
    /// Writes the same text as [`Consistency::as_cql_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_cql_str())
    }
}
172
173/// The core driver trait abstracting database operations.
174///
175/// All methods are async and return `Result` for proper error propagation.
176/// Implementations must be `Send + Sync` for use across async tasks.
177#[async_trait]
178pub trait CqlDriver: Send + Sync {
179    /// Establish a connection to the database cluster.
180    async fn connect(config: &ConnectionConfig) -> Result<Self>
181    where
182        Self: Sized;
183
184    /// Execute a raw CQL query string without parameters.
185    async fn execute_unpaged(&self, query: &str) -> Result<CqlResult>;
186
187    /// Execute a CQL query with automatic paging, returning all rows.
188    async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult>;
189
190    /// Prepare a CQL statement for repeated execution.
191    async fn prepare(&self, query: &str) -> Result<PreparedId>;
192
193    /// Execute a previously prepared statement with the given values.
194    async fn execute_prepared(
195        &self,
196        prepared_id: &PreparedId,
197        values: &[CqlValue],
198    ) -> Result<CqlResult>;
199
200    /// Switch the current keyspace (USE <keyspace>).
201    async fn use_keyspace(&self, keyspace: &str) -> Result<()>;
202
203    /// Get the current consistency level.
204    fn get_consistency(&self) -> Consistency;
205
206    /// Set the consistency level for subsequent queries.
207    fn set_consistency(&self, consistency: Consistency);
208
209    /// Get the current serial consistency level.
210    fn get_serial_consistency(&self) -> Option<Consistency>;
211
212    /// Set the serial consistency level for subsequent queries.
213    fn set_serial_consistency(&self, consistency: Option<Consistency>);
214
215    /// Enable or disable request tracing.
216    fn set_tracing(&self, enabled: bool);
217
218    /// Check if tracing is currently enabled.
219    fn is_tracing_enabled(&self) -> bool;
220
221    /// Get the last tracing session ID (if tracing was enabled).
222    fn last_trace_id(&self) -> Option<uuid::Uuid>;
223
224    /// Retrieve tracing session data for a given trace ID.
225    async fn get_trace_session(&self, trace_id: uuid::Uuid) -> Result<Option<TracingSession>>;
226
227    /// Get metadata for all keyspaces.
228    async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>>;
229
230    /// Get metadata for all tables in a keyspace.
231    async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>>;
232
233    /// Get metadata for a specific table.
234    async fn get_table_metadata(
235        &self,
236        keyspace: &str,
237        table: &str,
238    ) -> Result<Option<TableMetadata>>;
239
240    /// Get metadata for all user-defined types in a keyspace.
241    async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>>;
242
243    /// Get metadata for all user-defined functions in a keyspace.
244    async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>>;
245
246    /// Get metadata for all user-defined aggregates in a keyspace.
247    async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>>;
248
249    /// Get the cluster name.
250    async fn get_cluster_name(&self) -> Result<Option<String>>;
251
252    /// Get the CQL version from the connected node.
253    async fn get_cql_version(&self) -> Result<Option<String>>;
254
255    /// Get the release version of the connected node.
256    async fn get_release_version(&self) -> Result<Option<String>>;
257
258    /// Get the ScyllaDB version (None if not ScyllaDB).
259    async fn get_scylla_version(&self) -> Result<Option<String>>;
260
261    /// Check if the connection is still alive.
262    async fn is_connected(&self) -> bool;
263}
264
/// Handle returned by `prepare` that identifies a prepared statement.
///
/// The contents are meaningful only to the driver implementation that
/// produced it.
#[derive(Debug, Clone)]
pub struct PreparedId {
    /// Implementation-defined identifier bytes.
    pub(crate) inner: Vec<u8>,
}
271
272/// Tracing session data returned by the database.
273#[derive(Debug, Clone)]
274pub struct TracingSession {
275    pub trace_id: uuid::Uuid,
276    pub client: Option<String>,
277    pub command: Option<String>,
278    pub coordinator: Option<String>,
279    pub duration: Option<i32>,
280    pub parameters: HashMap<String, String>,
281    pub request: Option<String>,
282    pub started_at: Option<String>,
283    pub events: Vec<TracingEvent>,
284}
285
/// One event recorded within a tracing session.
#[derive(Debug, Clone)]
pub struct TracingEvent {
    /// Description of the activity, if reported.
    pub activity: Option<String>,
    /// Node that recorded the event, if reported.
    pub source: Option<String>,
    /// Elapsed time at the source when the event was recorded
    /// (unit not fixed here — presumably microseconds; confirm).
    pub source_elapsed: Option<i32>,
    /// Name of the thread that recorded the event, if reported.
    pub thread: Option<String>,
}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn udt_metadata_fields() {
        let address = UdtMetadata {
            keyspace: String::from("ks"),
            name: String::from("address"),
            field_names: strings(&["street", "city"]),
            field_types: strings(&["text", "text"]),
        };
        assert_eq!(address.keyspace, "ks");
        assert_eq!(address.name, "address");
        assert_eq!(address.field_names.len(), 2);
        assert_eq!(address.field_types.len(), 2);
        assert_eq!(address.field_names[0], "street");
        assert_eq!(address.field_types[0], "text");
    }

    #[test]
    fn function_metadata_fields() {
        let my_func = FunctionMetadata {
            keyspace: String::from("ks"),
            name: String::from("my_func"),
            argument_types: strings(&["int", "text"]),
            return_type: String::from("boolean"),
        };
        assert_eq!(my_func.keyspace, "ks");
        assert_eq!(my_func.name, "my_func");
        assert_eq!(my_func.argument_types, ["int", "text"]);
        assert_eq!(my_func.return_type, "boolean");
    }

    #[test]
    fn aggregate_metadata_fields() {
        let my_agg = AggregateMetadata {
            keyspace: String::from("ks"),
            name: String::from("my_agg"),
            argument_types: strings(&["int"]),
            return_type: String::from("bigint"),
        };
        assert_eq!(my_agg.keyspace, "ks");
        assert_eq!(my_agg.name, "my_agg");
        assert_eq!(my_agg.argument_types, ["int"]);
        assert_eq!(my_agg.return_type, "bigint");
    }

    #[test]
    fn udt_metadata_clone() {
        let original = UdtMetadata {
            keyspace: String::from("ks"),
            name: String::from("my_type"),
            field_names: strings(&["f1"]),
            field_types: strings(&["int"]),
        };
        let duplicate = original.clone();
        assert_eq!(duplicate.keyspace, original.keyspace);
        assert_eq!(duplicate.name, original.name);
    }

    #[test]
    fn function_metadata_empty_args() {
        let nullary = FunctionMetadata {
            keyspace: String::from("ks"),
            name: String::from("no_args_func"),
            argument_types: Vec::new(),
            return_type: String::from("text"),
        };
        assert!(nullary.argument_types.is_empty());
    }

    #[test]
    fn aggregate_metadata_clone() {
        let original = AggregateMetadata {
            keyspace: String::from("ks"),
            name: String::from("my_agg"),
            argument_types: strings(&["int"]),
            return_type: String::from("bigint"),
        };
        let duplicate = original.clone();
        assert_eq!(duplicate.return_type, original.return_type);
    }

    #[test]
    fn consistency_from_str() {
        let cases = [
            ("QUORUM", Some(Consistency::Quorum)),
            ("local_quorum", Some(Consistency::LocalQuorum)),
            ("LOCAL_SERIAL", Some(Consistency::LocalSerial)),
            ("INVALID", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(Consistency::from_str_cql(input), *expected, "input: {}", input);
        }
    }

    #[test]
    fn consistency_display() {
        assert_eq!(format!("{}", Consistency::One), "ONE");
        assert_eq!(format!("{}", Consistency::LocalQuorum), "LOCAL_QUORUM");
    }

    #[test]
    fn consistency_roundtrip() {
        let every_level = [
            Consistency::Any,
            Consistency::One,
            Consistency::Two,
            Consistency::Three,
            Consistency::Quorum,
            Consistency::All,
            Consistency::LocalQuorum,
            Consistency::EachQuorum,
            Consistency::Serial,
            Consistency::LocalSerial,
            Consistency::LocalOne,
        ];
        for level in every_level.iter().copied() {
            let rendered = level.as_cql_str();
            assert_eq!(Consistency::from_str_cql(rendered), Some(level));
        }
    }
}