refactor: avoid serialization logic duplication

This commit is contained in:
Blayne Chard 2021-11-02 11:40:10 +13:00 committed by Dmitry Dygalo
parent 734445856b
commit 045d4aaab5
2 changed files with 17 additions and 34 deletions

View File

@ -2,7 +2,7 @@ use pyo3::{
exceptions,
ffi::{
PyDictObject, PyFloat_AS_DOUBLE, PyList_GET_ITEM, PyList_GET_SIZE, PyLong_AsLongLong,
Py_TYPE, PyTuple_GET_ITEM, PyTuple_GET_SIZE
PyTuple_GET_ITEM, PyTuple_GET_SIZE, Py_TYPE,
},
prelude::*,
types::PyAny,
@ -149,11 +149,16 @@ impl Serialize for SerializePyObject {
map.end()
}
}
ObjectType::List => {
ObjectType::Tuple | ObjectType::List => {
if self.recursion_depth == RECURSION_LIMIT {
return Err(ser::Error::custom("Recursion limit reached"));
}
let length = unsafe { PyList_GET_SIZE(self.object) } as usize;
let length = match self.object_type {
ObjectType::Tuple => unsafe { PyTuple_GET_SIZE(self.object) as usize },
ObjectType::List => unsafe { PyList_GET_SIZE(self.object) as usize },
_ => return Err(ser::Error::custom("Object is not a list or tuple")),
};
if length == 0 {
serializer.serialize_seq(Some(0))?.end()
} else {
@ -161,35 +166,13 @@ impl Serialize for SerializePyObject {
let mut ob_type = ObjectType::Str;
let mut sequence = serializer.serialize_seq(Some(length))?;
for i in 0..length {
let elem = unsafe { PyList_GET_ITEM(self.object, i as isize) };
let current_ob_type = unsafe { Py_TYPE(elem) };
if current_ob_type != type_ptr {
type_ptr = current_ob_type;
ob_type = get_object_type(current_ob_type);
}
#[allow(clippy::integer_arithmetic)]
sequence.serialize_element(&SerializePyObject::with_obtype(
elem,
ob_type.clone(),
self.recursion_depth + 1,
))?;
}
sequence.end()
}
}
ObjectType::Tuple => {
if self.recursion_depth == RECURSION_LIMIT {
return Err(ser::Error::custom("Recursion limit reached"));
}
let length = unsafe { PyTuple_GET_SIZE(self.object) } as usize;
if length == 0 {
serializer.serialize_seq(Some(0))?.end()
} else {
let mut type_ptr = std::ptr::null_mut();
let mut ob_type = ObjectType::Str;
let mut sequence = serializer.serialize_seq(Some(length))?;
for i in 0..length {
let elem = unsafe { PyTuple_GET_ITEM(self.object, i as isize) };
let elem = match self.object_type {
ObjectType::Tuple => unsafe {
PyTuple_GET_ITEM(self.object, i as isize)
},
ObjectType::List => unsafe { PyList_GET_ITEM(self.object, i as isize) },
_ => return Err(ser::Error::custom("Object is not a list or tuple")),
};
let current_ob_type = unsafe { Py_TYPE(elem) };
if current_ob_type != type_ptr {
type_ptr = current_ob_type;

View File

@ -1,6 +1,6 @@
use pyo3::ffi::{
PyDict_New, PyFloat_FromDouble, PyList_New, PyLong_FromLongLong, PyTypeObject, PyUnicode_New,
Py_None, Py_TYPE, Py_True, PyTuple_New
PyDict_New, PyFloat_FromDouble, PyList_New, PyLong_FromLongLong, PyTuple_New, PyTypeObject,
PyUnicode_New, Py_None, Py_TYPE, Py_True,
};
use std::sync::Once;