Skip to content

Commit 4f7b277

Browse files
committed
Type multi.pyi
1 parent 7981f2c commit 4f7b277

File tree

5 files changed

+160
-40
lines changed

5 files changed

+160
-40
lines changed

pandas-stubs/_typing.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,7 @@ np_1darray_complex: TypeAlias = np_1darray[np.complexfloating]
951951
np_1darray_object: TypeAlias = np_1darray[np.object_]
952952
np_1darray_bool: TypeAlias = np_1darray[np.bool]
953953
np_1darray_intp: TypeAlias = np_1darray[np.intp]
954+
np_1darray_int8: TypeAlias = np_1darray[np.int8]
954955
np_1darray_int64: TypeAlias = np_1darray[np.int64]
955956
np_1darray_anyint: TypeAlias = np_1darray[np.integer]
956957
np_1darray_float: TypeAlias = np_1darray[np.floating]

pandas-stubs/core/indexes/multi.pyi

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ from collections.abc import (
77
)
88
from typing import (
99
Any,
10-
final,
1110
overload,
1211
)
1312

@@ -19,15 +18,18 @@ from typing_extensions import Self
1918
from pandas._typing import (
2019
AnyAll,
2120
Axes,
22-
DropKeep,
2321
Dtype,
2422
HashableT,
2523
IndexLabel,
24+
Label,
2625
Level,
2726
MaskType,
2827
NaPosition,
2928
SequenceNotStr,
29+
Shape,
3030
np_1darray_bool,
31+
np_1darray_int8,
32+
np_1darray_intp,
3133
np_ndarray_anyint,
3234
)
3335

@@ -70,19 +72,46 @@ class MultiIndex(Index):
7072
sortorder: int | None = ...,
7173
names: SequenceNotStr[Hashable] = ...,
7274
) -> Self: ...
73-
@property
74-
def shape(self): ...
7575
@property # Should be read-only
7676
def levels(self) -> list[Index]: ...
77-
def set_levels(self, levels, *, level=..., verify_integrity: bool = ...): ...
77+
@overload
78+
def set_levels(
79+
self,
80+
levels: Sequence[SequenceNotStr[Hashable]],
81+
*,
82+
level: Sequence[Level] | None = None,
83+
verify_integrity: bool = True,
84+
) -> MultiIndex: ...
85+
@overload
86+
def set_levels(
87+
self,
88+
levels: SequenceNotStr[Hashable],
89+
*,
90+
level: Level,
91+
verify_integrity: bool = True,
92+
) -> MultiIndex: ...
7893
@property
79-
def codes(self): ...
80-
def set_codes(self, codes, *, level=..., verify_integrity: bool = ...): ...
94+
def codes(self) -> list[np_1darray_int8]: ...
95+
@overload
96+
def set_codes(
97+
self,
98+
codes: Sequence[Sequence[int]],
99+
*,
100+
level: Sequence[Level] | None = None,
101+
verify_integrity: bool = True,
102+
) -> MultiIndex: ...
103+
@overload
104+
def set_codes(
105+
self,
106+
codes: Sequence[int],
107+
*,
108+
level: Level,
109+
verify_integrity: bool = True,
110+
) -> MultiIndex: ...
81111
def copy( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] # pyrefly: ignore
82112
self, names: SequenceNotStr[Hashable] = ..., deep: bool = False
83113
) -> Self: ...
84-
def view(self, cls=...): ...
85-
def __contains__(self, key) -> bool: ...
114+
def view(self, cls: Any = None) -> MultiIndex: ... # type: ignore[override] # pyrefly: ignore[bad-override] # pyright: ignore[reportIncompatibleMethodOverride]
86115
@property
87116
def dtype(self) -> np.dtype: ...
88117
@property
@@ -92,29 +121,34 @@ class MultiIndex(Index):
92121
def nbytes(self) -> int: ...
93122
def __len__(self) -> int: ...
94123
@property
95-
def values(self): ...
96-
@property
97124
def is_monotonic_increasing(self) -> bool: ...
98125
@property
99126
def is_monotonic_decreasing(self) -> bool: ...
100-
def duplicated(self, keep: DropKeep = "first"): ...
101127
def dropna(self, how: AnyAll = "any") -> Self: ...
102128
def droplevel(self, level: Level | Sequence[Level] = 0) -> MultiIndex | Index: ... # type: ignore[override]
103129
def get_level_values(self, level: str | int) -> Index: ...
104-
def unique(self, level=...): ...
130+
@overload # type: ignore[override]
131+
def unique( # pyrefly: ignore[bad-override]
132+
self, level: None = None
133+
) -> MultiIndex: ...
134+
@overload
135+
def unique( # ty: ignore[invalid-method-override] # pyright: ignore[reportIncompatibleMethodOverride]
136+
self, level: Level
137+
) -> (
138+
Index
139+
): ... # ty: ignore[invalid-method-override] # pyrefly: ignore[bad-override]
105140
def to_frame( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
106141
self,
107142
index: bool = True,
108143
name: list[HashableT] = ...,
109144
allow_duplicates: bool = False,
110145
) -> pd.DataFrame: ...
111146
def to_flat_index(self) -> Index: ...
112-
def remove_unused_levels(self): ...
147+
def remove_unused_levels(self) -> MultiIndex: ...
113148
@property
114149
def nlevels(self) -> int: ...
115150
@property
116-
def levshape(self): ...
117-
def __reduce__(self): ...
151+
def levshape(self) -> Shape: ...
118152
@overload # type: ignore[override]
119153
# pyrefly: ignore # bad-override
120154
def __getitem__(
@@ -125,36 +159,33 @@ class MultiIndex(Index):
125159
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] # ty: ignore[invalid-method-override]
126160
self, key: int
127161
) -> tuple[Hashable, ...]: ...
128-
def append(self, other): ...
129-
def repeat(self, repeats, axis=...): ...
130-
def drop(self, codes, level: Level | None = None, errors: str = "raise") -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
162+
@overload # type: ignore[override]
163+
def append(self, other: MultiIndex | Sequence[MultiIndex]) -> MultiIndex: ...
164+
@overload
165+
def append( # pyright: ignore[reportIncompatibleMethodOverride]
166+
self, other: Index | Sequence[Index]
167+
) -> Index: ... # pyrefly: ignore[bad-override]
168+
def drop(self, codes: Level | Sequence[Level], level: Level | None = None, errors: str = "raise") -> MultiIndex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
131169
def swaplevel(self, i: int = -2, j: int = -1) -> Self: ...
132-
def reorder_levels(self, order): ...
170+
def reorder_levels(self, order: Sequence[Level]) -> MultiIndex: ...
133171
def sortlevel(
134172
self,
135173
level: Level | Sequence[Level] = 0,
136174
ascending: bool = True,
137175
sort_remaining: bool = True,
138176
na_position: NaPosition = "first",
139-
): ...
140-
@final
141-
def get_indexer(self, target, method=..., limit=..., tolerance=...): ...
142-
def get_indexer_non_unique(self, target): ...
143-
def reindex(self, target, method=..., level=..., limit=..., tolerance=...): ...
144-
def get_slice_bound(
145-
self, label: Hashable | Sequence[Hashable], side: str
146-
) -> int: ...
177+
) -> tuple[MultiIndex, np_1darray_intp]: ...
147178
def get_loc_level(
148-
self, key, level: Level | list[Level] | None = None, drop_level: bool = True
149-
): ...
150-
def get_locs(self, seq): ...
179+
self,
180+
key: Label | Sequence[Label],
181+
level: Level | Sequence[Level] | None = None,
182+
drop_level: bool = True,
183+
) -> tuple[int | slice | np_1darray_bool, Index]: ...
184+
def get_locs(self, seq: Level | Sequence[Level]) -> np_1darray_intp: ...
151185
def truncate(
152186
self, before: IndexLabel | None = None, after: IndexLabel | None = None
153-
): ...
154-
def equals(self, other) -> bool: ...
155-
def equal_levels(self, other): ...
156-
def insert(self, loc, item): ...
157-
def delete(self, loc): ...
187+
) -> MultiIndex: ...
188+
def equal_levels(self, other: MultiIndex) -> bool: ...
158189
@overload # type: ignore[override]
159190
def isin( # pyrefly: ignore[bad-override]
160191
self, values: Iterable[Any], level: Level

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,6 @@ ignore = [
205205
"PYI042", # https://docs.astral.sh/ruff/rules/snake-case-type-alias/
206206
"ERA001", "PLR0402", "PLC0105"
207207
]
208-
"multi.pyi" = [
209-
# TODO: remove when multi.pyi is fully typed
210-
"ANN001", "ANN201", "ANN204", "ANN206",
211-
]
212208
"indexing.pyi" = [
213209
# TODO: remove when indexing.pyi is fully typed
214210
"ANN001", "ANN201", "ANN204", "ANN206",

tests/_typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
np_1darray_complex,
3232
np_1darray_dt,
3333
np_1darray_float,
34+
np_1darray_int8,
3435
np_1darray_int64,
3536
np_1darray_intp,
3637
np_1darray_object,
@@ -81,6 +82,7 @@
8182
"np_ndarray_dt",
8283
"np_1darray_object",
8384
"np_1darray_td",
85+
"np_1darray_int8",
8486
"np_1darray_int64",
8587
"np_ndarray_num",
8688
"FloatDtypeArg",
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from __future__ import annotations
2+
3+
import pandas as pd
4+
from typing_extensions import (
5+
assert_type,
6+
)
7+
8+
from tests import (
9+
check,
10+
)
11+
from tests._typing import (
12+
np_1darray_int8,
13+
np_1darray_intp,
14+
)
15+
16+
17+
def test_multiindex_unique() -> None:
18+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
19+
check(assert_type(mi.unique(), pd.MultiIndex), pd.MultiIndex)
20+
check(assert_type(mi.unique(level=0), pd.Index), pd.Index)
21+
22+
23+
def test_multiindex_set_levels() -> None:
24+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
25+
res = mi.set_levels([[10, 20, 30], [40, 50, 60]])
26+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
27+
res = mi.set_levels([10, 20, 30], level=0)
28+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
29+
30+
31+
def test_multiindex_codes() -> None:
32+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
33+
check(assert_type(mi.codes, list[np_1darray_int8]), list)
34+
35+
36+
def test_multiindex_set_codes() -> None:
37+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
38+
res = mi.set_codes([[0, 1, 2], [0, 1, 2]])
39+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
40+
res = mi.set_codes([0, 1, 2], level=0)
41+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
42+
43+
44+
def test_multiindex_view() -> None:
45+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
46+
check(assert_type(mi.view(), pd.MultiIndex), pd.MultiIndex)
47+
check(assert_type(mi.view(pd.Index), pd.MultiIndex), pd.MultiIndex)
48+
49+
50+
def test_multiindex_remove_unused_levels() -> None:
51+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
52+
res = mi.remove_unused_levels()
53+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
54+
55+
56+
def test_multiindex_levshape() -> None:
57+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
58+
ls = mi.levshape
59+
check(assert_type(ls, tuple[int, ...]), tuple, int)
60+
61+
62+
def test_multiindex_append() -> None:
63+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
64+
check(assert_type(mi.append([mi]), pd.MultiIndex), pd.MultiIndex)
65+
check(assert_type(mi.append([pd.Index([1, 2])]), pd.Index), pd.Index)
66+
67+
68+
def test_multiindex_drop() -> None:
69+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
70+
dropped = mi.drop([1])
71+
check(assert_type(dropped, pd.MultiIndex), pd.MultiIndex)
72+
73+
74+
def test_multiindex_reorder_levels() -> None:
75+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
76+
reordered = mi.reorder_levels([1, 0])
77+
check(assert_type(reordered, pd.MultiIndex), pd.MultiIndex)
78+
79+
80+
def test_multiindex_get_locs() -> None:
81+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
82+
locs = mi.get_locs([1, 4])
83+
check(assert_type(locs, np_1darray_intp), np_1darray_intp)
84+
85+
86+
def test_multiindex_equal_levels() -> None:
87+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
88+
mi2 = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
89+
eq = mi.equal_levels(mi2)
90+
check(assert_type(eq, bool), bool)

0 commit comments

Comments
 (0)