prc: fix possible stack overflow in StreamReader

Fixes #754
This commit is contained in:
rdb 2019-10-08 11:27:35 +02:00
parent 3c9591cfbf
commit c6990b9f9b
3 changed files with 63 additions and 16 deletions

View File

@ -26,14 +26,20 @@ get_string() {
// First, get the length of the string
size_t size = get_uint16();
std::string result(size, 0);
if (size == 0) {
return string();
return result;
}
char *buffer = (char *)alloca(size);
_in->read(buffer, size);
_in->read(&result[0], size);
size_t read_bytes = _in->gcount();
return string(buffer, read_bytes);
if (read_bytes == size) {
return result;
} else {
return result.substr(0, read_bytes);
}
}
/**
@ -82,17 +88,17 @@ string StreamReader::
get_fixed_string(size_t size) {
nassertr(!_in->eof() && !_in->fail(), string());
std::string result(size, 0);
if (size == 0) {
return string();
return result;
}
char *buffer = (char *)alloca(size);
_in->read(buffer, size);
_in->read(&result[0], size);
size_t read_bytes = _in->gcount();
string result(buffer, read_bytes);
result.resize(read_bytes);
size_t zero_byte = result.find('\0');
return result.substr(0, zero_byte);
return result.substr(0, std::min(zero_byte, read_bytes));
}
/**

View File

@ -21,14 +21,20 @@
*/
PyObject *Extension<StreamReader>::
extract_bytes(size_t size) {
unsigned char *buffer = (unsigned char *)alloca(size);
size_t read_bytes = _this->extract_bytes(buffer, size);
std::istream *in = _this->get_istream();
if (in->eof() || in->fail() || size == 0) {
return PyBytes_FromStringAndSize(nullptr, 0);
}
#if PY_MAJOR_VERSION >= 3
return PyBytes_FromStringAndSize((char *)buffer, read_bytes);
#else
return PyString_FromStringAndSize((char *)buffer, read_bytes);
#endif
PyObject *bytes = PyBytes_FromStringAndSize(nullptr, size);
in->read(PyBytes_AS_STRING(bytes), size);
size_t read_bytes = in->gcount();
if (read_bytes == size || _PyBytes_Resize(&bytes, read_bytes) == 0) {
return bytes;
} else {
return nullptr;
}
}
/**

View File

@ -155,3 +155,38 @@ def test_streamreader_readline():
stream = StringStream(b'\x00\x00')
reader = StreamReader(stream, False)
assert reader.readline() == b'\x00\x00'
def test_streamreader_extract_bytes():
# Empty bytes
stream = StringStream(b'')
reader = StreamReader(stream, False)
assert reader.extract_bytes(10) == b''
# Small bytes object, small reads
stream = StringStream(b'abcd')
reader = StreamReader(stream, False)
assert reader.extract_bytes(2) == b'ab'
assert reader.extract_bytes(2) == b'cd'
assert reader.extract_bytes(2) == b''
# Embedded null bytes
stream = StringStream(b'a\x00b\x00c')
reader = StreamReader(stream, False)
assert reader.extract_bytes(5) == b'a\x00b\x00c'
# Not enough data in stream to fill buffer
stream = StringStream(b'abcdefghijklmnop')
reader = StreamReader(stream, False)
assert reader.extract_bytes(10) == b'abcdefghij'
assert reader.extract_bytes(10) == b'klmnop'
assert reader.extract_bytes(10) == b''
# Read of 0 bytes
stream = StringStream(b'abcd')
reader = StreamReader(stream, False)
assert reader.extract_bytes(0) == b''
assert reader.extract_bytes(0) == b''
# Very large read (8 MiB buffer allocation)
assert reader.extract_bytes(8 * 1024 * 1024) == b'abcd'