Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions statemachine/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Tuple
from typing import TypeVar

from . import registry
from .event import Event
Expand All @@ -17,6 +19,10 @@
from .transition_list import TransitionList


TModel = TypeVar("TModel")
"""TypeVar for the model type in StateMachine."""


class StateMachineMetaclass(type):
"Metaclass for constructing StateMachine classes"

Expand All @@ -36,7 +42,9 @@ def __init__(

cls._abstract = True
cls._strict_states = strict_states
cls._events: Dict[Event, None] = {} # used Dict to preserve order and avoid duplicates
cls._events: Dict[Event, None] = (
{}
) # used Dict to preserve order and avoid duplicates
cls._protected_attrs: set = set()
cls._events_to_update: Dict[Event, Event | None] = {}

Expand Down Expand Up @@ -98,9 +106,9 @@ def _check_final_states(cls):

if final_state_with_invalid_transitions:
raise InvalidDefinition(
_("Cannot declare transitions from final state. Invalid state(s): {}").format(
[s.id for s in final_state_with_invalid_transitions]
)
_(
"Cannot declare transitions from final state. Invalid state(s): {}"
).format([s.id for s in final_state_with_invalid_transitions])
)

def _check_trap_states(cls):
Expand Down Expand Up @@ -133,7 +141,8 @@ def _states_without_path_to_final_states(cls):
return [
state
for state in cls.states
if not state.final and not any(s.final for s in visit_connected_states(state))
if not state.final
and not any(s.final for s in visit_connected_states(state))
]

def _disconnected_states(cls, starting_state):
Expand Down Expand Up @@ -259,3 +268,17 @@ def _update_event_references(cls):
@property
def events(self):
return list(self._events)


class GenericStateMachineMetaclass(StateMachineMetaclass, type(Generic)): # type: ignore[misc]
"""
Metaclass that combines StateMachineMetaclass with Generic.

This allows StateMachine to be parameterized with a model type using Generic[TModel],
enabling type checkers to infer the correct type of the `model` attribute.

The type: ignore[misc] is necessary because mypy has limitations with generic metaclasses,
but this pattern works correctly at runtime and with type checkers.
"""

pass
23 changes: 16 additions & 7 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Generic
from typing import List

from .callbacks import SPECS_ALL
Expand All @@ -18,7 +19,8 @@
from .exceptions import InvalidDefinition
from .exceptions import InvalidStateValue
from .exceptions import TransitionNotAllowed
from .factory import StateMachineMetaclass
from .factory import GenericStateMachineMetaclass
from .factory import TModel
from .graph import iterate_states_and_transitions
from .i18n import _
from .model import Model
Expand All @@ -29,7 +31,7 @@
from .state import State


class StateMachine(metaclass=StateMachineMetaclass):
class StateMachine(Generic[TModel], metaclass=GenericStateMachineMetaclass):
"""

Args:
Expand Down Expand Up @@ -68,14 +70,14 @@ class StateMachine(metaclass=StateMachineMetaclass):

def __init__(
self,
model: Any = None,
model: "TModel | None" = None,
state_field: str = "state",
start_value: Any = None,
rtc: bool = True,
allow_event_without_transition: bool = False,
listeners: "List[object] | None" = None,
):
self.model = model if model is not None else Model()
self.model: TModel = model if model is not None else Model() # type: ignore[assignment]
self.state_field = state_field
self.start_value = start_value
self.allow_event_without_transition = allow_event_without_transition
Expand Down Expand Up @@ -149,7 +151,9 @@ def __setstate__(self, state):
self._engine = self._get_engine(rtc)

def _get_initial_state(self):
initial_state_value = self.start_value if self.start_value else self.initial_state.value
initial_state_value = (
self.start_value if self.start_value else self.initial_state.value
)
try:
return self.states_map[initial_state_value]
except KeyError as err:
Expand All @@ -170,7 +174,9 @@ def bind_events_to(self, *targets):
continue
setattr(target, event, trigger)

def _add_listener(self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL):
def _add_listener(
self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL
):
registry = self._callbacks
for visited in iterate_states_and_transitions(self.states):
listeners.resolve(
Expand Down Expand Up @@ -292,7 +298,10 @@ def events(self) -> "List[Event]":
@property
def allowed_events(self) -> "List[Event]":
"""List of the current allowed events."""
return [getattr(self, event) for event in self.current_state.transitions.unique_events]
return [
getattr(self, event)
for event in self.current_state.transitions.unique_events
]

def _put_nonblocking(self, trigger_data: TriggerData):
"""Put the trigger on the queue without blocking the caller."""
Expand Down
128 changes: 128 additions & 0 deletions tests/test_generic_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for Generic[TModel] support in StateMachine.

Test that type checkers can infer the correct model type when using Generic[TModel].
"""

import pytest

from statemachine import State
from statemachine import StateMachine


class CustomModel:
"""Custom model for testing"""

def __init__(self):
self.state = None
self.custom_attr = "test_value"
self.counter = 0


class GenericStateMachine(StateMachine[CustomModel]):
"""State machine using Generic[CustomModel] for type safety"""

initial = State("Initial", initial=True)
processing = State("Processing")
final = State("Final", final=True)

start = initial.to(processing)
finish = processing.to(final)


class TestGenericSupport:
"""Test suite for Generic[TModel] support"""

def test_generic_statemachine_with_custom_model(self):
"""Test that StateMachine[CustomModel] works with a custom model instance"""
model = CustomModel()
sm = GenericStateMachine(model=model)

assert sm.model is model
assert sm.model.custom_attr == "test_value"
assert sm.model.counter == 0

def test_generic_statemachine_with_default_model(self):
"""Test that StateMachine[CustomModel] works with default Model()"""
sm = GenericStateMachine()

# Default model should be Model(), not CustomModel
assert sm.model is not None
assert sm.current_state == sm.initial

def test_generic_statemachine_transitions_work(self):
"""Test that transitions work correctly with generic state machine"""
model = CustomModel()
sm = GenericStateMachine(model=model)

assert sm.current_state == sm.initial

sm.start()
assert sm.current_state == sm.processing

sm.finish()
assert sm.current_state == sm.final

def test_generic_statemachine_model_persists_across_transitions(self):
"""Test that model state persists across transitions"""
model = CustomModel()
sm = GenericStateMachine(model=model)

# Modify model
sm.model.counter = 42
sm.model.custom_attr = "modified"

# Transition
sm.start()

# Model state should persist
assert sm.model.counter == 42
assert sm.model.custom_attr == "modified"

def test_backward_compatibility_without_generic(self):
"""Test that traditional usage without Generic still works"""

class TraditionalMachine(StateMachine):
"""Non-generic state machine for backward compatibility"""

idle = State("Idle", initial=True)
running = State("Running")

run = idle.to(running)

sm = TraditionalMachine()
assert sm.current_state == sm.idle

sm.run()
assert sm.current_state == sm.running

def test_multiple_generic_machines_with_different_models(self):
"""Test that different generic machines can use different model types"""

class ModelA:
def __init__(self):
self.state = None
self.value_a = "A"

class ModelB:
def __init__(self):
self.state = None
self.value_b = "B"

class MachineA(StateMachine[ModelA]):
initial = State("Initial", initial=True)
final = State("Final", final=True)
go = initial.to(final)

class MachineB(StateMachine[ModelB]):
start = State("Start", initial=True)
end = State("End", final=True)
advance = start.to(end)

model_a = ModelA()
model_b = ModelB()

sm_a = MachineA(model=model_a)
sm_b = MachineB(model=model_b)

assert sm_a.model.value_a == "A"
assert sm_b.model.value_b == "B"