From e40a35a983fc2b2c88406327acd0c854879c1351 Mon Sep 17 00:00:00 2001 From: "o.ermakov" Date: Tue, 18 Oct 2022 19:03:38 +0400 Subject: [PATCH] fix(python): Review fixes --- bindings/python/src/ser.rs | 30 +++++++++++++-------- bindings/python/tests-py/test_jsonschema.py | 5 ++-- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/bindings/python/src/ser.rs b/bindings/python/src/ser.rs index e257505..accc3d1 100644 --- a/bindings/python/src/ser.rs +++ b/bindings/python/src/ser.rs @@ -74,6 +74,23 @@ fn get_object_type_from_object(object: *mut pyo3::ffi::PyObject) -> ObjectType { } } +fn get_type_name(object_type: *mut pyo3::ffi::PyTypeObject) -> String { + let type_name = unsafe { CStr::from_ptr((*object_type).tp_name).to_string_lossy() }; + type_name.to_string() +} + +#[inline] +fn check_type_is_str(object: *mut pyo3::ffi::PyObject) -> Result<(), E> { + let object_type = unsafe { Py_TYPE(object) }; + if object_type != unsafe { types::STR_TYPE } { + return Err(ser::Error::custom(format!( + "Dict key must be str. Got '{}'", + get_type_name(object_type) + ))); + } + Ok(()) +} + #[inline] pub fn get_object_type(object_type: *mut pyo3::ffi::PyTypeObject) -> ObjectType { if object_type == unsafe { types::STR_TYPE } { @@ -95,8 +112,7 @@ pub fn get_object_type(object_type: *mut pyo3::ffi::PyTypeObject) -> ObjectType } else if is_enum_subclass(object_type) { ObjectType::Enum } else { - let type_name = unsafe { CStr::from_ptr((*object_type).tp_name).to_string_lossy() }; - ObjectType::Unknown(type_name.to_string()) + ObjectType::Unknown(get_type_name(object_type)) } } @@ -141,15 +157,7 @@ impl Serialize for SerializePyObject { unsafe { pyo3::ffi::PyDict_Next(self.object, &mut pos, &mut key, &mut value); } - match get_object_type_from_object(key) { - ObjectType::Str => {} - object_type => { - return Err(ser::Error::custom(format!( - "Supported only str key type. Provided type '{:?}'", - object_type - ))) - } - } + check_type_is_str(key)?; let uni = unsafe { string::read_utf8_from_str(key, &mut str_size) }; let slice = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts( diff --git a/bindings/python/tests-py/test_jsonschema.py b/bindings/python/tests-py/test_jsonschema.py index 9734463..afdae45 100644 --- a/bindings/python/tests-py/test_jsonschema.py +++ b/bindings/python/tests-py/test_jsonschema.py @@ -1,3 +1,4 @@ +import uuid from collections import namedtuple from contextlib import suppress from enum import Enum @@ -236,7 +237,7 @@ def test_enums(type_, value, expected): def test_dict_with_non_str_keys(): schema = {"type": "object"} - instance = {1234567: "foo"} + instance = {uuid.uuid4(): "foo"} with pytest.raises(ValueError) as exec_info: validate(schema, instance) - assert exec_info.value.args[0] == "Supported only str key type. Provided type 'Int'" + assert exec_info.value.args[0] == "Dict key must be str. Got 'UUID'"