389 lines
10 KiB
Go
389 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
|
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
|
)
|
|
|
|
// appConfig holds the process configuration that loadConfig reads from the
// environment.
type appConfig struct {
	// Addr is the HTTP listen address (SERVER_ADDR).
	Addr string

	// Region, Bucket and Prefix locate the S3 objects served by the API
	// (AWS_REGION, S3_BUCKET, S3_PREFIX).
	Region, Bucket, Prefix string

	// FrontendOrigin is the one Origin value granted CORS headers
	// (FRONTEND_ORIGIN); empty disables CORS headers entirely.
	FrontendOrigin string

	// EndpointURL, when non-empty, replaces the S3 client's base endpoint
	// (AWS_ENDPOINT_URL); UsePathStyle switches the client to path-style
	// addressing (AWS_USE_PATH_STYLE).
	EndpointURL  string
	UsePathStyle bool
}
|
|
|
|
// fileEntry is one element of the "files" array returned by GET /api/files.
type fileEntry struct {
	Key          string    `json:"key"`          // object key with the store prefix stripped
	Size         int64     `json:"size"`         // object size as reported by S3 (bytes)
	LastModified time.Time `json:"lastModified"` // last-modified time as reported by S3
}
|
|
|
|
type s3Store struct {
|
|
client *s3.Client
|
|
uploader *manager.Uploader
|
|
bucket string
|
|
prefix string
|
|
}
|
|
|
|
func main() {
|
|
cfg, err := loadConfig()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
store, err := newS3Store(cfg)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
server := newServer(cfg, store)
|
|
log.Printf(
|
|
"storage config bucket=%q region=%q endpoint=%q path_style=%t prefix=%q",
|
|
cfg.Bucket,
|
|
cfg.Region,
|
|
cfg.EndpointURL,
|
|
cfg.UsePathStyle,
|
|
cfg.Prefix,
|
|
)
|
|
log.Printf("listening on %s", cfg.Addr)
|
|
if err := http.ListenAndServe(cfg.Addr, server); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func loadConfig() (appConfig, error) {
|
|
cfg := appConfig{
|
|
Addr: envOrDefault("SERVER_ADDR", ":8080"),
|
|
Region: strings.TrimSpace(os.Getenv("AWS_REGION")),
|
|
Bucket: strings.TrimSpace(os.Getenv("S3_BUCKET")),
|
|
Prefix: strings.Trim(strings.TrimSpace(os.Getenv("S3_PREFIX")), "/"),
|
|
FrontendOrigin: strings.TrimSpace(os.Getenv("FRONTEND_ORIGIN")),
|
|
EndpointURL: strings.TrimSpace(os.Getenv("AWS_ENDPOINT_URL")),
|
|
UsePathStyle: strings.EqualFold(strings.TrimSpace(os.Getenv("AWS_USE_PATH_STYLE")), "true"),
|
|
}
|
|
|
|
if cfg.Region == "" {
|
|
return cfg, errors.New("AWS_REGION is required")
|
|
}
|
|
if cfg.Bucket == "" {
|
|
return cfg, errors.New("S3_BUCKET is required")
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
func newS3Store(cfg appConfig) (*s3Store, error) {
|
|
ctx := context.Background()
|
|
loadOptions := []func(*awsconfig.LoadOptions) error{
|
|
awsconfig.WithRegion(cfg.Region),
|
|
}
|
|
|
|
accessKey := strings.TrimSpace(os.Getenv("AWS_ACCESS_KEY_ID"))
|
|
secretKey := strings.TrimSpace(os.Getenv("AWS_SECRET_ACCESS_KEY"))
|
|
sessionToken := strings.TrimSpace(os.Getenv("AWS_SESSION_TOKEN"))
|
|
if accessKey != "" && secretKey != "" {
|
|
loadOptions = append(loadOptions, awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken)))
|
|
}
|
|
|
|
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, loadOptions...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load aws config: %w", err)
|
|
}
|
|
|
|
clientOptions := func(o *s3.Options) {
|
|
o.UsePathStyle = cfg.UsePathStyle
|
|
if cfg.EndpointURL != "" {
|
|
o.BaseEndpoint = aws.String(cfg.EndpointURL)
|
|
}
|
|
}
|
|
|
|
client := s3.NewFromConfig(awsCfg, clientOptions)
|
|
return &s3Store{
|
|
client: client,
|
|
uploader: manager.NewUploader(client),
|
|
bucket: cfg.Bucket,
|
|
prefix: cfg.Prefix,
|
|
}, nil
|
|
}
|
|
|
|
func newServer(cfg appConfig, store *s3Store) http.Handler {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) {
|
|
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
|
})
|
|
mux.HandleFunc("/api/files", func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
handleListFiles(store, w, r)
|
|
case http.MethodPost:
|
|
handleUploadFile(store, w, r)
|
|
default:
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
})
|
|
mux.HandleFunc("/api/files/", func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
handleDownloadFile(store, w, r)
|
|
case http.MethodDelete:
|
|
handleDeleteFile(store, w, r)
|
|
default:
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
})
|
|
|
|
return withCORS(withLogging(mux), cfg.FrontendOrigin)
|
|
}
|
|
|
|
func handleListFiles(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
input := &s3.ListObjectsV2Input{
|
|
Bucket: aws.String(store.bucket),
|
|
Prefix: prefixValue(store.prefix),
|
|
}
|
|
|
|
var files []fileEntry
|
|
paginator := s3.NewListObjectsV2Paginator(store.client, input)
|
|
for paginator.HasMorePages() {
|
|
page, err := paginator.NextPage(ctx)
|
|
if err != nil {
|
|
writeAPIError(w, http.StatusBadGateway, "list files", err)
|
|
return
|
|
}
|
|
for _, item := range page.Contents {
|
|
if item.Key == nil {
|
|
continue
|
|
}
|
|
key := strings.TrimPrefix(aws.ToString(item.Key), store.prefix)
|
|
key = strings.TrimPrefix(key, "/")
|
|
if key == "" || strings.HasSuffix(key, "/") {
|
|
continue
|
|
}
|
|
files = append(files, fileEntry{
|
|
Key: key,
|
|
Size: aws.ToInt64(item.Size),
|
|
LastModified: aws.ToTime(item.LastModified),
|
|
})
|
|
}
|
|
}
|
|
|
|
sort.Slice(files, func(i, j int) bool {
|
|
return files[i].Key < files[j].Key
|
|
})
|
|
|
|
writeJSON(w, http.StatusOK, map[string]any{"files": files})
|
|
}
|
|
|
|
func handleUploadFile(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseMultipartForm(64 << 20); err != nil {
|
|
writeAPIError(w, http.StatusBadRequest, "parse upload payload", err)
|
|
return
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
writeAPIError(w, http.StatusBadRequest, "read upload file", errors.New("missing form field 'file'"))
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
key := sanitizeKey(r.FormValue("key"))
|
|
if key == "" {
|
|
key = sanitizeKey(header.Filename)
|
|
}
|
|
if key == "" {
|
|
writeAPIError(w, http.StatusBadRequest, "validate upload key", errors.New("file key is required"))
|
|
return
|
|
}
|
|
|
|
contentType := header.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
contentType = mime.TypeByExtension(filepath.Ext(key))
|
|
}
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
|
|
defer cancel()
|
|
|
|
_, err = store.uploader.Upload(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(store.bucket),
|
|
Key: aws.String(store.objectKey(key)),
|
|
Body: file,
|
|
ContentType: aws.String(contentType),
|
|
})
|
|
if err != nil {
|
|
writeAPIError(w, http.StatusBadGateway, "upload file", err)
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusCreated, map[string]string{"key": key})
|
|
}
|
|
|
|
func handleDownloadFile(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
|
key := objectKeyFromRequest(r)
|
|
if key == "" {
|
|
writeAPIError(w, http.StatusBadRequest, "read download key", errors.New("file key is required"))
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
|
|
defer cancel()
|
|
|
|
output, err := store.client.GetObject(ctx, &s3.GetObjectInput{
|
|
Bucket: aws.String(store.bucket),
|
|
Key: aws.String(store.objectKey(key)),
|
|
})
|
|
if err != nil {
|
|
var noSuchKey *types.NoSuchKey
|
|
if errors.As(err, &noSuchKey) {
|
|
writeAPIError(w, http.StatusNotFound, "download file", errors.New("file not found"))
|
|
return
|
|
}
|
|
writeAPIError(w, http.StatusBadGateway, "download file", err)
|
|
return
|
|
}
|
|
defer output.Body.Close()
|
|
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filepath.Base(key)))
|
|
if output.ContentType != nil {
|
|
w.Header().Set("Content-Type", aws.ToString(output.ContentType))
|
|
} else {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
}
|
|
w.Header().Set("Content-Length", strconv.FormatInt(aws.ToInt64(output.ContentLength), 10))
|
|
if _, err := io.Copy(w, output.Body); err != nil {
|
|
log.Printf("stream download %q: %v", key, err)
|
|
}
|
|
}
|
|
|
|
func handleDeleteFile(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
|
key := objectKeyFromRequest(r)
|
|
if key == "" {
|
|
writeAPIError(w, http.StatusBadRequest, "read delete key", errors.New("file key is required"))
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
_, err := store.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
|
Bucket: aws.String(store.bucket),
|
|
Key: aws.String(store.objectKey(key)),
|
|
})
|
|
if err != nil {
|
|
writeAPIError(w, http.StatusBadGateway, "delete file", err)
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]string{"deleted": key})
|
|
}
|
|
|
|
func (s *s3Store) objectKey(key string) string {
|
|
cleanKey := sanitizeKey(key)
|
|
if s.prefix == "" {
|
|
return cleanKey
|
|
}
|
|
return s.prefix + "/" + cleanKey
|
|
}
|
|
|
|
func objectKeyFromRequest(r *http.Request) string {
|
|
encodedKey := strings.TrimPrefix(r.URL.Path, "/api/files/")
|
|
decoded, err := url.PathUnescape(encodedKey)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return sanitizeKey(decoded)
|
|
}
|
|
|
|
// sanitizeKey normalizes a user-supplied object key. Whitespace is trimmed
// from the whole key and from each "/"-separated segment, and empty, "." and
// ".." segments are dropped (not resolved). The remaining segments are joined
// with "/", so the result never begins or ends with "/" and may be "".
func sanitizeKey(value string) string {
	var out strings.Builder
	for _, segment := range strings.Split(strings.TrimSpace(value), "/") {
		switch segment = strings.TrimSpace(segment); segment {
		case "", ".", "..":
			continue
		}
		if out.Len() > 0 {
			out.WriteByte('/')
		}
		out.WriteString(segment)
	}
	return out.String()
}
|
|
|
|
// prefixValue builds the ListObjectsV2 Prefix filter for the store prefix. It
// returns nil (no filter) for an empty prefix; otherwise it returns a pointer
// to prefix + "/", so only keys under that prefix segment match.
func prefixValue(prefix string) *string {
	var filter *string
	if prefix != "" {
		withSlash := prefix + "/"
		filter = &withSlash
	}
	return filter
}
|
|
|
|
// envOrDefault returns the whitespace-trimmed value of the environment
// variable name, or fallback when it is unset or blank.
func envOrDefault(name, fallback string) string {
	value := strings.TrimSpace(os.Getenv(name))
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// writeJSON sends payload as an application/json response with the given
// status. The status line is written before encoding, so an encoding failure
// can only be logged, not reported to the client.
func writeJSON(w http.ResponseWriter, status int, payload any) {
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr == nil {
		return
	}
	log.Printf("write json: %v", encodeErr)
}
|
|
|
|
func writeAPIError(w http.ResponseWriter, status int, operation string, err error) {
|
|
message := fmt.Sprintf("%s: %v", operation, err)
|
|
log.Printf("error %s", message)
|
|
writeJSON(w, status, map[string]string{"error": message})
|
|
}
|
|
|
|
// withLogging wraps next so that each request is logged after it completes,
// as "<method> <path> <duration>" with the duration rounded to the millisecond.
func withLogging(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began).Round(time.Millisecond)
		log.Printf("%s %s %s", r.Method, r.URL.Path, elapsed)
	}
	return http.HandlerFunc(logged)
}
|
|
|
|
// withCORS adds CORS headers for a single allowed origin.
//
// When allowedOrigin is non-empty, every response carries "Vary: Origin".
// A request whose Origin header equals allowedOrigin exactly also gets
// Access-Control-Allow-Origin/-Headers/-Methods. Every OPTIONS request is
// answered with 204 without reaching next. When allowedOrigin is empty, no
// CORS headers are sent at all.
func withCORS(next http.Handler, allowedOrigin string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if allowedOrigin != "" {
			// These headers depend on the request's Origin. Tell shared caches
			// to key on it, so a response cached for one origin (with or
			// without Access-Control-Allow-Origin) is not replayed to another.
			w.Header().Add("Vary", "Origin")
			if r.Header.Get("Origin") == allowedOrigin {
				w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
				w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
				w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
			}
		}

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		next.ServeHTTP(w, r)
	})
}
|