use pyo3::{ exceptions, ffi::{ PyDictObject, PyFloat_AS_DOUBLE, PyList_GET_ITEM, PyList_GET_SIZE, PyLong_AsLongLong, PyObject_GetAttr, PyTuple_GET_ITEM, PyTuple_GET_SIZE, PyUnicode_AsUTF8AndSize, Py_DECREF, Py_TYPE, }, prelude::*, types::PyAny, }; use serde::{ ser::{self, Serialize, SerializeMap, SerializeSeq}, Serializer, }; use crate::{ffi, types}; use std::ffi::CStr; pub const RECURSION_LIMIT: u8 = 255; #[derive(Clone)] pub enum ObjectType { Str, Int, Bool, None, Float, List, Dict, Tuple, Enum, Unknown(String), } pub(crate) struct SerializePyObject { object: *mut pyo3::ffi::PyObject, object_type: ObjectType, recursion_depth: u8, } impl SerializePyObject { #[inline] pub fn new(object: *mut pyo3::ffi::PyObject, recursion_depth: u8) -> Self { SerializePyObject { object, object_type: get_object_type_from_object(object), recursion_depth, } } #[inline] pub const fn with_obtype( object: *mut pyo3::ffi::PyObject, object_type: ObjectType, recursion_depth: u8, ) -> Self { SerializePyObject { object, object_type, recursion_depth, } } } #[inline] fn is_enum_subclass(object_type: *mut pyo3::ffi::PyTypeObject) -> bool { unsafe { (*(object_type.cast::())).ob_type == types::ENUM_TYPE } } #[inline] fn is_dict_subclass(object_type: *mut pyo3::ffi::PyTypeObject) -> bool { // traverse the object's inheritance chain to check if it's a dict subclass let mut current_type = object_type; while !current_type.is_null() { if current_type == unsafe { types::DICT_TYPE } { return true; } current_type = unsafe { (*current_type).tp_base } } false } fn get_object_type_from_object(object: *mut pyo3::ffi::PyObject) -> ObjectType { unsafe { let object_type = Py_TYPE(object); get_object_type(object_type) } } fn get_type_name(object_type: *mut pyo3::ffi::PyTypeObject) -> std::borrow::Cow<'static, str> { unsafe { CStr::from_ptr((*object_type).tp_name).to_string_lossy() } } #[inline] fn check_type_is_str(object: *mut pyo3::ffi::PyObject) -> Result<(), E> { let object_type = unsafe { Py_TYPE(object) }; if object_type != unsafe { types::STR_TYPE } { return Err(ser::Error::custom(format!( "Dict key must be str. Got '{}'", get_type_name(object_type) ))); } Ok(()) } #[inline] pub fn get_object_type(object_type: *mut pyo3::ffi::PyTypeObject) -> ObjectType { if object_type == unsafe { types::STR_TYPE } { ObjectType::Str } else if object_type == unsafe { types::FLOAT_TYPE } { ObjectType::Float } else if object_type == unsafe { types::BOOL_TYPE } { ObjectType::Bool } else if object_type == unsafe { types::INT_TYPE } { ObjectType::Int } else if object_type == unsafe { types::NONE_TYPE } { ObjectType::None } else if object_type == unsafe { types::LIST_TYPE } { ObjectType::List } else if object_type == unsafe { types::TUPLE_TYPE } { ObjectType::Tuple } else if object_type == unsafe { types::DICT_TYPE } { ObjectType::Dict } else if is_enum_subclass(object_type) { ObjectType::Enum } else if is_dict_subclass(object_type) { ObjectType::Dict } else { ObjectType::Unknown(get_type_name(object_type).to_string()) } } macro_rules! bail_on_integer_conversion_error { ($value:expr) => { if !$value.is_null() { let repr = unsafe { pyo3::ffi::PyObject_Str($value) }; let mut size = 0; let ptr = unsafe { PyUnicode_AsUTF8AndSize(repr, &mut size) }; return if !ptr.is_null() { let slice = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts( ptr.cast::(), size as usize, )) }; let message = String::from(slice); unsafe { Py_DECREF(repr) }; Err(ser::Error::custom(message)) } else { Err(ser::Error::custom( "Internal Error: Failed to convert exception to string", )) }; } }; } /// Convert a Python value to `serde_json::Value` impl Serialize for SerializePyObject { fn serialize(&self, serializer: S) -> Result where S: Serializer, { match self.object_type { ObjectType::Str => { let mut str_size: pyo3::ffi::Py_ssize_t = 0; let ptr = unsafe { PyUnicode_AsUTF8AndSize(self.object, &mut str_size) }; let slice = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts( ptr.cast::(), str_size as usize, )) }; serializer.serialize_str(slice) } ObjectType::Int => { let value = unsafe { PyLong_AsLongLong(self.object) }; if value == -1 { #[cfg(Py_3_12)] { let exception = unsafe { pyo3::ffi::PyErr_GetRaisedException() }; bail_on_integer_conversion_error!(exception); }; #[cfg(not(Py_3_12))] { let mut ptype: *mut pyo3::ffi::PyObject = std::ptr::null_mut(); let mut pvalue: *mut pyo3::ffi::PyObject = std::ptr::null_mut(); let mut ptraceback: *mut pyo3::ffi::PyObject = std::ptr::null_mut(); unsafe { pyo3::ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback) }; bail_on_integer_conversion_error!(pvalue); }; } serializer.serialize_i64(value) } ObjectType::Float => { serializer.serialize_f64(unsafe { PyFloat_AS_DOUBLE(self.object) }) } ObjectType::Bool => serializer.serialize_bool(self.object == unsafe { types::TRUE }), ObjectType::None => serializer.serialize_unit(), ObjectType::Dict => { if self.recursion_depth == RECURSION_LIMIT { return Err(ser::Error::custom("Recursion limit reached")); } let length = unsafe { (*self.object.cast::()).ma_used } as usize; if length == 0 { serializer.serialize_map(Some(0))?.end() } else { let mut map = serializer.serialize_map(Some(length))?; let mut pos = 0_isize; let mut str_size: pyo3::ffi::Py_ssize_t = 0; let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut(); let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut(); for _ in 0..length { unsafe { pyo3::ffi::PyDict_Next(self.object, &mut pos, &mut key, &mut value); } check_type_is_str(key)?; let ptr = unsafe { PyUnicode_AsUTF8AndSize(key, &mut str_size) }; let slice = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts( ptr.cast::(), str_size as usize, )) }; #[allow(clippy::arithmetic_side_effects)] map.serialize_entry( slice, &SerializePyObject::new(value, self.recursion_depth + 1), )?; } map.end() } } ObjectType::List => { if self.recursion_depth == RECURSION_LIMIT { return Err(ser::Error::custom("Recursion limit reached")); } let length = unsafe { PyList_GET_SIZE(self.object) as usize }; if length == 0 { serializer.serialize_seq(Some(0))?.end() } else { let mut type_ptr = std::ptr::null_mut(); let mut ob_type = ObjectType::Str; let mut sequence = serializer.serialize_seq(Some(length))?; for i in 0..length { let elem = unsafe { PyList_GET_ITEM(self.object, i as isize) }; let current_ob_type = unsafe { Py_TYPE(elem) }; if current_ob_type != type_ptr { type_ptr = current_ob_type; ob_type = get_object_type(current_ob_type); } #[allow(clippy::arithmetic_side_effects)] sequence.serialize_element(&SerializePyObject::with_obtype( elem, ob_type.clone(), self.recursion_depth + 1, ))?; } sequence.end() } } ObjectType::Tuple => { if self.recursion_depth == RECURSION_LIMIT { return Err(ser::Error::custom("Recursion limit reached")); } let length = unsafe { PyTuple_GET_SIZE(self.object) as usize }; if length == 0 { serializer.serialize_seq(Some(0))?.end() } else { let mut type_ptr = std::ptr::null_mut(); let mut ob_type = ObjectType::Str; let mut sequence = serializer.serialize_seq(Some(length))?; for i in 0..length { let elem = unsafe { PyTuple_GET_ITEM(self.object, i as isize) }; let current_ob_type = unsafe { Py_TYPE(elem) }; if current_ob_type != type_ptr { type_ptr = current_ob_type; ob_type = get_object_type(current_ob_type); } #[allow(clippy::arithmetic_side_effects)] sequence.serialize_element(&SerializePyObject::with_obtype( elem, ob_type.clone(), self.recursion_depth + 1, ))?; } sequence.end() } } ObjectType::Enum => { let value = unsafe { PyObject_GetAttr(self.object, types::VALUE_STR) }; #[allow(clippy::arithmetic_side_effects)] SerializePyObject::new(value, self.recursion_depth + 1).serialize(serializer) } ObjectType::Unknown(ref type_name) => Err(ser::Error::custom(format!( "Unsupported type: '{}'", type_name ))), } } } #[inline] pub(crate) fn to_value(object: &PyAny) -> PyResult { serde_json::to_value(SerializePyObject::new(object.as_ptr(), 0)) .map_err(|err| exceptions::PyValueError::new_err(err.to_string())) }