Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 105 additions & 5 deletions backend/cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
Expand All @@ -17,6 +18,7 @@ import (
"image-gen-service/internal/api"
"image-gen-service/internal/config"
"image-gen-service/internal/model"
"image-gen-service/internal/platform"
"image-gen-service/internal/provider"
"image-gen-service/internal/storage"
"image-gen-service/internal/templates"
Expand All @@ -27,7 +29,7 @@ import (

func getWorkDir() string {
// 如果是作为 Tauri 边车运行,使用用户目录下的应用支持目录
if os.Getenv("TAURI_PLATFORM") != "" || os.Getenv("TAURI_FAMILY") != "" {
if platform.IsTauriSidecar() {
configDir, err := os.UserConfigDir()
if err == nil {
appDir := configDir + "/com.dztool.banana"
Expand All @@ -38,6 +40,69 @@ func getWorkDir() string {
return "."
}

func isLoopbackOrigin(origin string) bool {
u, err := url.Parse(origin)
if err != nil || u == nil {
return false
}
host := strings.ToLower(u.Hostname())
return host == "localhost" || host == "127.0.0.1" || host == "::1"
}

func isNullOrigin(origin string) bool {
return strings.EqualFold(strings.TrimSpace(origin), "null")
}

func isAllowedTauriOrigin(origin string) bool {
origin = strings.TrimSpace(origin)
if origin == "" || isNullOrigin(origin) {
return false
}
u, err := url.Parse(origin)
if err != nil || u == nil {
return false
}
if !strings.EqualFold(u.Scheme, "tauri") {
return isLoopbackOrigin(origin)
}
if !strings.EqualFold(u.Hostname(), "localhost") {
return false
}
if strings.TrimSpace(u.Path) != "" && strings.TrimSpace(u.Path) != "/" {
return false
}
return true
}

func loadCORSAllowlistFromEnv() map[string]struct{} {
raw := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGINS"))
if raw == "" {
return map[string]struct{}{}
}
allowlist := make(map[string]struct{})
for _, part := range strings.Split(raw, ",") {
v := strings.TrimSpace(part)
if v == "" {
continue
}
allowlist[v] = struct{}{}
}
return allowlist
}

func originInAllowlist(origin string, allowlist map[string]struct{}) bool {
if len(allowlist) == 0 {
return false
}
_, ok := allowlist[origin]
return ok
}

func allowlistHasWildcard(allowlist map[string]struct{}) bool {
_, ok := allowlist["*"]
return ok
}

// isRunningInDocker 检测是否运行在 Docker 容器中
// 使用多种检测方式组合,确保可靠性
func isRunningInDocker() bool {
Expand Down Expand Up @@ -133,19 +198,54 @@ func main() {

// 5. 设置路由
r := gin.Default()
corsAllowlist := loadCORSAllowlistFromEnv()

// 允许跨域请求
r.Use(func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
log.Printf("[CORS] Request from Origin: %s, Method: %s, Path: %s", origin, c.Request.Method, c.Request.URL.Path)

if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
if platform.IsTauriSidecar() {
if !isAllowedTauriOrigin(origin) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "origin not allowed",
"data": nil,
})
return
}
if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
}
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
} else {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
trimmedOrigin := strings.TrimSpace(origin)
hasWildcard := allowlistHasWildcard(corsAllowlist)
if isNullOrigin(trimmedOrigin) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "origin not allowed",
"data": nil,
})
return
} else if trimmedOrigin == "" || (hasWildcard && len(corsAllowlist) > 0) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if len(corsAllowlist) == 0 {
// 非 Tauri 模式默认放开跨域,但不允许携带凭证,避免“反射 Origin + credentials”风险
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if originInAllowlist(trimmedOrigin, corsAllowlist) {
c.Writer.Header().Set("Access-Control-Allow-Origin", trimmedOrigin)
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
Comment on lines +236 to +238
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The CORS implementation incorrectly handles the wildcard (*) in the CORS_ALLOW_ORIGINS allowlist. If the allowlist contains *, the originInAllowlist function returns true for any origin, which is then reflected in the Access-Control-Allow-Origin header with Access-Control-Allow-Credentials set to true. This effectively bypasses browser security, allowing malicious websites to make authenticated requests if the allowlist is misconfigured with a wildcard. The logic should be updated to ensure that if a wildcard is used, credentials are not permitted, or the origin is not reflected unless it matches a specific trusted domain. Furthermore, the non-Tauri mode logic in this CORS middleware is complex and repetitive, especially where if trimmedOrigin == "" and else if len(corsAllowlist) == 0 branches perform the same c.Writer.Header().Set("Access-Control-Allow-Origin", "*") operation. Reorganizing this logic to prioritize invalid origins and then handle corsAllowlist presence separately would improve clarity and maintainability.

} else {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "origin not allowed",
"data": nil,
})
return
}
}

c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, *")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")

Expand Down
92 changes: 89 additions & 3 deletions backend/internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"

"image-gen-service/internal/config"
"image-gen-service/internal/model"
"image-gen-service/internal/platform"
"image-gen-service/internal/provider"
"image-gen-service/internal/storage"
"image-gen-service/internal/worker"
Expand Down Expand Up @@ -124,6 +126,77 @@ func defaultTimeoutSecondsForProvider(providerName string) int {
}
}

func normalizePathForCheck(path string) string {
cleaned := filepath.Clean(path)
cleaned = strings.TrimRight(cleaned, string(filepath.Separator))
if cleaned == "" {
return string(filepath.Separator)
}
return cleaned
}

func pathWithinRoot(path, root string) bool {
nPath := normalizePathForCheck(path)
nRoot := normalizePathForCheck(root)
if runtime.GOOS == "windows" {
nPath = strings.ToLower(nPath)
nRoot = strings.ToLower(nRoot)
}
if nPath == nRoot {
return true
}
rel, err := filepath.Rel(nRoot, nPath)
if err != nil {
return false
}
rel = strings.TrimSpace(rel)
if rel == "." {
return true
}
if rel == "" {
return false
}
return !strings.HasPrefix(rel, "..")
}

func allowedRefPathRoots() []string {
roots := make([]string, 0, 4)
if configDir, err := os.UserConfigDir(); err == nil && strings.TrimSpace(configDir) != "" {
roots = append(roots, filepath.Join(configDir, "com.dztool.banana"))
}
if cacheDir, err := os.UserCacheDir(); err == nil && strings.TrimSpace(cacheDir) != "" {
roots = append(roots, filepath.Join(cacheDir, "com.dztool.banana"))
}
roots = append(roots, os.TempDir())
return roots
}

func validateRefPathForTauri(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", fmt.Errorf("empty ref path")
}
abs, err := filepath.Abs(trimmed)
if err != nil {
return "", fmt.Errorf("invalid ref path: %w", err)
}
abs = filepath.Clean(abs)
resolved, err := filepath.EvalSymlinks(abs)
if err != nil {
return "", fmt.Errorf("ref path could not be resolved: %w", err)
}
real := filepath.Clean(strings.TrimSpace(resolved))
if real == "" {
return "", fmt.Errorf("ref path could not be resolved")
}
for _, root := range allowedRefPathRoots() {
if pathWithinRoot(real, root) {
return real, nil
}
}
return "", fmt.Errorf("ref path is outside allowed directories")
}

// ProviderConfigRequest 设置 Provider 配置请求
type ProviderConfigRequest struct {
ProviderName string `json:"provider_name" binding:"required"`
Expand Down Expand Up @@ -429,10 +502,23 @@ func GenerateWithImagesHandler(c *gin.Context) {
// 处理本地路径请求 (Tauri 优化)
for _, path := range req.RefPaths {
if path != "" {
content, err := os.ReadFile(path)
if !platform.IsTauriSidecar() {
Error(c, http.StatusBadRequest, 400, "refPaths 仅支持桌面端模式")
return
}
targetPath := path
validatedPath, validateErr := validateRefPathForTauri(path)
if validateErr != nil {
log.Printf("[API] 非法本地参考图路径: %s, err: %v\n", path, validateErr)
Error(c, http.StatusBadRequest, 400, "参考图路径不在允许目录内")
return
}
targetPath = validatedPath
content, err := os.ReadFile(targetPath)
Comment on lines +509 to +517
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-critical critical

The GenerateWithImagesHandler function allows reading arbitrary files from the server's file system when not running in Tauri mode. The path parameter from the request is used directly in os.ReadFile without any validation or sanitization. An attacker can exploit this to read sensitive files such as configuration files, credentials, or system files. While validation was added for Tauri mode, other modes (like Docker) remain unprotected. Strict path validation should be implemented for all environments.

if err != nil {
log.Printf("[API] 读取本地参考图失败: %s, err: %v\n", path, err)
continue
log.Printf("[API] 读取本地参考图失败: %s, err: %v\n", targetPath, err)
Error(c, http.StatusBadRequest, 400, "读取本地参考图失败")
return
}
refImageBytes = append(refImageBytes, content)
}
Expand Down
11 changes: 5 additions & 6 deletions backend/internal/model/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func migrateOldTasksToMonthFolders() {
folderCache := make(map[string]uint)

batchSize := 100
offset := 0
processed := 0
totalMigrated := 0
var totalTasks int64

Expand All @@ -242,10 +242,10 @@ func migrateOldTasksToMonthFolders() {
// 分批处理任务
for {
var tasks []Task
// 分批查询未归类的任务
// 每批都从当前剩余未归类任务中取前 N 条,避免更新后使用 offset 漏扫
result := DB.Where("folder_id = ? OR folder_id IS NULL", "").
Order("id ASC").
Limit(batchSize).
Offset(offset).
Find(&tasks)

if result.Error != nil {
Expand Down Expand Up @@ -312,9 +312,8 @@ func migrateOldTasksToMonthFolders() {
totalMigrated++
}

// 下一批任务
offset += len(tasks)
log.Printf("[Migration] 已处理 %d/%d 个任务,继续下一批...\n", offset, totalTasks)
processed += len(tasks)
log.Printf("[Migration] 已处理 %d/%d 个任务,继续下一批...\n", processed, totalTasks)
}

log.Printf("[Migration] 迁移完成: %d/%d 个任务已归类\n", totalMigrated, totalTasks)
Expand Down
7 changes: 7 additions & 0 deletions backend/internal/platform/runtime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package platform

import "os"

func IsTauriSidecar() bool {
return os.Getenv("TAURI_PLATFORM") != "" || os.Getenv("TAURI_FAMILY") != ""
}
Loading
Loading