diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index a4f5ffa4..9693116e 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -97,7 +97,7 @@ def custom_transform(x): from pydantic import InstanceOf, validate_call from pydantic.dataclasses import dataclass -from pymc.distributions.shape_utils import Dims +from pymc.distributions.shape_utils import Dims, StrongDims from pymc_extras.deserialize import deserialize, register_deserialization @@ -576,7 +576,7 @@ def __init__( ) -> None: self.distribution = distribution self.parameters = parameters - self.dims = dims + self.dims: StrongDims = dims self.centered = centered self.transform = transform @@ -606,12 +606,16 @@ def transform(self, transform: str | None) -> None: self.pytensor_transform = not transform or _get_transform(transform) # type: ignore @property - def dims(self) -> Dims: - """The dimensions of the variable.""" + def dims(self) -> StrongDims: + """The dimensions of the variable. + + It will always be a tuple. Empty tuple for scalar variables. + + """ return self._dims @dims.setter - def dims(self, dims) -> None: + def dims(self, dims: Dims | None) -> None: if isinstance(dims, str): dims = (dims,)