di container optimizations, ability to resolve as ref by RomanEmreis · Pull Request #70 · RomanEmreis/volga · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

di container optimizations, ability to resolve as ref #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 209 additions & 46 deletions src/di/container.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use super::{Inject, DiError};
use crate::error::Error;
use futures_util::TryFutureExt;
use tokio::sync::OnceCell;
use std::{
any::{Any, TypeId},
collections::HashMap,
hash::{BuildHasherDefault, Hasher},
sync::Arc
};

Expand All @@ -13,15 +15,48 @@ type ArcService = Arc<
+ Sync
>;

#[derive(Clone)]
pub(crate) enum ServiceEntry {
Singleton(ArcService),
Scoped(OnceCell<ArcService>),
Transient,
}

impl ServiceEntry {
#[inline]
fn as_scope(&self) -> Self {
match self {
ServiceEntry::Singleton(service) => ServiceEntry::Singleton(service.clone()),
ServiceEntry::Scoped(_) => ServiceEntry::Scoped(OnceCell::new()),
ServiceEntry::Transient => ServiceEntry::Transient,
}
}
}

type ServiceMap = HashMap<TypeId, ServiceEntry, BuildHasherDefault<TypeIdHasher>>;

/// Pass-through hasher for maps keyed by [`TypeId`].
///
/// A value fed through [`Hasher::write_u64`] is used verbatim as the hash, so
/// no additional mixing is performed on that path. This is the path the
/// container relies on for `TypeId` keys.
#[derive(Default)]
struct TypeIdHasher(u64);

impl Hasher for TypeIdHasher {
    /// Fallback for any input that does not arrive through `write_u64`.
    ///
    /// The default implementations of the other `write_*` methods forward to
    /// this method, so panicking here would make every map lookup panic as
    /// soon as the key's `Hash` impl feeds anything but a single `u64`.
    /// Instead the bytes are folded into the state deterministically: eight
    /// bytes reproduce the value they encode (most significant byte first),
    /// longer inputs wrap around and are XOR-ed together.
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.0 = self.0.rotate_left(8) ^ u64::from(byte);
        }
    }

    /// Stores `id` as the hash value, replacing any previous state.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    /// Returns the stored value unchanged.
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}

/// Represents a DI container builder
pub struct ContainerBuilder {
services: HashMap<TypeId, ServiceEntry>
services: ServiceMap
}

impl Default for ContainerBuilder {
Expand All @@ -33,7 +68,7 @@ impl Default for ContainerBuilder {
impl ContainerBuilder {
/// Create a new DI container builder
pub fn new() -> Self {
Self { services: HashMap::new() }
Self { services: ServiceMap::default() }
}

/// Build a DI container
@@ -64,73 +99,81 @@ impl ContainerBuilder {

/// Represents a DI container
pub struct Container {
services: HashMap<TypeId, ServiceEntry>
services: ServiceMap
}

impl Clone for Container {
#[inline]
fn clone(&self) -> Self {
let mut new_services = HashMap::new();
for (key, value) in &self.services {
let cloned_value = match value {
ServiceEntry::Singleton(service) => ServiceEntry::Singleton(service.clone()),
ServiceEntry::Scoped(service) => ServiceEntry::Scoped(service.clone()),
ServiceEntry::Transient => ServiceEntry::Transient,
};
new_services.insert(*key, cloned_value);
}
Self { services: new_services }
let services = self.services.iter()
.map(|(key, value)| (*key, value.clone()))
.collect();
Self { services }
}
}

impl Container {
/// Creates a new container where Scoped services are not created yet
#[inline]
pub fn create_scope(&self) -> Self {
let mut new_services = HashMap::new();
for (key, value) in &self.services {
let cloned_value = match value {
ServiceEntry::Singleton(service) => ServiceEntry::Singleton(service.clone()),
ServiceEntry::Scoped(_) => ServiceEntry::Scoped(OnceCell::new()),
ServiceEntry::Transient => ServiceEntry::Transient,
};
new_services.insert(*key, cloned_value);
}
Self { services: new_services }
let services = self.services.iter()
.map(|(key, value)| (*key, value.as_scope()))
.collect();
Self { services }
}

/// Resolve a service
pub async fn resolve<T: Inject + 'static>(&mut self) -> Result<T, Error> {
let type_id = TypeId::of::<T>();
let entry = self.services.get(&type_id);
if entry.is_none() {
return Err(DiError::service_not_registered(std::any::type_name::<T>()));
match self.get_service_entry::<T>()? {
ServiceEntry::Transient => T::inject(self).await,
ServiceEntry::Singleton(instance) => Self::resolve_internal(instance).cloned(),
ServiceEntry::Scoped(cell) => {
let instance = cell
.get_or_try_init(|| async {
T::inject(&mut self.clone())
.map_ok(|scoped| Arc::new(scoped) as ArcService)
.await
})
.await?;
Self::resolve_internal(instance).cloned()
}
}
if let Some(service_entry) = entry {
return match service_entry {
ServiceEntry::Transient => T::inject(self).await,
ServiceEntry::Singleton(instance) => Self::resolve_internal(instance),
ServiceEntry::Scoped(cell) => {
let instance = cell
.get_or_try_init(|| async {
self.clone().resolve_scoped::<T>().await
}).await?;
Self::resolve_internal(instance)
}
}

/// Resolve a service as ref
pub async fn resolve_ref<T: Inject + 'static>(&mut self) -> Result<&T, Error> {
match self.get_service_entry::<T>()? {
ServiceEntry::Transient => Err(Error::server_error(
"cannot resolve a `Transient` service as ref, use `resolve::<T>()` or `Dc<T>` instead",
)),
ServiceEntry::Singleton(instance) => Self::resolve_internal(instance),
ServiceEntry::Scoped(cell) => {
let instance = cell
.get_or_try_init(|| async {
T::inject(&mut self.clone())
.map_ok(|scoped| Arc::new(scoped) as ArcService)
.await
})
.await?;
Self::resolve_internal(instance)
}
}
unreachable!();
}

/// Fetch the service entry or return an error if not registered.
#[inline]
fn resolve_internal<T: Inject + 'static>(instance: &ArcService) -> Result<T, Error> {
(**instance).downcast_ref::<T>()
.ok_or(DiError::resolve_error(std::any::type_name::<T>()))
.cloned()
fn get_service_entry<T: Inject + 'static>(&self) -> Result<&ServiceEntry, Error> {
let type_id = TypeId::of::<T>();
self.services
.get(&type_id)
.ok_or_else(|| DiError::service_not_registered(std::any::type_name::<T>()))
}

#[inline]
async fn resolve_scoped<T: Inject + 'static>(&mut self) -> Result<ArcService, Error> {
let scoped = T::inject(self).await?;
Ok(Arc::new(scoped))
fn resolve_internal<T: Inject + 'static>(instance: &ArcService) -> Result<&T, Error> {
(**instance)
.downcast_ref::<T>()
.ok_or(DiError::resolve_error(std::any::type_name::<T>()))
}
}

Expand Down Expand Up @@ -194,6 +237,22 @@ mod tests {
assert_eq!(key, "value");
}

/// Checks that a value written through one borrowed singleton handle is
/// visible through a second `resolve_ref` lookup on the same container.
#[tokio::test]
async fn it_registers_singleton_and_resolves_as_ref() {
    let mut builder = ContainerBuilder::new();
    builder.register_singleton(InMemoryCache::default());
    let mut container = builder.build();

    // Write through the first borrowed handle.
    let writer = container.resolve_ref::<InMemoryCache>().await.unwrap();
    writer.set("key", "value");

    // Read the value back through a separate lookup.
    let reader = container.resolve_ref::<InMemoryCache>().await.unwrap();
    let stored = reader.get("key").unwrap();

    assert_eq!(stored, "value");
}

#[tokio::test]
async fn it_registers_transient() {
let mut container = ContainerBuilder::new();
Expand Down Expand Up @@ -247,6 +306,43 @@ mod tests {
assert_eq!(key, "value 1");
}

/// Checks that each scope created from the container sees its own
/// `InMemoryCache` when resolving by reference, and that the instance
/// resolved from the initial container keeps its own value.
#[tokio::test]
async fn it_registers_scoped_and_resolves_as_ref() {
    let mut builder = ContainerBuilder::new();
    builder.register_scoped::<InMemoryCache>();
    let mut container = builder.build();

    // Initial scope: take an owned handle and store the first value.
    let root_cache = container.resolve::<InMemoryCache>().await.unwrap();
    root_cache.set("key", "value 1");

    // First child scope: a write followed by a read inside the same scope
    // must observe that write.
    {
        let mut scope = container.create_scope();
        let writer = scope.resolve_ref::<InMemoryCache>().await.unwrap();
        writer.set("key", "value 2");

        let reader = scope.resolve_ref::<InMemoryCache>().await.unwrap();
        assert_eq!(reader.get("key").unwrap(), "value 2");
    }

    // Second child scope: nothing has been written here, so the key is absent.
    {
        let mut scope = container.create_scope();
        let fresh_cache = scope.resolve_ref::<InMemoryCache>().await.unwrap();
        assert!(fresh_cache.get("key").is_none());
    }

    // The handle from the initial scope still holds its original value.
    assert_eq!(root_cache.get("key").unwrap(), "value 1");
}

#[tokio::test]
async fn it_resolves_inner_dependencies() {
let mut container = ContainerBuilder::new();
Expand All @@ -268,6 +364,27 @@ mod tests {
assert_eq!(key, "value 1");
}

/// Checks that a scoped `CacheWrapper` resolved by reference writes into the
/// singleton `InMemoryCache` that the container itself hands out.
#[tokio::test]
async fn it_resolves_inner_dependencies_as_ref() {
    let mut builder = ContainerBuilder::new();
    builder.register_singleton(InMemoryCache::default());
    builder.register_scoped::<CacheWrapper>();
    let mut container = builder.build();

    // Write through the wrapper resolved inside a child scope.
    {
        let mut scope = container.create_scope();
        let wrapper = scope.resolve_ref::<CacheWrapper>().await.unwrap();
        wrapper.inner.set("key", "value 1");
    }

    // The singleton resolved from the outer container must see that write.
    let shared_cache = container.resolve_ref::<InMemoryCache>().await.unwrap();
    assert_eq!(shared_cache.get("key").unwrap(), "value 1");
}

#[tokio::test]
async fn inner_scope_does_not_affect_outer() {
let mut container = ContainerBuilder::new();
Expand All @@ -292,6 +409,30 @@ mod tests {
assert!(key.is_none())
}

/// Checks that writes made through a `CacheWrapper` resolved in a child scope
/// are not visible in the scoped `InMemoryCache` resolved from the outer
/// container.
#[tokio::test]
async fn inner_scope_does_not_affect_outer_with_ref() {
    let mut builder = ContainerBuilder::new();
    builder.register_scoped::<InMemoryCache>();
    builder.register_scoped::<CacheWrapper>();
    let mut container = builder.build();

    // Two writes through the child scope's wrapper, each via its own lookup.
    {
        let mut scope = container.create_scope();

        let first = scope.resolve_ref::<CacheWrapper>().await.unwrap();
        first.inner.set("key", "value 1");

        let second = scope.resolve_ref::<CacheWrapper>().await.unwrap();
        second.inner.set("key", "value 2");
    }

    // The outer container's cache must not contain the key.
    let outer_cache = container.resolve_ref::<InMemoryCache>().await.unwrap();
    assert!(outer_cache.get("key").is_none());
}

#[tokio::test]
async fn it_resolves_inner_scoped_dependencies() {
let mut container = ContainerBuilder::new();
Expand All @@ -313,4 +454,26 @@ mod tests {
assert_eq!(cache.inner.get("key1").unwrap(), "value 1");
assert_eq!(cache.inner.get("key2").unwrap(), "value 2");
}

/// Checks that repeated `resolve_ref::<CacheWrapper>()` calls on one scope
/// accumulate their writes: both keys are readable through a third lookup.
#[tokio::test]
async fn it_resolves_inner_scoped_dependencies_as_ref() {
    let mut builder = ContainerBuilder::new();
    builder.register_scoped::<InMemoryCache>();
    builder.register_scoped::<CacheWrapper>();
    let container = builder.build();

    let mut scope = container.create_scope();

    // Each write goes through a separately resolved wrapper reference.
    let first = scope.resolve_ref::<CacheWrapper>().await.unwrap();
    first.inner.set("key1", "value 1");

    let second = scope.resolve_ref::<CacheWrapper>().await.unwrap();
    second.inner.set("key2", "value 2");

    // A final lookup must observe both earlier writes.
    let wrapper = scope.resolve_ref::<CacheWrapper>().await.unwrap();

    assert_eq!(wrapper.inner.get("key1").unwrap(), "value 1");
    assert_eq!(wrapper.inner.get("key2").unwrap(), "value 2");
}
}
0