From 732ff5b891b1b8b9657595799d5fb36024b9a3b1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 16:34:38 +0000 Subject: [PATCH 1/5] dataclass serialization speedups --- src/serializers/fields.rs | 285 ++++++++++++------ src/serializers/type_serializers/dataclass.rs | 93 +++++- tests/benchmarks/test_serialization_micro.py | 40 +++ 3 files changed, 306 insertions(+), 112 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index f48f8e0b9..fbe9ac37f 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -4,6 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; use ahash::AHashMap; +use pyo3::types::iter::PyDictIterator; use serde::ser::SerializeMap; use crate::serializers::extra::SerCheck; @@ -100,6 +101,15 @@ pub struct GeneralFieldsSerializer { required_fields: usize, } +macro_rules! option_length { + ($op_has_len:expr) => { + match $op_has_len { + Some(ref has_len) => has_len.len(), + None => 0, + } + }; +} + impl GeneralFieldsSerializer { pub(super) fn new( fields: AHashMap, @@ -136,51 +146,22 @@ impl GeneralFieldsSerializer { } } } -} -macro_rules! option_length { - ($op_has_len:expr) => { - match $op_has_len { - Some(ref has_len) => has_len.len(), - None => 0, - } - }; -} - -impl_py_gc_traverse!(GeneralFieldsSerializer { - fields, - computed_fields -}); - -impl TypeSerializer for GeneralFieldsSerializer { - fn to_python( + pub fn main_to_python<'py>( &self, - value: &PyAny, - include: Option<&PyAny>, - exclude: Option<&PyAny>, - extra: &Extra, - ) -> PyResult { - let py = value.py(); - // If there is already a model registered (from a dataclass, BaseModel) - // then do not touch it - // If there is no model, we (a TypedDict) are the model - let td_extra = Extra { - model: extra.model.map_or_else(|| Some(value), Some), - ..*extra - }; - let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) { - main_extra_dict - } else { - td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?; - return infer_to_python(value, include, exclude, &td_extra); - }; - + py: Python<'py>, + main_iter: impl Iterator>, + include: Option<&'py PyAny>, + exclude: Option<&'py PyAny>, + extra: Extra, + ) -> PyResult<&'py PyDict> { let output_dict = PyDict::new(py); let mut used_req_fields: usize = 0; // NOTE! we maintain the order of the input dict assuming that's right - for (key, value) in main_dict { - let key_str = key_str(key)?; + for result in main_iter { + let (key, value) = result?; + let key_str = key.to_str()?; let op_field = self.fields.get(key_str); if extra.exclude_none && value.is_none() { if let Some(field) = op_field { @@ -190,16 +171,16 @@ impl TypeSerializer for GeneralFieldsSerializer { } continue; } - let extra = Extra { + let field_extra = Extra { field_name: Some(key_str), - ..td_extra + ..extra }; if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? { if let Some(field) = op_field { if let Some(ref serializer) = field.serializer { - if !exclude_default(value, &extra, serializer)? { - let value = serializer.to_python(value, next_include, next_exclude, &extra)?; - let output_key = field.get_key_py(output_dict.py(), &extra); + if !exclude_default(value, &field_extra, serializer)? 
{ + let value = serializer.to_python(value, next_include, next_exclude, &field_extra)?; + let output_key = field.get_key_py(output_dict.py(), &field_extra); output_dict.set_item(output_key, value)?; } } @@ -209,23 +190,140 @@ impl TypeSerializer for GeneralFieldsSerializer { } } else if self.mode == FieldsMode::TypedDictAllow { let value = match &self.extra_serializer { - Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?, - None => infer_to_python(value, next_include, next_exclude, &extra)?, + Some(serializer) => serializer.to_python(value, next_include, next_exclude, &field_extra)?, + None => infer_to_python(value, next_include, next_exclude, &field_extra)?, }; output_dict.set_item(key, value)?; - } else if extra.check == SerCheck::Strict { + } else if field_extra.check == SerCheck::Strict { return Err(PydanticSerializationUnexpectedValue::new_err(None)); } } } - if td_extra.check.enabled() + + if extra.check.enabled() // If any of these are true we can't count fields && !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none) // Check for missing fields, we can't have extra fields here && self.required_fields > used_req_fields { - return Err(PydanticSerializationUnexpectedValue::new_err(None)); + Err(PydanticSerializationUnexpectedValue::new_err(None)) + } else { + Ok(output_dict) + } + } + + pub fn main_serde_serialize<'py, S: serde::ser::Serializer>( + &self, + main_iter: impl Iterator>, + expected_len: usize, + serializer: S, + include: Option<&'py PyAny>, + exclude: Option<&'py PyAny>, + extra: Extra, + ) -> Result { + // NOTE! As above, we maintain the order of the input dict assuming that's right + // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used + let mut map = serializer.serialize_map(Some(expected_len))?; + + for result in main_iter { + let (key, value) = result.map_err(py_err_se_err)?; + if extra.exclude_none && value.is_none() { + continue; + } + let key_str = key_str(key).map_err(py_err_se_err)?; + let field_extra = Extra { + field_name: Some(key_str), + ..extra + }; + + let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; + if let Some((next_include, next_exclude)) = filter { + if let Some(field) = self.fields.get(key_str) { + if let Some(ref serializer) = field.serializer { + if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? 
{ + let s = + PydanticSerializer::new(value, serializer, next_include, next_exclude, &field_extra); + let output_key = field.get_key_json(key_str, &field_extra); + map.serialize_entry(&output_key, &s)?; + } + } + } else if self.mode == FieldsMode::TypedDictAllow { + let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?; + let s = SerializeInfer::new(value, next_include, next_exclude, &field_extra); + map.serialize_entry(&output_key, &s)?; + } + // no error case here since unions (which need the error case) use `to_python(..., mode='json')` + } + } + Ok(map) + } + + pub fn add_computed_fields_python( + &self, + model: Option<&PyAny>, + output_dict: &PyDict, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult<()> { + if let Some(ref computed_fields) = self.computed_fields { + if let Some(model_value) = model { + let cf_extra = Extra { model, ..*extra }; + computed_fields.to_python(model_value, output_dict, &self.filter, include, exclude, &cf_extra)?; + } } + Ok(()) + } + + pub fn add_computed_fields_json( + &self, + model: Option<&PyAny>, + map: &mut S::SerializeMap, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result<(), S::Error> { + if let Some(ref computed_fields) = self.computed_fields { + if let Some(model) = model { + computed_fields.serde_serialize::(model, map, &self.filter, include, exclude, extra)?; + } + } + Ok(()) + } + + pub fn computed_field_count(&self) -> usize { + option_length!(self.computed_fields) + } +} + +impl_py_gc_traverse!(GeneralFieldsSerializer { + fields, + computed_fields +}); + +impl TypeSerializer for GeneralFieldsSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + // If there is already a model registered (from a dataclass, BaseModel) + // then do not touch it + // If there is no model, we (a TypedDict) are the model + let model = extra.model.map_or_else(|| Some(value), Some); + let td_extra = Extra { model, ..*extra }; + let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) { + main_extra_dict + } else { + td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?; + return infer_to_python(value, include, exclude, &td_extra); + }; + + let output_dict = self.main_to_python(py, DictResultIterator::new(main_dict), include, exclude, td_extra)?; + // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { for (key, value) in extra_dict { @@ -241,11 +339,7 @@ impl TypeSerializer for GeneralFieldsSerializer { } } } - if let Some(ref computed_fields) = self.computed_fields { - if let Some(model) = td_extra.model { - computed_fields.to_python(model, output_dict, &self.filter, include, exclude, &td_extra)?; - } - } + self.add_computed_fields_python(model, output_dict, include, exclude, extra)?; Ok(output_dict.into_py(py)) } @@ -271,46 +365,23 @@ impl TypeSerializer for GeneralFieldsSerializer { // If there is already a model registered (from a dataclass, BaseModel) // then do not touch it // If there is no model, we (a TypedDict) are the model - let td_extra = Extra { - model: extra.model.map_or_else(|| Some(value), Some), - ..*extra - }; + let model = extra.model.map_or_else(|| Some(value), Some); + let td_extra = Extra { model, ..*extra }; let expected_len = match self.mode { - FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields), - _ 
=> self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields), + FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(), + _ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(), }; // NOTE! As above, we maintain the order of the input dict assuming that's right // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used - let mut map = serializer.serialize_map(Some(expected_len))?; - - for (key, value) in main_dict { - if extra.exclude_none && value.is_none() { - continue; - } - let key_str = key_str(key).map_err(py_err_se_err)?; - let extra = Extra { - field_name: Some(key_str), - ..td_extra - }; + let mut map = self.main_serde_serialize( + DictResultIterator::new(main_dict), + expected_len, + serializer, + include, + exclude, + td_extra, + )?; - let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = filter { - if let Some(field) = self.fields.get(key_str) { - if let Some(ref serializer) = field.serializer { - if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? { - let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra); - let output_key = field.get_key_json(key_str, &extra); - map.serialize_entry(&output_key, &s)?; - } - } - } else if self.mode == FieldsMode::TypedDictAllow { - let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?; - let s = SerializeInfer::new(value, next_include, next_exclude, &extra); - map.serialize_entry(&output_key, &s)?; - } - // no error case here since unions (which need the error case) use `to_python(..., mode='json')` - } - } // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { for (key, value) in extra_dict { @@ -319,17 +390,14 @@ impl TypeSerializer for GeneralFieldsSerializer { } let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; if let Some((next_include, next_exclude)) = filter { - let output_key = infer_json_key(key, &td_extra).map_err(py_err_se_err)?; - let s = SerializeInfer::new(value, next_include, next_exclude, &td_extra); + let output_key = infer_json_key(key, extra).map_err(py_err_se_err)?; + let s = SerializeInfer::new(value, next_include, next_exclude, extra); map.serialize_entry(&output_key, &s)?; } } } - if let Some(ref computed_fields) = self.computed_fields { - if let Some(model) = td_extra.model { - computed_fields.serde_serialize::(model, &mut map, &self.filter, include, exclude, &td_extra)?; - } - } + + self.add_computed_fields_json::(model, &mut map, include, exclude, extra)?; map.end() } @@ -341,3 +409,28 @@ impl TypeSerializer for GeneralFieldsSerializer { fn key_str(key: &PyAny) -> PyResult<&str> { key.downcast::()?.to_str() } + +pub struct DictResultIterator<'py> { + dict_iter: PyDictIterator<'py>, +} + +impl<'py> DictResultIterator<'py> { + pub fn new(dict: &'py PyDict) -> Self { + Self { dict_iter: dict.iter() } + } +} + +impl<'py> Iterator for DictResultIterator<'py> { + type Item = PyResult<(&'py PyString, &'py PyAny)>; + + fn next(&mut self) -> Option { + if let Some((key, value)) = self.dict_iter.next() { + match key.downcast::() { + Ok(key_str) => Some(Ok((key_str, value))), + Err(e) => Some(Err(e.into())), + } + } else { + None + } + } +} diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index a82643186..2ba8e9661 100644 --- 
a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyType}; use std::borrow::Cow; use ahash::AHashMap; +use serde::ser::SerializeMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; @@ -131,16 +132,30 @@ impl TypeSerializer for DataclassSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let extra = Extra { + let dc_extra = Extra { model: Some(value), ..*extra }; - if self.allow_value(value, &extra)? { - let inner_value = self.get_inner_value(value)?; - self.serializer.to_python(inner_value, include, exclude, &extra) + if self.allow_value(value, extra)? { + let py = value.py(); + if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { + let output_dict = fields_serializer.main_to_python( + py, + DataclassResultIterator::new(&self.fields, value), + include, + exclude, + dc_extra, + )?; + + fields_serializer.add_computed_fields_python(Some(value), output_dict, include, exclude, extra)?; + Ok(output_dict.into_py(py)) + } else { + let inner_value = self.get_inner_value(value)?; + self.serializer.to_python(inner_value, include, exclude, &dc_extra) + } } else { - extra.warnings.on_fallback_py(self.get_name(), value, &extra)?; - infer_to_python(value, include, exclude, &extra) + extra.warnings.on_fallback_py(self.get_name(), value, &dc_extra)?; + infer_to_python(value, include, exclude, &dc_extra) } } @@ -161,17 +176,29 @@ impl TypeSerializer for DataclassSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let extra = Extra { - model: Some(value), - ..*extra - }; - if self.allow_value(value, &extra).map_err(py_err_se_err)? { - let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?; - self.serializer - .serde_serialize(inner_value, serializer, include, exclude, &extra) + let model = Some(value); + let dc_extra = Extra { model, ..*extra }; + if self.allow_value(value, extra).map_err(py_err_se_err)? 
{ + if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { + let expected_len = self.fields.len() + fields_serializer.computed_field_count(); + let mut map = fields_serializer.main_serde_serialize( + DataclassResultIterator::new(&self.fields, value), + expected_len, + serializer, + include, + exclude, + dc_extra, + )?; + fields_serializer.add_computed_fields_json::(model, &mut map, include, exclude, extra)?; + map.end() + } else { + let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?; + self.serializer + .serde_serialize(inner_value, serializer, include, exclude, extra) + } } else { - extra.warnings.on_fallback_ser::(self.get_name(), value, &extra)?; - infer_serialize(value, serializer, include, exclude, &extra) + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) } } @@ -183,3 +210,37 @@ impl TypeSerializer for DataclassSerializer { true } } + +pub struct DataclassResultIterator<'a, 'py> { + index: usize, + fields: &'a [Py], + dataclass: &'py PyAny, +} + +impl<'a, 'py> DataclassResultIterator<'a, 'py> { + pub fn new(fields: &'a [Py], dataclass: &'py PyAny) -> Self { + Self { + index: 0, + fields, + dataclass, + } + } +} + +impl<'a, 'py> Iterator for DataclassResultIterator<'a, 'py> { + type Item = PyResult<(&'py PyString, &'py PyAny)>; + + fn next(&mut self) -> Option { + if let Some(field) = self.fields.get(self.index) { + self.index += 1; + let py = self.dataclass.py(); + let field_ref = field.clone_ref(py).into_ref(py); + match self.dataclass.getattr(field_ref) { + Ok(value) => Some(Ok((field_ref, value))), + Err(e) => Some(Err(e)), + } + } else { + None + } + } +} diff --git a/tests/benchmarks/test_serialization_micro.py b/tests/benchmarks/test_serialization_micro.py index 9200fc24b..7fc8eb245 100644 --- a/tests/benchmarks/test_serialization_micro.py +++ b/tests/benchmarks/test_serialization_micro.py @@ -1,4 +1,5 @@ import json +from dataclasses import dataclass from datetime import date, datetime from uuid import UUID @@ -409,3 +410,42 @@ def test_ser_list_of_lists(benchmark): data = [[i + j for j in range(10)] for i in range(1000)] benchmark(s.to_json, data) + + +@dataclass +class Foo: + a: str + b: bytes + c: int + d: float + + +dataclass_schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema()), + core_schema.dataclass_field(name='c', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='d', schema=core_schema.float_schema()), + ], + ), + ['a', 'b', 'c', 'd'], +) + + +@pytest.mark.benchmark(group='dataclass-ser') +def test_dataclass_serialization_python(benchmark): + s = SchemaSerializer(dataclass_schema) + dc = Foo(a='hello', b=b'more', c=123, d=1.23) + assert s.to_python(dc) == {'a': 'hello', 'b': b'more', 'c': 123, 'd': 1.23} + benchmark(s.to_python, dc) + + +@pytest.mark.benchmark(group='dataclass-ser') +def test_dataclass_serialization_json(benchmark): + s = SchemaSerializer(dataclass_schema) + dc = Foo(a='hello', b=b'more', c=123, d=1.23) + assert s.to_python(dc) == {'a': 'hello', 'b': b'more', 'c': 123, 'd': 1.23} + benchmark(s.to_json, dc) From a6a9ae6662a1fb44c9890547f49c8daa0758d691 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 18:15:17 +0000 Subject: [PATCH 2/5] more dataclass serialization improvements --- src/serializers/fields.rs 
| 34 +--- src/serializers/infer.rs | 154 +++++++++++------- src/serializers/shared.rs | 93 ++++++++--- src/serializers/type_serializers/dataclass.rs | 2 +- 4 files changed, 175 insertions(+), 108 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index fbe9ac37f..84012be16 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -4,7 +4,6 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; use ahash::AHashMap; -use pyo3::types::iter::PyDictIterator; use serde::ser::SerializeMap; use crate::serializers::extra::SerCheck; @@ -16,7 +15,7 @@ use super::extra::Extra; use super::filter::SchemaFilter; use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer}; use super::shared::PydanticSerializer; -use super::shared::{CombinedSerializer, TypeSerializer}; +use super::shared::{CombinedSerializer, DictResultIterator, TypeSerializer}; /// representation of a field for serialization #[derive(Debug, Clone)] @@ -150,7 +149,7 @@ impl GeneralFieldsSerializer { pub fn main_to_python<'py>( &self, py: Python<'py>, - main_iter: impl Iterator>, + main_iter: impl Iterator>, include: Option<&'py PyAny>, exclude: Option<&'py PyAny>, extra: Extra, @@ -161,7 +160,7 @@ impl GeneralFieldsSerializer { // NOTE! we maintain the order of the input dict assuming that's right for result in main_iter { let (key, value) = result?; - let key_str = key.to_str()?; + let key_str = key_str(key)?; let op_field = self.fields.get(key_str); if extra.exclude_none && value.is_none() { if let Some(field) = op_field { @@ -214,7 +213,7 @@ impl GeneralFieldsSerializer { pub fn main_serde_serialize<'py, S: serde::ser::Serializer>( &self, - main_iter: impl Iterator>, + main_iter: impl Iterator>, expected_len: usize, serializer: S, include: Option<&'py PyAny>, @@ -409,28 +408,3 @@ impl TypeSerializer for GeneralFieldsSerializer { fn key_str(key: &PyAny) -> PyResult<&str> { key.downcast::()?.to_str() } - -pub struct DictResultIterator<'py> { - dict_iter: PyDictIterator<'py>, -} - -impl<'py> DictResultIterator<'py> { - pub fn new(dict: &'py PyDict) -> Self { - Self { dict_iter: dict.iter() } - } -} - -impl<'py> Iterator for DictResultIterator<'py> { - type Item = PyResult<(&'py PyString, &'py PyAny)>; - - fn next(&mut self) -> Option { - if let Some((key, value)) = self.dict_iter.next() { - match key.downcast::() { - Ok(key_str) => Some(Ok((key_str, value))), - Err(e) => Some(Err(e.into())), - } - } else { - None - } - } -} diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index e39ca38f8..78f36d885 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -10,19 +10,17 @@ use pyo3::types::{ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; use crate::input::{EitherTimedelta, Int}; -use crate::serializers::config::InfNanMode; -use crate::serializers::errors::SERIALIZATION_ERR_MARKER; -use crate::serializers::filter::SchemaFilter; -use crate::serializers::shared::{PydanticSerializer, TypeSerializer}; -use crate::serializers::SchemaSerializer; use crate::tools::{extract_i64, py_err, safe_repr}; use crate::url::{PyMultiHostUrl, PyUrl}; +use super::config::InfNanMode; +use super::errors::SERIALIZATION_ERR_MARKER; use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; -use super::filter::AnyFilter; +use super::filter::{AnyFilter, SchemaFilter}; use super::ob_type::ObType; -use super::shared::dataclass_to_dict; +use super::shared::{DataclassSerializer, DictResultIterator, 
PydanticSerializer, TypeSerializer}; +use super::SchemaSerializer; pub(crate) fn infer_to_python( value: &PyAny, @@ -83,22 +81,6 @@ pub(crate) fn infer_to_python_known( }}; } - let serialize_dict = |dict: &PyDict| { - let new_dict = PyDict::new(py); - let filter = AnyFilter::new(); - - for (k, v) in dict { - let op_next = filter.key_filter(k, include, exclude)?; - if let Some((next_include, next_exclude)) = op_next { - let k_str = infer_json_key(k, extra)?; - let k = PyString::new(py, &k_str); - let v = infer_to_python(v, next_include, next_exclude, extra)?; - new_dict.set_item(k, v)?; - } - } - Ok::(new_dict.into_py(py)) - }; - let serialize_with_serializer = || { let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?; let serializer: PyRef = py_serializer.extract()?; @@ -168,7 +150,13 @@ pub(crate) fn infer_to_python_known( let elements = serialize_seq!(PyFrozenSet); PyList::new(py, elements).into_py(py) } - ObType::Dict => serialize_dict(value.downcast()?)?, + ObType::Dict => serialize_pairs_python_mode_json( + py, + DictResultIterator::new(value.downcast()?), + include, + exclude, + extra, + )?, ObType::Datetime => { let py_dt: &PyDateTime = value.downcast()?; let iso_dt = super::type_serializers::datetime_etc::datetime_to_string(py_dt)?; @@ -205,7 +193,9 @@ pub(crate) fn infer_to_python_known( uuid.into_py(py) } ObType::PydanticSerializable => serialize_with_serializer()?, - ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, + ObType::Dataclass => { + serialize_pairs_python_mode_json(py, DataclassSerializer::new(value)?, include, exclude, extra)? + } ObType::Enum => { let v = value.getattr(intern!(py, "value"))?; infer_to_python(v, include, exclude, extra)?.into_py(py) @@ -256,22 +246,10 @@ pub(crate) fn infer_to_python_known( PyFrozenSet::new(py, &elements)?.into_py(py) } ObType::Dict => { - // different logic for keys from above - let dict: &PyDict = value.downcast()?; - let new_dict = PyDict::new(py); - let filter = AnyFilter::new(); - - for (k, v) in dict { - let op_next = filter.key_filter(k, include, exclude)?; - if let Some((next_include, next_exclude)) = op_next { - let v = infer_to_python(v, next_include, next_exclude, extra)?; - new_dict.set_item(k, v)?; - } - } - new_dict.into_py(py) + serialize_pairs_python(py, DictResultIterator::new(value.downcast()?), include, exclude, extra)? } ObType::PydanticSerializable => serialize_with_serializer()?, - ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, + ObType::Dataclass => serialize_pairs_python(py, DataclassSerializer::new(value)?, include, exclude, extra)?, ObType::Generator => { let iter = super::type_serializers::generator::SerializationIterator::new( value.downcast()?, @@ -405,23 +383,6 @@ pub(crate) fn infer_serialize_known( }}; } - macro_rules! 
serialize_dict { - ($py_dict:expr) => {{ - let mut map = serializer.serialize_map(Some($py_dict.len()))?; - let filter = AnyFilter::new(); - - for (key, value) in $py_dict { - let op_next = filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let key = infer_json_key(key, extra).map_err(py_err_se_err)?; - let value_serializer = SerializeInfer::new(value, next_include, next_exclude, extra); - map.serialize_entry(&key, &value_serializer)?; - } - } - map.end() - }}; - } - let ser_result = match ob_type { ObType::None => serializer.serialize_none(), ObType::Int | ObType::IntSubclass => serialize!(Int), @@ -445,7 +406,10 @@ pub(crate) fn infer_serialize_known( .bytes_mode .serialize_bytes(unsafe { py_byte_array.as_bytes() }, serializer) } - ObType::Dict => serialize_dict!(value.downcast::().map_err(py_err_se_err)?), + ObType::Dict => { + let dict = value.downcast::().map_err(py_err_se_err)?; + serialize_pairs_json(DictResultIterator::new(dict), serializer, include, exclude, extra) + } ObType::List => serialize_seq_filter!(PyList), ObType::Tuple => serialize_seq_filter!(PyTuple), ObType::Set => serialize_seq!(PySet), @@ -503,7 +467,13 @@ pub(crate) fn infer_serialize_known( PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); pydantic_serializer.serialize(serializer) } - ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?), + ObType::Dataclass => serialize_pairs_json( + DataclassSerializer::new(value).map_err(py_err_se_err)?, + serializer, + include, + exclude, + extra, + ), ObType::Uuid => { let py_uuid: &PyAny = value.downcast().map_err(py_err_se_err)?; let uuid = super::type_serializers::uuid::uuid_to_string(py_uuid).map_err(py_err_se_err)?; @@ -672,3 +642,71 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra: } } } + +fn serialize_pairs_python<'py>( + py: Python, + pairs_iter: impl Iterator>, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, +) -> PyResult { + let new_dict = PyDict::new(py); + let filter = AnyFilter::new(); + + for result in pairs_iter { + let (k, v) = result?; + let op_next = filter.key_filter(k, include, exclude)?; + if let Some((next_include, next_exclude)) = op_next { + let v = infer_to_python(v, next_include, next_exclude, extra)?; + new_dict.set_item(k, v)?; + } + } + Ok(new_dict.into_py(py)) +} + +fn serialize_pairs_python_mode_json<'py>( + py: Python, + pairs_iter: impl Iterator>, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, +) -> PyResult { + let new_dict = PyDict::new(py); + let filter = AnyFilter::new(); + + for result in pairs_iter { + let (k, v) = result?; + let op_next = filter.key_filter(k, include, exclude)?; + if let Some((next_include, next_exclude)) = op_next { + let k_str = infer_json_key(k, extra)?; + let k = PyString::new(py, &k_str); + let v = infer_to_python(v, next_include, next_exclude, extra)?; + new_dict.set_item(k, v)?; + } + } + Ok(new_dict.into_py(py)) +} + +fn serialize_pairs_json<'py, S: Serializer>( + pairs_iter: impl Iterator>, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, +) -> Result { + let (_, expected) = pairs_iter.size_hint(); + let mut map = serializer.serialize_map(expected)?; + let filter = AnyFilter::new(); + + for result in pairs_iter { + let (key, value) = result.map_err(py_err_se_err)?; + + let op_next = filter.key_filter(key, include, 
exclude).map_err(py_err_se_err)?; + if let Some((next_include, next_exclude)) = op_next { + let key = infer_json_key(key, extra).map_err(py_err_se_err)?; + let value_serializer = SerializeInfer::new(value, next_include, next_exclude, extra); + map.serialize_entry(&key, &value_serializer)?; + } + } + map.end() +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 11aac037d..797199435 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -4,6 +4,7 @@ use std::fmt::Debug; use pyo3::exceptions::PyTypeError; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; +use pyo3::types::iter::PyDictIterator; use pyo3::types::{PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -364,29 +365,83 @@ pub(crate) fn to_json_bytes( Ok(bytes) } +pub(super) struct DictResultIterator<'py> { + dict_iter: PyDictIterator<'py>, +} + +impl<'py> DictResultIterator<'py> { + pub fn new(dict: &'py PyDict) -> Self { + Self { dict_iter: dict.iter() } + } +} + +impl<'py> Iterator for DictResultIterator<'py> { + type Item = PyResult<(&'py PyAny, &'py PyAny)>; + + fn next(&mut self) -> Option { + self.dict_iter.next().map(Ok) + } + + fn size_hint(&self) -> (usize, Option) { + self.dict_iter.size_hint() + } +} + +pub(super) struct DataclassSerializer<'py> { + dataclass: &'py PyAny, + fields_iter: PyDictIterator<'py>, + field_type_marker: &'py PyAny, +} + +impl<'py> DataclassSerializer<'py> { + pub fn new(dc: &'py PyAny) -> PyResult { + let py = dc.py(); + let fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + Ok(Self { + dataclass: dc, + fields_iter: fields.iter(), + field_type_marker: get_field_marker(py)?, + }) + } + + fn _next(&mut self) -> PyResult> { + if let Some((field_name, field)) = self.fields_iter.next() { + let field_type = field.getattr(intern!(self.dataclass.py(), "_field_type"))?; + if field_type.is(self.field_type_marker) { + let field_name: &PyString = field_name.downcast()?; + let value = self.dataclass.getattr(field_name)?; + Ok(Some((field_name, value))) + } else { + self._next() + } + } else { + Ok(None) + } + } +} + +impl<'py> Iterator for DataclassSerializer<'py> { + type Item = PyResult<(&'py PyAny, &'py PyAny)>; + + fn next(&mut self) -> Option { + match self._next() { + Ok(Some(v)) => Some(Ok(v)), + Ok(None) => None, + Err(e) => Some(Err(e)), + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } +} + static DC_FIELD_MARKER: GILOnceCell = GILOnceCell::new(); /// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)` -pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { +fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || { - let field_ = py.import("dataclasses")?.getattr("_FIELD")?; - Ok::(field_.into_py(py)) + py.import("dataclasses")?.getattr("_FIELD").map(|f| f.into_py(py)) })?; Ok(field_type_marker_obj.as_ref(py)) } - -pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> { - let py = dc.py(); - let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; - let dict = PyDict::new(py); - - let field_type_marker = get_field_marker(py)?; - for (field_name, field) in dc_fields { - let field_type = field.getattr(intern!(py, "_field_type"))?; - if field_type.is(field_type_marker) { - let field_name: &PyString = field_name.downcast()?; - dict.set_item(field_name, dc.getattr(field_name)?)?; - } - } - Ok(dict) -} diff --git 
a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 2ba8e9661..fea5c68d9 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -228,7 +228,7 @@ impl<'a, 'py> DataclassResultIterator<'a, 'py> { } impl<'a, 'py> Iterator for DataclassResultIterator<'a, 'py> { - type Item = PyResult<(&'py PyString, &'py PyAny)>; + type Item = PyResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { if let Some(field) = self.fields.get(self.index) { From 2a73357777722956fac4bae168ee7a1601f49e76 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 19:04:19 +0000 Subject: [PATCH 3/5] to_json benchmark --- tests/benchmarks/test_serialization_micro.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/benchmarks/test_serialization_micro.py b/tests/benchmarks/test_serialization_micro.py index 7fc8eb245..96170b5eb 100644 --- a/tests/benchmarks/test_serialization_micro.py +++ b/tests/benchmarks/test_serialization_micro.py @@ -449,3 +449,9 @@ def test_dataclass_serialization_json(benchmark): dc = Foo(a='hello', b=b'more', c=123, d=1.23) assert s.to_python(dc) == {'a': 'hello', 'b': b'more', 'c': 123, 'd': 1.23} benchmark(s.to_json, dc) + + +@pytest.mark.benchmark(group='dataclass-ser') +def test_dataclass_to_json(benchmark): + dc = Foo(a='hello', b=b'more', c=123, d=1.23) + benchmark(to_json, dc) From 0797e59997aba60c17438fd918d1a1b076843555 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 16 Jan 2024 09:29:34 +0000 Subject: [PATCH 4/5] rename iterator structs for consistency --- src/serializers/fields.rs | 6 ++--- src/serializers/infer.rs | 26 ++++++++----------- src/serializers/shared.rs | 12 ++++----- src/serializers/type_serializers/dataclass.rs | 10 +++---- 4 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index 84012be16..e3dea93d1 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -15,7 +15,7 @@ use super::extra::Extra; use super::filter::SchemaFilter; use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer}; use super::shared::PydanticSerializer; -use super::shared::{CombinedSerializer, DictResultIterator, TypeSerializer}; +use super::shared::{CombinedSerializer, DictIterator, TypeSerializer}; /// representation of a field for serialization #[derive(Debug, Clone)] @@ -321,7 +321,7 @@ impl TypeSerializer for GeneralFieldsSerializer { return infer_to_python(value, include, exclude, &td_extra); }; - let output_dict = self.main_to_python(py, DictResultIterator::new(main_dict), include, exclude, td_extra)?; + let output_dict = self.main_to_python(py, DictIterator::new(main_dict), include, exclude, td_extra)?; // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { @@ -373,7 +373,7 @@ impl TypeSerializer for GeneralFieldsSerializer { // NOTE! 
As above, we maintain the order of the input dict assuming that's right // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used let mut map = self.main_serde_serialize( - DictResultIterator::new(main_dict), + DictIterator::new(main_dict), expected_len, serializer, include, diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 78f36d885..d4bc2c0b6 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -19,7 +19,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; use super::filter::{AnyFilter, SchemaFilter}; use super::ob_type::ObType; -use super::shared::{DataclassSerializer, DictResultIterator, PydanticSerializer, TypeSerializer}; +use super::shared::{AnyDataclassIterator, DictIterator, PydanticSerializer, TypeSerializer}; use super::SchemaSerializer; pub(crate) fn infer_to_python( @@ -150,13 +150,9 @@ pub(crate) fn infer_to_python_known( let elements = serialize_seq!(PyFrozenSet); PyList::new(py, elements).into_py(py) } - ObType::Dict => serialize_pairs_python_mode_json( - py, - DictResultIterator::new(value.downcast()?), - include, - exclude, - extra, - )?, + ObType::Dict => { + serialize_pairs_python_mode_json(py, DictIterator::new(value.downcast()?), include, exclude, extra)? + } ObType::Datetime => { let py_dt: &PyDateTime = value.downcast()?; let iso_dt = super::type_serializers::datetime_etc::datetime_to_string(py_dt)?; @@ -194,7 +190,7 @@ pub(crate) fn infer_to_python_known( } ObType::PydanticSerializable => serialize_with_serializer()?, ObType::Dataclass => { - serialize_pairs_python_mode_json(py, DataclassSerializer::new(value)?, include, exclude, extra)? + serialize_pairs_python_mode_json(py, AnyDataclassIterator::new(value)?, include, exclude, extra)? } ObType::Enum => { let v = value.getattr(intern!(py, "value"))?; @@ -245,11 +241,11 @@ pub(crate) fn infer_to_python_known( let elements = serialize_seq!(PyFrozenSet); PyFrozenSet::new(py, &elements)?.into_py(py) } - ObType::Dict => { - serialize_pairs_python(py, DictResultIterator::new(value.downcast()?), include, exclude, extra)? - } + ObType::Dict => serialize_pairs_python(py, DictIterator::new(value.downcast()?), include, exclude, extra)?, ObType::PydanticSerializable => serialize_with_serializer()?, - ObType::Dataclass => serialize_pairs_python(py, DataclassSerializer::new(value)?, include, exclude, extra)?, + ObType::Dataclass => { + serialize_pairs_python(py, AnyDataclassIterator::new(value)?, include, exclude, extra)? 
+ } ObType::Generator => { let iter = super::type_serializers::generator::SerializationIterator::new( value.downcast()?, @@ -408,7 +404,7 @@ pub(crate) fn infer_serialize_known( } ObType::Dict => { let dict = value.downcast::().map_err(py_err_se_err)?; - serialize_pairs_json(DictResultIterator::new(dict), serializer, include, exclude, extra) + serialize_pairs_json(DictIterator::new(dict), serializer, include, exclude, extra) } ObType::List => serialize_seq_filter!(PyList), ObType::Tuple => serialize_seq_filter!(PyTuple), @@ -468,7 +464,7 @@ pub(crate) fn infer_serialize_known( pydantic_serializer.serialize(serializer) } ObType::Dataclass => serialize_pairs_json( - DataclassSerializer::new(value).map_err(py_err_se_err)?, + AnyDataclassIterator::new(value).map_err(py_err_se_err)?, serializer, include, exclude, diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 797199435..7f44a4c29 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -365,17 +365,17 @@ pub(crate) fn to_json_bytes( Ok(bytes) } -pub(super) struct DictResultIterator<'py> { +pub(super) struct DictIterator<'py> { dict_iter: PyDictIterator<'py>, } -impl<'py> DictResultIterator<'py> { +impl<'py> DictIterator<'py> { pub fn new(dict: &'py PyDict) -> Self { Self { dict_iter: dict.iter() } } } -impl<'py> Iterator for DictResultIterator<'py> { +impl<'py> Iterator for DictIterator<'py> { type Item = PyResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { @@ -387,13 +387,13 @@ impl<'py> Iterator for DictResultIterator<'py> { } } -pub(super) struct DataclassSerializer<'py> { +pub(super) struct AnyDataclassIterator<'py> { dataclass: &'py PyAny, fields_iter: PyDictIterator<'py>, field_type_marker: &'py PyAny, } -impl<'py> DataclassSerializer<'py> { +impl<'py> AnyDataclassIterator<'py> { pub fn new(dc: &'py PyAny) -> PyResult { let py = dc.py(); let fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; @@ -420,7 +420,7 @@ impl<'py> DataclassSerializer<'py> { } } -impl<'py> Iterator for DataclassSerializer<'py> { +impl<'py> Iterator for AnyDataclassIterator<'py> { type Item = PyResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index fea5c68d9..c42f03b65 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -141,7 +141,7 @@ impl TypeSerializer for DataclassSerializer { if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let output_dict = fields_serializer.main_to_python( py, - DataclassResultIterator::new(&self.fields, value), + KnownDataclassIterator::new(&self.fields, value), include, exclude, dc_extra, @@ -182,7 +182,7 @@ impl TypeSerializer for DataclassSerializer { if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let expected_len = self.fields.len() + fields_serializer.computed_field_count(); let mut map = fields_serializer.main_serde_serialize( - DataclassResultIterator::new(&self.fields, value), + KnownDataclassIterator::new(&self.fields, value), expected_len, serializer, include, @@ -211,13 +211,13 @@ impl TypeSerializer for DataclassSerializer { } } -pub struct DataclassResultIterator<'a, 'py> { +pub struct KnownDataclassIterator<'a, 'py> { index: usize, fields: &'a [Py], dataclass: &'py PyAny, } -impl<'a, 'py> DataclassResultIterator<'a, 'py> { +impl<'a, 'py> KnownDataclassIterator<'a, 'py> { pub fn new(fields: &'a [Py], 
dataclass: &'py PyAny) -> Self { Self { index: 0, @@ -227,7 +227,7 @@ impl<'a, 'py> DataclassResultIterator<'a, 'py> { } } -impl<'a, 'py> Iterator for DataclassResultIterator<'a, 'py> { +impl<'a, 'py> Iterator for KnownDataclassIterator<'a, 'py> { type Item = PyResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { From 6740e739415d6856d60c7dfe80b41bfb1f6cd066 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 16 Jan 2024 18:58:53 +0000 Subject: [PATCH 5/5] implement suggestions --- src/serializers/fields.rs | 6 +- src/serializers/infer.rs | 62 ++++++--------- src/serializers/shared.rs | 79 ++++--------------- src/serializers/type_serializers/dataclass.rs | 45 +++-------- 4 files changed, 52 insertions(+), 140 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index e3dea93d1..cefdec1d7 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -15,7 +15,7 @@ use super::extra::Extra; use super::filter::SchemaFilter; use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer}; use super::shared::PydanticSerializer; -use super::shared::{CombinedSerializer, DictIterator, TypeSerializer}; +use super::shared::{CombinedSerializer, TypeSerializer}; /// representation of a field for serialization #[derive(Debug, Clone)] @@ -321,7 +321,7 @@ impl TypeSerializer for GeneralFieldsSerializer { return infer_to_python(value, include, exclude, &td_extra); }; - let output_dict = self.main_to_python(py, DictIterator::new(main_dict), include, exclude, td_extra)?; + let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?; // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { @@ -373,7 +373,7 @@ impl TypeSerializer for GeneralFieldsSerializer { // NOTE! As above, we maintain the order of the input dict assuming that's right // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used let mut map = self.main_serde_serialize( - DictIterator::new(main_dict), + main_dict.iter().map(Ok), expected_len, serializer, include, diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index d4bc2c0b6..5ddf77597 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -19,7 +19,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; use super::filter::{AnyFilter, SchemaFilter}; use super::ob_type::ObType; -use super::shared::{AnyDataclassIterator, DictIterator, PydanticSerializer, TypeSerializer}; +use super::shared::{any_dataclass_iter, PydanticSerializer, TypeSerializer}; use super::SchemaSerializer; pub(crate) fn infer_to_python( @@ -151,7 +151,10 @@ pub(crate) fn infer_to_python_known( PyList::new(py, elements).into_py(py) } ObType::Dict => { - serialize_pairs_python_mode_json(py, DictIterator::new(value.downcast()?), include, exclude, extra)? + let dict: &PyDict = value.downcast()?; + serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, |k| { + Ok(PyString::new(py, &infer_json_key(k, extra)?)) + })? } ObType::Datetime => { let py_dt: &PyDateTime = value.downcast()?; @@ -190,7 +193,9 @@ pub(crate) fn infer_to_python_known( } ObType::PydanticSerializable => serialize_with_serializer()?, ObType::Dataclass => { - serialize_pairs_python_mode_json(py, AnyDataclassIterator::new(value)?, include, exclude, extra)? 
+ serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, |k| { + Ok(PyString::new(py, &infer_json_key(k, extra)?)) + })? } ObType::Enum => { let v = value.getattr(intern!(py, "value"))?; @@ -241,11 +246,12 @@ pub(crate) fn infer_to_python_known( let elements = serialize_seq!(PyFrozenSet); PyFrozenSet::new(py, &elements)?.into_py(py) } - ObType::Dict => serialize_pairs_python(py, DictIterator::new(value.downcast()?), include, exclude, extra)?, - ObType::PydanticSerializable => serialize_with_serializer()?, - ObType::Dataclass => { - serialize_pairs_python(py, AnyDataclassIterator::new(value)?, include, exclude, extra)? + ObType::Dict => { + let dict: &PyDict = value.downcast()?; + serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, Ok)? } + ObType::PydanticSerializable => serialize_with_serializer()?, + ObType::Dataclass => serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, Ok)?, ObType::Generator => { let iter = super::type_serializers::generator::SerializationIterator::new( value.downcast()?, @@ -404,7 +410,7 @@ pub(crate) fn infer_serialize_known( } ObType::Dict => { let dict = value.downcast::().map_err(py_err_se_err)?; - serialize_pairs_json(DictIterator::new(dict), serializer, include, exclude, extra) + serialize_pairs_json(dict.iter().map(Ok), dict.len(), serializer, include, exclude, extra) } ObType::List => serialize_seq_filter!(PyList), ObType::Tuple => serialize_seq_filter!(PyTuple), @@ -463,13 +469,10 @@ pub(crate) fn infer_serialize_known( PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); pydantic_serializer.serialize(serializer) } - ObType::Dataclass => serialize_pairs_json( - AnyDataclassIterator::new(value).map_err(py_err_se_err)?, - serializer, - include, - exclude, - extra, - ), + ObType::Dataclass => { + let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?; + serialize_pairs_json(pairs_iter, fields_dict.len(), serializer, include, exclude, extra) + } ObType::Uuid => { let py_uuid: &PyAny = value.downcast().map_err(py_err_se_err)?; let uuid = super::type_serializers::uuid::uuid_to_string(py_uuid).map_err(py_err_se_err)?; @@ -645,6 +648,7 @@ fn serialize_pairs_python<'py>( include: Option<&PyAny>, exclude: Option<&PyAny>, extra: &Extra, + key_transform: impl Fn(&'py PyAny) -> PyResult<&'py PyAny>, ) -> PyResult { let new_dict = PyDict::new(py); let filter = AnyFilter::new(); @@ -653,29 +657,7 @@ fn serialize_pairs_python<'py>( let (k, v) = result?; let op_next = filter.key_filter(k, include, exclude)?; if let Some((next_include, next_exclude)) = op_next { - let v = infer_to_python(v, next_include, next_exclude, extra)?; - new_dict.set_item(k, v)?; - } - } - Ok(new_dict.into_py(py)) -} - -fn serialize_pairs_python_mode_json<'py>( - py: Python, - pairs_iter: impl Iterator>, - include: Option<&PyAny>, - exclude: Option<&PyAny>, - extra: &Extra, -) -> PyResult { - let new_dict = PyDict::new(py); - let filter = AnyFilter::new(); - - for result in pairs_iter { - let (k, v) = result?; - let op_next = filter.key_filter(k, include, exclude)?; - if let Some((next_include, next_exclude)) = op_next { - let k_str = infer_json_key(k, extra)?; - let k = PyString::new(py, &k_str); + let k = key_transform(k)?; let v = infer_to_python(v, next_include, next_exclude, extra)?; new_dict.set_item(k, v)?; } @@ -685,13 +667,13 @@ fn serialize_pairs_python_mode_json<'py>( fn serialize_pairs_json<'py, S: Serializer>( pairs_iter: impl Iterator>, + 
iter_size: usize, serializer: S, include: Option<&PyAny>, exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let (_, expected) = pairs_iter.size_hint(); - let mut map = serializer.serialize_map(expected)?; + let mut map = serializer.serialize_map(Some(iter_size))?; let filter = AnyFilter::new(); for result in pairs_iter { diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 7f44a4c29..7cfe6ce6e 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -4,7 +4,6 @@ use std::fmt::Debug; use pyo3::exceptions::PyTypeError; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; -use pyo3::types::iter::PyDictIterator; use pyo3::types::{PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -365,75 +364,25 @@ pub(crate) fn to_json_bytes( Ok(bytes) } -pub(super) struct DictIterator<'py> { - dict_iter: PyDictIterator<'py>, -} - -impl<'py> DictIterator<'py> { - pub fn new(dict: &'py PyDict) -> Self { - Self { dict_iter: dict.iter() } - } -} - -impl<'py> Iterator for DictIterator<'py> { - type Item = PyResult<(&'py PyAny, &'py PyAny)>; - - fn next(&mut self) -> Option { - self.dict_iter.next().map(Ok) - } - - fn size_hint(&self) -> (usize, Option) { - self.dict_iter.size_hint() - } -} - -pub(super) struct AnyDataclassIterator<'py> { +pub(super) fn any_dataclass_iter<'py>( dataclass: &'py PyAny, - fields_iter: PyDictIterator<'py>, - field_type_marker: &'py PyAny, -} - -impl<'py> AnyDataclassIterator<'py> { - pub fn new(dc: &'py PyAny) -> PyResult { - let py = dc.py(); - let fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; - Ok(Self { - dataclass: dc, - fields_iter: fields.iter(), - field_type_marker: get_field_marker(py)?, - }) - } - - fn _next(&mut self) -> PyResult> { - if let Some((field_name, field)) = self.fields_iter.next() { - let field_type = field.getattr(intern!(self.dataclass.py(), "_field_type"))?; - if field_type.is(self.field_type_marker) { - let field_name: &PyString = field_name.downcast()?; - let value = self.dataclass.getattr(field_name)?; - Ok(Some((field_name, value))) - } else { - self._next() - } +) -> PyResult<(impl Iterator> + 'py, &PyDict)> { + let py = dataclass.py(); + let fields: &PyDict = dataclass.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + let field_type_marker = get_field_marker(py)?; + + let next = move |(field_name, field): (&'py PyAny, &'py PyAny)| -> PyResult> { + let field_type = field.getattr(intern!(py, "_field_type"))?; + if field_type.is(field_type_marker) { + let field_name: &PyString = field_name.downcast()?; + let value = dataclass.getattr(field_name)?; + Ok(Some((field_name, value))) } else { Ok(None) } - } -} - -impl<'py> Iterator for AnyDataclassIterator<'py> { - type Item = PyResult<(&'py PyAny, &'py PyAny)>; - - fn next(&mut self) -> Option { - match self._next() { - Ok(Some(v)) => Some(Ok(v)), - Ok(None) => None, - Err(e) => Some(Err(e)), - } - } + }; - fn size_hint(&self) -> (usize, Option) { - (0, None) - } + Ok((fields.iter().filter_map(move |field| next(field).transpose()), fields)) } static DC_FIELD_MARKER: GILOnceCell = GILOnceCell::new(); diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index c42f03b65..93548bcba 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -141,7 +141,7 @@ impl TypeSerializer for DataclassSerializer { if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let output_dict = 
fields_serializer.main_to_python( py, - KnownDataclassIterator::new(&self.fields, value), + known_dataclass_iter(&self.fields, value), include, exclude, dc_extra, @@ -182,7 +182,7 @@ impl TypeSerializer for DataclassSerializer { if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let expected_len = self.fields.len() + fields_serializer.computed_field_count(); let mut map = fields_serializer.main_serde_serialize( - KnownDataclassIterator::new(&self.fields, value), + known_dataclass_iter(&self.fields, value), expected_len, serializer, include, @@ -211,36 +211,17 @@ impl TypeSerializer for DataclassSerializer { } } -pub struct KnownDataclassIterator<'a, 'py> { - index: usize, +fn known_dataclass_iter<'a, 'py>( fields: &'a [Py], dataclass: &'py PyAny, -} - -impl<'a, 'py> KnownDataclassIterator<'a, 'py> { - pub fn new(fields: &'a [Py], dataclass: &'py PyAny) -> Self { - Self { - index: 0, - fields, - dataclass, - } - } -} - -impl<'a, 'py> Iterator for KnownDataclassIterator<'a, 'py> { - type Item = PyResult<(&'py PyAny, &'py PyAny)>; - - fn next(&mut self) -> Option { - if let Some(field) = self.fields.get(self.index) { - self.index += 1; - let py = self.dataclass.py(); - let field_ref = field.clone_ref(py).into_ref(py); - match self.dataclass.getattr(field_ref) { - Ok(value) => Some(Ok((field_ref, value))), - Err(e) => Some(Err(e)), - } - } else { - None - } - } +) -> impl Iterator> + 'a +where + 'py: 'a, +{ + let py = dataclass.py(); + fields.iter().map(move |field| { + let field_ref = field.clone_ref(py).into_ref(py); + let value = dataclass.getattr(field_ref)?; + Ok((field_ref as &PyAny, value)) + }) }
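
For reference, the `any_dataclass_iter` helper introduced in PATCH 5/5 reproduces the filtering that `dataclasses.fields()` performs: it walks `__dataclass_fields__` and keeps only entries whose private `_field_type` marker is `dataclasses._FIELD`, so ClassVar/InitVar pseudo-fields are skipped. A minimal pure-Python sketch of the same iteration (hypothetical helper name; it relies on the same private `dataclasses._FIELD` sentinel that the Rust code imports via `get_field_marker`) is:

    import dataclasses

    def iter_dataclass_fields(dc):
        """Yield (name, value) pairs for the real fields of a dataclass instance,
        skipping ClassVar/InitVar pseudo-fields, mirroring dataclasses.fields()."""
        for name, field in dc.__dataclass_fields__.items():
            # dataclasses.fields() keeps only entries whose _field_type is the
            # private _FIELD sentinel; ClassVar and InitVar entries use other markers.
            if field._field_type is dataclasses._FIELD:
                yield name, getattr(dc, name)

    @dataclasses.dataclass
    class Foo:
        a: str
        c: int

    assert dict(iter_dataclass_fields(Foo(a='hello', c=123))) == {'a': 'hello', 'c': 123}

The Rust version returns this as a lazy iterator of `PyResult<(&PyAny, &PyAny)>` pairs so that both `GeneralFieldsSerializer` and the infer-path helpers can consume dict items and dataclass attributes through the same code path without building an intermediate `PyDict`.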