Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/agentscope_react/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ buffer:
response_key: 'answer'
rollout_args:
temperature: 1.0
default_workflow_type: 'as_react_workflow'
default_workflow_type: 'agentscope_react_workflow'
eval_tasksets: []
trainer_input:
experience_buffer:
Expand Down
3 changes: 2 additions & 1 deletion trinity/common/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
# tool_call
"tool_call_workflow": "trinity.common.workflows.customized_toolcall_workflows.ToolCallWorkflow",
# agentscope
"agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
"agentscope_workflow_adapter": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapter",
"agentscope_workflow_adapter_v1": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapterV1",
"agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
"agentscope_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow.AgentScopeReactMathWorkflow",
"as_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
"agentscopev0_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow.AgentScopeV0ReactMathWorkflow",
Expand Down
109 changes: 107 additions & 2 deletions trinity/common/workflows/agentscope_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(
from agentscope.model import TrinityChatModel
except ImportError:
raise ImportError(
"This workflow requires agentscope >= 0.1.6, please install "
"it via `pip install agentscope>=0.1.6`",
"This workflow requires agentscope >= 1.0.7, please install "
"it via `pip install agentscope>=1.0.7`",
)

super().__init__(
Expand Down Expand Up @@ -72,3 +72,108 @@ async def run_async(self) -> List[Experience]:
"""Run the workflow asynchronously and return experiences."""
reward = await self.workflow_func(self.task.raw_task, self.chat_model) # type: ignore [arg-type]
return self.construct_experiences(reward)


class AgentScopeWorkflowAdapterV1(Workflow):
"""A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow."""

is_async: bool = True

def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
"""Initialize the adapter with the task and model."""
try:
from agentscope.model import TrinityChatModel
except ImportError:
raise ImportError(
"This workflow requires agentscope >= 1.0.11, please install "
"it via `pip install agentscope>=1.0.11`",
)

super().__init__(
task=task,
model=model,
auxiliary_models=auxiliary_models,
)
self.workflow_func = task.workflow_args.get("workflow_func", None)
self.judge_func = task.workflow_args.get("judge_func", None)

if self.workflow_func is None:
raise ValueError(
"The 'workflow_func' is not provided.",
)

self.chat_model: TrinityChatModel = TrinityChatModel(
model.get_openai_async_client(),
generate_kwargs={
"temperature": self.task.rollout_args.temperature,
"top_p": self.task.rollout_args.top_p,
"max_tokens": self.task.rollout_args.max_tokens or 4096,
"logprobs": True,
"top_logprobs": self.task.rollout_args.logprobs,
},
)
self.auxiliary_chat_models = [
TrinityChatModel(
openai_async_client=aux_model,
# TODO: customize generate_kwargs for auxiliary models if needed
)
for aux_model in (self.auxiliary_models or [])
]

def construct_experiences(
self,
reward: float,
metrics: Dict,
) -> List[Experience]:
"""Construct experiences from the agent's interaction history.

Args:
reward (float): The reward value to assign to each experience.
metrics (Dict): A dictionary of metrics to be attached to the last experience.

Returns:
List: A list of Experience objects.
"""
exps = self.model.extract_experience_from_history()
for exp in exps:
exp.reward = reward
# only attach metrics to the last experience
if len(exps) > 0:
exps[-1].metrics = metrics
return exps

async def run_async(self) -> List[Experience]:
"""Run the workflow asynchronously and return experiences."""
try:
from agentscope.tuner import JudgeOutput, WorkflowOutput
except ImportError:
raise ImportError(
"Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed."
)

metrics = {}
workflow_output: WorkflowOutput = await self.workflow_func(
self.task.raw_task, self.chat_model, self.auxiliary_chat_models
) # type: ignore [arg-type]
metrics.update(workflow_output.metrics or {})
if self.judge_func is not None:
assert (
workflow_output.response is not None
), "Workflow must provide response for judging."
judge_output: JudgeOutput = await self.judge_func(
self.task.raw_task, workflow_output.response, self.auxiliary_chat_models
) # type: ignore [arg-type]
reward = judge_output.reward
metrics.update(judge_output.metrics or {})
else:
assert (
workflow_output.reward is not None
), "Either workflow or judge must provide reward."
reward = workflow_output.reward
return self.construct_experiences(reward, metrics)