Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0acab2b
Added bridge class for llms to fix retrieval issue
pk-zipstack Feb 12, 2026
5fc3ef8
Fixed import error with sub-question retrieval
pk-zipstack Feb 12, 2026
dc57c53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2026
0055521
[FEAT] Rewrite LLMCompat to emulate llama-index interface without dep…
hari-kuriakose Feb 19, 2026
0be9fb1
Merge branch 'main' into fix/retriever-llm-bridge-class
pk-zipstack Feb 19, 2026
922a8a1
Remove investigation notes file from branch
pk-zipstack Mar 3, 2026
4a6c392
Update prompt-service/src/unstract/prompt_service/core/retrievers/bas…
pk-zipstack Mar 3, 2026
92dba83
Address PR review comments: simplify LLM conversion in retrievers
pk-zipstack Mar 3, 2026
a1040c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
e3e91fe
Merge branch 'main' into fix/retriever-llm-bridge-class
pk-zipstack Mar 3, 2026
2b3c3ff
Merge branch 'main' into fix/retriever-llm-bridge-class
pk-zipstack Mar 12, 2026
306e3e7
Add LLMCompat.from_llm() factory and unit tests for retriever LLM
pk-zipstack Mar 13, 2026
97f5d2c
Commit uv.lock changes
pk-zipstack Mar 13, 2026
c6fe158
Move SDK1 tests to sdk1/tests and keep only RetrieverLLM tests in pro…
pk-zipstack Mar 13, 2026
cbec777
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
211baf6
Address code review feedback from greptile
pk-zipstack Mar 13, 2026
3273408
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
12f78e1
Fix test mock for system_prompt and move litellm.drop_params to modul…
pk-zipstack Mar 13, 2026
2ac1613
Make RetrieverLLM construction lazy to avoid redundant init
pk-zipstack Mar 13, 2026
0b9c5e5
Merge branch 'main' into fix/retriever-llm-bridge-class
pk-zipstack Mar 13, 2026
d55f8e4
Reuse existing LLM instance in LLMCompat.from_llm() instead of re-cre…
pk-zipstack Mar 18, 2026
ebeb1a3
Delegate LLMCompat calls to LLM instead of calling litellm directly
pk-zipstack Mar 18, 2026
f8f88c0
Add tests for LLMCompat delegation and BaseRetriever.llm lazy property
pk-zipstack Mar 18, 2026
787eb0d
Use drop litellm params at the module level
pk-zipstack Mar 18, 2026
4d68900
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
3b5b6f0
Update LLMCompat and _messages_to_prompt docstrings
pk-zipstack Mar 18, 2026
17aabaf
Preserve system messages in _messages_to_prompt and add pytest-asynci…
pk-zipstack Mar 18, 2026
b77c707
Commit uv.lock changes
pk-zipstack Mar 18, 2026
aa2e5ee
Merge branch 'main' into fix/retriever-llm-bridge-class
pk-zipstack Mar 18, 2026
08564f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
b4f63d5
Add require_llm() guard for retrievers that need an LLM
pk-zipstack Mar 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions platform-service/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions prompt-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ unstract-sdk1 = { path = "../unstract/sdk1", editable = true }
[dependency-groups]
test = [
"pytest~=8.0.1",
"pytest-asyncio>=0.23.0",
"pytest-dotenv==0.5.2",
"pytest-mock~=3.14.0",
"pytest-md-report>=0.6.2",
"python-dotenv==1.0.1",
"flask-WTF~=1.1",
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from unstract.prompt_service.core.retrievers.retriever_llm import RetrieverLLM
from unstract.sdk1.llm import LLM
from unstract.sdk1.vector_db import VectorDB

Expand All @@ -23,7 +24,37 @@ def __init__(
self.prompt = prompt
self.doc_id = doc_id
self.top_k = top_k
self.llm = llm if llm else None
self._llm: LLM | None = llm
self._retriever_llm: RetrieverLLM | None = None

@property
def llm(self) -> RetrieverLLM | None:
    """Lazily build and cache a llama-index compatible LLM.

    ``RetrieverLLM`` construction (adapter init, CallbackManager setup)
    is deferred until first access, so retrievers that never touch the
    LLM (Simple, Automerging, Recursive) pay nothing for it. Returns
    ``None`` when the retriever was constructed without an SDK1 LLM.
    """
    base = self._llm
    if base is None:
        return None
    cached = self._retriever_llm
    if cached is None:
        cached = RetrieverLLM(llm=base)
        self._retriever_llm = cached
    return cached

def require_llm(self) -> RetrieverLLM:
    """Return the llama-index LLM, failing fast when none is configured.

    Retrievers that depend on an LLM (KeywordTable, Fusion, Subquestion)
    call this instead of reading ``self.llm`` directly, so a missing LLM
    surfaces as a clear error rather than llama-index silently falling
    back to its default OpenAI LLM.

    Raises:
        ValueError: If the retriever was constructed without an LLM.
    """
    resolved = self.llm
    if resolved is not None:
        return resolved
    raise ValueError(
        f"{type(self).__name__} requires an LLM. "
        "Pass llm= when constructing the retriever."
    )

@staticmethod
def retrieve() -> set[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def retrieve(self) -> set[str]:
set[str]: A set of text chunks retrieved from the database.
"""
try:
llm = self.require_llm()
logger.info(
f"Retrieving chunks for {self.doc_id} using LlamaIndex QueryFusionRetriever."
)
Expand Down Expand Up @@ -64,7 +65,7 @@ def retrieve(self) -> set[str]:
mode="simple", # Use simple fusion mode (reciprocal rank fusion)
use_async=False,
verbose=True,
llm=self.llm, # LLM generates query variations
llm=llm,
)

# Retrieve nodes using fusion technique
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def retrieve(self) -> set[str]:
set[str]: A set of text chunks retrieved from the database.
"""
try:
llm = self.require_llm()
logger.info(
f"Retrieving chunks for {self.doc_id} using LlamaIndex KeywordTableIndex."
)
Expand Down Expand Up @@ -48,7 +49,7 @@ def retrieve(self) -> set[str]:
keyword_index = KeywordTableIndex(
nodes=[node.node for node in all_nodes],
show_progress=True,
llm=self.llm, # Use the provided LLM instead of defaulting to OpenAI
llm=llm,
)

# Create retriever from keyword index
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from collections.abc import Sequence
from typing import Any

from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseAsyncGen,
CompletionResponseGen,
LLMMetadata,
MessageRole,
)
from llama_index.core.llms.llm import LLM as LlamaIndexBaseLLM # noqa: N811
from pydantic import PrivateAttr

from unstract.sdk1.llm import LLM, LLMCompat


class RetrieverLLM(LlamaIndexBaseLLM):
    """Adapt SDK1's ``LLMCompat`` to llama-index's ``LLM`` interface.

    llama-index's ``resolve_llm()`` asserts ``isinstance(llm, LLM)`` where
    ``LLM`` is ``llama_index.core.llms.llm.LLM``. SDK1's ``LLMCompat`` is a
    plain class with no llama-index ancestry, so it fails that check.

    This adapter subclasses the llama-index ``LLM`` base (satisfying the
    isinstance check) and forwards every call to an internal ``LLMCompat``
    instance. Streaming variants are intentionally unsupported.
    """

    # Wrapped SDK1 compat layer; PrivateAttr keeps it out of pydantic fields.
    _compat: LLMCompat = PrivateAttr()

    def __init__(self, llm: LLM, **kwargs: Any) -> None:  # noqa: ANN401
        """Wrap the given SDK1 ``LLM`` instance."""
        # Pydantic's __init__ must run before private attributes are assigned.
        super().__init__(**kwargs)
        self._compat = LLMCompat.from_llm(llm)

    @property
    def metadata(self) -> LLMMetadata:
        """Describe the wrapped model as a chat-capable LLM."""
        return LLMMetadata(
            is_chat_model=True,
            model_name=self._compat.get_model_name(),
        )

    @staticmethod
    def _to_chat_response(result: Any) -> ChatResponse:  # noqa: ANN401
        """Convert an ``LLMCompat`` chat result into a ``ChatResponse``."""
        assistant_msg = ChatMessage(
            role=MessageRole.ASSISTANT,
            content=result.message.content,
        )
        return ChatResponse(message=assistant_msg, raw=result.raw)

    # ── Sync ─────────────────────────────────────────────────────────────────

    def chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,  # noqa: ANN401
    ) -> ChatResponse:
        """Run a synchronous chat via the wrapped ``LLMCompat``."""
        return self._to_chat_response(self._compat.chat(messages, **kwargs))

    def complete(
        self,
        prompt: str,
        formatted: bool = False,
        **kwargs: Any,  # noqa: ANN401
    ) -> CompletionResponse:
        """Run a synchronous completion via the wrapped ``LLMCompat``."""
        result = self._compat.complete(prompt, formatted=formatted, **kwargs)
        return CompletionResponse(text=result.text, raw=result.raw)

    def stream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,  # noqa: ANN401
    ) -> ChatResponseGen:
        """Streaming chat is not offered by this adapter."""
        raise NotImplementedError("stream_chat is not supported.")

    def stream_complete(
        self,
        prompt: str,
        formatted: bool = False,
        **kwargs: Any,  # noqa: ANN401
    ) -> CompletionResponseGen:
        """Streaming completion is not offered by this adapter."""
        raise NotImplementedError("stream_complete is not supported.")

    # ── Async ────────────────────────────────────────────────────────────────

    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,  # noqa: ANN401
    ) -> ChatResponse:
        """Run an asynchronous chat via the wrapped ``LLMCompat``."""
        result = await self._compat.achat(messages, **kwargs)
        return self._to_chat_response(result)

    async def acomplete(
        self,
        prompt: str,
        formatted: bool = False,
        **kwargs: Any,  # noqa: ANN401
    ) -> CompletionResponse:
        """Run an asynchronous completion via the wrapped ``LLMCompat``."""
        result = await self._compat.acomplete(prompt, formatted=formatted, **kwargs)
        return CompletionResponse(text=result.text, raw=result.raw)

    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,  # noqa: ANN401
    ) -> ChatResponseAsyncGen:
        """Async streaming chat is not offered by this adapter."""
        raise NotImplementedError("astream_chat is not supported.")

    async def astream_complete(
        self,
        prompt: str,
        formatted: bool = False,
        **kwargs: Any,  # noqa: ANN401
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion is not offered by this adapter."""
        raise NotImplementedError("astream_complete is not supported.")
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.question_gen.llm_generators import LLMQuestionGenerator
from llama_index.core.schema import QueryBundle
from llama_index.core.tools import QueryEngineTool, ToolMetadata

Expand All @@ -22,9 +23,10 @@ def retrieve(self) -> set[str]:
set[str]: A set of text chunks retrieved from the database.
"""
try:
llm = self.require_llm()
logger.info("Initialising vector query engine...")
vector_query_engine = self.vector_db.get_vector_store_index().as_query_engine(
llm=self.llm, similarity_top_k=self.top_k
llm=llm, similarity_top_k=self.top_k
)
logger.info(
f"Retrieving chunks for {self.doc_id} using SubQuestionQueryEngine."
Expand All @@ -39,10 +41,14 @@ def retrieve(self) -> set[str]:
]
query_bundle = QueryBundle(query_str=self.prompt)

question_gen = LLMQuestionGenerator.from_defaults(
llm=llm,
)
query_engine = SubQuestionQueryEngine.from_defaults(
query_engine_tools=query_engine_tools,
question_gen=question_gen,
use_async=True,
llm=self.llm,
llm=llm,
)

response = query_engine.query(str_or_query_bundle=query_bundle)
Expand Down
Empty file.
12 changes: 12 additions & 0 deletions prompt-service/src/unstract/prompt_service/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Pytest configuration for unit tests.

Unit tests should not require external dependencies or the full app.
This conftest intentionally does NOT import Flask app components.

WARNING: This file is NOT auto-loaded when running via tox because
--noconftest is used to skip the parent tests/conftest.py (which
imports Flask blueprints and triggers the full adapter import chain).
If you add shared fixtures here, either remove --noconftest from
tox.ini and fix the parent conftest's eager imports, or define
fixtures directly in test files.
"""
Loading
Loading