db.redis: add TLS support (#25374)

This commit is contained in:
David Legrand 2025-09-23 21:38:06 +02:00 committed by GitHub
parent 2015035218
commit 0df4435635
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@
module redis
import net
import net.ssl
import strings
// RedisValue represents all possible RESP (Redis Serialization Protocol) data types
@ -19,7 +20,9 @@ pub struct DB {
pub mut:
version int // RESP protocol version
mut:
conn &net.TcpConn = unsafe { nil } // TCP connection to Redis
conn &net.TcpConn = unsafe { nil } // TCP connection to Redis
ssl_conn &ssl.SSLConn = unsafe { nil } // SSL connection to Redis
tls bool
// Pre-allocated buffers to reduce memory allocations
cmd_buf []u8 // Buffer for building commands
@ -36,18 +39,27 @@ pub mut:
host string = '127.0.0.1' // Redis server host
port u16 = 6379 // Redis server port
password string // Redis server password (optional)
tls bool // Enable TLS/SSL connection
version int = 2 // RESP protocol version (default: v2)
}
// connect establishes a connection to a Redis server
pub fn connect(config Config) !DB {
mut db := DB{
conn: net.dial_tcp('${config.host}:${config.port}')!
version: config.version
tls: config.tls
cmd_buf: []u8{cap: cmd_buf_pre_allocate_len}
resp_buf: []u8{cap: resp_buf_pre_allocate_len}
}
if config.tls {
mut ssl_conn := ssl.new_ssl_conn(ssl.SSLConnectConfig{ validate: false })!
ssl_conn.dial(config.host, int(config.port))!
db.ssl_conn = ssl_conn
} else {
db.conn = net.dial_tcp('${config.host}:${config.port}')!
}
// Authenticate if password is provided
if config.password.len > 0 {
db.auth(config.password)!
@ -58,7 +70,36 @@ pub fn connect(config Config) !DB {
// close terminates the connection to Redis server
pub fn (mut db DB) close() ! {
db.conn.close()!
if db.tls {
db.ssl_conn.close()!
} else {
db.conn.close()!
}
}
// Helper methods for TLS abstraction
fn (mut db DB) write_data(data []u8) ! {
if db.tls {
db.ssl_conn.write(data)!
} else {
db.conn.write(data)!
}
}
fn (mut db DB) read_data(mut buf []u8) !int {
if db.tls {
return db.ssl_conn.read(mut buf)
} else {
return db.conn.read(mut buf)
}
}
fn (mut db DB) read_ptr_data(ptr &u8, len int) !int {
if db.tls {
return db.ssl_conn.socket_read_into_ptr(ptr, len)!
} else {
return db.conn.read_ptr(ptr, len)!
}
}
// auth sends an AUTH command to the server with the given password.
@ -78,8 +119,7 @@ pub fn (mut db DB) auth(password string) ! {
// ping sends a PING command to verify server responsiveness
pub fn (mut db DB) ping() !string {
db.conn.write_string('*1\r\n$4\r\nPING\r\n')!
return db.read_response()! as string
return db.cmd('PING')! as string
}
// del deletes a `key`
@ -92,7 +132,7 @@ pub fn (mut db DB) del(key string) !i64 {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
// read resp
return db.read_response()! as i64
@ -121,7 +161,7 @@ pub fn (mut db DB) set[T](key string, value T) !string {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
return db.read_response()! as string
}
return ''
@ -137,7 +177,7 @@ pub fn (mut db DB) get[T](key string) !T {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
resp := db.read_response()!
match resp {
[]u8 {
@ -171,7 +211,7 @@ pub fn (mut db DB) incr(key string) !i64 {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
// read resp
return db.read_response()! as i64
@ -189,7 +229,7 @@ pub fn (mut db DB) decr(key string) !i64 {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
// read resp
return db.read_response()! as i64
@ -222,7 +262,7 @@ pub fn (mut db DB) hset[T](key string, m map[string]T) !int {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
return int(db.read_response()! as i64)
}
return 0
@ -239,7 +279,7 @@ pub fn (mut db DB) hget[T](key string, m_key string) !T {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
resp := db.read_response()! as []u8
$if T is string {
return resp.bytestr()
@ -266,7 +306,7 @@ pub fn (mut db DB) hgetall[T](key string) !map[string]T {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
resp := db.read_response()!
match resp {
[]RedisValue {
@ -315,7 +355,7 @@ pub fn (mut db DB) expire(key string, seconds int) !bool {
db.pipeline_buffer << db.cmd_buf
db.pipeline_cmd_count++
} else {
db.conn.write(db.cmd_buf)!
db.write_data(db.cmd_buf)!
// read resp
resp := db.read_response()! as i64
@ -331,7 +371,7 @@ fn (mut db DB) read_response_bulk_string() !RedisValue {
db.resp_buf.clear()
for {
bytes_read := db.conn.read(mut chunk) or {
bytes_read := db.read_data(mut chunk) or {
return error('`read_response_bulk_string()`: connection error ${err}')
}
if bytes_read == 0 {
@ -355,7 +395,7 @@ fn (mut db DB) read_response_bulk_string() !RedisValue {
if data_length == 0 {
mut terminator := []u8{len: 2}
db.conn.read(mut terminator)!
db.read_data(mut terminator)!
if terminator[0] != `\r` || terminator[1] != `\n` {
return error('invalid terminator for empty string')
}
@ -373,7 +413,7 @@ fn (mut db DB) read_response_bulk_string() !RedisValue {
chunk_size := if remaining > 1 { 1 } else { remaining }
mut chunk_ptr := unsafe { &data_buf[total_read] }
bytes_read := db.conn.read_ptr(chunk_ptr, chunk_size)!
bytes_read := db.read_ptr_data(chunk_ptr, chunk_size)!
total_read += bytes_read
if bytes_read == 0 && total_read < data_buf.len {
@ -400,7 +440,7 @@ fn (mut db DB) read_response_i64() !i64 {
chunk_size := if remaining > 1 { 1 } else { remaining }
mut chunk_ptr := unsafe { &db.resp_buf[total_read] }
bytes_read := db.conn.read_ptr(chunk_ptr, chunk_size)!
bytes_read := db.read_ptr_data(chunk_ptr, chunk_size)!
total_read += bytes_read
if total_read > 2 {
@ -427,7 +467,7 @@ fn (mut db DB) read_response_simple_string() !string {
chunk_size := if remaining > 1 { 1 } else { remaining }
mut chunk_ptr := unsafe { &db.resp_buf[total_read] }
bytes_read := db.conn.read_ptr(chunk_ptr, chunk_size)!
bytes_read := db.read_ptr_data(chunk_ptr, chunk_size)!
total_read += bytes_read
if total_read > 2 {
@ -449,7 +489,7 @@ fn (mut db DB) read_response_array() !RedisValue {
db.resp_buf.clear()
for {
bytes_read := db.conn.read(mut chunk) or {
bytes_read := db.read_data(mut chunk) or {
return error('`read_response_array()`: connection error: ${err}')
}
if bytes_read == 0 {
@ -492,7 +532,7 @@ fn (mut db DB) read_response_array() !RedisValue {
fn (mut db DB) read_response() !RedisValue {
db.resp_buf.clear()
unsafe { db.resp_buf.grow_len(1) }
read_len := db.conn.read(mut db.resp_buf)!
read_len := db.read_data(mut db.resp_buf)!
if read_len != 1 {
return error('`read_response()`: empty response from server')
}
@ -532,7 +572,7 @@ pub fn (mut db DB) cmd(cmd ...string) !RedisValue {
db.pipeline_buffer << unsafe { sb.reuse_as_plain_u8_array() }
db.pipeline_cmd_count++
} else {
unsafe { db.conn.write(sb.reuse_as_plain_u8_array())! }
db.write_data(unsafe { sb.reuse_as_plain_u8_array() })!
return db.read_response()!
}
return RedisNull{}
@ -557,7 +597,7 @@ pub fn (mut db DB) pipeline_execute() ![]RedisValue {
return []RedisValue{}
}
db.conn.write(db.pipeline_buffer)!
db.write_data(db.pipeline_buffer)!
mut results := []RedisValue{cap: db.pipeline_cmd_count}
for _ in 0 .. db.pipeline_cmd_count {