1use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use anyhow::{anyhow, Context, Result};
13use async_trait::async_trait;
14use chrono::{Datelike, Timelike};
15use futures::TryStreamExt;
16use scylla::client::session::Session;
17use scylla::client::session_builder::SessionBuilder;
18use scylla::response::query_result::QueryResult;
19use scylla::statement::prepared::PreparedStatement;
20use scylla::statement::Statement;
21use scylla::value::{
22 Counter as ScyllaCounter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid,
23 CqlValue as ScyllaCqlValue, CqlVarint, Row,
24};
25use uuid::Uuid;
26
27use super::types::{CqlColumn, CqlResult, CqlRow, CqlValue};
28use super::{
29 AggregateMetadata, ColumnMetadata, ConnectionConfig, Consistency, CqlDriver, FunctionMetadata,
30 KeyspaceMetadata, PreparedId, SslConfig, TableMetadata, TracingEvent, TracingSession,
31 UdtMetadata,
32};
33
34pub struct ScyllaDriver {
36 session: Session,
37 prepared_cache: Mutex<HashMap<Vec<u8>, PreparedStatement>>,
39 consistency: Mutex<Consistency>,
41 serial_consistency: Mutex<Option<Consistency>>,
43 tracing_enabled: AtomicBool,
45 last_trace_id: Mutex<Option<Uuid>>,
47}
48
49impl ScyllaDriver {
50 fn build_rustls_config(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
52 use rustls::pki_types::CertificateDer;
53 use std::fs::File;
54 use std::io::BufReader;
55
56 let mut root_store = rustls::RootCertStore::empty();
57
58 if let Some(certfile) = &ssl_config.certfile {
60 let file = File::open(certfile)
61 .with_context(|| format!("opening CA certificate: {certfile}"))?;
62 let mut reader = BufReader::new(file);
63 let certs = rustls_pemfile::certs(&mut reader)
64 .collect::<std::result::Result<Vec<_>, _>>()
65 .with_context(|| format!("parsing CA certificate: {certfile}"))?;
66 for cert in certs {
67 root_store
68 .add(cert)
69 .context("adding CA certificate to root store")?;
70 }
71 }
72
73 let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
74
75 let config = if let (Some(usercert_path), Some(userkey_path)) =
77 (&ssl_config.usercert, &ssl_config.userkey)
78 {
79 let cert_file = File::open(usercert_path)
80 .with_context(|| format!("opening client certificate: {usercert_path}"))?;
81 let mut cert_reader = BufReader::new(cert_file);
82 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
83 .collect::<std::result::Result<Vec<_>, _>>()
84 .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
85
86 let key_file = File::open(userkey_path)
87 .with_context(|| format!("opening client key: {userkey_path}"))?;
88 let mut key_reader = BufReader::new(key_file);
89 let key = rustls_pemfile::private_key(&mut key_reader)
90 .with_context(|| format!("parsing client key: {userkey_path}"))?
91 .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
92
93 builder
94 .with_client_auth_cert(certs, key)
95 .context("configuring mutual TLS")?
96 } else {
97 builder.with_no_client_auth()
98 };
99
100 Ok(Arc::new(config))
101 }
102
103 fn extract_string_list_val(val: Option<&CqlValue>) -> Vec<String> {
105 match val {
106 Some(CqlValue::List(items)) => items.iter().map(|v| v.to_string()).collect(),
107 _ => Vec::new(),
108 }
109 }
110
111 fn convert_query_result(result: QueryResult) -> Result<CqlResult> {
113 let tracing_id = result.tracing_id();
114 let warnings: Vec<String> = result.warnings().map(|s| s.to_string()).collect();
115
116 if !result.is_rows() {
118 return Ok(CqlResult {
119 columns: Vec::new(),
120 rows: Vec::new(),
121 has_rows: false,
122 tracing_id,
123 warnings,
124 });
125 }
126
127 let rows_result = result
129 .into_rows_result()
130 .context("converting query result to rows")?;
131
132 let col_specs = rows_result.column_specs();
134 let columns: Vec<CqlColumn> = col_specs
135 .iter()
136 .map(|spec| CqlColumn {
137 name: spec.name().to_string(),
138 type_name: format!("{:?}", spec.typ()),
139 })
140 .collect();
141
142 let typed_rows = rows_result.rows::<Row>().context("deserializing rows")?;
144
145 let mut cql_rows = Vec::new();
146 for row_result in typed_rows {
147 let row = row_result.context("deserializing row")?;
148 let values: Vec<CqlValue> = row
149 .columns
150 .into_iter()
151 .enumerate()
152 .map(|(col_idx, opt_val)| match opt_val {
153 Some(v) => {
154 tracing::debug!(
155 column = col_idx,
156 variant = ?std::mem::discriminant(&v),
157 "converting ScyllaCqlValue: {v:?}"
158 );
159 Self::convert_scylla_value(v)
160 }
161 None => {
162 tracing::debug!(column = col_idx, "column value is None (null)");
163 CqlValue::Null
164 }
165 })
166 .collect();
167 cql_rows.push(CqlRow { values });
168 }
169
170 Ok(CqlResult {
171 columns,
172 rows: cql_rows,
173 has_rows: true,
174 tracing_id,
175 warnings,
176 })
177 }
178
179 fn convert_scylla_value(value: ScyllaCqlValue) -> CqlValue {
181 match value {
182 ScyllaCqlValue::Ascii(s) => CqlValue::Ascii(s),
183 ScyllaCqlValue::Boolean(b) => CqlValue::Boolean(b),
184 ScyllaCqlValue::Blob(bytes) => CqlValue::Blob(bytes),
185 ScyllaCqlValue::Counter(c) => CqlValue::Counter(c.0),
186 ScyllaCqlValue::Decimal(d) => {
187 let (int_val, scale) = d.as_signed_be_bytes_slice_and_exponent();
188 let big_int = num_bigint::BigInt::from_signed_bytes_be(int_val);
189 CqlValue::Decimal(bigdecimal::BigDecimal::new(big_int, scale.into()))
190 }
191 ScyllaCqlValue::Date(d) => {
192 let days = d.0;
194 let epoch_offset = days as i64 - (1i64 << 31);
195 match chrono::NaiveDate::from_num_days_from_ce_opt((epoch_offset + 719_163) as i32)
196 {
197 Some(date) => CqlValue::Date(date),
198 None => CqlValue::Text(format!("<invalid date: {days}>")),
199 }
200 }
201 ScyllaCqlValue::Double(d) => CqlValue::Double(d),
202 ScyllaCqlValue::Duration(d) => CqlValue::Duration {
203 months: d.months,
204 days: d.days,
205 nanoseconds: d.nanoseconds,
206 },
207 ScyllaCqlValue::Empty => CqlValue::Null,
208 ScyllaCqlValue::Float(f) => CqlValue::Float(f),
209 ScyllaCqlValue::Int(i) => CqlValue::Int(i),
210 ScyllaCqlValue::BigInt(i) => CqlValue::BigInt(i),
211 ScyllaCqlValue::Text(s) => CqlValue::Text(s),
212 ScyllaCqlValue::Timestamp(t) => CqlValue::Timestamp(t.0),
213 ScyllaCqlValue::Inet(addr) => CqlValue::Inet(addr),
214 ScyllaCqlValue::List(items) => {
215 CqlValue::List(items.into_iter().map(Self::convert_scylla_value).collect())
216 }
217 ScyllaCqlValue::Map(entries) => CqlValue::Map(
218 entries
219 .into_iter()
220 .map(|(k, v)| (Self::convert_scylla_value(k), Self::convert_scylla_value(v)))
221 .collect(),
222 ),
223 ScyllaCqlValue::Set(items) => {
224 CqlValue::Set(items.into_iter().map(Self::convert_scylla_value).collect())
225 }
226 ScyllaCqlValue::UserDefinedType {
227 keyspace,
228 name,
229 fields,
230 } => CqlValue::UserDefinedType {
231 keyspace,
232 type_name: name,
233 fields: fields
234 .into_iter()
235 .map(|(n, val)| (n, val.map(Self::convert_scylla_value)))
236 .collect(),
237 },
238 ScyllaCqlValue::SmallInt(i) => CqlValue::SmallInt(i),
239 ScyllaCqlValue::TinyInt(i) => CqlValue::TinyInt(i),
240 ScyllaCqlValue::Time(t) => {
241 let nanos = t.0;
242 let secs = (nanos / 1_000_000_000) as u32;
243 let nano_part = (nanos % 1_000_000_000) as u32;
244 match chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part) {
245 Some(time) => CqlValue::Time(time),
246 None => CqlValue::Text(format!("<invalid time: {nanos}>")),
247 }
248 }
249 ScyllaCqlValue::Timeuuid(u) => CqlValue::TimeUuid(u.into()),
250 ScyllaCqlValue::Tuple(items) => CqlValue::Tuple(
251 items
252 .into_iter()
253 .map(|v| v.map(Self::convert_scylla_value))
254 .collect(),
255 ),
256 ScyllaCqlValue::Uuid(u) => CqlValue::Uuid(u),
257 ScyllaCqlValue::Varint(v) => {
258 let big_int =
259 num_bigint::BigInt::from_signed_bytes_be(v.as_signed_bytes_be_slice());
260 CqlValue::Varint(big_int)
261 }
262 _ => {
264 tracing::warn!("unhandled ScyllaCqlValue variant: {value:?}");
265 CqlValue::Text(format!("{value:?}"))
266 }
267 }
268 }
269
270 fn internal_to_scylla_cql(v: &CqlValue) -> ScyllaCqlValue {
272 match v {
273 CqlValue::Ascii(s) => ScyllaCqlValue::Ascii(s.clone()),
274 CqlValue::Boolean(b) => ScyllaCqlValue::Boolean(*b),
275 CqlValue::Blob(bytes) => ScyllaCqlValue::Blob(bytes.clone()),
276 CqlValue::Counter(n) => ScyllaCqlValue::Counter(ScyllaCounter(*n)),
277 CqlValue::Double(d) => ScyllaCqlValue::Double(*d),
278 CqlValue::Duration {
279 months,
280 days,
281 nanoseconds,
282 } => ScyllaCqlValue::Duration(CqlDuration {
283 months: *months,
284 days: *days,
285 nanoseconds: *nanoseconds,
286 }),
287 CqlValue::Float(f) => ScyllaCqlValue::Float(*f),
288 CqlValue::Int(i) => ScyllaCqlValue::Int(*i),
289 CqlValue::BigInt(i) => ScyllaCqlValue::BigInt(*i),
290 CqlValue::SmallInt(i) => ScyllaCqlValue::SmallInt(*i),
291 CqlValue::TinyInt(i) => ScyllaCqlValue::TinyInt(*i),
292 CqlValue::Text(s) => ScyllaCqlValue::Text(s.clone()),
293 CqlValue::Timestamp(ms) => ScyllaCqlValue::Timestamp(CqlTimestamp(*ms)),
294 CqlValue::Inet(addr) => ScyllaCqlValue::Inet(*addr),
295 CqlValue::Uuid(u) => ScyllaCqlValue::Uuid(*u),
296 CqlValue::TimeUuid(u) => ScyllaCqlValue::Timeuuid(CqlTimeuuid::from(*u)),
297 CqlValue::Date(d) => {
298 let days_from_ce = d.num_days_from_ce();
300 let epoch_offset = days_from_ce as i64 - 719_163;
301 let cql_days = (epoch_offset + (1i64 << 31)) as u32;
302 ScyllaCqlValue::Date(CqlDate(cql_days))
303 }
304 CqlValue::Time(t) => {
305 let nanos =
306 t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64;
307 ScyllaCqlValue::Time(CqlTime(nanos))
308 }
309 CqlValue::Varint(bi) => {
310 let bytes = bi.to_signed_bytes_be();
311 ScyllaCqlValue::Varint(CqlVarint::from_signed_bytes_be(bytes))
312 }
313 CqlValue::Decimal(d) => {
314 let (int_val, scale) = d.as_bigint_and_exponent();
315 let bytes = int_val.to_signed_bytes_be();
316 ScyllaCqlValue::Decimal(CqlDecimal::from_signed_be_bytes_slice_and_exponent(
317 &bytes,
318 scale as i32,
319 ))
320 }
321 CqlValue::List(items) => {
322 ScyllaCqlValue::List(items.iter().map(Self::internal_to_scylla_cql).collect())
323 }
324 CqlValue::Set(items) => {
325 ScyllaCqlValue::Set(items.iter().map(Self::internal_to_scylla_cql).collect())
326 }
327 CqlValue::Map(entries) => ScyllaCqlValue::Map(
328 entries
329 .iter()
330 .map(|(k, v)| {
331 (
332 Self::internal_to_scylla_cql(k),
333 Self::internal_to_scylla_cql(v),
334 )
335 })
336 .collect(),
337 ),
338 CqlValue::Tuple(items) => ScyllaCqlValue::Tuple(
339 items
340 .iter()
341 .map(|opt| opt.as_ref().map(Self::internal_to_scylla_cql))
342 .collect(),
343 ),
344 CqlValue::UserDefinedType {
345 keyspace,
346 type_name,
347 fields,
348 } => ScyllaCqlValue::UserDefinedType {
349 keyspace: keyspace.clone(),
350 name: type_name.clone(),
351 fields: fields
352 .iter()
353 .map(|(n, v)| (n.clone(), v.as_ref().map(Self::internal_to_scylla_cql)))
354 .collect(),
355 },
356 CqlValue::Null | CqlValue::Unset => ScyllaCqlValue::Empty,
357 }
358 }
359
360 fn to_scylla_consistency(c: Consistency) -> scylla::statement::Consistency {
362 use scylla::statement::Consistency as SC;
363 match c {
364 Consistency::Any => SC::Any,
365 Consistency::One => SC::One,
366 Consistency::Two => SC::Two,
367 Consistency::Three => SC::Three,
368 Consistency::Quorum => SC::Quorum,
369 Consistency::All => SC::All,
370 Consistency::LocalQuorum => SC::LocalQuorum,
371 Consistency::EachQuorum => SC::EachQuorum,
372 Consistency::Serial => SC::Serial,
373 Consistency::LocalSerial => SC::LocalSerial,
374 Consistency::LocalOne => SC::LocalOne,
375 }
376 }
377
378 fn to_scylla_serial_consistency(
380 c: Consistency,
381 ) -> Option<scylla::statement::SerialConsistency> {
382 use scylla::statement::SerialConsistency as SC;
383 match c {
384 Consistency::Serial => Some(SC::Serial),
385 Consistency::LocalSerial => Some(SC::LocalSerial),
386 _ => None,
387 }
388 }
389
390 fn build_query(&self, cql: &str) -> Statement {
392 let mut stmt = Statement::new(cql);
393
394 let consistency = *self.consistency.lock().unwrap();
395 stmt.set_consistency(Self::to_scylla_consistency(consistency));
396
397 let serial = *self.serial_consistency.lock().unwrap();
398 if let Some(sc) = serial {
399 if let Some(sc) = Self::to_scylla_serial_consistency(sc) {
400 stmt.set_serial_consistency(Some(sc));
401 }
402 }
403
404 if self.tracing_enabled.load(Ordering::Relaxed) {
405 stmt.set_tracing(true);
406 }
407
408 stmt
409 }
410
411 fn store_trace_id(&self, result: &QueryResult) {
413 if let Some(trace_id) = result.tracing_id() {
414 *self.last_trace_id.lock().unwrap() = Some(trace_id);
415 }
416 }
417}
418
419#[async_trait]
420impl CqlDriver for ScyllaDriver {
421 async fn connect(config: &ConnectionConfig) -> Result<Self> {
422 let addr = format!("{}:{}", config.host, config.port);
423
424 let mut builder = SessionBuilder::new().known_node(&addr);
425
426 if let (Some(username), Some(password)) = (&config.username, &config.password) {
428 builder = builder.user(username, password);
429 }
430
431 builder = builder.connection_timeout(Duration::from_secs(config.connect_timeout));
433
434 if let Some(keyspace) = &config.keyspace {
436 builder = builder.use_keyspace(keyspace, false);
437 }
438
439 if config.ssl {
441 let tls_config = if let Some(ssl_config) = &config.ssl_config {
442 Self::build_rustls_config(ssl_config)?
443 } else {
444 let root_store = rustls::RootCertStore::empty();
446 Arc::new(
447 rustls::ClientConfig::builder()
448 .with_root_certificates(root_store)
449 .with_no_client_auth(),
450 )
451 };
452 builder = builder.tls_context(Some(tls_config));
453 }
454
455 let session = builder.build().await.context("connecting to cluster")?;
456
457 Ok(ScyllaDriver {
458 session,
459 prepared_cache: Mutex::new(HashMap::new()),
460 consistency: Mutex::new(Consistency::One),
461 serial_consistency: Mutex::new(None),
462 tracing_enabled: AtomicBool::new(false),
463 last_trace_id: Mutex::new(None),
464 })
465 }
466
467 async fn execute_unpaged(&self, query: &str) -> Result<CqlResult> {
468 let stmt = self.build_query(query);
469
470 let result = self.session.query_unpaged(stmt, ()).await?;
471
472 self.store_trace_id(&result);
473 Self::convert_query_result(result)
474 }
475
476 async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
477 let mut stmt = self.build_query(query);
478 stmt.set_page_size(page_size);
479
480 let query_pager = self
481 .session
482 .query_iter(stmt, ())
483 .await
484 .context("starting paged query")?;
485
486 let col_specs = query_pager.column_specs();
488 let columns: Vec<CqlColumn> = col_specs
489 .iter()
490 .map(|spec| CqlColumn {
491 name: spec.name().to_string(),
492 type_name: format!("{:?}", spec.typ()),
493 })
494 .collect();
495
496 let mut rows_stream = query_pager.rows_stream::<Row>()?;
498 let mut cql_rows = Vec::new();
499
500 while let Some(row) = rows_stream.try_next().await? {
501 let values: Vec<CqlValue> = row
502 .columns
503 .into_iter()
504 .map(|opt_val| match opt_val {
505 Some(v) => Self::convert_scylla_value(v),
506 None => CqlValue::Null,
507 })
508 .collect();
509 cql_rows.push(CqlRow { values });
510 }
511
512 Ok(CqlResult {
513 columns,
514 rows: cql_rows,
515 has_rows: true,
516 tracing_id: None,
517 warnings: Vec::new(),
518 })
519 }
520
521 async fn prepare(&self, query: &str) -> Result<PreparedId> {
522 let prepared = self
523 .session
524 .prepare(query)
525 .await
526 .context("preparing CQL statement")?;
527
528 let id = prepared.get_id().to_vec();
529 self.prepared_cache
530 .lock()
531 .unwrap()
532 .insert(id.clone(), prepared);
533
534 Ok(PreparedId { inner: id })
535 }
536
537 async fn execute_prepared(
538 &self,
539 prepared_id: &PreparedId,
540 values: &[CqlValue],
541 ) -> Result<CqlResult> {
542 let prepared = self
543 .prepared_cache
544 .lock()
545 .unwrap()
546 .get(&prepared_id.inner)
547 .cloned()
548 .ok_or_else(|| anyhow!("prepared statement not found in cache"))?;
549
550 let scylla_values: Vec<Option<ScyllaCqlValue>> = values
553 .iter()
554 .map(|v| match v {
555 CqlValue::Null | CqlValue::Unset => None,
556 other => Some(Self::internal_to_scylla_cql(other)),
557 })
558 .collect();
559
560 let result = self
561 .session
562 .execute_unpaged(&prepared, scylla_values)
563 .await
564 .context("executing prepared statement")?;
565
566 self.store_trace_id(&result);
567 Self::convert_query_result(result)
568 }
569
570 async fn use_keyspace(&self, keyspace: &str) -> Result<()> {
571 self.session
572 .use_keyspace(keyspace, false)
573 .await
574 .with_context(|| format!("switching to keyspace: {keyspace}"))?;
575 Ok(())
576 }
577
578 fn get_consistency(&self) -> Consistency {
579 *self.consistency.lock().unwrap()
580 }
581
582 fn set_consistency(&self, consistency: Consistency) {
583 *self.consistency.lock().unwrap() = consistency;
584 }
585
586 fn get_serial_consistency(&self) -> Option<Consistency> {
587 *self.serial_consistency.lock().unwrap()
588 }
589
590 fn set_serial_consistency(&self, consistency: Option<Consistency>) {
591 *self.serial_consistency.lock().unwrap() = consistency;
592 }
593
594 fn set_tracing(&self, enabled: bool) {
595 self.tracing_enabled.store(enabled, Ordering::Relaxed);
596 }
597
598 fn is_tracing_enabled(&self) -> bool {
599 self.tracing_enabled.load(Ordering::Relaxed)
600 }
601
602 fn last_trace_id(&self) -> Option<Uuid> {
603 *self.last_trace_id.lock().unwrap()
604 }
605
606 async fn get_trace_session(&self, trace_id: Uuid) -> Result<Option<TracingSession>> {
607 let query = format!(
608 "SELECT client, command, coordinator, duration, parameters, request, started_at \
609 FROM system_traces.sessions WHERE session_id = {}",
610 trace_id
611 );
612 let result = self.execute_unpaged(&query).await?;
613
614 if result.rows.is_empty() {
615 return Ok(None);
616 }
617
618 let events_query = format!(
619 "SELECT activity, source, source_elapsed, thread \
620 FROM system_traces.events WHERE session_id = {}",
621 trace_id
622 );
623 let events_result = self.execute_unpaged(&events_query).await?;
624
625 let events: Vec<TracingEvent> = events_result
626 .rows
627 .iter()
628 .map(|row| TracingEvent {
629 activity: row.get(0).and_then(cql_value_to_string),
630 source: row.get(1).and_then(cql_value_to_string),
631 source_elapsed: row.get(2).and_then(cql_value_to_i32),
632 thread: row.get(3).and_then(cql_value_to_string),
633 })
634 .collect();
635
636 let session_row = &result.rows[0];
637 Ok(Some(TracingSession {
638 trace_id,
639 client: session_row.get(0).and_then(cql_value_to_string),
640 command: session_row.get(1).and_then(cql_value_to_string),
641 coordinator: session_row.get(2).and_then(cql_value_to_string),
642 duration: session_row.get(3).and_then(cql_value_to_i32),
643 parameters: HashMap::new(),
644 request: session_row.get(5).and_then(cql_value_to_string),
645 started_at: session_row.get(6).and_then(cql_value_to_string),
646 events,
647 }))
648 }
649
650 async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
651 let result = self
652 .execute_unpaged(
653 "SELECT keyspace_name, replication, durable_writes \
654 FROM system_schema.keyspaces",
655 )
656 .await?;
657
658 let mut keyspaces = Vec::new();
659 for row in &result.rows {
660 let name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
661 let durable_writes = match row.get(2) {
662 Some(CqlValue::Boolean(b)) => *b,
663 _ => true,
664 };
665
666 keyspaces.push(KeyspaceMetadata {
667 name,
668 replication: HashMap::new(),
669 durable_writes,
670 });
671 }
672
673 Ok(keyspaces)
674 }
675
676 async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
677 let result = self
678 .execute_unpaged(&format!(
679 "SELECT table_name FROM system_schema.tables WHERE keyspace_name = '{}'",
680 keyspace.replace('\'', "''")
681 ))
682 .await?;
683
684 let mut tables = Vec::new();
685 for row in &result.rows {
686 let table_name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
687
688 let col_result = self
689 .execute_unpaged(&format!(
690 "SELECT column_name, type, kind \
691 FROM system_schema.columns \
692 WHERE keyspace_name = '{}' AND table_name = '{}'",
693 keyspace.replace('\'', "''"),
694 table_name.replace('\'', "''")
695 ))
696 .await?;
697
698 let mut columns = Vec::new();
699 let mut partition_key = Vec::new();
700 let mut clustering_key = Vec::new();
701
702 for col_row in &col_result.rows {
703 let col_name = col_row
704 .get(0)
705 .and_then(cql_value_to_string)
706 .unwrap_or_default();
707 let col_type = col_row
708 .get(1)
709 .and_then(cql_value_to_string)
710 .unwrap_or_default();
711 let kind = col_row
712 .get(2)
713 .and_then(cql_value_to_string)
714 .unwrap_or_default();
715
716 columns.push(ColumnMetadata {
717 name: col_name.clone(),
718 type_name: col_type,
719 });
720
721 match kind.as_str() {
722 "partition_key" => partition_key.push(col_name),
723 "clustering" => clustering_key.push(col_name),
724 _ => {}
725 }
726 }
727
728 tables.push(TableMetadata {
729 keyspace: keyspace.to_string(),
730 name: table_name,
731 columns,
732 partition_key,
733 clustering_key,
734 });
735 }
736
737 Ok(tables)
738 }
739
740 async fn get_table_metadata(
741 &self,
742 keyspace: &str,
743 table: &str,
744 ) -> Result<Option<TableMetadata>> {
745 let tables = self.get_tables(keyspace).await?;
746 Ok(tables.into_iter().find(|t| t.name == table))
747 }
748
749 async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
750 let query = format!(
751 "SELECT type_name, field_names, field_types FROM system_schema.types WHERE keyspace_name = '{}'",
752 keyspace.replace('\'', "''")
753 );
754 let result = self.execute_unpaged(&query).await?;
755 let udts = result
756 .rows
757 .iter()
758 .filter_map(|row| {
759 let name = row.get_by_name("type_name", &result.columns)?.to_string();
760 let field_names =
761 Self::extract_string_list_val(row.get_by_name("field_names", &result.columns));
762 let field_types =
763 Self::extract_string_list_val(row.get_by_name("field_types", &result.columns));
764 Some(UdtMetadata {
765 keyspace: keyspace.to_string(),
766 name,
767 field_names,
768 field_types,
769 })
770 })
771 .collect();
772 Ok(udts)
773 }
774
775 async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
776 let query = format!(
777 "SELECT function_name, argument_types, return_type FROM system_schema.functions WHERE keyspace_name = '{}'",
778 keyspace.replace('\'', "''")
779 );
780 let result = self.execute_unpaged(&query).await?;
781 let functions = result
782 .rows
783 .iter()
784 .filter_map(|row| {
785 let name = row
786 .get_by_name("function_name", &result.columns)?
787 .to_string();
788 let argument_types = Self::extract_string_list_val(
789 row.get_by_name("argument_types", &result.columns),
790 );
791 let return_type = row
792 .get_by_name("return_type", &result.columns)
793 .map(|v| v.to_string())
794 .unwrap_or_default();
795 Some(FunctionMetadata {
796 keyspace: keyspace.to_string(),
797 name,
798 argument_types,
799 return_type,
800 })
801 })
802 .collect();
803 Ok(functions)
804 }
805
806 async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
807 let query = format!(
808 "SELECT aggregate_name, argument_types, return_type FROM system_schema.aggregates WHERE keyspace_name = '{}'",
809 keyspace.replace('\'', "''")
810 );
811 let result = self.execute_unpaged(&query).await?;
812 let aggregates = result
813 .rows
814 .iter()
815 .filter_map(|row| {
816 let name = row
817 .get_by_name("aggregate_name", &result.columns)?
818 .to_string();
819 let argument_types = Self::extract_string_list_val(
820 row.get_by_name("argument_types", &result.columns),
821 );
822 let return_type = row
823 .get_by_name("return_type", &result.columns)
824 .map(|v| v.to_string())
825 .unwrap_or_default();
826 Some(AggregateMetadata {
827 keyspace: keyspace.to_string(),
828 name,
829 argument_types,
830 return_type,
831 })
832 })
833 .collect();
834 Ok(aggregates)
835 }
836
837 async fn get_cluster_name(&self) -> Result<Option<String>> {
838 let result = self
839 .execute_unpaged("SELECT cluster_name FROM system.local")
840 .await?;
841 Ok(result
842 .rows
843 .first()
844 .and_then(|row| row.get(0))
845 .and_then(cql_value_to_string))
846 }
847
848 async fn get_cql_version(&self) -> Result<Option<String>> {
849 let result = self
850 .execute_unpaged("SELECT cql_version FROM system.local")
851 .await?;
852 Ok(result
853 .rows
854 .first()
855 .and_then(|row| row.get(0))
856 .and_then(cql_value_to_string))
857 }
858
859 async fn get_release_version(&self) -> Result<Option<String>> {
860 let result = self
861 .execute_unpaged("SELECT release_version FROM system.local")
862 .await?;
863 Ok(result
864 .rows
865 .first()
866 .and_then(|row| row.get(0))
867 .and_then(cql_value_to_string))
868 }
869
870 async fn get_scylla_version(&self) -> Result<Option<String>> {
871 let result = self
874 .execute_unpaged("SELECT scylla_version FROM system.local")
875 .await;
876 match result {
877 Ok(r) => Ok(r
878 .rows
879 .first()
880 .and_then(|row| row.get(0))
881 .and_then(cql_value_to_string)),
882 Err(_) => Ok(None), }
884 }
885
886 async fn is_connected(&self) -> bool {
887 self.execute_unpaged("SELECT key FROM system.local LIMIT 1")
888 .await
889 .is_ok()
890 }
891}
892
893fn cql_value_to_string(v: &CqlValue) -> Option<String> {
895 match v {
896 CqlValue::Text(s) | CqlValue::Ascii(s) => Some(s.clone()),
897 CqlValue::Inet(addr) => Some(addr.to_string()),
898 CqlValue::Null => None,
899 other => Some(other.to_string()),
900 }
901}
902
903fn cql_value_to_i32(v: &CqlValue) -> Option<i32> {
905 match v {
906 CqlValue::Int(i) => Some(*i),
907 CqlValue::BigInt(i) => Some(*i as i32),
908 CqlValue::SmallInt(i) => Some(*i as i32),
909 CqlValue::TinyInt(i) => Some(*i as i32),
910 _ => None,
911 }
912}
913
/// Unit tests for the value conversions, consistency mappings and helper
/// functions. None of them needs a live cluster.
#[cfg(test)]
mod tests {
    use super::*;

    // --- convert_scylla_value: scalars ---

    #[test]
    fn convert_scylla_value_text() {
        let input = ScyllaCqlValue::Text("hello".to_string());
        let expected = CqlValue::Text("hello".to_string());
        assert_eq!(ScyllaDriver::convert_scylla_value(input), expected);
    }

    #[test]
    fn convert_scylla_value_int() {
        let input = ScyllaCqlValue::Int(42);
        assert_eq!(ScyllaDriver::convert_scylla_value(input), CqlValue::Int(42));
    }

    #[test]
    fn convert_scylla_value_boolean() {
        let input = ScyllaCqlValue::Boolean(true);
        assert_eq!(
            ScyllaDriver::convert_scylla_value(input),
            CqlValue::Boolean(true)
        );
    }

    #[test]
    fn convert_scylla_value_null() {
        let input = ScyllaCqlValue::Empty;
        assert_eq!(ScyllaDriver::convert_scylla_value(input), CqlValue::Null);
    }

    #[test]
    fn convert_scylla_value_list() {
        let input = ScyllaCqlValue::List(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(ScyllaDriver::convert_scylla_value(input), expected);
    }

    #[test]
    fn convert_scylla_value_uuid() {
        let nil = Uuid::nil();
        let converted = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Uuid(nil));
        assert_eq!(converted, CqlValue::Uuid(nil));
    }

    #[test]
    fn convert_scylla_value_blob() {
        let payload = vec![0xde, 0xad, 0xbe, 0xef];
        let converted = ScyllaDriver::convert_scylla_value(ScyllaCqlValue::Blob(payload.clone()));
        assert_eq!(converted, CqlValue::Blob(payload));
    }

    #[test]
    fn convert_scylla_value_float() {
        let input = ScyllaCqlValue::Float(1.5);
        assert_eq!(
            ScyllaDriver::convert_scylla_value(input),
            CqlValue::Float(1.5)
        );
    }

    #[test]
    fn convert_scylla_value_double() {
        let input = ScyllaCqlValue::Double(1.5);
        assert_eq!(
            ScyllaDriver::convert_scylla_value(input),
            CqlValue::Double(1.5)
        );
    }

    // --- convert_scylla_value: containers ---

    #[test]
    fn convert_scylla_value_map() {
        let input = ScyllaCqlValue::Map(vec![(
            ScyllaCqlValue::Text("key".to_string()),
            ScyllaCqlValue::Int(42),
        )]);
        let expected = CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))]);
        assert_eq!(ScyllaDriver::convert_scylla_value(input), expected);
    }

    #[test]
    fn convert_scylla_value_set() {
        let input = ScyllaCqlValue::Set(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(ScyllaDriver::convert_scylla_value(input), expected);
    }

    #[test]
    fn convert_scylla_value_udt() {
        let input = ScyllaCqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(ScyllaCqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        let expected = CqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            type_name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(CqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        assert_eq!(ScyllaDriver::convert_scylla_value(input), expected);
    }

    // --- consistency mappings ---

    #[test]
    fn to_scylla_consistency_mapping() {
        use scylla::statement::Consistency as SC;

        let one = ScyllaDriver::to_scylla_consistency(Consistency::One);
        assert!(matches!(one, SC::One));

        let quorum = ScyllaDriver::to_scylla_consistency(Consistency::Quorum);
        assert!(matches!(quorum, SC::Quorum));

        let local_quorum = ScyllaDriver::to_scylla_consistency(Consistency::LocalQuorum);
        assert!(matches!(local_quorum, SC::LocalQuorum));

        let all = ScyllaDriver::to_scylla_consistency(Consistency::All);
        assert!(matches!(all, SC::All));
    }

    #[test]
    fn to_scylla_serial_consistency_mapping() {
        use scylla::statement::SerialConsistency as SC;

        let serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::Serial);
        assert!(matches!(serial, Some(SC::Serial)));

        let local_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::LocalSerial);
        assert!(matches!(local_serial, Some(SC::LocalSerial)));

        let non_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::One);
        assert!(non_serial.is_none());
    }

    // --- free helper functions ---

    #[test]
    fn cql_value_to_string_helper() {
        let text = CqlValue::Text("hello".to_string());
        assert_eq!(cql_value_to_string(&text), Some("hello".to_string()));

        let int = CqlValue::Int(42);
        assert_eq!(cql_value_to_string(&int), Some("42".to_string()));

        assert_eq!(cql_value_to_string(&CqlValue::Null), None);
    }

    #[test]
    fn cql_value_to_i32_helper() {
        let int = CqlValue::Int(42);
        assert_eq!(cql_value_to_i32(&int), Some(42));

        let big_int = CqlValue::BigInt(100);
        assert_eq!(cql_value_to_i32(&big_int), Some(100));

        let text = CqlValue::Text("x".to_string());
        assert_eq!(cql_value_to_i32(&text), None);
    }
}