diff --git a/README.md b/README.md index c7243033b..97f1a0fe7 100644 --- a/README.md +++ b/README.md @@ -1101,6 +1101,11 @@ Possible options: - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) +- **get_milestone** - Get repository milestone. + - `milestone_number`: Milestone number to fetch (number, required) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + - **get_release_by_tag** - Get a release by tag name - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) @@ -1125,6 +1130,15 @@ Possible options: - `repo`: Repository name (string, required) - `sha`: Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch of the repository. If a commit SHA is provided, will list commits up to that SHA. (string, optional) +- **list_milestones** - List repository milestones. + - `direction`: Sort direction: asc or desc (string, optional) + - `owner`: Repository owner (username or organization name) (string, required) + - `page`: Page number (1-indexed) (number, optional) + - `per_page`: Results per page (max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sort`: Sort field: due_on or completeness (string, optional) + - `state`: Filter by state: open, closed, or all (string, optional) + - **list_releases** - List releases - `owner`: Repository owner (string, required) - `page`: Page number for pagination (min 1) (number, optional) @@ -1137,6 +1151,16 @@ Possible options: - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - `repo`: Repository name (string, required) +- **milestone_write** - Write operations on repository milestones. + - `description`: Milestone description (string, optional) + - `due_on`: Due date in ISO-8601 date (YYYY-MM-DD) or RFC3339 timestamp (string, optional) + - `method`: Operation to perform: 'create', 'update', or 'delete' (string, required) + - `milestone_number`: Milestone number to update or delete (number, optional) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + - `state`: Milestone state: 'open' or 'closed' (string, optional) + - `title`: Milestone title (required for create) (string, optional) + - **push_files** - Push files to repository - `branch`: Branch to push to (string, required) - `files`: Array of file objects to push, each object with path (string) and content (string) (object[], required) @@ -1151,6 +1175,14 @@ Possible options: - `query`: Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. Supports exact matching, language filters, path filters, and more. (string, required) - `sort`: Sort field ('indexed' only) (string, optional) +- **search_milestones** - Search repository milestones. + - `owner`: Repository owner (username or organization name) (string, required) + - `page`: Page number (1-indexed) (number, optional) + - `per_page`: Results per page (max 100) (number, optional) + - `query`: Text to search for in milestone title or description (string, required) + - `repo`: Repository name (string, required) + - `state`: Filter by state: open, closed, or all (default: open) (string, optional) + - **search_repositories** - Search repositories - `minimal_output`: Return minimal repository information (default: true). When false, returns full GitHub API repository objects. (boolean, optional) - `order`: Sort order (string, optional) diff --git a/pkg/github/__toolsnaps__/get_milestone.snap b/pkg/github/__toolsnaps__/get_milestone.snap new file mode 100644 index 000000000..7c536737a --- /dev/null +++ b/pkg/github/__toolsnaps__/get_milestone.snap @@ -0,0 +1,30 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "Get repository milestone." + }, + "description": "Get a milestone by number.", + "inputSchema": { + "type": "object", + "required": [ + "owner", + "repo", + "milestone_number" + ], + "properties": { + "milestone_number": { + "type": "number", + "description": "Milestone number to fetch" + }, + "owner": { + "type": "string", + "description": "Repository owner (username or organization name)" + }, + "repo": { + "type": "string", + "description": "Repository name" + } + } + }, + "name": "get_milestone" +} \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/list_milestones.snap b/pkg/github/__toolsnaps__/list_milestones.snap new file mode 100644 index 000000000..ca911d7c6 --- /dev/null +++ b/pkg/github/__toolsnaps__/list_milestones.snap @@ -0,0 +1,58 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "List repository milestones." + }, + "description": "List milestones for a repository.", + "inputSchema": { + "type": "object", + "required": [ + "owner", + "repo" + ], + "properties": { + "direction": { + "type": "string", + "description": "Sort direction: asc or desc", + "enum": [ + "asc", + "desc" + ] + }, + "owner": { + "type": "string", + "description": "Repository owner (username or organization name)" + }, + "page": { + "type": "number", + "description": "Page number (1-indexed)" + }, + "per_page": { + "type": "number", + "description": "Results per page (max 100)" + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "sort": { + "type": "string", + "description": "Sort field: due_on or completeness", + "enum": [ + "due_on", + "completeness" + ] + }, + "state": { + "type": "string", + "description": "Filter by state: open, closed, or all", + "enum": [ + "open", + "closed", + "all" + ] + } + } + }, + "name": "list_milestones" +} \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/milestone_write.snap b/pkg/github/__toolsnaps__/milestone_write.snap new file mode 100644 index 000000000..fc09ca993 --- /dev/null +++ b/pkg/github/__toolsnaps__/milestone_write.snap @@ -0,0 +1,58 @@ +{ + "annotations": { + "title": "Write operations on repository milestones." + }, + "description": "Create, update, or delete milestones in a repository.", + "inputSchema": { + "type": "object", + "required": [ + "method", + "owner", + "repo" + ], + "properties": { + "description": { + "type": "string", + "description": "Milestone description" + }, + "due_on": { + "type": "string", + "description": "Due date in ISO-8601 date (YYYY-MM-DD) or RFC3339 timestamp" + }, + "method": { + "type": "string", + "description": "Operation to perform: 'create', 'update', or 'delete'", + "enum": [ + "create", + "update", + "delete" + ] + }, + "milestone_number": { + "type": "number", + "description": "Milestone number to update or delete" + }, + "owner": { + "type": "string", + "description": "Repository owner (username or organization name)" + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "state": { + "type": "string", + "description": "Milestone state: 'open' or 'closed'", + "enum": [ + "open", + "closed" + ] + }, + "title": { + "type": "string", + "description": "Milestone title (required for create)" + } + } + }, + "name": "milestone_write" +} \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/search_milestones.snap b/pkg/github/__toolsnaps__/search_milestones.snap new file mode 100644 index 000000000..75b6e1f88 --- /dev/null +++ b/pkg/github/__toolsnaps__/search_milestones.snap @@ -0,0 +1,47 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "Search repository milestones." + }, + "description": "Search milestones for a repository.", + "inputSchema": { + "type": "object", + "required": [ + "owner", + "repo", + "query" + ], + "properties": { + "owner": { + "type": "string", + "description": "Repository owner (username or organization name)" + }, + "page": { + "type": "number", + "description": "Page number (1-indexed)" + }, + "per_page": { + "type": "number", + "description": "Results per page (max 100)" + }, + "query": { + "type": "string", + "description": "Text to search for in milestone title or description" + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "state": { + "type": "string", + "description": "Filter by state: open, closed, or all (default: open)", + "enum": [ + "open", + "closed", + "all" + ] + } + } + }, + "name": "search_milestones" +} \ No newline at end of file diff --git a/pkg/github/milestones.go b/pkg/github/milestones.go new file mode 100644 index 000000000..db4598805 --- /dev/null +++ b/pkg/github/milestones.go @@ -0,0 +1,725 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/sanitize" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/google/go-github/v79/github" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const ( + milestoneStateOpen = "open" + milestoneStateClosed = "closed" +) + +// SearchMilestones lists milestones and filters them by a text query across title and description. +func SearchMilestones(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { + tool := mcp.Tool{ + Name: "search_milestones", + Description: t("TOOL_SEARCH_MILESTONES_DESCRIPTION", "Search milestones for a repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_MILESTONES_TITLE", "Search repository milestones."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "query": { + Type: "string", + Description: "Text to search for in milestone title or description", + }, + "state": { + Type: "string", + Description: "Filter by state: open, closed, or all (default: open)", + Enum: []any{milestoneStateOpen, milestoneStateClosed, "all"}, + }, + "per_page": { + Type: "number", + Description: "Results per page (max 100)", + }, + "page": { + Type: "number", + Description: "Page number (1-indexed)", + }, + }, + Required: []string{"owner", "repo", "query"}, + }, + } + + handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if state == "" { + state = milestoneStateOpen + } + if state != milestoneStateOpen && state != milestoneStateClosed && state != "all" { + return utils.NewToolResultError("state must be 'open', 'closed', or 'all'"), nil, nil + } + + perPage, err := OptionalIntParam(args, "per_page") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + page, err := OptionalIntParam(args, "page") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := getClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + opts := &github.MilestoneListOptions{ + State: state, + } + if perPage > 0 { + opts.ListOptions.PerPage = perPage + } + if page > 0 { + opts.ListOptions.Page = page + } + + milestones, resp, err := client.Issues.ListMilestones(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to search milestones", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to search milestones: %s", string(body))), nil, nil + } + + if flags.LockdownMode { + if cache == nil { + return nil, nil, fmt.Errorf("lockdown cache is not configured") + } + filtered := make([]*github.Milestone, 0, len(milestones)) + for _, milestone := range milestones { + creator := milestone.Creator + if creator == nil || creator.GetLogin() == "" { + filtered = append(filtered, milestone) + continue + } + isSafeContent, err := cache.IsSafeContent(ctx, creator.GetLogin(), owner, repo) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil, nil + } + if isSafeContent { + filtered = append(filtered, milestone) + } + } + milestones = filtered + } + + lowerQuery := strings.ToLower(strings.TrimSpace(query)) + matchAll := lowerQuery == "" || lowerQuery == "*" + result := make([]map[string]any, 0, len(milestones)) + for _, m := range milestones { + if matchAll { + result = append(result, milestoneSummary(m)) + continue + } + + title := strings.ToLower(m.GetTitle()) + description := strings.ToLower(m.GetDescription()) + creatorLogin := "" + if m.Creator != nil { + creatorLogin = strings.ToLower(m.Creator.GetLogin()) + } + + if strings.Contains(title, lowerQuery) || + strings.Contains(description, lowerQuery) || + (creatorLogin != "" && strings.Contains(creatorLogin, lowerQuery)) { + result = append(result, milestoneSummary(m)) + } + } + + payload := map[string]any{ + "milestones": result, + "count": len(result), + } + out, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal milestones: %w", err) + } + + return utils.NewToolResultText(string(out)), nil, nil + }) + + return tool, handler +} + +// ListMilestones lists milestones for a repository. +func ListMilestones(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { + tool := mcp.Tool{ + Name: "list_milestones", + Description: t("TOOL_LIST_MILESTONES_DESCRIPTION", "List milestones for a repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_MILESTONES_TITLE", "List repository milestones."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "state": { + Type: "string", + Description: "Filter by state: open, closed, or all", + Enum: []any{milestoneStateOpen, milestoneStateClosed, "all"}, + }, + "sort": { + Type: "string", + Description: "Sort field: due_on or completeness", + Enum: []any{"due_on", "completeness"}, + }, + "direction": { + Type: "string", + Description: "Sort direction: asc or desc", + Enum: []any{"asc", "desc"}, + }, + "per_page": { + Type: "number", + Description: "Results per page (max 100)", + }, + "page": { + Type: "number", + Description: "Page number (1-indexed)", + }, + }, + Required: []string{"owner", "repo"}, + }, + } + + handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if state != "" && state != milestoneStateOpen && state != milestoneStateClosed && state != "all" { + return utils.NewToolResultError("state must be 'open', 'closed', or 'all'"), nil, nil + } + + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if sort != "" && sort != "due_on" && sort != "completeness" { + return utils.NewToolResultError("sort must be 'due_on' or 'completeness'"), nil, nil + } + + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if direction != "" && direction != "asc" && direction != "desc" { + return utils.NewToolResultError("direction must be 'asc' or 'desc'"), nil, nil + } + + perPage, err := OptionalIntParam(args, "per_page") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + page, err := OptionalIntParam(args, "page") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := getClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + opts := &github.MilestoneListOptions{ + State: state, + Sort: sort, + Direction: direction, + } + if perPage > 0 { + opts.ListOptions.PerPage = perPage + } + if page > 0 { + opts.ListOptions.Page = page + } + + milestones, resp, err := client.Issues.ListMilestones(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list milestones", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to list milestones: %s", string(body))), nil, nil + } + + if flags.LockdownMode { + if cache == nil { + return nil, nil, fmt.Errorf("lockdown cache is not configured") + } + filtered := make([]*github.Milestone, 0, len(milestones)) + for _, milestone := range milestones { + creator := milestone.Creator + if creator == nil || creator.GetLogin() == "" { + filtered = append(filtered, milestone) + continue + } + isSafeContent, err := cache.IsSafeContent(ctx, creator.GetLogin(), owner, repo) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil, nil + } + if isSafeContent { + filtered = append(filtered, milestone) + } + } + milestones = filtered + } + + result := make([]map[string]any, 0, len(milestones)) + for _, m := range milestones { + result = append(result, milestoneSummary(m)) + } + + payload := map[string]any{ + "milestones": result, + "count": len(result), + } + out, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal milestones: %w", err) + } + + return utils.NewToolResultText(string(out)), nil, nil + }) + + return tool, handler +} + +// GetMilestone fetches a single milestone by number. +func GetMilestone(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { + tool := mcp.Tool{ + Name: "get_milestone", + Description: t("TOOL_GET_MILESTONE_DESCRIPTION", "Get a milestone by number."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_MILESTONE_TITLE", "Get repository milestone."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "milestone_number": { + Type: "number", + Description: "Milestone number to fetch", + }, + }, + Required: []string{"owner", "repo", "milestone_number"}, + }, + } + + handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + number, err := RequiredInt(args, "milestone_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := getClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + milestone, resp, err := client.Issues.GetMilestone(ctx, owner, repo, number) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get milestone", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to get milestone: %s", string(body))), nil, nil + } + + if flags.LockdownMode { + if cache == nil { + return nil, nil, fmt.Errorf("lockdown cache is not configured") + } + creator := milestone.Creator + if creator != nil && creator.GetLogin() != "" { + isSafeContent, err := cache.IsSafeContent(ctx, creator.GetLogin(), owner, repo) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil, nil + } + if !isSafeContent { + return utils.NewToolResultError("access to milestone is restricted by lockdown mode"), nil, nil + } + } + } + + out, err := json.Marshal(milestoneSummary(milestone)) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal milestone: %w", err) + } + + return utils.NewToolResultText(string(out)), nil, nil + }) + + return tool, handler +} + +func MilestoneWrite(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { + tool := mcp.Tool{ + Name: "milestone_write", + Description: t("TOOL_MILESTONE_WRITE_DESCRIPTION", "Create, update, or delete milestones in a repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_MILESTONE_WRITE_TITLE", "Write operations on repository milestones."), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: "Operation to perform: 'create', 'update', or 'delete'", + Enum: []any{"create", "update", "delete"}, + }, + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "title": { + Type: "string", + Description: "Milestone title (required for create)", + }, + "description": { + Type: "string", + Description: "Milestone description", + }, + "state": { + Type: "string", + Description: "Milestone state: 'open' or 'closed'", + Enum: []any{milestoneStateOpen, milestoneStateClosed}, + }, + "due_on": { + Type: "string", + Description: "Due date in ISO-8601 date (YYYY-MM-DD) or RFC3339 timestamp", + }, + "milestone_number": { + Type: "number", + Description: "Milestone number to update or delete", + }, + }, + Required: []string{"method", "owner", "repo"}, + }, + } + + handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + method = strings.ToLower(method) + + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + title, err := OptionalParam[string](args, "title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if state != "" && state != milestoneStateOpen && state != milestoneStateClosed { + return utils.NewToolResultError("state must be 'open' or 'closed'"), nil, nil + } + dueOnRaw, err := OptionalParam[string](args, "due_on") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + dueOn, err := parseDueOn(dueOnRaw) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := getClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + switch method { + case "create": + if title == "" { + return utils.NewToolResultError("missing required parameter: title"), nil, nil + } + return createMilestone(ctx, client, owner, repo, title, description, state, dueOn) + case "update": + milestoneNumber, err := RequiredInt(args, "milestone_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if title == "" && description == "" && state == "" && dueOn == nil { + return utils.NewToolResultError("at least one of title, description, state, or due_on must be provided for update"), nil, nil + } + return updateMilestone(ctx, client, owner, repo, milestoneNumber, title, description, state, dueOn) + case "delete": + milestoneNumber, err := RequiredInt(args, "milestone_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + return deleteMilestone(ctx, client, owner, repo, milestoneNumber) + default: + return utils.NewToolResultError("invalid method, must be either 'create', 'update', or 'delete'"), nil, nil + } + }) + + return tool, handler +} + +func createMilestone(ctx context.Context, client *github.Client, owner, repo, title, description, state string, dueOn *time.Time) (*mcp.CallToolResult, any, error) { + req := &github.Milestone{ + Title: github.Ptr(title), + } + + if description != "" { + req.Description = github.Ptr(description) + } + if state != "" { + req.State = github.Ptr(state) + } + if dueOn != nil { + req.DueOn = &github.Timestamp{Time: *dueOn} + } + + milestone, resp, err := client.Issues.CreateMilestone(ctx, owner, repo, req) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create milestone", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to create milestone: %s", string(body))), nil, nil + } + + return marshalMilestoneResponse(milestone) +} + +func updateMilestone(ctx context.Context, client *github.Client, owner, repo string, number int, title, description, state string, dueOn *time.Time) (*mcp.CallToolResult, any, error) { + req := &github.Milestone{} + if title != "" { + req.Title = github.Ptr(title) + } + if description != "" { + req.Description = github.Ptr(description) + } + if state != "" { + req.State = github.Ptr(state) + } + if dueOn != nil { + req.DueOn = &github.Timestamp{Time: *dueOn} + } + + milestone, resp, err := client.Issues.EditMilestone(ctx, owner, repo, number, req) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update milestone", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to update milestone: %s", string(body))), nil, nil + } + + return marshalMilestoneResponse(milestone) +} + +func deleteMilestone(ctx context.Context, client *github.Client, owner, repo string, number int) (*mcp.CallToolResult, any, error) { + resp, err := client.Issues.DeleteMilestone(ctx, owner, repo, number) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to delete milestone", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to delete milestone: %s", string(body))), nil, nil + } + + return utils.NewToolResultText(fmt.Sprintf("milestone %d deleted", number)), nil, nil +} + +func milestoneSummary(milestone *github.Milestone) map[string]any { + dueOn := "" + if milestone.DueOn != nil { + dueOn = milestone.DueOn.Time.Format(time.RFC3339) + } + + title := sanitize.Sanitize(milestone.GetTitle()) + description := sanitize.Sanitize(milestone.GetDescription()) + + return map[string]any{ + "id": fmt.Sprintf("%d", milestone.GetID()), + "number": milestone.GetNumber(), + "title": title, + "state": milestone.GetState(), + "description": description, + "due_on": dueOn, + "open_issues": milestone.GetOpenIssues(), + "closed_issues": milestone.GetClosedIssues(), + "url": milestone.GetHTMLURL(), + } +} + +func marshalMilestoneResponse(milestone *github.Milestone) (*mcp.CallToolResult, any, error) { + minimalResponse := map[string]any{ + "id": fmt.Sprintf("%d", milestone.GetID()), + "number": milestone.GetNumber(), + "url": milestone.GetHTMLURL(), + } + + out, err := json.Marshal(minimalResponse) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + return utils.NewToolResultText(string(out)), nil, nil +} + +func parseDueOn(value string) (*time.Time, error) { + if value == "" { + return nil, nil + } + + if ts, err := time.Parse(time.RFC3339, value); err == nil { + return &ts, nil + } + + if ts, err := time.Parse("2006-01-02", value); err == nil { + return &ts, nil + } + + return nil, fmt.Errorf("invalid due_on format; use YYYY-MM-DD or RFC3339 timestamp") +} diff --git a/pkg/github/milestones_test.go b/pkg/github/milestones_test.go new file mode 100644 index 000000000..82e237f07 --- /dev/null +++ b/pkg/github/milestones_test.go @@ -0,0 +1,1218 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/github/github-mcp-server/internal/toolsnaps" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v79/github" + "github.com/google/jsonschema-go/jsonschema" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/shurcooL/githubv4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newMilestoneTestClient(t *testing.T, handler http.Handler) *github.Client { + t.Helper() + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + client := github.NewClient(server.Client()) + baseURL, err := url.Parse(server.URL + "/") + require.NoError(t, err) + client.BaseURL = baseURL + + return client +} + +func TestListMilestones_ToolDefinition(t *testing.T) { + t.Parallel() + + mockClient := github.NewClient(nil) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + tool, _ := ListMilestones(stubGetClientFn(mockClient), cache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "list_milestones", tool.Name) + assert.True(t, tool.Annotations.ReadOnlyHint) + + schema, ok := tool.InputSchema.(*jsonschema.Schema) + require.True(t, ok) + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "state") + assert.Contains(t, schema.Properties, "sort") + assert.Contains(t, schema.Properties, "direction") + assert.ElementsMatch(t, schema.Required, []string{"owner", "repo"}) +} + +func TestListMilestones_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mockedClient *http.Client + args map[string]any + expectError bool + errContains string + }{ + { + name: "success with filters and pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + expectQueryParams(t, map[string]string{ + "state": "all", + "sort": "due_on", + "direction": "desc", + "page": "2", + "per_page": "50", + }).andThen( + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 1, + "number": 10, + "title": "v1", + "state": "open", + "description": "first", + "due_on": "2024-01-02T00:00:00Z", + "open_issues": 3, + "closed_issues": 1, + "html_url": "https://example.com/1", + }, + { + "id": 2, + "number": 11, + "title": "v2", + "state": "closed", + "description": "second", + "open_issues": 0, + "closed_issues": 4, + "html_url": "https://example.com/2", + }, + }), + ), + ), + ), + args: map[string]any{ + "owner": "owner", + "repo": "repo", + "state": "all", + "sort": "due_on", + "direction": "desc", + "per_page": float64(50), + "page": float64(2), + }, + }, + { + name: "api error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + mockResponse(t, http.StatusInternalServerError, map[string]string{"message": "boom"}), + ), + ), + args: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + errContains: "failed to list milestones", + }, + { + name: "validation error", + args: map[string]any{ + "owner": "o", + "repo": "r", + "state": "invalid", + }, + expectError: true, + errContains: "state must be", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := ListMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + request := createMCPRequest(tc.args) + result, _, err := handler(context.Background(), &request, tc.args) + + require.NoError(t, err) + require.NotNil(t, result) + + if tc.expectError { + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, tc.errContains) + return + } + + require.False(t, result.IsError) + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + assert.Equal(t, float64(2), resp["count"]) + }) + } +} + +func TestListMilestones_LockdownFiltersUnsafeCreators(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 1, + "number": 10, + "title": "unsafe", + "description": "from reader", + "creator": map[string]any{ + "login": "testuser", + }, + "html_url": "https://example.com/1", + }, + { + "id": 2, + "number": 11, + "title": "safe", + "description": "from writer", + "creator": map[string]any{ + "login": "testuser2", + }, + "html_url": "https://example.com/2", + }, + }), + ), + ) + + client := github.NewClient(mockedClient) + gqlClient := githubv4.NewClient(newRepoAccessHTTPClient()) + cache := stubRepoAccessCache(gqlClient, 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": true}) + _, handler := ListMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + + milestones, ok := resp["milestones"].([]any) + require.True(t, ok) + assert.Len(t, milestones, 1) + + first, ok := milestones[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "safe", first["title"]) +} + +func TestListMilestones_ValidationError(t *testing.T) { + t.Parallel() + + client := github.NewClient(nil) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := ListMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "o", + "repo": "r", + "state": "invalid", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "state must be") +} + +func TestSearchMilestones_ToolDefinition(t *testing.T) { + t.Parallel() + + mockClient := github.NewClient(nil) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + tool, _ := SearchMilestones(stubGetClientFn(mockClient), cache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "search_milestones", tool.Name) + assert.True(t, tool.Annotations.ReadOnlyHint) + + schema, ok := tool.InputSchema.(*jsonschema.Schema) + require.True(t, ok) + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "query") + assert.Contains(t, schema.Properties, "state") + assert.ElementsMatch(t, schema.Required, []string{"owner", "repo", "query"}) +} + +func TestSearchMilestones_Success(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + expectQueryParams(t, map[string]string{ + "state": "open", + "page": "1", + "per_page": "25", + }).andThen( + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 1, + "number": 10, + "title": "Alpha release", + "state": "open", + "description": "first milestone", + "html_url": "https://example.com/1", + "open_issues": 3, + "closed_issues": 1, + }, + { + "id": 2, + "number": 11, + "title": "Beta", + "state": "open", + "description": "stability work", + "html_url": "https://example.com/2", + "open_issues": 0, + "closed_issues": 4, + }, + }), + ), + ), + ) + + client := github.NewClient(mockedClient) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "query": "alpha", + "per_page": float64(25), + "page": float64(1), + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + + assert.Equal(t, float64(1), resp["count"]) + milestones, ok := resp["milestones"].([]any) + require.True(t, ok) + require.Len(t, milestones, 1) + first, ok := milestones[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "Alpha release", first["title"]) +} + +func TestSearchMilestones_WildcardMatchesAll(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + expectQueryParams(t, map[string]string{ + "state": "closed", + "page": "1", + "per_page": "25", + }).andThen( + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 3, + "number": 12, + "title": "Should be closed", + "state": "closed", + "description": "no filter needed", + "html_url": "https://example.com/3", + "open_issues": 0, + "closed_issues": 0, + }, + }), + ), + ), + ) + + client := github.NewClient(mockedClient) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "query": "*", + "state": "closed", + "per_page": float64(25), + "page": float64(1), + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + + assert.Equal(t, float64(1), resp["count"]) + milestones, ok := resp["milestones"].([]any) + require.True(t, ok) + require.Len(t, milestones, 1) + first, ok := milestones[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "Should be closed", first["title"]) + assert.Equal(t, "closed", first["state"]) +} + +func TestSearchMilestones_MatchesCreatorLogin(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 4, + "number": 13, + "title": "Release tasks", + "state": "open", + "description": "tracking release work", + "creator": map[string]any{ + "login": "alice-dev", + }, + "html_url": "https://example.com/4", + }, + { + "id": 5, + "number": 14, + "title": "Other tasks", + "state": "open", + "description": "misc work", + "creator": map[string]any{ + "login": "bob-dev", + }, + "html_url": "https://example.com/5", + }, + }), + ), + ) + + client := github.NewClient(mockedClient) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "query": "alice", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + + assert.Equal(t, float64(1), resp["count"]) + milestones, ok := resp["milestones"].([]any) + require.True(t, ok) + require.Len(t, milestones, 1) + first, ok := milestones[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "Release tasks", first["title"]) +} + +func TestSearchMilestones_Validation(t *testing.T) { + t.Parallel() + + client := github.NewClient(nil) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "o", + "repo": "r", + "query": "q", + "state": "invalid", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "state must be") +} + +func TestSearchMilestones_Lockdown(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 1, + "number": 10, + "title": "Unsafe alpha", + "description": "match me", + "creator": map[string]any{ + "login": "testuser", + }, + "html_url": "https://example.com/1", + }, + { + "id": 2, + "number": 11, + "title": "Safe alpha", + "description": "match me too", + "creator": map[string]any{ + "login": "testuser2", + }, + "html_url": "https://example.com/2", + }, + }), + ), + ) + + client := github.NewClient(mockedClient) + gqlClient := githubv4.NewClient(newRepoAccessHTTPClient()) + cache := stubRepoAccessCache(gqlClient, 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": true}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "query": "alpha", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + + milestones, ok := resp["milestones"].([]any) + require.True(t, ok) + assert.Len(t, milestones, 1) + + first, ok := milestones[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "Safe alpha", first["title"]) +} + +func TestSearchMilestones_LockdownBlocksUnsafeCreatorMatches(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, []map[string]any{ + { + "id": 1, + "number": 10, + "title": "Match me", + "description": "created by unsafe user", + "creator": map[string]any{ + "login": "testuser", + }, + "html_url": "https://example.com/1", + }, + { + "id": 2, + "number": 11, + "title": "Other milestone", + "description": "safe creator unrelated", + "creator": map[string]any{ + "login": "testuser2", + }, + "html_url": "https://example.com/2", + }, + }), + ), + ) + + client := github.NewClient(mockedClient) + gqlClient := githubv4.NewClient(newRepoAccessHTTPClient()) + cache := stubRepoAccessCache(gqlClient, 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": true}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "query": "match", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + + assert.Equal(t, float64(0), resp["count"]) + milestones, ok := resp["milestones"].([]any) + require.True(t, ok) + assert.Len(t, milestones, 0) +} + +func TestSearchMilestones_ApiError(t *testing.T) { + t.Parallel() + + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones", Method: http.MethodGet}, + mockResponse(t, http.StatusInternalServerError, map[string]any{"message": "boom"}), + ), + ) + + client := github.NewClient(mockedClient) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := SearchMilestones(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "query": "alpha", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsError) + + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "failed to search milestones") +} + +func TestGetMilestone_ToolDefinition(t *testing.T) { + t.Parallel() + + mockClient := github.NewClient(nil) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + tool, _ := GetMilestone(stubGetClientFn(mockClient), cache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "get_milestone", tool.Name) + assert.True(t, tool.Annotations.ReadOnlyHint) + + schema, ok := tool.InputSchema.(*jsonschema.Schema) + require.True(t, ok) + assert.Contains(t, schema.Properties, "milestone_number") + assert.ElementsMatch(t, schema.Required, []string{"owner", "repo", "milestone_number"}) +} + +func TestGetMilestone_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mockedClient *http.Client + args map[string]any + expectError bool + errContains string + }{ + { + name: "success", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones/{milestone_number}", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, map[string]any{ + "id": 55, + "number": 5, + "title": "v1", + "state": "open", + "description": "first", + "due_on": "2024-01-02T00:00:00Z", + "open_issues": 3, + "closed_issues": 1, + "html_url": "https://example.com/1", + }), + ), + ), + args: map[string]any{ + "owner": "owner", + "repo": "repo", + "milestone_number": float64(5), + }, + }, + { + name: "not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones/{milestone_number}", Method: http.MethodGet}, + mockResponse(t, http.StatusNotFound, map[string]string{"message": "not found"}), + ), + ), + args: map[string]any{ + "owner": "owner", + "repo": "repo", + "milestone_number": float64(99), + }, + expectError: true, + errContains: "failed to get milestone", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := GetMilestone(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + request := createMCPRequest(tc.args) + result, _, err := handler(context.Background(), &request, tc.args) + + require.NoError(t, err) + require.NotNil(t, result) + + if tc.expectError { + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, tc.errContains) + return + } + + require.False(t, result.IsError) + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + assert.Equal(t, float64(5), resp["number"]) + assert.Equal(t, "v1", resp["title"]) + }) + } +} + +func TestGetMilestone_LockdownEnforced(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mockedClient *http.Client + args map[string]any + expectError bool + errContains string + }{ + { + name: "blocked when creator lacks push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones/{milestone_number}", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, map[string]any{ + "id": 55, + "number": 5, + "title": "v1", + "state": "open", + "description": "first", + "html_url": "https://example.com/1", + "creator": map[string]any{ + "login": "testuser", + }, + }), + ), + ), + args: map[string]any{ + "owner": "owner", + "repo": "repo", + "milestone_number": float64(5), + }, + expectError: true, + errContains: "access to milestone is restricted by lockdown mode", + }, + { + name: "allowed for private repo creator", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{Pattern: "/repos/{owner}/{repo}/milestones/{milestone_number}", Method: http.MethodGet}, + mockResponse(t, http.StatusOK, map[string]any{ + "id": 56, + "number": 6, + "title": "v2", + "state": "open", + "description": "second", + "html_url": "https://example.com/2", + "creator": map[string]any{ + "login": "testuser2", + }, + }), + ), + ), + args: map[string]any{ + "owner": "owner2", + "repo": "repo2", + "milestone_number": float64(6), + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + gqlClient := githubv4.NewClient(newRepoAccessHTTPClient()) + cache := stubRepoAccessCache(gqlClient, 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": true}) + _, handler := GetMilestone(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + request := createMCPRequest(tc.args) + result, _, err := handler(context.Background(), &request, tc.args) + + require.NoError(t, err) + require.NotNil(t, result) + + if tc.expectError { + require.True(t, result.IsError) + errText := getErrorResult(t, result) + assert.Contains(t, errText.Text, tc.errContains) + return + } + + require.False(t, result.IsError) + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + assert.Equal(t, tc.args["milestone_number"], resp["number"]) + }) + } +} + +func TestGetMilestone_NotFound(t *testing.T) { + t.Parallel() + + client := newMilestoneTestClient(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]any{ + "message": "Not Found", + }) + })) + + cache := stubRepoAccessCache(githubv4.NewClient(nil), 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": false}) + _, handler := GetMilestone(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + + args := map[string]any{ + "owner": "owner", + "repo": "repo", + "milestone_number": float64(1), + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "failed to get milestone") +} + +func TestMilestoneWrite_ToolDefinition(t *testing.T) { + t.Parallel() + + mockClient := github.NewClient(nil) + tool, _ := MilestoneWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "milestone_write", tool.Name) + assert.NotEmpty(t, tool.Description) + + schema, ok := tool.InputSchema.(*jsonschema.Schema) + require.True(t, ok, "InputSchema should be *jsonschema.Schema") + + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "method") + assert.Contains(t, schema.Properties, "title") + assert.Contains(t, schema.Properties, "due_on") + assert.ElementsMatch(t, schema.Required, []string{"method", "owner", "repo"}) +} + +func TestMilestoneWrite_CreateSuccess(t *testing.T) { + t.Parallel() + + client := newMilestoneTestClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/repos/owner/repo/milestones", r.URL.Path) + + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": 101, + "number": 5, + "html_url": "https://example.com/milestones/5", + "created_at": time.Now(), + }) + })) + + tool, handler := MilestoneWrite(stubGetClientFn(client), translations.NullTranslationHelper) + require.Equal(t, "milestone_write", tool.Name) + + args := map[string]any{ + "method": "create", + "owner": "owner", + "repo": "repo", + "title": "v1.0", + "description": "Initial release", + "state": "open", + "due_on": "2024-01-02", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + assert.Equal(t, "101", resp["id"]) + assert.Equal(t, float64(5), resp["number"]) + assert.Equal(t, "https://example.com/milestones/5", resp["url"]) +} + +func TestMilestoneWrite_UpdateSuccess(t *testing.T) { + t.Parallel() + + client := newMilestoneTestClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPatch, r.Method) + assert.Equal(t, "/repos/owner/repo/milestones/7", r.URL.Path) + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": 202, + "number": 7, + "html_url": "https://example.com/milestones/7", + }) + })) + + _, handler := MilestoneWrite(stubGetClientFn(client), translations.NullTranslationHelper) + + args := map[string]any{ + "method": "update", + "owner": "owner", + "repo": "repo", + "milestone_number": float64(7), + "title": "v1.1", + "state": "closed", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &resp)) + assert.Equal(t, "202", resp["id"]) + assert.Equal(t, float64(7), resp["number"]) + assert.Equal(t, "https://example.com/milestones/7", resp["url"]) +} + +func TestMilestoneWrite_DeleteSuccess(t *testing.T) { + t.Parallel() + + client := newMilestoneTestClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodDelete, r.Method) + assert.Equal(t, "/repos/owner/repo/milestones/3", r.URL.Path) + w.WriteHeader(http.StatusNoContent) + })) + + _, handler := MilestoneWrite(stubGetClientFn(client), translations.NullTranslationHelper) + + args := map[string]any{ + "method": "delete", + "owner": "owner", + "repo": "repo", + "milestone_number": float64(3), + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + assert.Contains(t, text.Text, "milestone 3 deleted") +} + +func TestMilestoneWrite_DeleteApiError(t *testing.T) { + t.Parallel() + + client := newMilestoneTestClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodDelete, r.Method) + assert.Equal(t, "/repos/owner/repo/milestones/3", r.URL.Path) + + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "message": "delete failed", + }) + })) + + _, handler := MilestoneWrite(stubGetClientFn(client), translations.NullTranslationHelper) + + args := map[string]any{ + "method": "delete", + "owner": "owner", + "repo": "repo", + "milestone_number": float64(3), + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.True(t, result.IsError) + + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "failed to delete milestone") +} + +func TestMilestoneWrite_ValidationErrors(t *testing.T) { + t.Parallel() + + client := github.NewClient(nil) + _, handler := MilestoneWrite(stubGetClientFn(client), translations.NullTranslationHelper) + + tests := []struct { + name string + args map[string]any + errSubstr string + }{ + { + name: "missing title on create", + args: map[string]any{ + "method": "create", + "owner": "o", + "repo": "r", + }, + errSubstr: "missing required parameter: title", + }, + { + name: "invalid state", + args: map[string]any{ + "method": "create", + "owner": "o", + "repo": "r", + "title": "x", + "state": "pending", + }, + errSubstr: "state must be 'open' or 'closed'", + }, + { + name: "invalid due_on", + args: map[string]any{ + "method": "create", + "owner": "o", + "repo": "r", + "title": "x", + "due_on": "not-a-date", + }, + errSubstr: "invalid due_on format", + }, + { + name: "update without fields", + args: map[string]any{ + "method": "update", + "owner": "o", + "repo": "r", + "milestone_number": float64(1), + }, + errSubstr: "at least one of title", + }, + { + name: "invalid method", + args: map[string]any{ + "method": "noop", + "owner": "o", + "repo": "r", + }, + errSubstr: "invalid method", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.args) + result, _, err := handler(context.Background(), &request, tc.args) + + require.NoError(t, err) + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, tc.errSubstr) + }) + } +} + +func TestMilestoneWrite_ApiError(t *testing.T) { + t.Parallel() + + client := newMilestoneTestClient(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(github.ErrorResponse{ + Message: "bad request", + }) + })) + + _, handler := MilestoneWrite(stubGetClientFn(client), translations.NullTranslationHelper) + + args := map[string]any{ + "method": "create", + "owner": "o", + "repo": "r", + "title": "t", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.True(t, result.IsError) + + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "failed to create milestone") +} + +func TestParseDueOn(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + expectNil bool + }{ + {"empty", "", true}, + {"date only", "2024-01-02", false}, + {"rfc3339", "2024-01-02T15:04:05Z", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + ts, err := parseDueOn(tc.value) + require.NoError(t, err) + if tc.expectNil { + assert.Nil(t, ts) + return + } + require.NotNil(t, ts) + }) + } + + _, err := parseDueOn("bad-date") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid due_on format") +} + +func TestMarshalMilestoneResponse(t *testing.T) { + t.Parallel() + + milestone := &github.Milestone{ + ID: github.Ptr(int64(5)), + Number: github.Ptr(3), + HTMLURL: github.Ptr("https://example.com/m/3"), + } + + result, _, err := marshalMilestoneResponse(milestone) + require.NoError(t, err) + require.False(t, result.IsError) + + text := getTextResult(t, result) + var out map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &out)) + assert.Equal(t, "5", out["id"]) + assert.Equal(t, float64(3), out["number"]) + assert.Equal(t, "https://example.com/m/3", out["url"]) +} + +func TestMilestoneWrite_ClientError(t *testing.T) { + t.Parallel() + + getClientErr := func(_ context.Context) (*github.Client, error) { + return nil, fmt.Errorf("boom") + } + + _, handler := MilestoneWrite(getClientErr, translations.NullTranslationHelper) + + args := map[string]any{ + "method": "create", + "owner": "o", + "repo": "r", + "title": "t", + } + + request := createMCPRequest(args) + result, _, err := handler(context.Background(), &request, args) + + require.NoError(t, err) + require.True(t, result.IsError) + + text := getErrorResult(t, result) + assert.Contains(t, text.Text, "failed to get GitHub client") +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index d37af98b8..66741a15e 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -178,6 +178,9 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(ListReleases(getClient, t)), toolsets.NewServerTool(GetLatestRelease(getClient, t)), toolsets.NewServerTool(GetReleaseByTag(getClient, t)), + toolsets.NewServerTool(SearchMilestones(getClient, cache, t, flags)), + toolsets.NewServerTool(ListMilestones(getClient, cache, t, flags)), + toolsets.NewServerTool(GetMilestone(getClient, cache, t, flags)), ). AddWriteTools( toolsets.NewServerTool(CreateOrUpdateFile(getClient, t)), @@ -186,6 +189,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(CreateBranch(getClient, t)), toolsets.NewServerTool(PushFiles(getClient, t)), toolsets.NewServerTool(DeleteFile(getClient, t)), + toolsets.NewServerTool(MilestoneWrite(getClient, t)), ). AddResourceTemplates( toolsets.NewServerResourceTemplate(GetRepositoryResourceContent(getClient, getRawClient, t)),