Commit 024b2e0

bryce13950 and claude committed
MAESTRO: Add ArchitectureAnalysis dataclass for priority analysis (Phase 06)
Added new dataclasses to schemas.py for architecture prioritization:
- TopModel: Represents a top model with model_id and downloads count
- ArchitectureAnalysis: Comprehensive analysis dataclass containing:
  - architecture_id, total_models, total_downloads, avg_model_downloads
  - top_models (list of TopModel), priority_score, has_official_implementation

Both classes include to_dict() and from_dict() methods for JSON serialization.
Added comprehensive unit tests in test_schemas.py.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 845b68b commit 024b2e0

3 files changed: +346 -0 lines changed

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# Phase 06: Architecture Priority Analysis
+
+This phase enhances the architecture gap analysis tool with intelligent prioritization features. Instead of just showing model counts, it provides actionable insights about which architectures would be most valuable to support next based on multiple factors.
+
+## Tasks
+
+- [x] Enhance `schemas.py` with ArchitectureAnalysis dataclass containing: architecture_id, total_models, total_downloads, avg_model_downloads, top_models (list of top 5 by downloads), priority_score, has_official_implementation
+  - **Completed**: Added `TopModel` dataclass (model_id, downloads) and `ArchitectureAnalysis` dataclass with all required fields. Both include `to_dict()`, `from_dict()` methods. Added comprehensive tests to `test_schemas.py` (TestTopModel and TestArchitectureAnalysis classes).
+- [ ] Update `get_all_architectures()` to also collect aggregate download statistics per architecture
+- [ ] Create `transformer_lens/tools/model_registry/priority.py` with function `calculate_priority_score(architecture)` using weighted formula: (total_downloads * 0.4) + (model_count * 0.3) + (recent_activity * 0.3)
+- [ ] Add function `get_top_models_per_architecture(architecture_id, n=5)` returning the most downloaded models for each architecture
+- [ ] Add function `detect_architecture_family(architecture_id)` to group related architectures (e.g., Llama variants, GPT variants)
+- [ ] Update `generate_architecture_gaps.py` to use ArchitectureAnalysis and include priority scores in output
+- [ ] Create `docs/ARCHITECTURE_ROADMAP.md` generator showing prioritized list with reasoning for each priority level
+- [ ] Add `--top-n` CLI flag to limit architecture gap analysis to top N by priority score (default: 50)
+- [ ] Add function `find_similar_supported_architectures(unsupported_id)` that suggests which existing adapter might be closest match for implementation reference
+- [ ] Create `docs/IMPLEMENTATION_GUIDE.md` generator that for each top priority architecture shows: similar supported architecture, key differences to handle, estimated complexity
+- [ ] Add `--analyze ARCHITECTURE_ID` CLI command that provides deep analysis of a specific unsupported architecture
+- [ ] Add monthly trend tracking by storing snapshots of architecture statistics in `data/trends/YYYY-MM.json`
+- [ ] Create `docs/TRENDS.md` generator showing which unsupported architectures are growing fastest
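
The weighted formula in the task list above can be sketched as follows. This is a hypothetical illustration, not the project's implementation: `calculate_priority_score` is still an open task at this commit, and the signature here takes separate arguments rather than the single `architecture` argument named in the task. The log scaling, the `_log_scale` helper, and the `max_downloads` / `max_model_count` reference values are assumptions, introduced so that all three terms share a 0 to 1 scale (raw download counts would otherwise swamp the other two terms). `recent_activity` is assumed to already be a fraction in [0, 1].

```python
import math


def _log_scale(value: float, max_value: float) -> float:
    """Map a non-negative count onto [0, 1] with log scaling (an assumption)."""
    if value <= 0 or max_value <= 0:
        return 0.0
    return min(1.0, math.log1p(value) / math.log1p(max_value))


def calculate_priority_score(
    total_downloads: int,
    model_count: int,
    recent_activity: float,
    max_downloads: int,
    max_model_count: int,
) -> float:
    """Weighted sum: 0.4 * downloads + 0.3 * model count + 0.3 * recent activity."""
    downloads_term = _log_scale(total_downloads, max_downloads)
    count_term = _log_scale(model_count, max_model_count)
    # Clamp the activity fraction so out-of-range inputs cannot skew the score.
    activity_term = min(1.0, max(0.0, recent_activity))
    return 0.4 * downloads_term + 0.3 * count_term + 0.3 * activity_term
```

With normalized terms the result stays in [0, 1], which matches the magnitude of the `priority_score` values used in the unit tests (for example 0.85), though the commit itself does not say how the score will be scaled.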

tests/tools/model_registry/test_schemas.py

Lines changed: 246 additions & 0 deletions
@@ -14,11 +14,13 @@
 import pytest
 
 from transformer_lens.tools.model_registry.schemas import (
+    ArchitectureAnalysis,
     ArchitectureGap,
     ArchitectureGapsReport,
     ModelEntry,
     ModelMetadata,
     SupportedModelsReport,
+    TopModel,
 )
 
 
@@ -295,6 +297,250 @@ def test_json_serialization(self):
         assert parsed["architecture_id"] == "LlamaForCausalLM"
 
 
+class TestTopModel:
+    """Tests for TopModel dataclass."""
+
+    def test_initialization(self):
+        """Test TopModel initialization."""
+        model = TopModel(
+            model_id="meta-llama/Llama-2-7b-hf",
+            downloads=1_000_000,
+        )
+        assert model.model_id == "meta-llama/Llama-2-7b-hf"
+        assert model.downloads == 1_000_000
+
+    def test_to_dict(self):
+        """Test to_dict serialization."""
+        model = TopModel(
+            model_id="google/gemma-2b",
+            downloads=500_000,
+        )
+        result = model.to_dict()
+        assert result == {
+            "model_id": "google/gemma-2b",
+            "downloads": 500_000,
+        }
+
+    def test_from_dict(self):
+        """Test from_dict deserialization."""
+        data = {
+            "model_id": "mistralai/Mistral-7B-v0.1",
+            "downloads": 750_000,
+        }
+        model = TopModel.from_dict(data)
+        assert model.model_id == "mistralai/Mistral-7B-v0.1"
+        assert model.downloads == 750_000
+
+    def test_roundtrip_serialization(self):
+        """Test to_dict -> from_dict roundtrip."""
+        original = TopModel(
+            model_id="test/model",
+            downloads=12345,
+        )
+        serialized = original.to_dict()
+        deserialized = TopModel.from_dict(serialized)
+        assert deserialized.model_id == original.model_id
+        assert deserialized.downloads == original.downloads
+
+    def test_json_serialization(self):
+        """Test that to_dict output is JSON serializable."""
+        model = TopModel(model_id="org/model", downloads=100)
+        json_str = json.dumps(model.to_dict())
+        parsed = json.loads(json_str)
+        assert parsed["model_id"] == "org/model"
+        assert parsed["downloads"] == 100
+
+
+class TestArchitectureAnalysis:
+    """Tests for ArchitectureAnalysis dataclass."""
+
+    def test_required_fields_only(self):
+        """Test ArchitectureAnalysis with required fields only."""
+        analysis = ArchitectureAnalysis(
+            architecture_id="LlamaForCausalLM",
+            total_models=1000,
+            total_downloads=50_000_000,
+            avg_model_downloads=50_000.0,
+        )
+        assert analysis.architecture_id == "LlamaForCausalLM"
+        assert analysis.total_models == 1000
+        assert analysis.total_downloads == 50_000_000
+        assert analysis.avg_model_downloads == 50_000.0
+        assert analysis.top_models == []
+        assert analysis.priority_score == 0.0
+        assert analysis.has_official_implementation is False
+
+    def test_all_fields(self):
+        """Test ArchitectureAnalysis with all fields populated."""
+        top_models = [
+            TopModel(model_id="meta-llama/Llama-2-7b-hf", downloads=5_000_000),
+            TopModel(model_id="meta-llama/Llama-2-13b-hf", downloads=3_000_000),
+        ]
+        analysis = ArchitectureAnalysis(
+            architecture_id="LlamaForCausalLM",
+            total_models=1000,
+            total_downloads=50_000_000,
+            avg_model_downloads=50_000.0,
+            top_models=top_models,
+            priority_score=0.85,
+            has_official_implementation=True,
+        )
+        assert analysis.architecture_id == "LlamaForCausalLM"
+        assert analysis.total_models == 1000
+        assert analysis.total_downloads == 50_000_000
+        assert analysis.avg_model_downloads == 50_000.0
+        assert len(analysis.top_models) == 2
+        assert analysis.top_models[0].model_id == "meta-llama/Llama-2-7b-hf"
+        assert analysis.priority_score == 0.85
+        assert analysis.has_official_implementation is True
+
+    def test_to_dict_without_top_models(self):
+        """Test to_dict serialization without top models."""
+        analysis = ArchitectureAnalysis(
+            architecture_id="GPT2LMHeadModel",
+            total_models=500,
+            total_downloads=10_000_000,
+            avg_model_downloads=20_000.0,
+        )
+        result = analysis.to_dict()
+        assert result == {
+            "architecture_id": "GPT2LMHeadModel",
+            "total_models": 500,
+            "total_downloads": 10_000_000,
+            "avg_model_downloads": 20_000.0,
+            "top_models": [],
+            "priority_score": 0.0,
+            "has_official_implementation": False,
+        }
+
+    def test_to_dict_with_top_models(self):
+        """Test to_dict serialization with top models."""
+        top_models = [
+            TopModel(model_id="openai-community/gpt2", downloads=2_000_000),
+            TopModel(model_id="openai-community/gpt2-medium", downloads=500_000),
+        ]
+        analysis = ArchitectureAnalysis(
+            architecture_id="GPT2LMHeadModel",
+            total_models=500,
+            total_downloads=10_000_000,
+            avg_model_downloads=20_000.0,
+            top_models=top_models,
+            priority_score=0.75,
+            has_official_implementation=True,
+        )
+        result = analysis.to_dict()
+        assert result["architecture_id"] == "GPT2LMHeadModel"
+        assert result["total_models"] == 500
+        assert result["total_downloads"] == 10_000_000
+        assert result["avg_model_downloads"] == 20_000.0
+        assert len(result["top_models"]) == 2
+        assert result["top_models"][0]["model_id"] == "openai-community/gpt2"
+        assert result["priority_score"] == 0.75
+        assert result["has_official_implementation"] is True
+
+    def test_from_dict_required_fields_only(self):
+        """Test from_dict with required fields only."""
+        data = {
+            "architecture_id": "MistralForCausalLM",
+            "total_models": 200,
+            "total_downloads": 5_000_000,
+            "avg_model_downloads": 25_000.0,
+        }
+        analysis = ArchitectureAnalysis.from_dict(data)
+        assert analysis.architecture_id == "MistralForCausalLM"
+        assert analysis.total_models == 200
+        assert analysis.total_downloads == 5_000_000
+        assert analysis.avg_model_downloads == 25_000.0
+        assert analysis.top_models == []
+        assert analysis.priority_score == 0.0
+        assert analysis.has_official_implementation is False
+
+    def test_from_dict_all_fields(self):
+        """Test from_dict with all fields."""
+        data = {
+            "architecture_id": "LlamaForCausalLM",
+            "total_models": 1000,
+            "total_downloads": 50_000_000,
+            "avg_model_downloads": 50_000.0,
+            "top_models": [
+                {"model_id": "meta-llama/Llama-2-7b-hf", "downloads": 5_000_000},
+                {"model_id": "meta-llama/Llama-2-13b-hf", "downloads": 3_000_000},
+            ],
+            "priority_score": 0.85,
+            "has_official_implementation": True,
+        }
+        analysis = ArchitectureAnalysis.from_dict(data)
+        assert analysis.architecture_id == "LlamaForCausalLM"
+        assert analysis.total_models == 1000
+        assert len(analysis.top_models) == 2
+        assert analysis.top_models[0].model_id == "meta-llama/Llama-2-7b-hf"
+        assert analysis.top_models[1].downloads == 3_000_000
+        assert analysis.priority_score == 0.85
+        assert analysis.has_official_implementation is True
+
+    def test_roundtrip_serialization(self):
+        """Test to_dict -> from_dict roundtrip."""
+        top_models = [
+            TopModel(model_id="org/model1", downloads=100_000),
+            TopModel(model_id="org/model2", downloads=50_000),
+            TopModel(model_id="org/model3", downloads=25_000),
+        ]
+        original = ArchitectureAnalysis(
+            architecture_id="TestArch",
+            total_models=300,
+            total_downloads=1_000_000,
+            avg_model_downloads=3333.33,
+            top_models=top_models,
+            priority_score=0.42,
+            has_official_implementation=True,
+        )
+        serialized = original.to_dict()
+        deserialized = ArchitectureAnalysis.from_dict(serialized)
+        assert deserialized.architecture_id == original.architecture_id
+        assert deserialized.total_models == original.total_models
+        assert deserialized.total_downloads == original.total_downloads
+        assert deserialized.avg_model_downloads == original.avg_model_downloads
+        assert len(deserialized.top_models) == len(original.top_models)
+        assert deserialized.top_models[0].model_id == original.top_models[0].model_id
+        assert deserialized.priority_score == original.priority_score
+        assert deserialized.has_official_implementation == original.has_official_implementation
+
+    def test_json_serialization(self):
+        """Test that to_dict output is JSON serializable."""
+        analysis = ArchitectureAnalysis(
+            architecture_id="JsonArch",
+            total_models=50,
+            total_downloads=100_000,
+            avg_model_downloads=2000.0,
+            top_models=[TopModel(model_id="test/model", downloads=10_000)],
+            priority_score=0.5,
+            has_official_implementation=False,
+        )
+        json_str = json.dumps(analysis.to_dict())
+        parsed = json.loads(json_str)
+        assert parsed["architecture_id"] == "JsonArch"
+        assert parsed["total_models"] == 50
+        assert len(parsed["top_models"]) == 1
+
+    def test_top_models_empty_list_default(self):
+        """Test that top_models is a new list instance for each ArchitectureAnalysis."""
+        analysis1 = ArchitectureAnalysis(
+            architecture_id="Arch1",
+            total_models=100,
+            total_downloads=1000,
+            avg_model_downloads=10.0,
+        )
+        analysis2 = ArchitectureAnalysis(
+            architecture_id="Arch2",
+            total_models=200,
+            total_downloads=2000,
+            avg_model_downloads=10.0,
+        )
+        analysis1.top_models.append(TopModel(model_id="test/model", downloads=100))
+        # Verify that modifying analysis1's top_models doesn't affect analysis2
+        assert analysis2.top_models == []
+
+
 class TestArchitectureGap:
     """Tests for ArchitectureGap dataclass."""
 

transformer_lens/tools/model_registry/schemas.py

Lines changed: 80 additions & 0 deletions
@@ -99,6 +99,86 @@ def from_dict(cls, data: dict) -> "ModelEntry":
         )
 
 
+@dataclass
+class TopModel:
+    """Represents a top model for an architecture.
+
+    Attributes:
+        model_id: The HuggingFace model identifier
+        downloads: Total download count for the model
+    """
+
+    model_id: str
+    downloads: int
+
+    def to_dict(self) -> dict:
+        """Convert to a dictionary for JSON serialization."""
+        return {
+            "model_id": self.model_id,
+            "downloads": self.downloads,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "TopModel":
+        """Create a TopModel from a dictionary."""
+        return cls(
+            model_id=data["model_id"],
+            downloads=data["downloads"],
+        )
+
+
+@dataclass
+class ArchitectureAnalysis:
+    """Comprehensive analysis of an architecture with prioritization data.
+
+    This dataclass provides detailed information about an architecture to help
+    prioritize which unsupported architectures should be implemented next.
+
+    Attributes:
+        architecture_id: The HuggingFace architecture class name (e.g., "LlamaForCausalLM")
+        total_models: The total number of models using this architecture on HuggingFace
+        total_downloads: The aggregate download count across all models of this architecture
+        avg_model_downloads: The average downloads per model for this architecture
+        top_models: List of top 5 models by downloads for this architecture
+        priority_score: Calculated priority score for implementation (higher = more important)
+        has_official_implementation: Whether an official (non-community) implementation exists
+    """
+
+    architecture_id: str
+    total_models: int
+    total_downloads: int
+    avg_model_downloads: float
+    top_models: list[TopModel] = field(default_factory=list)
+    priority_score: float = 0.0
+    has_official_implementation: bool = False
+
+    def to_dict(self) -> dict:
+        """Convert to a dictionary for JSON serialization."""
+        return {
+            "architecture_id": self.architecture_id,
+            "total_models": self.total_models,
+            "total_downloads": self.total_downloads,
+            "avg_model_downloads": self.avg_model_downloads,
+            "top_models": [m.to_dict() for m in self.top_models],
+            "priority_score": self.priority_score,
+            "has_official_implementation": self.has_official_implementation,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "ArchitectureAnalysis":
+        """Create an ArchitectureAnalysis from a dictionary."""
+        top_models = [TopModel.from_dict(m) for m in data.get("top_models", [])]
+        return cls(
+            architecture_id=data["architecture_id"],
+            total_models=data["total_models"],
+            total_downloads=data["total_downloads"],
+            avg_model_downloads=data["avg_model_downloads"],
+            top_models=top_models,
+            priority_score=data.get("priority_score", 0.0),
+            has_official_implementation=data.get("has_official_implementation", False),
+        )
+
+
 @dataclass
 class ArchitectureGap:
     """Represents an unsupported architecture and its model count on HuggingFace.
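
As a rough illustration of how the new dataclasses might be populated and round-tripped through JSON, the sketch below aggregates per-model download counts into one record. The dataclasses are condensed copies of those in the diff (docstrings dropped, `dataclasses.asdict` standing in for `to_dict`) so the example runs standalone. `build_analysis` and the `"ExampleArch"` data are hypothetical; the commit does not yet include any code that fills these fields.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class TopModel:
    model_id: str
    downloads: int


@dataclass
class ArchitectureAnalysis:
    architecture_id: str
    total_models: int
    total_downloads: int
    avg_model_downloads: float
    top_models: list[TopModel] = field(default_factory=list)
    priority_score: float = 0.0
    has_official_implementation: bool = False

    @classmethod
    def from_dict(cls, data: dict) -> "ArchitectureAnalysis":
        # Optional keys fall back to the dataclass defaults, as in the diff.
        return cls(
            architecture_id=data["architecture_id"],
            total_models=data["total_models"],
            total_downloads=data["total_downloads"],
            avg_model_downloads=data["avg_model_downloads"],
            top_models=[TopModel(**m) for m in data.get("top_models", [])],
            priority_score=data.get("priority_score", 0.0),
            has_official_implementation=data.get("has_official_implementation", False),
        )


def build_analysis(
    architecture_id: str, downloads_by_model: dict[str, int], n: int = 5
) -> ArchitectureAnalysis:
    """Hypothetical helper: aggregate per-model downloads into one record."""
    total = sum(downloads_by_model.values())
    count = len(downloads_by_model)
    # Rank models by download count, highest first, and keep the top n.
    ranked = sorted(downloads_by_model.items(), key=lambda kv: kv[1], reverse=True)
    return ArchitectureAnalysis(
        architecture_id=architecture_id,
        total_models=count,
        total_downloads=total,
        avg_model_downloads=total / count if count else 0.0,
        top_models=[TopModel(model_id=m, downloads=d) for m, d in ranked[:n]],
    )


analysis = build_analysis("ExampleArch", {"org/a": 300, "org/b": 100, "org/c": 200}, n=2)
restored = ArchitectureAnalysis.from_dict(json.loads(json.dumps(asdict(analysis))))
```

Dataclass equality compares fields recursively, so `restored == analysis` is a compact way to check the round trip instead of asserting field by field.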
