Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions environments/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ dependencies:
- python=3.10
- pip:
- py-data-juicer
- agentscope
- flask
- omegaconf
- sqlalchemy
- psycopg2
- networkx
- transformers
- "-e ..[dev]"
5 changes: 5 additions & 0 deletions environments/env_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
"env_name": "trinity_data",
"env_yaml": "environments/data.yaml",
"env_entry": "trinity/data/server.py"
},
"trinity.training": {
"env_name": "trinity",
"env_yaml": "environments/training.yaml",
"env_entry": "trinity/cli/server.py"
}
}
7 changes: 7 additions & 0 deletions environments/training.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
name: trinity
channels:
- defaults
dependencies:
- python=3.10
- pip:
- "-e ..[dev]"
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ dependencies = [
"math_verify",
"ninja",
"fire",
"flask",
"requests",
]

[project.scripts]
Expand Down
8 changes: 4 additions & 4 deletions scripts/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def main():
env_mapping = json.load(f)
for env_path, env_config in env_mapping.items():
env_name = env_config["env_name"]
print(f"Installing dependencies for module {env_name}...")
print(f"Installing dependencies for module [{env_name}]...")
# check if it's existing
res = subprocess.run(
f"{env_mng} env list | grep {env_name}", shell=True, text=True, stdout=subprocess.PIPE
)
if res.returncode == 0 and env_name in res.stdout:
print(f"Environment {env_name} already exists. Skipping...")
print(f"Environment [{env_name}] already exists. Skipping...")
else:
res = subprocess.run(
f'{env_mng} env create -f {env_config["env_yaml"]}'
Expand All @@ -39,9 +39,9 @@ def main():
shell=True,
)
if res.returncode == 0:
print(f"Environment {env_name} created successfully.")
print(f"Environment [{env_name}] created successfully.")
else:
print(f"Failed to create environment {env_name} with exit code {res.returncode}.")
print(f"Failed to create environment [{env_name}] with exit code {res.returncode}.")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions scripts/start_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def main():
env_mapping = json.load(f)
for env_path, env_config in env_mapping.items():
env_name = env_config["env_name"]
print(f"Starting server for module {env_name}...")
print(f"Starting server for module [{env_name}]...")
timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
with open(os.path.join(args.log_dir, f"{env_name}_{timestamp}_log.txt"), "w") as log_file:
server = subprocess.Popen(
Expand All @@ -38,7 +38,7 @@ def main():
shell=True,
)
servers.append(server)
print(f"Server of module {env_name} is started with PID {server.pid}")
print(f"Server of module [{env_name}] is started with PID {server.pid}")
for server in servers:
server.wait()

Expand Down
15 changes: 10 additions & 5 deletions trinity/data/client.py → trinity/cli/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import requests

LOCAL_SERVER_URL = "http://127.0.0.1:5000/data_workflow"


def send_get_request(url: str, params: dict) -> None:
def send_get_request(url: str, params: dict):
"""
Send GET request with parameters.

Expand Down Expand Up @@ -32,8 +30,15 @@ def request(url, **kwargs):


if __name__ == "__main__":
# --- only for local testing
LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow"
LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_rft"
# --- only for local testing

res = request(
url=LOCAL_SERVER_URL,
url=LOCAL_DATA_WORKFLOW_SERVER_URL,
configPath="examples/grpo_gsm8k/gsm8k.yaml",
)
print(res)
if res:
print(res)
print(res["message"])
38 changes: 22 additions & 16 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,38 @@ def both(config: Config) -> None:
raise e


def activate_data_module(data_workflow_url: str, config_path: str):
    """Ask the data workflow server to preprocess the datasets.

    Sends ``config_path`` (as the ``configPath`` query parameter) to the
    server at ``data_workflow_url`` and logs an error if the server reports
    a failure. Failures are logged rather than raised.

    Args:
        data_workflow_url: URL of the data workflow endpoint.
        config_path: Path of the config file forwarded to the server.
    """
    from trinity.cli.client import request

    logger.info("Activating data module...")
    res = request(
        url=data_workflow_url,
        configPath=config_path,
    )
    # NOTE(review): the client appears to return a falsy value when the
    # request itself fails -- guard so we log instead of raising TypeError.
    if not res:
        logger.error("Failed to activate data module: no response from the server.")
        return
    if res.get("return_code") != 0:
        # The servers reply with a "message" field; "return_msg" is kept as a
        # fallback for any older server still using that key.
        detail = res.get("message", res.get("return_msg"))
        logger.error(f"Failed to activate data module: {detail}.")
        return


def run(config_path: str):
    """Load the config at ``config_path`` and launch the mode it selects.

    The data module is contacted first when a workflow URL is configured
    together with either ``dj_config_path`` or ``dj_process_desc``. Ray is
    then initialised and the function matching ``config.mode`` is invoked.

    Args:
        config_path: Path of the config file to load.
    """
    config = load_config(config_path)
    config.check_and_update()

    # try to activate data module
    data_cfg = config.data
    if data_cfg.data_workflow_url:
        if data_cfg.dj_config_path or data_cfg.dj_process_desc:
            activate_data_module(data_cfg.data_workflow_url, config_path)

    ray.init()

    mode = config.mode
    if mode == "explore":
        explore(config)
    elif mode == "train":
        train(config)
    elif mode == "both":
        both(config)


def main() -> None:
"""The main entrypoint."""
parser = argparse.ArgumentParser()
Expand All @@ -132,19 +150,7 @@ def main() -> None:
args = parser.parse_args()
if args.command == "run":
# TODO: support parse all args from command line
config = load_config(args.config)
config.check_and_update()
# try to activate data module
data_config = config.data
if data_config.dj_config_path or data_config.dj_process_desc:
activate_data_module(args.config)
ray.init()
if config.mode == "explore":
explore(config)
elif config.mode == "train":
train(config)
elif config.mode == "both":
both(config)
run(args.config)


if __name__ == "__main__":
Expand Down
32 changes: 32 additions & 0 deletions trinity/cli/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import traceback

import fire
from flask import Flask, jsonify, request

# Flask application object; routes below are registered on it.
app = Flask(__name__)

# Route name served by this app; the endpoint is mounted at "/<APP_NAME>".
APP_NAME = "trinity_rft"


@app.route(f"/{APP_NAME}", methods=["GET"])
def trinity_training():
    """Run a training job for the config given in the ``configPath`` query arg.

    Returns:
        A JSON response with ``return_code`` (0 on success, 1 on failure) and
        ``message`` (a success note, or the formatted traceback on failure).
    """
    config_path = request.args.get("configPath")
    if not config_path:
        # Fail fast with a clear message instead of passing None to the launcher.
        return jsonify(
            {"return_code": 1, "message": "Missing required query parameter: configPath."}
        )
    try:
        # Imported lazily so that import errors are reported to the caller too.
        from trinity.cli.launcher import run

        run(config_path)
    except Exception:
        # Catch Exception rather than everything, so SystemExit and
        # KeyboardInterrupt can still stop the server.
        traceback.print_exc()
        ret = 1
        msg = traceback.format_exc()
    else:
        ret = 0
        msg = "Training Success."
    return jsonify({"return_code": ret, "message": msg})


def main(port=5006, debug=True):
    """Start the training server.

    Args:
        port: Port the Flask app listens on.
        debug: Whether to run Flask in debug mode. Defaults to ``True`` to
            keep the previous behaviour; pass ``False`` outside local
            development, since debug mode enables the interactive debugger
            and the auto-reloader.
    """
    app.run(port=port, debug=debug)


if __name__ == "__main__":
fire.Fire(main)
3 changes: 2 additions & 1 deletion trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class FormatConfig:
class DataConfig:
"""Data config"""

# TODO: add more
data_workflow_url: Optional[str] = None

dataset_path: str = ""
train_split: str = "train"
eval_split: Optional[str] = None # TODO: check data format
Expand Down
2 changes: 2 additions & 0 deletions trinity/data/controllers/default_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
},
"llm_quality_score_filter": {
"api_or_hf_model": "qwen2.5-72b-instruct",
"min_score": 0.0,
"enable_vllm": False,
},
"perplexity_filter": {
Expand All @@ -66,6 +67,7 @@
},
"llm_difficulty_score_filter": {
"api_or_hf_model": "qwen2.5-72b-instruct",
"min_score": 0.0,
"enable_vllm": False,
},
# human annotators
Expand Down
10 changes: 7 additions & 3 deletions trinity/data/core/dataset_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from trinity.buffer.utils import retry_session
from trinity.common.config import DataConfig
from trinity.common.schema import Base, RftDatasetModel
from trinity.data.core.dataset import RftDataset
from trinity.manager.sql_storage import retry_session
from trinity.utils.log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -55,7 +55,9 @@ def __init__(self, config: DataConfig) -> None:
self.session = sessionmaker(bind=self.engine)

def add_entries(self, dataset: RftDataset):
with retry_session(self) as session:
with retry_session(
self, self.config.max_retry_times, self.config.max_retry_interval
) as session:
session.add_all(rft_dataset_to_model(dataset))

def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = False):
Expand All @@ -65,7 +67,9 @@ def get_entries(self, num_entries: int, order_by: str = None, ascending: bool =
order_by_key = asc(order_by_key) if ascending else desc(order_by_key)
else:
order_by_key = None
with retry_session(self) as session:
with retry_session(
self, self.config.max_retry_times, self.config.max_retry_interval
) as session:
entries = (
session.query(RftDatasetModel)
.order_by(order_by_key)
Expand Down
4 changes: 2 additions & 2 deletions trinity/data/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ synth_data = synthesizer.process(clean_data)
- Request using our simple client:

```python
from trinity.data.client import request
from trinity.cli.client import request

res = request(
url="http://127.0.0.1:5000/data_workflow",
url="http://127.0.0.1:5005/data_workflow",
configPath="tests/test_configs/active_iterator_test_cfg.yaml"
)

Expand Down
17 changes: 12 additions & 5 deletions trinity/data/server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import fire
from flask import Flask, jsonify, request

from trinity.common.config import load_config
from trinity.data.controllers.active_iterator import DataActiveIterator

app = Flask(__name__)

APP_NAME = "data_workflow"


@app.route("/data_workflow", methods=["GET"])
@app.route(f"/{APP_NAME}", methods=["GET"])
def data_workflow():
from trinity.common.config import load_config
from trinity.data.controllers.active_iterator import DataActiveIterator

config_path = request.args.get("configPath")
config = load_config(config_path)

Expand All @@ -16,5 +19,9 @@ def data_workflow():
return jsonify({"return_code": ret, "message": msg})


def main(port=5005, debug=True):
    """Start the data workflow server.

    Args:
        port: Port the Flask app listens on.
        debug: Whether to run Flask in debug mode. Defaults to ``True`` to
            keep the previous behaviour; pass ``False`` outside local
            development, since debug mode enables the interactive debugger
            and the auto-reloader.
    """
    app.run(port=port, debug=debug)


if __name__ == "__main__":
app.run(debug=True)
fire.Fire(main)