diff --git a/CHANGELOG.md b/CHANGELOG.md index 438cdec..2790c9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [2.1.0] - 2025-11-24 :notes: +## [2.1.0] - 2026-03-?? - Improve `resolve()` typing, by @sobolevn. - Use `Self` type for Container, by @sobolevn. @@ -18,6 +18,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove Codecov from GitHub Workflow and from README. - Upgrade type annotations to Python >= 3.10. - Remove code checks for Python <= 3.10. +- Support mixing `__init__` parameters and class-level annotated properties for + dependency injection. Previously, when a class defined a custom `__init__`, + rodi would only inspect constructor parameters and ignore class-level type + annotations. Now both are resolved: constructor parameters are injected as + arguments, and any remaining class-level annotated properties are injected via + `setattr` after instantiation. This enables patterns like: + + ```python + class MyService: + extra_dep: ExtraDependency # injected via setattr + + def __init__(self, main_dep: MainDependency) -> None: + self.main_dep = main_dep + ``` + + Resolves [issue #43](https://github.com/Neoteroi/rodi/issues/43), reported by + [@lucas-labs](https://github.com/lucas-labs). ## [2.0.8] - 2025-04-12 diff --git a/Makefile b/Makefile index 9a738cd..e2b6242 100644 --- a/Makefile +++ b/Makefile @@ -42,3 +42,24 @@ format: lint-types: mypy rodi --explicit-package-bases + + +check-flake8: + @echo "$(BOLD)Checking flake8$(RESET)" + @flake8 rodi 2>&1 + @flake8 tests 2>&1 + + +check-isort: + @echo "$(BOLD)Checking isort$(RESET)" + @isort --check-only rodi 2>&1 + @isort --check-only tests 2>&1 + + +check-black: ## Run the black tool in check mode only (won't modify files) + @echo "$(BOLD)Checking black$(RESET)" + @black --check rodi 2>&1 + @black --check tests 2>&1 + + +lint: check-flake8 check-isort check-black diff --git a/rodi/__init__.py b/rodi/__init__.py index a36aa88..84a3997 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -463,6 +463,28 @@ def factory(context, parent_type): return FactoryResolver(concrete_type, factory, life_style)(resolver_context) +def get_mixed_type_provider( + concrete_type: Type, + args_callbacks: list, + annotation_resolvers: Mapping[str, Callable], + life_style: ServiceLifeStyle, + resolver_context: ResolutionContext, +): + """ + Provider that combines __init__ argument injection with class-level annotation + property injection. Used when a class defines both a custom __init__ (with or + without parameters) and class-level annotated attributes. + """ + + def factory(context, parent_type): + instance = concrete_type(*[fn(context, parent_type) for fn in args_callbacks]) + for name, resolver in annotation_resolvers.items(): + setattr(instance, name, resolver(context, parent_type)) + return instance + + return FactoryResolver(concrete_type, factory, life_style)(resolver_context) + + def _get_plain_class_factory(concrete_type: Type): def factory(*args): return concrete_type() @@ -645,6 +667,48 @@ def _resolve_by_annotations( self.concrete_type, resolvers, self.life_style, context ) + def _resolve_by_init_and_annotations( + self, context: ResolutionContext, extra_annotations: dict[str, Type] + ): + """ + Resolves by both __init__ parameters and class-level annotated properties. + Used when a class defines a custom __init__ AND class-level type annotations. + The __init__ parameters are injected as constructor arguments; the class + annotations are injected via setattr after instantiation. + """ + sig = Signature.from_callable(self.concrete_type.__init__) + params = { + key: Dependency(key, value.annotation) + for key, value in sig.parameters.items() + } + + if sys.version_info >= (3, 10): # pragma: no cover + globalns = dict(vars(sys.modules[self.concrete_type.__module__])) + globalns.update(_get_obj_globals(self.concrete_type)) + annotations = get_type_hints( + self.concrete_type.__init__, + globalns, + _get_obj_locals(self.concrete_type), + ) + for key, value in params.items(): + if key in annotations: + value.annotation = annotations[key] + + concrete_type = self.concrete_type + init_fns = self._get_resolvers_for_parameters(concrete_type, context, params) + + ann_params = { + key: Dependency(key, value) for key, value in extra_annotations.items() + } + ann_fns = self._get_resolvers_for_parameters(concrete_type, context, ann_params) + annotation_resolvers = { + name: ann_fns[i] for i, name in enumerate(ann_params.keys()) + } + + return get_mixed_type_provider( + concrete_type, init_fns, annotation_resolvers, self.life_style, context + ) + def __call__(self, context: ResolutionContext): concrete_type = self.concrete_type @@ -670,6 +734,35 @@ def __call__(self, context: ResolutionContext): concrete_type, _get_plain_class_factory(concrete_type), self.life_style )(context) + # Custom __init__: also check for class-level annotations to inject as + # properties. The cheap __annotations__ check avoids the expensive + # get_type_hints call for the common case of no class-level annotations. + if concrete_type.__annotations__: + class_annotations = get_type_hints( + concrete_type, + { + **dict(vars(sys.modules[concrete_type.__module__])), + **_get_obj_globals(concrete_type), + }, + _get_obj_locals(concrete_type), + ) + if class_annotations: + sig = Signature.from_callable(concrete_type.__init__) + init_param_names = set(sig.parameters.keys()) - {"self"} + extra_annotations = { + k: v + for k, v in class_annotations.items() + if k not in init_param_names + and not self._ignore_class_attribute(k, v) + } + if extra_annotations: + try: + return self._resolve_by_init_and_annotations( + context, extra_annotations + ) + except RecursionError: + raise CircularDependencyException(chain[0], concrete_type) + try: return self._resolve_by_init_method(context) except RecursionError: diff --git a/tests/examples.py b/tests/examples.py index 4ced794..77dfc76 100644 --- a/tests/examples.py +++ b/tests/examples.py @@ -264,3 +264,66 @@ class PrecedenceOfTypeHintsOverNames: def __init__(self, foo: Q, ko: P): self.q = foo self.p = ko + + +# Classes for testing mixed __init__ + class annotation injection + + +class MixedDep1: + pass + + +class MixedDep2: + pass + + +class MixedNoInitArgs: + """Has a custom __init__ with no injectable args, plus class-level annotations.""" + + injected: MixedDep1 + + def __init__(self) -> None: + self.value = "hello" + + +class MixedWithInitArgs: + """ + Has a custom __init__ with injectable args, plus additional class-level + annotations. + """ + + extra: MixedDep2 + + def __init__(self, dep1: MixedDep1) -> None: + self.dep1 = dep1 + self.value = "hello" + + +class MixedSingleton: + """Singleton variant for mixed injection.""" + + dep2: MixedDep2 + + def __init__(self, dep1: MixedDep1) -> None: + self.dep1 = dep1 + + +class MixedScoped: + """Scoped variant for mixed injection.""" + + dep2: MixedDep2 + + def __init__(self, dep1: MixedDep1) -> None: + self.dep1 = dep1 + + +class MixedAnnotationOverlapsInit: + """ + Class where a class annotation name matches an __init__ parameter. + The annotation should NOT be double-injected; init param takes precedence. + """ + + dep1: MixedDep1 # same name as __init__ param - should be handled by init only + + def __init__(self, dep1: MixedDep1) -> None: + self.dep1 = dep1 diff --git a/tests/test_services.py b/tests/test_services.py index c2cb3bf..2dbdd8f 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -5,7 +5,6 @@ from typing import ( Any, ClassVar, - Dict, Generic, Iterable, List, @@ -65,6 +64,13 @@ Jang, Jing, Ko, + MixedAnnotationOverlapsInit, + MixedDep1, + MixedDep2, + MixedNoInitArgs, + MixedScoped, + MixedSingleton, + MixedWithInitArgs, Ok, P, PrecedenceOfTypeHintsOverNames, @@ -2146,15 +2152,15 @@ class C: def test_dict_generic_alias(): container = Container() - def mapping_int_factory() -> Dict[int, int]: + def mapping_int_factory() -> dict[int, int]: return {1: 1, 2: 2, 3: 3} - def mapping_str_factory() -> Dict[str, int]: + def mapping_str_factory() -> dict[str, int]: return {"a": 1, "b": 2, "c": 3} class C: - a: Dict[int, int] - b: Dict[str, int] + a: dict[int, int] + b: dict[str, int] container.add_scoped_by_factory(mapping_int_factory) container.add_scoped_by_factory(mapping_str_factory) @@ -2760,51 +2766,149 @@ async def test_nested_scope_async_1(): nested_scope_async(), nested_scope_async(), ) - - -# Tests for inject(globalsns=...) being honoured during type resolution (#60) - - -def test_inject_globalsns_honoured_for_annotation_resolution(): - """ - When a class uses a forward reference in a class-level annotation and the - type is provided via inject(globalsns=...), it should be resolved correctly. - """ - - class LocalDep: - pass - - @inject(globalsns={"LocalDep": LocalDep}) - class Service: - dep: "LocalDep" - - container = Container() - container.add_transient(LocalDep) - container.add_transient(Service) - provider = container.build_provider() - - instance = provider.get(Service) - assert isinstance(instance.dep, LocalDep) - - -def test_inject_globalsns_honoured_for_init_resolution(): - """ - When a class uses a forward reference in __init__ and the type is provided - via inject(globalsns=...), it should be resolved correctly. - """ - - class LocalDep: - pass - - @inject(globalsns={"LocalDep": LocalDep}) - class Service: - def __init__(self, dep: "LocalDep") -> None: - self.dep = dep - - container = Container() - container.add_transient(LocalDep) - container.add_transient(Service) - provider = container.build_provider() - - instance = provider.get(Service) - assert isinstance(instance.dep, LocalDep) + + +# Tests for inject(globalsns=...) being honoured during type resolution (#60) + + +def test_inject_globalsns_honoured_for_annotation_resolution(): + """ + When a class uses a forward reference in a class-level annotation and the + type is provided via inject(globalsns=...), it should be resolved correctly. + """ + + class LocalDep: + pass + + @inject(globalsns={"LocalDep": LocalDep}) + class Service: + dep: "LocalDep" + + container = Container() + container.add_transient(LocalDep) + container.add_transient(Service) + provider = container.build_provider() + + instance = provider.get(Service) + assert isinstance(instance.dep, LocalDep) + + +def test_inject_globalsns_honoured_for_init_resolution(): + """ + When a class uses a forward reference in __init__ and the type is provided + via inject(globalsns=...), it should be resolved correctly. + """ + + class LocalDep: + pass + + @inject(globalsns={"LocalDep": LocalDep}) + class Service: + def __init__(self, dep: "LocalDep") -> None: + self.dep = dep + + container = Container() + container.add_transient(LocalDep) + container.add_transient(Service) + provider = container.build_provider() + + instance = provider.get(Service) + assert isinstance(instance.dep, LocalDep) + + +# Tests for mixed __init__ + class annotation injection (issue #43) + + +def test_mixed_no_init_args_transient(): + """ + Class with a custom no-arg __init__ AND class-level annotations: + annotations should be injected via setattr after instantiation. + """ + container = Container() + container.add_transient(MixedDep1) + container.add_transient(MixedNoInitArgs) + provider = container.build_provider() + + instance = provider.get(MixedNoInitArgs) + assert instance is not None + assert isinstance(instance.injected, MixedDep1) + assert instance.value == "hello" + + +def test_mixed_no_init_args_new_instance_each_time(): + """Transient mixed services produce a new instance on each resolve.""" + container = Container() + container.add_transient(MixedDep1) + container.add_transient(MixedNoInitArgs) + provider = container.build_provider() + + a = provider.get(MixedNoInitArgs) + b = provider.get(MixedNoInitArgs) + assert a is not b + + +def test_mixed_with_init_args_transient(): + """ + Class with a custom __init__ that has injectable params AND class-level annotations: + both should be injected. + """ + container = Container() + container.add_transient(MixedDep1) + container.add_transient(MixedDep2) + container.add_transient(MixedWithInitArgs) + provider = container.build_provider() + + instance = provider.get(MixedWithInitArgs) + assert instance is not None + assert isinstance(instance.dep1, MixedDep1) + assert isinstance(instance.extra, MixedDep2) + assert instance.value == "hello" + + +def test_mixed_with_init_args_singleton(): + """Singleton mixed service: same instance returned each time.""" + container = Container() + container.add_singleton(MixedDep1) + container.add_singleton(MixedDep2) + container.add_singleton(MixedSingleton) + provider = container.build_provider() + + a = provider.get(MixedSingleton) + b = provider.get(MixedSingleton) + assert a is b + assert isinstance(a.dep1, MixedDep1) + assert isinstance(a.dep2, MixedDep2) + + +def test_mixed_with_init_args_scoped(): + """Scoped mixed service: same instance within a scope, new across scopes.""" + container = Container() + container.add_scoped(MixedDep1) + container.add_scoped(MixedDep2) + container.add_scoped(MixedScoped) + provider = container.build_provider() + + with provider.create_scope() as scope1: + a = provider.get(MixedScoped, scope1) + b = provider.get(MixedScoped, scope1) + assert a is b + assert isinstance(a.dep1, MixedDep1) + assert isinstance(a.dep2, MixedDep2) + + with provider.create_scope() as scope2: + c = provider.get(MixedScoped, scope2) + assert c is not a + + +def test_mixed_annotation_overlaps_init_param(): + """ + When a class annotation has the same name as an __init__ parameter, + the annotation is NOT double-injected — init handles it, setattr is skipped. + """ + container = Container() + container.add_transient(MixedDep1) + container.add_transient(MixedAnnotationOverlapsInit) + provider = container.build_provider() + + instance = provider.get(MixedAnnotationOverlapsInit) + assert isinstance(instance.dep1, MixedDep1)