diff --git a/vlib/db/redis/redis.v b/vlib/db/redis/redis.v index ba664416ee..32802c1ba2 100644 --- a/vlib/db/redis/redis.v +++ b/vlib/db/redis/redis.v @@ -3,6 +3,7 @@ module redis import net +import net.ssl import strings // RedisValue represents all possible RESP (Redis Serialization Protocol) data types @@ -19,7 +20,9 @@ pub struct DB { pub mut: version int // RESP protocol version mut: - conn &net.TcpConn = unsafe { nil } // TCP connection to Redis + conn &net.TcpConn = unsafe { nil } // TCP connection to Redis + ssl_conn &ssl.SSLConn = unsafe { nil } // SSL connection to Redis + tls bool // Pre-allocated buffers to reduce memory allocations cmd_buf []u8 // Buffer for building commands @@ -36,18 +39,27 @@ pub mut: host string = '127.0.0.1' // Redis server host port u16 = 6379 // Redis server port password string // Redis server password (optional) + tls bool // Enable TLS/SSL connection version int = 2 // RESP protocol version (default: v2) } // connect establishes a connection to a Redis server pub fn connect(config Config) !DB { mut db := DB{ - conn: net.dial_tcp('${config.host}:${config.port}')! version: config.version + tls: config.tls cmd_buf: []u8{cap: cmd_buf_pre_allocate_len} resp_buf: []u8{cap: resp_buf_pre_allocate_len} } + if config.tls { + mut ssl_conn := ssl.new_ssl_conn(ssl.SSLConnectConfig{ validate: false })! + ssl_conn.dial(config.host, int(config.port))! + db.ssl_conn = ssl_conn + } else { + db.conn = net.dial_tcp('${config.host}:${config.port}')! + } + // Authenticate if password is provided if config.password.len > 0 { db.auth(config.password)! @@ -58,7 +70,36 @@ pub fn connect(config Config) !DB { // close terminates the connection to Redis server pub fn (mut db DB) close() ! { - db.conn.close()! + if db.tls { + db.ssl_conn.close()! + } else { + db.conn.close()! + } +} + +// Helper methods for TLS abstraction +fn (mut db DB) write_data(data []u8) ! { + if db.tls { + db.ssl_conn.write(data)! + } else { + db.conn.write(data)! + } +} + +fn (mut db DB) read_data(mut buf []u8) !int { + if db.tls { + return db.ssl_conn.read(mut buf) + } else { + return db.conn.read(mut buf) + } +} + +fn (mut db DB) read_ptr_data(ptr &u8, len int) !int { + if db.tls { + return db.ssl_conn.socket_read_into_ptr(ptr, len)! + } else { + return db.conn.read_ptr(ptr, len)! + } } // auth sends an AUTH command to the server with the given password. @@ -78,8 +119,7 @@ pub fn (mut db DB) auth(password string) ! { // ping sends a PING command to verify server responsiveness pub fn (mut db DB) ping() !string { - db.conn.write_string('*1\r\n$4\r\nPING\r\n')! - return db.read_response()! as string + return db.cmd('PING')! as string } // del deletes a `key` @@ -92,7 +132,7 @@ pub fn (mut db DB) del(key string) !i64 { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! // read resp return db.read_response()! as i64 @@ -121,7 +161,7 @@ pub fn (mut db DB) set[T](key string, value T) !string { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! return db.read_response()! as string } return '' @@ -137,7 +177,7 @@ pub fn (mut db DB) get[T](key string) !T { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! resp := db.read_response()! match resp { []u8 { @@ -171,7 +211,7 @@ pub fn (mut db DB) incr(key string) !i64 { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! // read resp return db.read_response()! as i64 @@ -189,7 +229,7 @@ pub fn (mut db DB) decr(key string) !i64 { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! // read resp return db.read_response()! as i64 @@ -222,7 +262,7 @@ pub fn (mut db DB) hset[T](key string, m map[string]T) !int { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! return int(db.read_response()! as i64) } return 0 @@ -239,7 +279,7 @@ pub fn (mut db DB) hget[T](key string, m_key string) !T { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! resp := db.read_response()! as []u8 $if T is string { return resp.bytestr() @@ -266,7 +306,7 @@ pub fn (mut db DB) hgetall[T](key string) !map[string]T { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! resp := db.read_response()! match resp { []RedisValue { @@ -315,7 +355,7 @@ pub fn (mut db DB) expire(key string, seconds int) !bool { db.pipeline_buffer << db.cmd_buf db.pipeline_cmd_count++ } else { - db.conn.write(db.cmd_buf)! + db.write_data(db.cmd_buf)! // read resp resp := db.read_response()! as i64 @@ -331,7 +371,7 @@ fn (mut db DB) read_response_bulk_string() !RedisValue { db.resp_buf.clear() for { - bytes_read := db.conn.read(mut chunk) or { + bytes_read := db.read_data(mut chunk) or { return error('`read_response_bulk_string()`: connection error ${err}') } if bytes_read == 0 { @@ -355,7 +395,7 @@ fn (mut db DB) read_response_bulk_string() !RedisValue { if data_length == 0 { mut terminator := []u8{len: 2} - db.conn.read(mut terminator)! + db.read_data(mut terminator)! if terminator[0] != `\r` || terminator[1] != `\n` { return error('invalid terminator for empty string') } @@ -373,7 +413,7 @@ fn (mut db DB) read_response_bulk_string() !RedisValue { chunk_size := if remaining > 1 { 1 } else { remaining } mut chunk_ptr := unsafe { &data_buf[total_read] } - bytes_read := db.conn.read_ptr(chunk_ptr, chunk_size)! + bytes_read := db.read_ptr_data(chunk_ptr, chunk_size)! total_read += bytes_read if bytes_read == 0 && total_read < data_buf.len { @@ -400,7 +440,7 @@ fn (mut db DB) read_response_i64() !i64 { chunk_size := if remaining > 1 { 1 } else { remaining } mut chunk_ptr := unsafe { &db.resp_buf[total_read] } - bytes_read := db.conn.read_ptr(chunk_ptr, chunk_size)! + bytes_read := db.read_ptr_data(chunk_ptr, chunk_size)! total_read += bytes_read if total_read > 2 { @@ -427,7 +467,7 @@ fn (mut db DB) read_response_simple_string() !string { chunk_size := if remaining > 1 { 1 } else { remaining } mut chunk_ptr := unsafe { &db.resp_buf[total_read] } - bytes_read := db.conn.read_ptr(chunk_ptr, chunk_size)! + bytes_read := db.read_ptr_data(chunk_ptr, chunk_size)! total_read += bytes_read if total_read > 2 { @@ -449,7 +489,7 @@ fn (mut db DB) read_response_array() !RedisValue { db.resp_buf.clear() for { - bytes_read := db.conn.read(mut chunk) or { + bytes_read := db.read_data(mut chunk) or { return error('`read_response_array()`: connection error: ${err}') } if bytes_read == 0 { @@ -492,7 +532,7 @@ fn (mut db DB) read_response_array() !RedisValue { fn (mut db DB) read_response() !RedisValue { db.resp_buf.clear() unsafe { db.resp_buf.grow_len(1) } - read_len := db.conn.read(mut db.resp_buf)! + read_len := db.read_data(mut db.resp_buf)! if read_len != 1 { return error('`read_response()`: empty response from server') } @@ -532,7 +572,7 @@ pub fn (mut db DB) cmd(cmd ...string) !RedisValue { db.pipeline_buffer << unsafe { sb.reuse_as_plain_u8_array() } db.pipeline_cmd_count++ } else { - unsafe { db.conn.write(sb.reuse_as_plain_u8_array())! } + db.write_data(unsafe { sb.reuse_as_plain_u8_array() })! return db.read_response()! } return RedisNull{} @@ -557,7 +597,7 @@ pub fn (mut db DB) pipeline_execute() ![]RedisValue { return []RedisValue{} } - db.conn.write(db.pipeline_buffer)! + db.write_data(db.pipeline_buffer)! mut results := []RedisValue{cap: db.pipeline_cmd_count} for _ in 0 .. db.pipeline_cmd_count {