1use anyhow::{bail, Result};
8
9use crate::config::MergedConfig;
10use crate::driver::types::CqlValue;
11use crate::driver::{
12 AggregateMetadata, ConnectionConfig, Consistency, CqlDriver, CqlResult, FunctionMetadata,
13 KeyspaceMetadata, PreparedId, ScyllaDriver, SslConfig, TableMetadata, TracingSession,
14 UdtMetadata,
15};
16
/// A CQL session built on top of [`ScyllaDriver`].
///
/// Besides owning the driver, it tracks the keyspace selected by the most
/// recent successful `USE` and caches server details that are looked up once
/// when the session is created by `CqlSession::connect`.
pub struct CqlSession {
    /// The underlying driver every query and setting is delegated to.
    driver: ScyllaDriver,
    /// Keyspace most recently switched to through `use_keyspace`, initialised
    /// from the configured keyspace (if any) at connect time.
    current_keyspace: Option<String>,
    /// `host:port` string of the node this session was asked to connect to.
    pub connection_display: String,
    /// Cluster name reported by the server; `None` if it could not be fetched.
    pub cluster_name: Option<String>,
    /// CQL version reported by the server; `None` if it could not be fetched.
    pub cql_version: Option<String>,
    /// Release version reported by the server; `None` if it could not be fetched.
    pub release_version: Option<String>,
    /// Scylla version reported by the server; `None` if it could not be
    /// fetched (presumably also the case for non-Scylla clusters — confirm).
    pub scylla_version: Option<String>,
}
33
34impl CqlSession {
35 pub async fn connect(config: &MergedConfig) -> Result<Self> {
37 let ssl_config = if config.ssl {
38 Some(SslConfig {
39 certfile: config.cqlshrc.ssl.certfile.clone(),
40 validate: config.cqlshrc.ssl.validate.unwrap_or(false),
41 userkey: config.cqlshrc.ssl.userkey.clone(),
42 usercert: config.cqlshrc.ssl.usercert.clone(),
43 host_certfiles: config.cqlshrc.certfiles.clone(),
44 })
45 } else {
46 None
47 };
48
49 let conn_config = ConnectionConfig {
50 host: config.host.clone(),
51 port: config.port,
52 username: config.username.clone(),
53 password: config.password.clone(),
54 keyspace: config.keyspace.clone(),
55 connect_timeout: config.connect_timeout,
56 request_timeout: config.request_timeout,
57 ssl: config.ssl,
58 ssl_config,
59 protocol_version: config.protocol_version,
60 };
61
62 let driver = ScyllaDriver::connect(&conn_config).await?;
63
64 let connection_display = format!("{}:{}", config.host, config.port);
65
66 let cluster_name = driver.get_cluster_name().await.ok().flatten();
68 let cql_version = driver.get_cql_version().await.ok().flatten();
69 let release_version = driver.get_release_version().await.ok().flatten();
70 let scylla_version = driver.get_scylla_version().await.ok().flatten();
71
72 if let Some(cl_str) = &config.consistency_level {
74 if let Some(cl) = Consistency::from_str_cql(cl_str) {
75 driver.set_consistency(cl);
76 }
77 }
78
79 if let Some(scl_str) = &config.serial_consistency_level {
81 if let Some(scl) = Consistency::from_str_cql(scl_str) {
82 driver.set_serial_consistency(Some(scl));
83 }
84 }
85
86 Ok(CqlSession {
87 driver,
88 current_keyspace: config.keyspace.clone(),
89 connection_display,
90 cluster_name,
91 cql_version,
92 release_version,
93 scylla_version,
94 })
95 }
96
97 pub async fn execute(&mut self, query: &str) -> Result<CqlResult> {
99 let trimmed = query.trim();
100
101 if let Some(keyspace) = parse_use_command(trimmed) {
103 self.use_keyspace(&keyspace).await?;
104 return Ok(CqlResult::empty());
105 }
106
107 self.driver.execute_unpaged(query).await
108 }
109
110 pub async fn execute_query(&self, query: &str) -> Result<CqlResult> {
115 self.driver.execute_unpaged(query).await
116 }
117
118 pub async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
120 self.driver.execute_paged(query, page_size).await
121 }
122
123 pub async fn prepare(&self, query: &str) -> Result<PreparedId> {
125 self.driver.prepare(query).await
126 }
127
128 pub async fn execute_prepared(
130 &self,
131 id: &PreparedId,
132 values: &[CqlValue],
133 ) -> Result<CqlResult> {
134 self.driver.execute_prepared(id, values).await
135 }
136
137 pub async fn use_keyspace(&mut self, keyspace: &str) -> Result<()> {
139 self.driver.use_keyspace(keyspace).await?;
140 self.current_keyspace = Some(keyspace.to_string());
141 Ok(())
142 }
143
144 pub fn current_keyspace(&self) -> Option<&str> {
146 self.current_keyspace.as_deref()
147 }
148
149 pub fn get_consistency(&self) -> Consistency {
151 self.driver.get_consistency()
152 }
153
154 pub fn set_consistency(&self, consistency: Consistency) {
156 self.driver.set_consistency(consistency);
157 }
158
159 pub fn set_consistency_str(&self, level: &str) -> Result<()> {
161 let consistency = Consistency::from_str_cql(level)
162 .ok_or_else(|| anyhow::anyhow!("invalid consistency level: {level}"))?;
163 self.driver.set_consistency(consistency);
164 Ok(())
165 }
166
167 pub fn get_serial_consistency(&self) -> Option<Consistency> {
169 self.driver.get_serial_consistency()
170 }
171
172 pub fn set_serial_consistency(&self, consistency: Option<Consistency>) {
174 self.driver.set_serial_consistency(consistency);
175 }
176
177 pub fn set_serial_consistency_str(&self, level: &str) -> Result<()> {
179 let consistency = Consistency::from_str_cql(level)
180 .ok_or_else(|| anyhow::anyhow!("invalid serial consistency level: {level}"))?;
181 match consistency {
182 Consistency::Serial | Consistency::LocalSerial => {
183 self.driver.set_serial_consistency(Some(consistency));
184 Ok(())
185 }
186 _ => bail!("serial consistency must be SERIAL or LOCAL_SERIAL, got {level}"),
187 }
188 }
189
190 pub fn set_tracing(&self, enabled: bool) {
192 self.driver.set_tracing(enabled);
193 }
194
195 pub fn is_tracing_enabled(&self) -> bool {
197 self.driver.is_tracing_enabled()
198 }
199
200 pub fn last_trace_id(&self) -> Option<uuid::Uuid> {
202 self.driver.last_trace_id()
203 }
204
205 pub async fn get_trace_session(&self, trace_id: uuid::Uuid) -> Result<Option<TracingSession>> {
207 self.driver.get_trace_session(trace_id).await
208 }
209
210 pub async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
212 self.driver.get_keyspaces().await
213 }
214
215 pub async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
217 self.driver.get_tables(keyspace).await
218 }
219
220 pub async fn get_table_metadata(
222 &self,
223 keyspace: &str,
224 table: &str,
225 ) -> Result<Option<TableMetadata>> {
226 self.driver.get_table_metadata(keyspace, table).await
227 }
228
229 pub async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
231 self.driver.get_udts(keyspace).await
232 }
233
234 pub async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
236 self.driver.get_functions(keyspace).await
237 }
238
239 pub async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
241 self.driver.get_aggregates(keyspace).await
242 }
243
244 pub async fn is_connected(&self) -> bool {
246 self.driver.is_connected().await
247 }
248}
249
/// Recognises a `USE <keyspace>` statement and returns the keyspace it names.
///
/// The `USE` keyword is matched case-insensitively (`USE`, `use`, `Use`, ...)
/// and may be separated from the keyspace by any whitespace, including tabs
/// and newlines. Surrounding whitespace and trailing semicolons are ignored.
///
/// If the keyspace is wrapped in a matching pair of double or single quotes,
/// that pair is removed. An unbalanced quote is returned as-is, so the server
/// can reject it, rather than causing a panic.
///
/// Returns `None` when `query` is not a `USE` statement or names no keyspace.
fn parse_use_command(query: &str) -> Option<String> {
    // Drop surrounding whitespace and any trailing run of `;`, even if it is
    // interleaved with whitespace (e.g. "USE ks ; ;").
    let statement = query
        .trim()
        .trim_end_matches(|c: char| c == ';' || c.is_whitespace());

    // Split off the leading keyword at the first whitespace. `find` returns a
    // char boundary, so `split_at` is safe for non-ASCII input.
    let keyword_end = statement.find(char::is_whitespace)?;
    let (keyword, rest) = statement.split_at(keyword_end);
    if !keyword.eq_ignore_ascii_case("USE") {
        return None;
    }
    let keyspace = rest.trim();

    // Remove one matching pair of quotes. strip_prefix/strip_suffix, unlike
    // index slicing, cannot panic on a lone quote character.
    let keyspace = ['"', '\'']
        .iter()
        .find_map(|&quote| keyspace.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(keyspace);

    if keyspace.is_empty() {
        None
    } else {
        Some(keyspace.to_string())
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// The value `parse_use_command` yields when it recognises `name`.
    fn parsed(name: &str) -> Option<String> {
        Some(name.to_owned())
    }

    #[test]
    fn parse_use_simple() {
        let got = parse_use_command("USE my_keyspace");
        assert_eq!(got, parsed("my_keyspace"));
    }

    #[test]
    fn parse_use_semicolon() {
        let got = parse_use_command("USE my_keyspace;");
        assert_eq!(got, parsed("my_keyspace"));
    }

    #[test]
    fn parse_use_lowercase() {
        let got = parse_use_command("use test_ks");
        assert_eq!(got, parsed("test_ks"));
    }

    #[test]
    fn parse_use_quoted() {
        let got = parse_use_command("USE \"MyKeyspace\"");
        assert_eq!(got, parsed("MyKeyspace"));
    }

    #[test]
    fn parse_use_single_quoted() {
        let got = parse_use_command("USE 'my_ks'");
        assert_eq!(got, parsed("my_ks"));
    }

    #[test]
    fn parse_use_with_whitespace() {
        let got = parse_use_command("  USE my_keyspace ; ");
        assert_eq!(got, parsed("my_keyspace"));
    }

    #[test]
    fn parse_not_use_command() {
        // Ordinary statements must be left for the server to execute.
        for query in ["SELECT * FROM table", "INSERT INTO users"] {
            assert_eq!(parse_use_command(query), None, "query: {query}");
        }
    }

    #[test]
    fn parse_use_empty() {
        // A bare `USE` with no keyspace is not treated as a keyspace switch.
        for query in ["USE ", "USE ;"] {
            assert_eq!(parse_use_command(query), None, "query: {query}");
        }
    }
}