 import pytest

 from transformer_lens.tools.model_registry.schemas import (
+    ArchitectureAnalysis,
     ArchitectureGap,
     ArchitectureGapsReport,
     ModelEntry,
     ModelMetadata,
     SupportedModelsReport,
+    TopModel,
 )


@@ -295,6 +297,250 @@ def test_json_serialization(self):
         assert parsed["architecture_id"] == "LlamaForCausalLM"


+class TestTopModel:
+    """Tests for TopModel dataclass."""
+
+    def test_initialization(self):
+        """Test TopModel initialization."""
+        model = TopModel(
+            model_id="meta-llama/Llama-2-7b-hf",
+            downloads=1_000_000,
+        )
+        assert model.model_id == "meta-llama/Llama-2-7b-hf"
+        assert model.downloads == 1_000_000
+
+    def test_to_dict(self):
+        """Test to_dict serialization."""
+        model = TopModel(
+            model_id="google/gemma-2b",
+            downloads=500_000,
+        )
+        result = model.to_dict()
+        assert result == {
+            "model_id": "google/gemma-2b",
+            "downloads": 500_000,
+        }
+
+    def test_from_dict(self):
+        """Test from_dict deserialization."""
+        data = {
+            "model_id": "mistralai/Mistral-7B-v0.1",
+            "downloads": 750_000,
+        }
+        model = TopModel.from_dict(data)
+        assert model.model_id == "mistralai/Mistral-7B-v0.1"
+        assert model.downloads == 750_000
+
+    def test_roundtrip_serialization(self):
+        """Test to_dict -> from_dict roundtrip."""
+        original = TopModel(
+            model_id="test/model",
+            downloads=12345,
+        )
+        serialized = original.to_dict()
+        deserialized = TopModel.from_dict(serialized)
+        assert deserialized.model_id == original.model_id
+        assert deserialized.downloads == original.downloads
+
+    def test_json_serialization(self):
+        """Test that to_dict output is JSON serializable."""
+        model = TopModel(model_id="org/model", downloads=100)
+        json_str = json.dumps(model.to_dict())
+        parsed = json.loads(json_str)
+        assert parsed["model_id"] == "org/model"
+        assert parsed["downloads"] == 100
+
+
+class TestArchitectureAnalysis:
+    """Tests for ArchitectureAnalysis dataclass."""
+
+    def test_required_fields_only(self):
+        """Test ArchitectureAnalysis with required fields only."""
+        analysis = ArchitectureAnalysis(
+            architecture_id="LlamaForCausalLM",
+            total_models=1000,
+            total_downloads=50_000_000,
+            avg_model_downloads=50_000.0,
+        )
+        assert analysis.architecture_id == "LlamaForCausalLM"
+        assert analysis.total_models == 1000
+        assert analysis.total_downloads == 50_000_000
+        assert analysis.avg_model_downloads == 50_000.0
+        assert analysis.top_models == []
+        assert analysis.priority_score == 0.0
+        assert analysis.has_official_implementation is False
+
+    def test_all_fields(self):
+        """Test ArchitectureAnalysis with all fields populated."""
+        top_models = [
+            TopModel(model_id="meta-llama/Llama-2-7b-hf", downloads=5_000_000),
+            TopModel(model_id="meta-llama/Llama-2-13b-hf", downloads=3_000_000),
+        ]
+        analysis = ArchitectureAnalysis(
+            architecture_id="LlamaForCausalLM",
+            total_models=1000,
+            total_downloads=50_000_000,
+            avg_model_downloads=50_000.0,
+            top_models=top_models,
+            priority_score=0.85,
+            has_official_implementation=True,
+        )
+        assert analysis.architecture_id == "LlamaForCausalLM"
+        assert analysis.total_models == 1000
+        assert analysis.total_downloads == 50_000_000
+        assert analysis.avg_model_downloads == 50_000.0
+        assert len(analysis.top_models) == 2
+        assert analysis.top_models[0].model_id == "meta-llama/Llama-2-7b-hf"
+        assert analysis.priority_score == 0.85
+        assert analysis.has_official_implementation is True
+
+    def test_to_dict_without_top_models(self):
+        """Test to_dict serialization without top models."""
+        analysis = ArchitectureAnalysis(
+            architecture_id="GPT2LMHeadModel",
+            total_models=500,
+            total_downloads=10_000_000,
+            avg_model_downloads=20_000.0,
+        )
+        result = analysis.to_dict()
+        assert result == {
+            "architecture_id": "GPT2LMHeadModel",
+            "total_models": 500,
+            "total_downloads": 10_000_000,
+            "avg_model_downloads": 20_000.0,
+            "top_models": [],
+            "priority_score": 0.0,
+            "has_official_implementation": False,
+        }
+
+    def test_to_dict_with_top_models(self):
+        """Test to_dict serialization with top models."""
+        top_models = [
+            TopModel(model_id="openai-community/gpt2", downloads=2_000_000),
+            TopModel(model_id="openai-community/gpt2-medium", downloads=500_000),
+        ]
+        analysis = ArchitectureAnalysis(
+            architecture_id="GPT2LMHeadModel",
+            total_models=500,
+            total_downloads=10_000_000,
+            avg_model_downloads=20_000.0,
+            top_models=top_models,
+            priority_score=0.75,
+            has_official_implementation=True,
+        )
+        result = analysis.to_dict()
+        assert result["architecture_id"] == "GPT2LMHeadModel"
+        assert result["total_models"] == 500
+        assert result["total_downloads"] == 10_000_000
+        assert result["avg_model_downloads"] == 20_000.0
+        assert len(result["top_models"]) == 2
+        assert result["top_models"][0]["model_id"] == "openai-community/gpt2"
+        assert result["priority_score"] == 0.75
+        assert result["has_official_implementation"] is True
+
+    def test_from_dict_required_fields_only(self):
+        """Test from_dict with required fields only."""
+        data = {
+            "architecture_id": "MistralForCausalLM",
+            "total_models": 200,
+            "total_downloads": 5_000_000,
+            "avg_model_downloads": 25_000.0,
+        }
+        analysis = ArchitectureAnalysis.from_dict(data)
+        assert analysis.architecture_id == "MistralForCausalLM"
+        assert analysis.total_models == 200
+        assert analysis.total_downloads == 5_000_000
+        assert analysis.avg_model_downloads == 25_000.0
+        assert analysis.top_models == []
+        assert analysis.priority_score == 0.0
+        assert analysis.has_official_implementation is False
+
+    def test_from_dict_all_fields(self):
+        """Test from_dict with all fields."""
+        data = {
+            "architecture_id": "LlamaForCausalLM",
+            "total_models": 1000,
+            "total_downloads": 50_000_000,
+            "avg_model_downloads": 50_000.0,
+            "top_models": [
+                {"model_id": "meta-llama/Llama-2-7b-hf", "downloads": 5_000_000},
+                {"model_id": "meta-llama/Llama-2-13b-hf", "downloads": 3_000_000},
+            ],
+            "priority_score": 0.85,
+            "has_official_implementation": True,
+        }
+        analysis = ArchitectureAnalysis.from_dict(data)
+        assert analysis.architecture_id == "LlamaForCausalLM"
+        assert analysis.total_models == 1000
+        assert len(analysis.top_models) == 2
+        assert analysis.top_models[0].model_id == "meta-llama/Llama-2-7b-hf"
+        assert analysis.top_models[1].downloads == 3_000_000
+        assert analysis.priority_score == 0.85
+        assert analysis.has_official_implementation is True
+
+    def test_roundtrip_serialization(self):
+        """Test to_dict -> from_dict roundtrip."""
+        top_models = [
+            TopModel(model_id="org/model1", downloads=100_000),
+            TopModel(model_id="org/model2", downloads=50_000),
+            TopModel(model_id="org/model3", downloads=25_000),
+        ]
+        original = ArchitectureAnalysis(
+            architecture_id="TestArch",
+            total_models=300,
+            total_downloads=1_000_000,
+            avg_model_downloads=3333.33,
+            top_models=top_models,
+            priority_score=0.42,
+            has_official_implementation=True,
+        )
+        serialized = original.to_dict()
+        deserialized = ArchitectureAnalysis.from_dict(serialized)
+        assert deserialized.architecture_id == original.architecture_id
+        assert deserialized.total_models == original.total_models
+        assert deserialized.total_downloads == original.total_downloads
+        assert deserialized.avg_model_downloads == original.avg_model_downloads
+        assert len(deserialized.top_models) == len(original.top_models)
+        assert deserialized.top_models[0].model_id == original.top_models[0].model_id
+        assert deserialized.priority_score == original.priority_score
+        assert deserialized.has_official_implementation == original.has_official_implementation
+
+    def test_json_serialization(self):
+        """Test that to_dict output is JSON serializable."""
+        analysis = ArchitectureAnalysis(
+            architecture_id="JsonArch",
+            total_models=50,
+            total_downloads=100_000,
+            avg_model_downloads=2000.0,
+            top_models=[TopModel(model_id="test/model", downloads=10_000)],
+            priority_score=0.5,
+            has_official_implementation=False,
+        )
+        json_str = json.dumps(analysis.to_dict())
+        parsed = json.loads(json_str)
+        assert parsed["architecture_id"] == "JsonArch"
+        assert parsed["total_models"] == 50
+        assert len(parsed["top_models"]) == 1
+
+    def test_top_models_empty_list_default(self):
+        """Test that top_models is a new list instance for each ArchitectureAnalysis."""
+        analysis1 = ArchitectureAnalysis(
+            architecture_id="Arch1",
+            total_models=100,
+            total_downloads=1000,
+            avg_model_downloads=10.0,
+        )
+        analysis2 = ArchitectureAnalysis(
+            architecture_id="Arch2",
+            total_models=200,
+            total_downloads=2000,
+            avg_model_downloads=10.0,
+        )
+        analysis1.top_models.append(TopModel(model_id="test/model", downloads=100))
+        # Verify that modifying analysis1's top_models doesn't affect analysis2
+        assert analysis2.top_models == []
+
+
 class TestArchitectureGap:
     """Tests for ArchitectureGap dataclass."""

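For context, the shape of the two schemas these tests exercise can be read off the assertions: TopModel is a flat (model_id, downloads) record, and ArchitectureAnalysis has four required fields plus per-instance defaults for top_models, priority_score, and has_official_implementation, with to_dict/from_dict handling nested TopModel entries. The sketch below is inferred from the tests only; the actual definitions in transformer_lens.tools.model_registry.schemas may differ in detail.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class TopModel:
    """One model ranked by download count (sketch inferred from the tests)."""

    model_id: str
    downloads: int

    def to_dict(self) -> dict[str, Any]:
        return {"model_id": self.model_id, "downloads": self.downloads}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TopModel":
        return cls(model_id=data["model_id"], downloads=data["downloads"])


@dataclass
class ArchitectureAnalysis:
    """Aggregate download statistics for one architecture (sketch inferred from the tests)."""

    architecture_id: str
    total_models: int
    total_downloads: int
    avg_model_downloads: float
    # default_factory gives each instance its own list, which is what
    # test_top_models_empty_list_default checks.
    top_models: list[TopModel] = field(default_factory=list)
    priority_score: float = 0.0
    has_official_implementation: bool = False

    def to_dict(self) -> dict[str, Any]:
        return {
            "architecture_id": self.architecture_id,
            "total_models": self.total_models,
            "total_downloads": self.total_downloads,
            "avg_model_downloads": self.avg_model_downloads,
            "top_models": [m.to_dict() for m in self.top_models],
            "priority_score": self.priority_score,
            "has_official_implementation": self.has_official_implementation,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ArchitectureAnalysis":
        # Optional keys fall back to the dataclass defaults, matching
        # test_from_dict_required_fields_only.
        return cls(
            architecture_id=data["architecture_id"],
            total_models=data["total_models"],
            total_downloads=data["total_downloads"],
            avg_model_downloads=data["avg_model_downloads"],
            top_models=[TopModel.from_dict(m) for m in data.get("top_models", [])],
            priority_score=data.get("priority_score", 0.0),
            has_official_implementation=data.get("has_official_implementation", False),
        )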