diff --git a/examples/agentscope_react/gsm8k.yaml b/examples/agentscope_react/gsm8k.yaml index c1b79f7016..102606c4ba 100644 --- a/examples/agentscope_react/gsm8k.yaml +++ b/examples/agentscope_react/gsm8k.yaml @@ -29,7 +29,7 @@ buffer: response_key: 'answer' rollout_args: temperature: 1.0 - default_workflow_type: 'as_react_workflow' + default_workflow_type: 'agentscope_react_workflow' eval_tasksets: [] trainer_input: experience_buffer: diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index b8496f8f41..ea7390b4a4 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -20,8 +20,9 @@ # tool_call "tool_call_workflow": "trinity.common.workflows.customized_toolcall_workflows.ToolCallWorkflow", # agentscope - "agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow", "agentscope_workflow_adapter": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapter", + "agentscope_workflow_adapter_v1": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapterV1", + "agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow", "agentscope_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow.AgentScopeReactMathWorkflow", "as_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow", "agentscopev0_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow.AgentScopeV0ReactMathWorkflow", diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 0de1c41c2a..a29594c1c0 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -22,8 +22,8 @@ def __init__( from agentscope.model import TrinityChatModel except ImportError: raise ImportError( - "This workflow requires agentscope >= 0.1.6, please install " - "it via `pip install agentscope>=0.1.6`", + "This workflow requires agentscope >= 1.0.7, please install " + "it via `pip install agentscope>=1.0.7`", ) super().__init__( @@ -72,3 +72,108 @@ async def run_async(self) -> List[Experience]: """Run the workflow asynchronously and return experiences.""" reward = await self.workflow_func(self.task.raw_task, self.chat_model) # type: ignore [arg-type] return self.construct_experiences(reward) + + +class AgentScopeWorkflowAdapterV1(Workflow): + """A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow.""" + + is_async: bool = True + + def __init__( + self, + *, + task: Task, + model: ModelWrapper, + auxiliary_models: Optional[List[ModelWrapper]] = None, + ): + """Initialize the adapter with the task and model.""" + try: + from agentscope.model import TrinityChatModel + except ImportError: + raise ImportError( + "This workflow requires agentscope >= 1.0.11, please install " + "it via `pip install agentscope>=1.0.11`", + ) + + super().__init__( + task=task, + model=model, + auxiliary_models=auxiliary_models, + ) + self.workflow_func = task.workflow_args.get("workflow_func", None) + self.judge_func = task.workflow_args.get("judge_func", None) + + if self.workflow_func is None: + raise ValueError( + "The 'workflow_func' is not provided.", + ) + + self.chat_model: TrinityChatModel = TrinityChatModel( + model.get_openai_async_client(), + generate_kwargs={ + "temperature": self.task.rollout_args.temperature, + "top_p": self.task.rollout_args.top_p, + "max_tokens": self.task.rollout_args.max_tokens or 4096, + "logprobs": True, + "top_logprobs": self.task.rollout_args.logprobs, + }, + ) + self.auxiliary_chat_models = [ + TrinityChatModel( + openai_async_client=aux_model, + # TODO: customize generate_kwargs for auxiliary models if needed + ) + for aux_model in (self.auxiliary_models or []) + ] + + def construct_experiences( + self, + reward: float, + metrics: Dict, + ) -> List[Experience]: + """Construct experiences from the agent's interaction history. + + Args: + reward (float): The reward value to assign to each experience. + metrics (Dict): A dictionary of metrics to be attached to the last experience. + + Returns: + List: A list of Experience objects. + """ + exps = self.model.extract_experience_from_history() + for exp in exps: + exp.reward = reward + # only attach metrics to the last experience + if len(exps) > 0: + exps[-1].metrics = metrics + return exps + + async def run_async(self) -> List[Experience]: + """Run the workflow asynchronously and return experiences.""" + try: + from agentscope.tuner import JudgeOutput, WorkflowOutput + except ImportError: + raise ImportError( + "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed." + ) + + metrics = {} + workflow_output: WorkflowOutput = await self.workflow_func( + self.task.raw_task, self.chat_model, self.auxiliary_chat_models + ) # type: ignore [arg-type] + metrics.update(workflow_output.metrics or {}) + if self.judge_func is not None: + assert ( + workflow_output.response is not None + ), "Workflow must provide response for judging." + judge_output: JudgeOutput = await self.judge_func( + self.task.raw_task, workflow_output.response, self.auxiliary_chat_models + ) # type: ignore [arg-type] + reward = judge_output.reward + metrics.update(judge_output.metrics or {}) + else: + assert ( + workflow_output.reward is not None + ), "Either workflow or judge must provide reward." + reward = workflow_output.reward + return self.construct_experiences(reward, metrics)