cqlsh_rs/driver/
scylla_driver.rs

1//! ScyllaDriver — CqlDriver implementation using the `scylla` crate.
2//!
3//! Provides connectivity to Apache Cassandra and ScyllaDB clusters using
4//! the scylla-rust-driver, with support for authentication, SSL/TLS,
5//! prepared statements, paging, and schema metadata queries.
6
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use anyhow::{anyhow, Context, Result};
13use async_trait::async_trait;
14use chrono::{Datelike, Timelike};
15use futures::TryStreamExt;
16use scylla::client::session::Session;
17use scylla::client::session_builder::SessionBuilder;
18use scylla::response::query_result::QueryResult;
19use scylla::statement::prepared::PreparedStatement;
20use scylla::statement::Statement;
21use scylla::value::{
22    Counter as ScyllaCounter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid,
23    CqlValue as ScyllaCqlValue, CqlVarint, Row,
24};
25use uuid::Uuid;
26
27use super::types::{CqlColumn, CqlResult, CqlRow, CqlValue};
28use super::{
29    AggregateMetadata, ColumnMetadata, ConnectionConfig, Consistency, CqlDriver, FunctionMetadata,
30    KeyspaceMetadata, PreparedId, SslConfig, TableMetadata, TracingEvent, TracingSession,
31    UdtMetadata,
32};
33
34/// ScyllaDriver wraps a scylla `Session` and provides the `CqlDriver` trait.
35pub struct ScyllaDriver {
36    session: Session,
37    /// Cache of prepared statements keyed by internal ID.
38    prepared_cache: Mutex<HashMap<Vec<u8>, PreparedStatement>>,
39    /// Current consistency level.
40    consistency: Mutex<Consistency>,
41    /// Current serial consistency level.
42    serial_consistency: Mutex<Option<Consistency>>,
43    /// Whether tracing is enabled for queries.
44    tracing_enabled: AtomicBool,
45    /// Last tracing session ID.
46    last_trace_id: Mutex<Option<Uuid>>,
47}
48
49impl ScyllaDriver {
50    /// Build the TLS configuration from SslConfig.
51    fn build_rustls_config(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
52        use rustls::pki_types::CertificateDer;
53        use std::fs::File;
54        use std::io::BufReader;
55
56        let mut root_store = rustls::RootCertStore::empty();
57
58        // Load CA certificate if provided
59        if let Some(certfile) = &ssl_config.certfile {
60            let file = File::open(certfile)
61                .with_context(|| format!("opening CA certificate: {certfile}"))?;
62            let mut reader = BufReader::new(file);
63            let certs = rustls_pemfile::certs(&mut reader)
64                .collect::<std::result::Result<Vec<_>, _>>()
65                .with_context(|| format!("parsing CA certificate: {certfile}"))?;
66            for cert in certs {
67                root_store
68                    .add(cert)
69                    .context("adding CA certificate to root store")?;
70            }
71        }
72
73        let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
74
75        // Client certificate authentication (mutual TLS)
76        let config = if let (Some(usercert_path), Some(userkey_path)) =
77            (&ssl_config.usercert, &ssl_config.userkey)
78        {
79            let cert_file = File::open(usercert_path)
80                .with_context(|| format!("opening client certificate: {usercert_path}"))?;
81            let mut cert_reader = BufReader::new(cert_file);
82            let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
83                .collect::<std::result::Result<Vec<_>, _>>()
84                .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
85
86            let key_file = File::open(userkey_path)
87                .with_context(|| format!("opening client key: {userkey_path}"))?;
88            let mut key_reader = BufReader::new(key_file);
89            let key = rustls_pemfile::private_key(&mut key_reader)
90                .with_context(|| format!("parsing client key: {userkey_path}"))?
91                .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
92
93            builder
94                .with_client_auth_cert(certs, key)
95                .context("configuring mutual TLS")?
96        } else {
97            builder.with_no_client_auth()
98        };
99
100        Ok(Arc::new(config))
101    }
102
103    /// Extract a `Vec<String>` from a `CqlValue::List` column value.
104    fn extract_string_list_val(val: Option<&CqlValue>) -> Vec<String> {
105        match val {
106            Some(CqlValue::List(items)) => items.iter().map(|v| v.to_string()).collect(),
107            _ => Vec::new(),
108        }
109    }
110
111    /// Convert a scylla QueryResult into our CqlResult type.
112    fn convert_query_result(result: QueryResult) -> Result<CqlResult> {
113        let tracing_id = result.tracing_id();
114        let warnings: Vec<String> = result.warnings().map(|s| s.to_string()).collect();
115
116        // Check if this is a non-row result (DDL/DML)
117        if !result.is_rows() {
118            return Ok(CqlResult {
119                columns: Vec::new(),
120                rows: Vec::new(),
121                has_rows: false,
122                tracing_id,
123                warnings,
124            });
125        }
126
127        // Convert to QueryRowsResult to access typed rows
128        let rows_result = result
129            .into_rows_result()
130            .context("converting query result to rows")?;
131
132        // Extract column metadata
133        let col_specs = rows_result.column_specs();
134        let columns: Vec<CqlColumn> = col_specs
135            .iter()
136            .map(|spec| CqlColumn {
137                name: spec.name().to_string(),
138                type_name: format!("{:?}", spec.typ()),
139            })
140            .collect();
141
142        // Deserialize rows as untyped Row (Vec<Option<CqlValue>>)
143        let typed_rows = rows_result.rows::<Row>().context("deserializing rows")?;
144
145        let mut cql_rows = Vec::new();
146        for row_result in typed_rows {
147            let row = row_result.context("deserializing row")?;
148            let values: Vec<CqlValue> = row
149                .columns
150                .into_iter()
151                .enumerate()
152                .map(|(col_idx, opt_val)| match opt_val {
153                    Some(v) => {
154                        tracing::debug!(
155                            column = col_idx,
156                            variant = ?std::mem::discriminant(&v),
157                            "converting ScyllaCqlValue: {v:?}"
158                        );
159                        Self::convert_scylla_value(v)
160                    }
161                    None => {
162                        tracing::debug!(column = col_idx, "column value is None (null)");
163                        CqlValue::Null
164                    }
165                })
166                .collect();
167            cql_rows.push(CqlRow { values });
168        }
169
170        Ok(CqlResult {
171            columns,
172            rows: cql_rows,
173            has_rows: true,
174            tracing_id,
175            warnings,
176        })
177    }
178
179    /// Convert a scylla CqlValue to our CqlValue type.
180    fn convert_scylla_value(value: ScyllaCqlValue) -> CqlValue {
181        match value {
182            ScyllaCqlValue::Ascii(s) => CqlValue::Ascii(s),
183            ScyllaCqlValue::Boolean(b) => CqlValue::Boolean(b),
184            ScyllaCqlValue::Blob(bytes) => CqlValue::Blob(bytes),
185            ScyllaCqlValue::Counter(c) => CqlValue::Counter(c.0),
186            ScyllaCqlValue::Decimal(d) => {
187                let (int_val, scale) = d.as_signed_be_bytes_slice_and_exponent();
188                let big_int = num_bigint::BigInt::from_signed_bytes_be(int_val);
189                CqlValue::Decimal(bigdecimal::BigDecimal::new(big_int, scale.into()))
190            }
191            ScyllaCqlValue::Date(d) => {
192                // scylla CqlDate wraps u32 days since epoch center (2^31)
193                let days = d.0;
194                let epoch_offset = days as i64 - (1i64 << 31);
195                match chrono::NaiveDate::from_num_days_from_ce_opt((epoch_offset + 719_163) as i32)
196                {
197                    Some(date) => CqlValue::Date(date),
198                    None => CqlValue::Text(format!("<invalid date: {days}>")),
199                }
200            }
201            ScyllaCqlValue::Double(d) => CqlValue::Double(d),
202            ScyllaCqlValue::Duration(d) => CqlValue::Duration {
203                months: d.months,
204                days: d.days,
205                nanoseconds: d.nanoseconds,
206            },
207            ScyllaCqlValue::Empty => CqlValue::Null,
208            ScyllaCqlValue::Float(f) => CqlValue::Float(f),
209            ScyllaCqlValue::Int(i) => CqlValue::Int(i),
210            ScyllaCqlValue::BigInt(i) => CqlValue::BigInt(i),
211            ScyllaCqlValue::Text(s) => CqlValue::Text(s),
212            ScyllaCqlValue::Timestamp(t) => CqlValue::Timestamp(t.0),
213            ScyllaCqlValue::Inet(addr) => CqlValue::Inet(addr),
214            ScyllaCqlValue::List(items) => {
215                CqlValue::List(items.into_iter().map(Self::convert_scylla_value).collect())
216            }
217            ScyllaCqlValue::Map(entries) => CqlValue::Map(
218                entries
219                    .into_iter()
220                    .map(|(k, v)| (Self::convert_scylla_value(k), Self::convert_scylla_value(v)))
221                    .collect(),
222            ),
223            ScyllaCqlValue::Set(items) => {
224                CqlValue::Set(items.into_iter().map(Self::convert_scylla_value).collect())
225            }
226            ScyllaCqlValue::UserDefinedType {
227                keyspace,
228                name,
229                fields,
230            } => CqlValue::UserDefinedType {
231                keyspace,
232                type_name: name,
233                fields: fields
234                    .into_iter()
235                    .map(|(n, val)| (n, val.map(Self::convert_scylla_value)))
236                    .collect(),
237            },
238            ScyllaCqlValue::SmallInt(i) => CqlValue::SmallInt(i),
239            ScyllaCqlValue::TinyInt(i) => CqlValue::TinyInt(i),
240            ScyllaCqlValue::Time(t) => {
241                let nanos = t.0;
242                let secs = (nanos / 1_000_000_000) as u32;
243                let nano_part = (nanos % 1_000_000_000) as u32;
244                match chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part) {
245                    Some(time) => CqlValue::Time(time),
246                    None => CqlValue::Text(format!("<invalid time: {nanos}>")),
247                }
248            }
249            ScyllaCqlValue::Timeuuid(u) => CqlValue::TimeUuid(u.into()),
250            ScyllaCqlValue::Tuple(items) => CqlValue::Tuple(
251                items
252                    .into_iter()
253                    .map(|v| v.map(Self::convert_scylla_value))
254                    .collect(),
255            ),
256            ScyllaCqlValue::Uuid(u) => CqlValue::Uuid(u),
257            ScyllaCqlValue::Varint(v) => {
258                let big_int =
259                    num_bigint::BigInt::from_signed_bytes_be(v.as_signed_bytes_be_slice());
260                CqlValue::Varint(big_int)
261            }
262            // CqlValue is non-exhaustive; handle future variants gracefully
263            _ => {
264                tracing::warn!("unhandled ScyllaCqlValue variant: {value:?}");
265                CqlValue::Text(format!("{value:?}"))
266            }
267        }
268    }
269
270    /// Convert our internal CqlValue to scylla's CqlValue (reverse of convert_scylla_value).
271    fn internal_to_scylla_cql(v: &CqlValue) -> ScyllaCqlValue {
272        match v {
273            CqlValue::Ascii(s) => ScyllaCqlValue::Ascii(s.clone()),
274            CqlValue::Boolean(b) => ScyllaCqlValue::Boolean(*b),
275            CqlValue::Blob(bytes) => ScyllaCqlValue::Blob(bytes.clone()),
276            CqlValue::Counter(n) => ScyllaCqlValue::Counter(ScyllaCounter(*n)),
277            CqlValue::Double(d) => ScyllaCqlValue::Double(*d),
278            CqlValue::Duration {
279                months,
280                days,
281                nanoseconds,
282            } => ScyllaCqlValue::Duration(CqlDuration {
283                months: *months,
284                days: *days,
285                nanoseconds: *nanoseconds,
286            }),
287            CqlValue::Float(f) => ScyllaCqlValue::Float(*f),
288            CqlValue::Int(i) => ScyllaCqlValue::Int(*i),
289            CqlValue::BigInt(i) => ScyllaCqlValue::BigInt(*i),
290            CqlValue::SmallInt(i) => ScyllaCqlValue::SmallInt(*i),
291            CqlValue::TinyInt(i) => ScyllaCqlValue::TinyInt(*i),
292            CqlValue::Text(s) => ScyllaCqlValue::Text(s.clone()),
293            CqlValue::Timestamp(ms) => ScyllaCqlValue::Timestamp(CqlTimestamp(*ms)),
294            CqlValue::Inet(addr) => ScyllaCqlValue::Inet(*addr),
295            CqlValue::Uuid(u) => ScyllaCqlValue::Uuid(*u),
296            CqlValue::TimeUuid(u) => ScyllaCqlValue::Timeuuid(CqlTimeuuid::from(*u)),
297            CqlValue::Date(d) => {
298                // Convert NaiveDate back to scylla's u32 days offset from 2^31 epoch
299                let days_from_ce = d.num_days_from_ce();
300                let epoch_offset = days_from_ce as i64 - 719_163;
301                let cql_days = (epoch_offset + (1i64 << 31)) as u32;
302                ScyllaCqlValue::Date(CqlDate(cql_days))
303            }
304            CqlValue::Time(t) => {
305                let nanos =
306                    t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64;
307                ScyllaCqlValue::Time(CqlTime(nanos))
308            }
309            CqlValue::Varint(bi) => {
310                let bytes = bi.to_signed_bytes_be();
311                ScyllaCqlValue::Varint(CqlVarint::from_signed_bytes_be(bytes))
312            }
313            CqlValue::Decimal(d) => {
314                let (int_val, scale) = d.as_bigint_and_exponent();
315                let bytes = int_val.to_signed_bytes_be();
316                ScyllaCqlValue::Decimal(CqlDecimal::from_signed_be_bytes_slice_and_exponent(
317                    &bytes,
318                    scale as i32,
319                ))
320            }
321            CqlValue::List(items) => {
322                ScyllaCqlValue::List(items.iter().map(Self::internal_to_scylla_cql).collect())
323            }
324            CqlValue::Set(items) => {
325                ScyllaCqlValue::Set(items.iter().map(Self::internal_to_scylla_cql).collect())
326            }
327            CqlValue::Map(entries) => ScyllaCqlValue::Map(
328                entries
329                    .iter()
330                    .map(|(k, v)| {
331                        (
332                            Self::internal_to_scylla_cql(k),
333                            Self::internal_to_scylla_cql(v),
334                        )
335                    })
336                    .collect(),
337            ),
338            CqlValue::Tuple(items) => ScyllaCqlValue::Tuple(
339                items
340                    .iter()
341                    .map(|opt| opt.as_ref().map(Self::internal_to_scylla_cql))
342                    .collect(),
343            ),
344            CqlValue::UserDefinedType {
345                keyspace,
346                type_name,
347                fields,
348            } => ScyllaCqlValue::UserDefinedType {
349                keyspace: keyspace.clone(),
350                name: type_name.clone(),
351                fields: fields
352                    .iter()
353                    .map(|(n, v)| (n.clone(), v.as_ref().map(Self::internal_to_scylla_cql)))
354                    .collect(),
355            },
356            CqlValue::Null | CqlValue::Unset => ScyllaCqlValue::Empty,
357        }
358    }
359
360    /// Convert our Consistency to scylla's Consistency.
361    fn to_scylla_consistency(c: Consistency) -> scylla::statement::Consistency {
362        use scylla::statement::Consistency as SC;
363        match c {
364            Consistency::Any => SC::Any,
365            Consistency::One => SC::One,
366            Consistency::Two => SC::Two,
367            Consistency::Three => SC::Three,
368            Consistency::Quorum => SC::Quorum,
369            Consistency::All => SC::All,
370            Consistency::LocalQuorum => SC::LocalQuorum,
371            Consistency::EachQuorum => SC::EachQuorum,
372            Consistency::Serial => SC::Serial,
373            Consistency::LocalSerial => SC::LocalSerial,
374            Consistency::LocalOne => SC::LocalOne,
375        }
376    }
377
378    /// Convert our Consistency to scylla's SerialConsistency.
379    fn to_scylla_serial_consistency(
380        c: Consistency,
381    ) -> Option<scylla::statement::SerialConsistency> {
382        use scylla::statement::SerialConsistency as SC;
383        match c {
384            Consistency::Serial => Some(SC::Serial),
385            Consistency::LocalSerial => Some(SC::LocalSerial),
386            _ => None,
387        }
388    }
389
390    /// Build a Statement with the current consistency and tracing settings.
391    fn build_query(&self, cql: &str) -> Statement {
392        let mut stmt = Statement::new(cql);
393
394        let consistency = *self.consistency.lock().unwrap();
395        stmt.set_consistency(Self::to_scylla_consistency(consistency));
396
397        let serial = *self.serial_consistency.lock().unwrap();
398        if let Some(sc) = serial {
399            if let Some(sc) = Self::to_scylla_serial_consistency(sc) {
400                stmt.set_serial_consistency(Some(sc));
401            }
402        }
403
404        if self.tracing_enabled.load(Ordering::Relaxed) {
405            stmt.set_tracing(true);
406        }
407
408        stmt
409    }
410
411    /// Store tracing ID from a result if present.
412    fn store_trace_id(&self, result: &QueryResult) {
413        if let Some(trace_id) = result.tracing_id() {
414            *self.last_trace_id.lock().unwrap() = Some(trace_id);
415        }
416    }
417}
418
419#[async_trait]
420impl CqlDriver for ScyllaDriver {
421    async fn connect(config: &ConnectionConfig) -> Result<Self> {
422        let addr = format!("{}:{}", config.host, config.port);
423
424        let mut builder = SessionBuilder::new().known_node(&addr);
425
426        // Authentication
427        if let (Some(username), Some(password)) = (&config.username, &config.password) {
428            builder = builder.user(username, password);
429        }
430
431        // Connection timeout
432        builder = builder.connection_timeout(Duration::from_secs(config.connect_timeout));
433
434        // Default keyspace
435        if let Some(keyspace) = &config.keyspace {
436            builder = builder.use_keyspace(keyspace, false);
437        }
438
439        // SSL/TLS
440        if config.ssl {
441            let tls_config = if let Some(ssl_config) = &config.ssl_config {
442                Self::build_rustls_config(ssl_config)?
443            } else {
444                // SSL enabled but no config — use default (no validation)
445                let root_store = rustls::RootCertStore::empty();
446                Arc::new(
447                    rustls::ClientConfig::builder()
448                        .with_root_certificates(root_store)
449                        .with_no_client_auth(),
450                )
451            };
452            builder = builder.tls_context(Some(tls_config));
453        }
454
455        let session = builder.build().await.context("connecting to cluster")?;
456
457        Ok(ScyllaDriver {
458            session,
459            prepared_cache: Mutex::new(HashMap::new()),
460            consistency: Mutex::new(Consistency::One),
461            serial_consistency: Mutex::new(None),
462            tracing_enabled: AtomicBool::new(false),
463            last_trace_id: Mutex::new(None),
464        })
465    }
466
467    async fn execute_unpaged(&self, query: &str) -> Result<CqlResult> {
468        let stmt = self.build_query(query);
469
470        let result = self.session.query_unpaged(stmt, ()).await?;
471
472        self.store_trace_id(&result);
473        Self::convert_query_result(result)
474    }
475
476    async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
477        let mut stmt = self.build_query(query);
478        stmt.set_page_size(page_size);
479
480        let query_pager = self
481            .session
482            .query_iter(stmt, ())
483            .await
484            .context("starting paged query")?;
485
486        // Get column metadata from the pager
487        let col_specs = query_pager.column_specs();
488        let columns: Vec<CqlColumn> = col_specs
489            .iter()
490            .map(|spec| CqlColumn {
491                name: spec.name().to_string(),
492                type_name: format!("{:?}", spec.typ()),
493            })
494            .collect();
495
496        // Stream all rows using the untyped Row type
497        let mut rows_stream = query_pager.rows_stream::<Row>()?;
498        let mut cql_rows = Vec::new();
499
500        while let Some(row) = rows_stream.try_next().await? {
501            let values: Vec<CqlValue> = row
502                .columns
503                .into_iter()
504                .map(|opt_val| match opt_val {
505                    Some(v) => Self::convert_scylla_value(v),
506                    None => CqlValue::Null,
507                })
508                .collect();
509            cql_rows.push(CqlRow { values });
510        }
511
512        Ok(CqlResult {
513            columns,
514            rows: cql_rows,
515            has_rows: true,
516            tracing_id: None,
517            warnings: Vec::new(),
518        })
519    }
520
521    async fn prepare(&self, query: &str) -> Result<PreparedId> {
522        let prepared = self
523            .session
524            .prepare(query)
525            .await
526            .context("preparing CQL statement")?;
527
528        let id = prepared.get_id().to_vec();
529        self.prepared_cache
530            .lock()
531            .unwrap()
532            .insert(id.clone(), prepared);
533
534        Ok(PreparedId { inner: id })
535    }
536
537    async fn execute_prepared(
538        &self,
539        prepared_id: &PreparedId,
540        values: &[CqlValue],
541    ) -> Result<CqlResult> {
542        let prepared = self
543            .prepared_cache
544            .lock()
545            .unwrap()
546            .get(&prepared_id.inner)
547            .cloned()
548            .ok_or_else(|| anyhow!("prepared statement not found in cache"))?;
549
550        // Convert internal CqlValues to scylla CqlValues for binding.
551        // Null/Unset become None (bound as null), all others become Some(value).
552        let scylla_values: Vec<Option<ScyllaCqlValue>> = values
553            .iter()
554            .map(|v| match v {
555                CqlValue::Null | CqlValue::Unset => None,
556                other => Some(Self::internal_to_scylla_cql(other)),
557            })
558            .collect();
559
560        let result = self
561            .session
562            .execute_unpaged(&prepared, scylla_values)
563            .await
564            .context("executing prepared statement")?;
565
566        self.store_trace_id(&result);
567        Self::convert_query_result(result)
568    }
569
570    async fn use_keyspace(&self, keyspace: &str) -> Result<()> {
571        self.session
572            .use_keyspace(keyspace, false)
573            .await
574            .with_context(|| format!("switching to keyspace: {keyspace}"))?;
575        Ok(())
576    }
577
578    fn get_consistency(&self) -> Consistency {
579        *self.consistency.lock().unwrap()
580    }
581
582    fn set_consistency(&self, consistency: Consistency) {
583        *self.consistency.lock().unwrap() = consistency;
584    }
585
586    fn get_serial_consistency(&self) -> Option<Consistency> {
587        *self.serial_consistency.lock().unwrap()
588    }
589
590    fn set_serial_consistency(&self, consistency: Option<Consistency>) {
591        *self.serial_consistency.lock().unwrap() = consistency;
592    }
593
594    fn set_tracing(&self, enabled: bool) {
595        self.tracing_enabled.store(enabled, Ordering::Relaxed);
596    }
597
598    fn is_tracing_enabled(&self) -> bool {
599        self.tracing_enabled.load(Ordering::Relaxed)
600    }
601
602    fn last_trace_id(&self) -> Option<Uuid> {
603        *self.last_trace_id.lock().unwrap()
604    }
605
606    async fn get_trace_session(&self, trace_id: Uuid) -> Result<Option<TracingSession>> {
607        let query = format!(
608            "SELECT client, command, coordinator, duration, parameters, request, started_at \
609             FROM system_traces.sessions WHERE session_id = {}",
610            trace_id
611        );
612        let result = self.execute_unpaged(&query).await?;
613
614        if result.rows.is_empty() {
615            return Ok(None);
616        }
617
618        let events_query = format!(
619            "SELECT activity, source, source_elapsed, thread \
620             FROM system_traces.events WHERE session_id = {}",
621            trace_id
622        );
623        let events_result = self.execute_unpaged(&events_query).await?;
624
625        let events: Vec<TracingEvent> = events_result
626            .rows
627            .iter()
628            .map(|row| TracingEvent {
629                activity: row.get(0).and_then(cql_value_to_string),
630                source: row.get(1).and_then(cql_value_to_string),
631                source_elapsed: row.get(2).and_then(cql_value_to_i32),
632                thread: row.get(3).and_then(cql_value_to_string),
633            })
634            .collect();
635
636        let session_row = &result.rows[0];
637        Ok(Some(TracingSession {
638            trace_id,
639            client: session_row.get(0).and_then(cql_value_to_string),
640            command: session_row.get(1).and_then(cql_value_to_string),
641            coordinator: session_row.get(2).and_then(cql_value_to_string),
642            duration: session_row.get(3).and_then(cql_value_to_i32),
643            parameters: HashMap::new(),
644            request: session_row.get(5).and_then(cql_value_to_string),
645            started_at: session_row.get(6).and_then(cql_value_to_string),
646            events,
647        }))
648    }
649
650    async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
651        let result = self
652            .execute_unpaged(
653                "SELECT keyspace_name, replication, durable_writes \
654                 FROM system_schema.keyspaces",
655            )
656            .await?;
657
658        let mut keyspaces = Vec::new();
659        for row in &result.rows {
660            let name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
661            let durable_writes = match row.get(2) {
662                Some(CqlValue::Boolean(b)) => *b,
663                _ => true,
664            };
665
666            keyspaces.push(KeyspaceMetadata {
667                name,
668                replication: HashMap::new(),
669                durable_writes,
670            });
671        }
672
673        Ok(keyspaces)
674    }
675
676    async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
677        let result = self
678            .execute_unpaged(&format!(
679                "SELECT table_name FROM system_schema.tables WHERE keyspace_name = '{}'",
680                keyspace.replace('\'', "''")
681            ))
682            .await?;
683
684        let mut tables = Vec::new();
685        for row in &result.rows {
686            let table_name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
687
688            let col_result = self
689                .execute_unpaged(&format!(
690                    "SELECT column_name, type, kind \
691                     FROM system_schema.columns \
692                     WHERE keyspace_name = '{}' AND table_name = '{}'",
693                    keyspace.replace('\'', "''"),
694                    table_name.replace('\'', "''")
695                ))
696                .await?;
697
698            let mut columns = Vec::new();
699            let mut partition_key = Vec::new();
700            let mut clustering_key = Vec::new();
701
702            for col_row in &col_result.rows {
703                let col_name = col_row
704                    .get(0)
705                    .and_then(cql_value_to_string)
706                    .unwrap_or_default();
707                let col_type = col_row
708                    .get(1)
709                    .and_then(cql_value_to_string)
710                    .unwrap_or_default();
711                let kind = col_row
712                    .get(2)
713                    .and_then(cql_value_to_string)
714                    .unwrap_or_default();
715
716                columns.push(ColumnMetadata {
717                    name: col_name.clone(),
718                    type_name: col_type,
719                });
720
721                match kind.as_str() {
722                    "partition_key" => partition_key.push(col_name),
723                    "clustering" => clustering_key.push(col_name),
724                    _ => {}
725                }
726            }
727
728            tables.push(TableMetadata {
729                keyspace: keyspace.to_string(),
730                name: table_name,
731                columns,
732                partition_key,
733                clustering_key,
734            });
735        }
736
737        Ok(tables)
738    }
739
740    async fn get_table_metadata(
741        &self,
742        keyspace: &str,
743        table: &str,
744    ) -> Result<Option<TableMetadata>> {
745        let tables = self.get_tables(keyspace).await?;
746        Ok(tables.into_iter().find(|t| t.name == table))
747    }
748
749    async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
750        let query = format!(
751            "SELECT type_name, field_names, field_types FROM system_schema.types WHERE keyspace_name = '{}'",
752            keyspace.replace('\'', "''")
753        );
754        let result = self.execute_unpaged(&query).await?;
755        let udts = result
756            .rows
757            .iter()
758            .filter_map(|row| {
759                let name = row.get_by_name("type_name", &result.columns)?.to_string();
760                let field_names =
761                    Self::extract_string_list_val(row.get_by_name("field_names", &result.columns));
762                let field_types =
763                    Self::extract_string_list_val(row.get_by_name("field_types", &result.columns));
764                Some(UdtMetadata {
765                    keyspace: keyspace.to_string(),
766                    name,
767                    field_names,
768                    field_types,
769                })
770            })
771            .collect();
772        Ok(udts)
773    }
774
775    async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
776        let query = format!(
777            "SELECT function_name, argument_types, return_type FROM system_schema.functions WHERE keyspace_name = '{}'",
778            keyspace.replace('\'', "''")
779        );
780        let result = self.execute_unpaged(&query).await?;
781        let functions = result
782            .rows
783            .iter()
784            .filter_map(|row| {
785                let name = row
786                    .get_by_name("function_name", &result.columns)?
787                    .to_string();
788                let argument_types = Self::extract_string_list_val(
789                    row.get_by_name("argument_types", &result.columns),
790                );
791                let return_type = row
792                    .get_by_name("return_type", &result.columns)
793                    .map(|v| v.to_string())
794                    .unwrap_or_default();
795                Some(FunctionMetadata {
796                    keyspace: keyspace.to_string(),
797                    name,
798                    argument_types,
799                    return_type,
800                })
801            })
802            .collect();
803        Ok(functions)
804    }
805
806    async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
807        let query = format!(
808            "SELECT aggregate_name, argument_types, return_type FROM system_schema.aggregates WHERE keyspace_name = '{}'",
809            keyspace.replace('\'', "''")
810        );
811        let result = self.execute_unpaged(&query).await?;
812        let aggregates = result
813            .rows
814            .iter()
815            .filter_map(|row| {
816                let name = row
817                    .get_by_name("aggregate_name", &result.columns)?
818                    .to_string();
819                let argument_types = Self::extract_string_list_val(
820                    row.get_by_name("argument_types", &result.columns),
821                );
822                let return_type = row
823                    .get_by_name("return_type", &result.columns)
824                    .map(|v| v.to_string())
825                    .unwrap_or_default();
826                Some(AggregateMetadata {
827                    keyspace: keyspace.to_string(),
828                    name,
829                    argument_types,
830                    return_type,
831                })
832            })
833            .collect();
834        Ok(aggregates)
835    }
836
837    async fn get_cluster_name(&self) -> Result<Option<String>> {
838        let result = self
839            .execute_unpaged("SELECT cluster_name FROM system.local")
840            .await?;
841        Ok(result
842            .rows
843            .first()
844            .and_then(|row| row.get(0))
845            .and_then(cql_value_to_string))
846    }
847
848    async fn get_cql_version(&self) -> Result<Option<String>> {
849        let result = self
850            .execute_unpaged("SELECT cql_version FROM system.local")
851            .await?;
852        Ok(result
853            .rows
854            .first()
855            .and_then(|row| row.get(0))
856            .and_then(cql_value_to_string))
857    }
858
859    async fn get_release_version(&self) -> Result<Option<String>> {
860        let result = self
861            .execute_unpaged("SELECT release_version FROM system.local")
862            .await?;
863        Ok(result
864            .rows
865            .first()
866            .and_then(|row| row.get(0))
867            .and_then(cql_value_to_string))
868    }
869
870    async fn get_scylla_version(&self) -> Result<Option<String>> {
871        // ScyllaDB exposes its version in system.local.scylla_version
872        // This column doesn't exist in Apache Cassandra, so errors are expected.
873        let result = self
874            .execute_unpaged("SELECT scylla_version FROM system.local")
875            .await;
876        match result {
877            Ok(r) => Ok(r
878                .rows
879                .first()
880                .and_then(|row| row.get(0))
881                .and_then(cql_value_to_string)),
882            Err(_) => Ok(None), // Column doesn't exist → not ScyllaDB
883        }
884    }
885
886    async fn is_connected(&self) -> bool {
887        self.execute_unpaged("SELECT key FROM system.local LIMIT 1")
888            .await
889            .is_ok()
890    }
891}
892
893/// Helper: extract a string from a CqlValue.
894fn cql_value_to_string(v: &CqlValue) -> Option<String> {
895    match v {
896        CqlValue::Text(s) | CqlValue::Ascii(s) => Some(s.clone()),
897        CqlValue::Inet(addr) => Some(addr.to_string()),
898        CqlValue::Null => None,
899        other => Some(other.to_string()),
900    }
901}
902
903/// Helper: extract an i32 from a CqlValue.
904fn cql_value_to_i32(v: &CqlValue) -> Option<i32> {
905    match v {
906        CqlValue::Int(i) => Some(*i),
907        CqlValue::BigInt(i) => Some(*i as i32),
908        CqlValue::SmallInt(i) => Some(*i as i32),
909        CqlValue::TinyInt(i) => Some(*i as i32),
910        _ => None,
911    }
912}
913
#[cfg(test)]
mod tests {
    //! Unit tests for value conversion, consistency mapping and the small
    //! CqlValue helpers. None of these tests require a live cluster.

    use super::*;

    #[test]
    fn convert_scylla_value_text() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Text("hello".to_string()));
        assert_eq!(v, CqlValue::Text("hello".to_string()));
    }

    #[test]
    fn convert_scylla_value_int() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Int(42));
        assert_eq!(v, CqlValue::Int(42));
    }

    #[test]
    fn convert_scylla_value_boolean() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Boolean(true));
        assert_eq!(v, CqlValue::Boolean(true));
    }

    #[test]
    fn convert_scylla_value_null() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Empty);
        assert_eq!(v, CqlValue::Null);
    }

    #[test]
    fn convert_scylla_value_list() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::List(vec![
            ScyllaCqlValue::Int(1),
            ScyllaCqlValue::Int(2),
        ]));
        assert_eq!(v, CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]));
    }

    #[test]
    fn convert_scylla_value_uuid() {
        let id = Uuid::nil();
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Uuid(id));
        assert_eq!(v, CqlValue::Uuid(id));
    }

    #[test]
    fn convert_scylla_value_blob() {
        let v =
            ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Blob(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(v, CqlValue::Blob(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    #[test]
    fn convert_scylla_value_float() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Float(1.5));
        assert_eq!(v, CqlValue::Float(1.5));
    }

    #[test]
    fn convert_scylla_value_double() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Double(1.5));
        assert_eq!(v, CqlValue::Double(1.5));
    }

    #[test]
    fn convert_scylla_value_map() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Map(vec![(
            ScyllaCqlValue::Text("key".to_string()),
            ScyllaCqlValue::Int(42),
        )]));
        assert_eq!(
            v,
            CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))])
        );
    }

    #[test]
    fn convert_scylla_value_set() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Set(vec![
            ScyllaCqlValue::Int(1),
            ScyllaCqlValue::Int(2),
        ]));
        assert_eq!(v, CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)]));
    }

    #[test]
    fn convert_scylla_value_udt() {
        let v = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(ScyllaCqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        });
        assert_eq!(
            v,
            CqlValue::UserDefinedType {
                keyspace: "ks".to_string(),
                type_name: "my_type".to_string(),
                fields: vec![
                    ("f1".to_string(), Some(CqlValue::Int(1))),
                    ("f2".to_string(), None),
                ],
            }
        );
    }

    #[test]
    fn to_scylla_consistency_mapping() {
        use scylla::statement::Consistency as SC;
        assert!(matches!(
            ScyllaDriver::to_scylla_consistency(Consistency::One),
            SC::One
        ));
        assert!(matches!(
            ScyllaDriver::to_scylla_consistency(Consistency::Quorum),
            SC::Quorum
        ));
        assert!(matches!(
            ScyllaDriver::to_scylla_consistency(Consistency::LocalQuorum),
            SC::LocalQuorum
        ));
        assert!(matches!(
            ScyllaDriver::to_scylla_consistency(Consistency::All),
            SC::All
        ));
    }

    #[test]
    fn to_scylla_serial_consistency_mapping() {
        use scylla::statement::SerialConsistency as SC;
        assert!(matches!(
            ScyllaDriver::to_scylla_serial_consistency(Consistency::Serial),
            Some(SC::Serial)
        ));
        assert!(matches!(
            ScyllaDriver::to_scylla_serial_consistency(Consistency::LocalSerial),
            Some(SC::LocalSerial)
        ));
        assert!(ScyllaDriver::to_scylla_serial_consistency(Consistency::One).is_none());
    }

    #[test]
    fn cql_value_to_string_helper() {
        assert_eq!(
            cql_value_to_string(&CqlValue::Text("hello".to_string())),
            Some("hello".to_string())
        );
        assert_eq!(
            cql_value_to_string(&CqlValue::Int(42)),
            Some("42".to_string())
        );
        assert_eq!(cql_value_to_string(&CqlValue::Null), None);
    }

    #[test]
    fn cql_value_to_string_ascii_is_raw() {
        // Ascii shares the Text branch: the raw contents are returned.
        assert_eq!(
            cql_value_to_string(&CqlValue::Ascii("abc".to_string())),
            Some("abc".to_string())
        );
    }

    #[test]
    fn cql_value_to_i32_helper() {
        assert_eq!(cql_value_to_i32(&CqlValue::Int(42)), Some(42));
        assert_eq!(cql_value_to_i32(&CqlValue::BigInt(100)), Some(100));
        assert_eq!(cql_value_to_i32(&CqlValue::Text("x".to_string())), None);
    }

    #[test]
    fn cql_value_to_i32_small_integer_types() {
        assert_eq!(cql_value_to_i32(&CqlValue::SmallInt(-7)), Some(-7));
        assert_eq!(cql_value_to_i32(&CqlValue::TinyInt(5)), Some(5));
        assert_eq!(cql_value_to_i32(&CqlValue::Null), None);
    }
}