Implement guided generation for SystemLanguageModel #59
base: main
Conversation
Pull request overview
This PR implements guided generation for SystemLanguageModel, enabling structured output generation using schema-based constraints. The implementation adds support for generating strongly-typed Swift structs from natural language prompts by converting GenerationSchema to FoundationModels' schema format and handling both streaming and non-streaming responses.
Key Changes:
- Added schema-based generation support for non-String types using FoundationModels' schema APIs with proper $defs dependency extraction
- Implemented partial JSON decoding with fallback to placeholder content for graceful error recovery
- Refactored ResponseStream to use fallbackSnapshot pattern instead of storing concrete content values
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| Sources/AnyLanguageModel/Models/SystemLanguageModel.swift | Core implementation of guided generation with schema conversion, streaming support, and placeholder fallback logic |
| Tests/AnyLanguageModelTests/SystemLanguageModelTests.swift | Comprehensive test suite covering simple structs, nested structures, arrays, enum constraints, and streaming |
| Sources/AnyLanguageModel/LanguageModelSession.swift | Refactored ResponseStream to use optional fallbackSnapshot for cleaner handling of streaming vs non-streaming cases |
| Sources/AnyLanguageModel/GenerationSchema.swift | Made defs property accessible to support schema resolution in SystemLanguageModel |
| Sources/AnyLanguageModel/Generable.swift | Enhanced asPartiallyGenerated() with safer conversion logic and fallback to generatedContent reconstruction |
```swift
            .init(content: placeholder.content, rawContent: placeholder.rawContent)
        )
    }
    continuation.finish()
```
**Copilot AI** (Dec 11, 2025):
In the error handling block of processTextFallback(), the error is caught but not propagated or logged. This silently swallows errors from the streaming process, making it difficult to debug issues. The catch block should either call continuation.finish(throwing: error) to propagate the error to the consumer, or at a minimum log the error before finishing normally with a placeholder.
Suggested change:

```diff
- continuation.finish()
+ continuation.finish(throwing: error)
```
```swift
if chunkText.count >= lastLength, chunkText.hasPrefix(accumulatedText) {
    let startIdx = chunkText.index(chunkText.startIndex, offsetBy: lastLength)
    let delta = String(chunkText[startIdx...])
    accumulatedText += delta
    lastLength = chunkText.count
} else if chunkText.hasPrefix(accumulatedText) {
    accumulatedText = chunkText
    lastLength = chunkText.count
} else if accumulatedText.hasPrefix(chunkText) {
    accumulatedText = chunkText
    lastLength = chunkText.count
} else {
    accumulatedText += chunkText
    lastLength = accumulatedText.count
}
```
**Copilot AI** (Dec 11, 2025):
This text accumulation logic is duplicated across processStringStream (lines 174-188) and processTextFallback (lines 224-238). Consider extracting this into a shared helper function to reduce code duplication and make maintenance easier.
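One possible shape for that shared helper (a sketch; the function name and tuple return are my own, not code from the PR):

```swift
// Hypothetical helper consolidating the delta-accumulation logic that is
// duplicated between processStringStream and processTextFallback.
func accumulate(chunk: String, into accumulated: String, lastLength: Int) -> (text: String, length: Int) {
    if chunk.count >= lastLength, chunk.hasPrefix(accumulated) {
        // Chunk extends the accumulated text: append only the new suffix.
        let startIdx = chunk.index(chunk.startIndex, offsetBy: lastLength)
        let text = accumulated + String(chunk[startIdx...])
        return (text, chunk.count)
    } else if chunk.hasPrefix(accumulated) || accumulated.hasPrefix(chunk) {
        // Chunk replaces (or truncates) the accumulated text wholesale.
        return (chunk, chunk.count)
    } else {
        // Disjoint chunk: append it as-is.
        let text = accumulated + chunk
        return (text, text.count)
    }
}
```

Both call sites would then reduce to a single `(accumulatedText, lastLength) = accumulate(chunk: chunkText, into: accumulatedText, lastLength: lastLength)` assignment.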
```swift
@Test func guidedGenerationMathCalculation() async throws {
    let session = LanguageModelSession(model: SystemLanguageModel.default)

    let response = try await session.respond(
        to: "Calculate 15 + 27",
        generating: MathResult.self
    )

    #expect(!response.content.expression.isEmpty)
    #expect(response.content.result == 42)
```
**Copilot AI** (Dec 11, 2025):
This test assumes the model will always return exactly 42 for "15 + 27", but language models can be non-deterministic and may format responses differently (e.g., "forty-two" as text, or produce calculation errors). This assertion could lead to flaky tests. Consider using a range check or verifying the response is reasonable rather than expecting an exact value.
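A less brittle variant might look like this (a sketch; how tight the bounds can be depends on how strictly the schema constrains the integer field):

```swift
// Sketch: assert the result is plausible rather than pinning one exact value.
// With schema-constrained generation the integer field should always decode,
// but the model may still miscalculate, so a range check avoids flakiness.
let response = try await session.respond(
    to: "Calculate 15 + 27",
    generating: MathResult.self
)
#expect(!response.content.expression.isEmpty)
#expect((0...1_000).contains(response.content.result))
```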
```diff
  // As a last resort, return an empty payload.
  let empty = GeneratedContent("")
  return LanguageModelSession.Response(
-     content: content,
-     rawContent: rawContent,
+     content: try Content(empty),
+     rawContent: empty,
      transcriptEntries: []
  )
```
**Copilot AI** (Dec 11, 2025):
The last resort empty payload fallback in collect() attempts to create Content from an empty string, which will throw an error for most structured types. This means collect() can fail unexpectedly when no snapshots are received. The try keyword propagates this error, but the fallback should either return a proper error indicating no content was received, or ensure the empty GeneratedContent is valid for the Content type.
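One alternative is to fail loudly when no snapshots arrived (a sketch; `StreamError` and `lastSnapshot` are names introduced here for illustration, not code from the PR):

```swift
// Sketch: surface a descriptive error instead of force-decoding "" into
// a structured Content type, which throws for most schemas anyway.
enum StreamError: Error {
    case noContentReceived
}

guard let lastSnapshot else {
    // No snapshots were received; report that rather than fabricating content.
    throw StreamError.noContentReceived
}
return LanguageModelSession.Response(
    content: lastSnapshot.content,
    rawContent: lastSnapshot.rawContent,
    transcriptEntries: []
)
```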
```swift
func finalize(content: Content) -> LanguageModelSession.Response<Content> {
    let normalizedRaw = content.generatedContent
    if normalizedRaw.jsonString.contains("[]"), let placeholder = placeholderContent(for: type) {
```
**Copilot AI** (Dec 11, 2025):
The check for empty arrays using contains("[]") is fragile and could produce false positives. For example, a valid JSON object with a field containing the literal string "[]" or an array property would incorrectly trigger this placeholder logic. Consider using proper JSON parsing to check if the response is actually empty or invalid, rather than string matching.
Suggested change:

```diff
- if normalizedRaw.jsonString.contains("[]"), let placeholder = placeholderContent(for: type) {
+ if let data = normalizedRaw.jsonString.data(using: .utf8),
+     let json = try? JSONSerialization.jsonObject(with: data, options: []),
+     let array = json as? [Any],
+     array.isEmpty,
+     let placeholder = placeholderContent(for: type) {
```
```swift
if let partial: Self.PartiallyGenerated = try? .init(self.generatedContent) {
    return partial
}
fatalError("Unable to convert \(Self.self) to partially generated form")
```
**Copilot AI** (Dec 11, 2025):
The fatalError will crash the application when a type cannot be converted to its partially generated form. This is a critical runtime failure that could occur during normal streaming operations when partial data is invalid. Consider throwing an error instead of using fatalError, or return a fallback value that indicates the conversion failed, allowing the caller to handle the error gracefully.
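A throwing variant could look like this (a sketch; `ConversionError` is a hypothetical type, and adopting this would require callers of `asPartiallyGenerated()` to handle the throw):

```swift
// Sketch: report the failed conversion instead of crashing the process.
struct ConversionError: Error {
    let typeName: String
}

func asPartiallyGenerated() throws -> Self.PartiallyGenerated {
    if let partial: Self.PartiallyGenerated = try? .init(self.generatedContent) {
        return partial
    }
    // Invalid partial data during streaming becomes a recoverable error.
    throw ConversionError(typeName: String(describing: Self.self))
}
```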
```swift
func processTextFallback() async {
    let fmTextStream: FoundationModels.LanguageModelSession.ResponseStream<String> =
        fmSession.streamResponse(to: fmPrompt, options: fmOptions)

    var accumulatedText = ""
    var didYield = false
    do {
        var lastLength = 0
        for try await snapshot in fmTextStream {
            var chunkText: String = snapshot.content
            if chunkText == "null" && accumulatedText.isEmpty {
                chunkText = ""
            }

            if chunkText.count >= lastLength, chunkText.hasPrefix(accumulatedText) {
                let startIdx = chunkText.index(chunkText.startIndex, offsetBy: lastLength)
                let delta = String(chunkText[startIdx...])
                accumulatedText += delta
                lastLength = chunkText.count
            } else if chunkText.hasPrefix(accumulatedText) {
                accumulatedText = chunkText
                lastLength = chunkText.count
            } else if accumulatedText.hasPrefix(chunkText) {
                accumulatedText = chunkText
                lastLength = chunkText.count
            } else {
                accumulatedText += chunkText
                lastLength = accumulatedText.count
            }

            let jsonString = accumulatedText
            if let partialContent = try? partialDecoder.decode(
                GeneratedContent.self,
                from: jsonString
            )
            .value {
                let partial: Content.PartiallyGenerated? = try? .init(partialContent)
                if let partial {
                    continuation.yield(.init(content: partial, rawContent: partialContent))
                    didYield = true
                }
            }
        }
        if !didYield, let placeholder = placeholderPartialContent(for: type) {
            continuation.yield(
                .init(content: placeholder.content, rawContent: placeholder.rawContent)
            )
        }
        continuation.finish()
    } catch {
        if !didYield, let placeholder = placeholderPartialContent(for: type) {
            continuation.yield(
                .init(content: placeholder.content, rawContent: placeholder.rawContent)
            )
        }
        continuation.finish()
    }
    // Build raw content from plain text
    let raw: GeneratedContent = GeneratedContent(accumulatedText)
}
```
**Copilot AI** (Dec 11, 2025):
The processTextFallback function initiates a completely new streaming request when the structured stream fails. This creates redundant API calls and doubles the resource usage. If the structured stream fails partway through, the fallback discards all previously received data and starts over. Consider either passing through the error immediately, or implementing a mechanism to preserve partial results before falling back.
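The pass-through alternative might be sketched as follows (names such as `fmStructuredStream` are assumptions standing in for the PR's structured-stream handle, not code from the PR):

```swift
// Sketch: on structured-stream failure, finish with the error (after any
// placeholder) instead of opening a second text-mode streaming request.
do {
    for try await snapshot in fmStructuredStream {
        // yield structured snapshots as before
    }
    continuation.finish()
} catch {
    if !didYield, let placeholder = placeholderPartialContent(for: type) {
        continuation.yield(
            .init(content: placeholder.content, rawContent: placeholder.rawContent)
        )
    }
    // Propagate the failure so the consumer can decide whether to retry.
    continuation.finish(throwing: error)
}
```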
```swift
}

continuation.yield(.init(content: snapshotContent, rawContent: raw))
let task: _Concurrency.Task<Void, Never> = _Concurrency.Task(priority: nil) {
```
**Copilot AI** (Dec 11, 2025):
The variable name 'task' shadows the outer 'Task' type from Swift Concurrency. While the explicit '_Concurrency.Task' qualification is used here, this shadowing could cause confusion. Consider renaming the variable to something more descriptive like 'streamingTask' or 'processingTask'.
```swift
private final class UnsafeSendableBox<T>: @unchecked Sendable {
    var value: T
    init(value: T) { self.value = value }
}
```
**Copilot AI** (Dec 11, 2025):
The UnsafeSendableBox type is changed from a struct to a class, but there are no usages of this type in the codebase. If this type is not needed, consider removing it. If it's intended for future use, consider adding a comment explaining its purpose and when it should be used. The mutable var property in the class also introduces additional concurrency risks compared to the previous immutable struct.
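If the box is kept for future use, an immutable struct narrows the concurrency surface (a sketch, not a drop-in replacement if mutation is actually required somewhere):

```swift
// Sketch: an immutable box. @unchecked Sendable remains a promise the
// compiler cannot verify, so the reason it is safe should be documented
// at the use site.
private struct UnsafeSendableBox<T>: @unchecked Sendable {
    let value: T
}
```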
Related to #27