package main import ( "context" "encoding/json" "errors" "fmt" "io" "log" "mime" "net/http" "net/url" "os" "path/filepath" "sort" "strconv" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" ) type appConfig struct { Addr string Region string Bucket string Prefix string FrontendOrigin string EndpointURL string UsePathStyle bool } type fileEntry struct { Key string `json:"key"` Size int64 `json:"size"` LastModified time.Time `json:"lastModified"` } type s3Store struct { client *s3.Client uploader *manager.Uploader bucket string prefix string } func main() { cfg, err := loadConfig() if err != nil { log.Fatal(err) } store, err := newS3Store(cfg) if err != nil { log.Fatal(err) } server := newServer(cfg, store) log.Printf( "storage config bucket=%q region=%q endpoint=%q path_style=%t prefix=%q", cfg.Bucket, cfg.Region, cfg.EndpointURL, cfg.UsePathStyle, cfg.Prefix, ) log.Printf("listening on %s", cfg.Addr) if err := http.ListenAndServe(cfg.Addr, server); err != nil { log.Fatal(err) } } func loadConfig() (appConfig, error) { cfg := appConfig{ Addr: envOrDefault("SERVER_ADDR", ":8080"), Region: strings.TrimSpace(os.Getenv("AWS_REGION")), Bucket: strings.TrimSpace(os.Getenv("S3_BUCKET")), Prefix: strings.Trim(strings.TrimSpace(os.Getenv("S3_PREFIX")), "/"), FrontendOrigin: strings.TrimSpace(os.Getenv("FRONTEND_ORIGIN")), EndpointURL: strings.TrimSpace(os.Getenv("AWS_ENDPOINT_URL")), UsePathStyle: strings.EqualFold(strings.TrimSpace(os.Getenv("AWS_USE_PATH_STYLE")), "true"), } if cfg.Region == "" { return cfg, errors.New("AWS_REGION is required") } if cfg.Bucket == "" { return cfg, errors.New("S3_BUCKET is required") } return cfg, nil } func newS3Store(cfg appConfig) (*s3Store, error) { ctx := context.Background() loadOptions := []func(*awsconfig.LoadOptions) error{ awsconfig.WithRegion(cfg.Region), } accessKey := strings.TrimSpace(os.Getenv("AWS_ACCESS_KEY_ID")) secretKey := strings.TrimSpace(os.Getenv("AWS_SECRET_ACCESS_KEY")) sessionToken := strings.TrimSpace(os.Getenv("AWS_SESSION_TOKEN")) if accessKey != "" && secretKey != "" { loadOptions = append(loadOptions, awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken))) } awsCfg, err := awsconfig.LoadDefaultConfig(ctx, loadOptions...) if err != nil { return nil, fmt.Errorf("load aws config: %w", err) } clientOptions := func(o *s3.Options) { o.UsePathStyle = cfg.UsePathStyle if cfg.EndpointURL != "" { o.BaseEndpoint = aws.String(cfg.EndpointURL) } } client := s3.NewFromConfig(awsCfg, clientOptions) return &s3Store{ client: client, uploader: manager.NewUploader(client), bucket: cfg.Bucket, prefix: cfg.Prefix, }, nil } func newServer(cfg appConfig, store *s3Store) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) }) mux.HandleFunc("/api/files", func(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: handleListFiles(store, w, r) case http.MethodPost: handleUploadFile(store, w, r) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } }) mux.HandleFunc("/api/files/", func(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: handleDownloadFile(store, w, r) case http.MethodDelete: handleDeleteFile(store, w, r) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } }) return withCORS(withLogging(mux), cfg.FrontendOrigin) } func handleListFiles(store *s3Store, w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) defer cancel() input := &s3.ListObjectsV2Input{ Bucket: aws.String(store.bucket), Prefix: prefixValue(store.prefix), } var files []fileEntry paginator := s3.NewListObjectsV2Paginator(store.client, input) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx) if err != nil { writeAPIError(w, http.StatusBadGateway, "list files", err) return } for _, item := range page.Contents { if item.Key == nil { continue } key := strings.TrimPrefix(aws.ToString(item.Key), store.prefix) key = strings.TrimPrefix(key, "/") if key == "" || strings.HasSuffix(key, "/") { continue } files = append(files, fileEntry{ Key: key, Size: aws.ToInt64(item.Size), LastModified: aws.ToTime(item.LastModified), }) } } sort.Slice(files, func(i, j int) bool { return files[i].Key < files[j].Key }) writeJSON(w, http.StatusOK, map[string]any{"files": files}) } func handleUploadFile(store *s3Store, w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(64 << 20); err != nil { writeAPIError(w, http.StatusBadRequest, "parse upload payload", err) return } file, header, err := r.FormFile("file") if err != nil { writeAPIError(w, http.StatusBadRequest, "read upload file", errors.New("missing form field 'file'")) return } defer file.Close() key := sanitizeKey(r.FormValue("key")) if key == "" { key = sanitizeKey(header.Filename) } if key == "" { writeAPIError(w, http.StatusBadRequest, "validate upload key", errors.New("file key is required")) return } contentType := header.Header.Get("Content-Type") if contentType == "" { contentType = mime.TypeByExtension(filepath.Ext(key)) } if contentType == "" { contentType = "application/octet-stream" } ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute) defer cancel() _, err = store.uploader.Upload(ctx, &s3.PutObjectInput{ Bucket: aws.String(store.bucket), Key: aws.String(store.objectKey(key)), Body: file, ContentType: aws.String(contentType), }) if err != nil { writeAPIError(w, http.StatusBadGateway, "upload file", err) return } writeJSON(w, http.StatusCreated, map[string]string{"key": key}) } func handleDownloadFile(store *s3Store, w http.ResponseWriter, r *http.Request) { key := objectKeyFromRequest(r) if key == "" { writeAPIError(w, http.StatusBadRequest, "read download key", errors.New("file key is required")) return } ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute) defer cancel() output, err := store.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(store.bucket), Key: aws.String(store.objectKey(key)), }) if err != nil { var noSuchKey *types.NoSuchKey if errors.As(err, &noSuchKey) { writeAPIError(w, http.StatusNotFound, "download file", errors.New("file not found")) return } writeAPIError(w, http.StatusBadGateway, "download file", err) return } defer output.Body.Close() w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filepath.Base(key))) if output.ContentType != nil { w.Header().Set("Content-Type", aws.ToString(output.ContentType)) } else { w.Header().Set("Content-Type", "application/octet-stream") } w.Header().Set("Content-Length", strconv.FormatInt(aws.ToInt64(output.ContentLength), 10)) if _, err := io.Copy(w, output.Body); err != nil { log.Printf("stream download %q: %v", key, err) } } func handleDeleteFile(store *s3Store, w http.ResponseWriter, r *http.Request) { key := objectKeyFromRequest(r) if key == "" { writeAPIError(w, http.StatusBadRequest, "read delete key", errors.New("file key is required")) return } ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) defer cancel() _, err := store.client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(store.bucket), Key: aws.String(store.objectKey(key)), }) if err != nil { writeAPIError(w, http.StatusBadGateway, "delete file", err) return } writeJSON(w, http.StatusOK, map[string]string{"deleted": key}) } func (s *s3Store) objectKey(key string) string { cleanKey := sanitizeKey(key) if s.prefix == "" { return cleanKey } return s.prefix + "/" + cleanKey } func objectKeyFromRequest(r *http.Request) string { encodedKey := strings.TrimPrefix(r.URL.Path, "/api/files/") decoded, err := url.PathUnescape(encodedKey) if err != nil { return "" } return sanitizeKey(decoded) } func sanitizeKey(value string) string { parts := strings.Split(strings.TrimSpace(value), "/") cleaned := make([]string, 0, len(parts)) for _, part := range parts { part = strings.TrimSpace(part) if part == "" || part == "." || part == ".." { continue } cleaned = append(cleaned, part) } return strings.Join(cleaned, "/") } func prefixValue(prefix string) *string { if prefix == "" { return nil } value := prefix + "/" return &value } func envOrDefault(name, fallback string) string { if value := strings.TrimSpace(os.Getenv(name)); value != "" { return value } return fallback } func writeJSON(w http.ResponseWriter, status int, payload any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(payload); err != nil { log.Printf("write json: %v", err) } } func writeAPIError(w http.ResponseWriter, status int, operation string, err error) { message := fmt.Sprintf("%s: %v", operation, err) log.Printf("error %s", message) writeJSON(w, status, map[string]string{"error": message}) } func withLogging(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start).Round(time.Millisecond)) }) } func withCORS(next http.Handler, allowedOrigin string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if allowedOrigin != "" && r.Header.Get("Origin") == allowedOrigin { w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) w.Header().Set("Access-Control-Allow-Headers", "Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS") } if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) }