initial commit
This commit is contained in:
388
backend/main.go
Normal file
388
backend/main.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
)
|
||||
|
||||
type appConfig struct {
|
||||
Addr string
|
||||
Region string
|
||||
Bucket string
|
||||
Prefix string
|
||||
FrontendOrigin string
|
||||
EndpointURL string
|
||||
UsePathStyle bool
|
||||
}
|
||||
|
||||
type fileEntry struct {
|
||||
Key string `json:"key"`
|
||||
Size int64 `json:"size"`
|
||||
LastModified time.Time `json:"lastModified"`
|
||||
}
|
||||
|
||||
type s3Store struct {
|
||||
client *s3.Client
|
||||
uploader *manager.Uploader
|
||||
bucket string
|
||||
prefix string
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := newS3Store(cfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
server := newServer(cfg, store)
|
||||
log.Printf(
|
||||
"storage config bucket=%q region=%q endpoint=%q path_style=%t prefix=%q",
|
||||
cfg.Bucket,
|
||||
cfg.Region,
|
||||
cfg.EndpointURL,
|
||||
cfg.UsePathStyle,
|
||||
cfg.Prefix,
|
||||
)
|
||||
log.Printf("listening on %s", cfg.Addr)
|
||||
if err := http.ListenAndServe(cfg.Addr, server); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func loadConfig() (appConfig, error) {
|
||||
cfg := appConfig{
|
||||
Addr: envOrDefault("SERVER_ADDR", ":8080"),
|
||||
Region: strings.TrimSpace(os.Getenv("AWS_REGION")),
|
||||
Bucket: strings.TrimSpace(os.Getenv("S3_BUCKET")),
|
||||
Prefix: strings.Trim(strings.TrimSpace(os.Getenv("S3_PREFIX")), "/"),
|
||||
FrontendOrigin: strings.TrimSpace(os.Getenv("FRONTEND_ORIGIN")),
|
||||
EndpointURL: strings.TrimSpace(os.Getenv("AWS_ENDPOINT_URL")),
|
||||
UsePathStyle: strings.EqualFold(strings.TrimSpace(os.Getenv("AWS_USE_PATH_STYLE")), "true"),
|
||||
}
|
||||
|
||||
if cfg.Region == "" {
|
||||
return cfg, errors.New("AWS_REGION is required")
|
||||
}
|
||||
if cfg.Bucket == "" {
|
||||
return cfg, errors.New("S3_BUCKET is required")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func newS3Store(cfg appConfig) (*s3Store, error) {
|
||||
ctx := context.Background()
|
||||
loadOptions := []func(*awsconfig.LoadOptions) error{
|
||||
awsconfig.WithRegion(cfg.Region),
|
||||
}
|
||||
|
||||
accessKey := strings.TrimSpace(os.Getenv("AWS_ACCESS_KEY_ID"))
|
||||
secretKey := strings.TrimSpace(os.Getenv("AWS_SECRET_ACCESS_KEY"))
|
||||
sessionToken := strings.TrimSpace(os.Getenv("AWS_SESSION_TOKEN"))
|
||||
if accessKey != "" && secretKey != "" {
|
||||
loadOptions = append(loadOptions, awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken)))
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, loadOptions...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
clientOptions := func(o *s3.Options) {
|
||||
o.UsePathStyle = cfg.UsePathStyle
|
||||
if cfg.EndpointURL != "" {
|
||||
o.BaseEndpoint = aws.String(cfg.EndpointURL)
|
||||
}
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, clientOptions)
|
||||
return &s3Store{
|
||||
client: client,
|
||||
uploader: manager.NewUploader(client),
|
||||
bucket: cfg.Bucket,
|
||||
prefix: cfg.Prefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newServer(cfg appConfig, store *s3Store) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
})
|
||||
mux.HandleFunc("/api/files", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
handleListFiles(store, w, r)
|
||||
case http.MethodPost:
|
||||
handleUploadFile(store, w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/files/", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
handleDownloadFile(store, w, r)
|
||||
case http.MethodDelete:
|
||||
handleDeleteFile(store, w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
return withCORS(withLogging(mux), cfg.FrontendOrigin)
|
||||
}
|
||||
|
||||
func handleListFiles(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(store.bucket),
|
||||
Prefix: prefixValue(store.prefix),
|
||||
}
|
||||
|
||||
var files []fileEntry
|
||||
paginator := s3.NewListObjectsV2Paginator(store.client, input)
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
writeAPIError(w, http.StatusBadGateway, "list files", err)
|
||||
return
|
||||
}
|
||||
for _, item := range page.Contents {
|
||||
if item.Key == nil {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimPrefix(aws.ToString(item.Key), store.prefix)
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
if key == "" || strings.HasSuffix(key, "/") {
|
||||
continue
|
||||
}
|
||||
files = append(files, fileEntry{
|
||||
Key: key,
|
||||
Size: aws.ToInt64(item.Size),
|
||||
LastModified: aws.ToTime(item.LastModified),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Key < files[j].Key
|
||||
})
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{"files": files})
|
||||
}
|
||||
|
||||
func handleUploadFile(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(64 << 20); err != nil {
|
||||
writeAPIError(w, http.StatusBadRequest, "parse upload payload", err)
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeAPIError(w, http.StatusBadRequest, "read upload file", errors.New("missing form field 'file'"))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
key := sanitizeKey(r.FormValue("key"))
|
||||
if key == "" {
|
||||
key = sanitizeKey(header.Filename)
|
||||
}
|
||||
if key == "" {
|
||||
writeAPIError(w, http.StatusBadRequest, "validate upload key", errors.New("file key is required"))
|
||||
return
|
||||
}
|
||||
|
||||
contentType := header.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = mime.TypeByExtension(filepath.Ext(key))
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
_, err = store.uploader.Upload(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(store.bucket),
|
||||
Key: aws.String(store.objectKey(key)),
|
||||
Body: file,
|
||||
ContentType: aws.String(contentType),
|
||||
})
|
||||
if err != nil {
|
||||
writeAPIError(w, http.StatusBadGateway, "upload file", err)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"key": key})
|
||||
}
|
||||
|
||||
func handleDownloadFile(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
||||
key := objectKeyFromRequest(r)
|
||||
if key == "" {
|
||||
writeAPIError(w, http.StatusBadRequest, "read download key", errors.New("file key is required"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
output, err := store.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(store.bucket),
|
||||
Key: aws.String(store.objectKey(key)),
|
||||
})
|
||||
if err != nil {
|
||||
var noSuchKey *types.NoSuchKey
|
||||
if errors.As(err, &noSuchKey) {
|
||||
writeAPIError(w, http.StatusNotFound, "download file", errors.New("file not found"))
|
||||
return
|
||||
}
|
||||
writeAPIError(w, http.StatusBadGateway, "download file", err)
|
||||
return
|
||||
}
|
||||
defer output.Body.Close()
|
||||
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filepath.Base(key)))
|
||||
if output.ContentType != nil {
|
||||
w.Header().Set("Content-Type", aws.ToString(output.ContentType))
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(aws.ToInt64(output.ContentLength), 10))
|
||||
if _, err := io.Copy(w, output.Body); err != nil {
|
||||
log.Printf("stream download %q: %v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleDeleteFile(store *s3Store, w http.ResponseWriter, r *http.Request) {
|
||||
key := objectKeyFromRequest(r)
|
||||
if key == "" {
|
||||
writeAPIError(w, http.StatusBadRequest, "read delete key", errors.New("file key is required"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := store.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(store.bucket),
|
||||
Key: aws.String(store.objectKey(key)),
|
||||
})
|
||||
if err != nil {
|
||||
writeAPIError(w, http.StatusBadGateway, "delete file", err)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"deleted": key})
|
||||
}
|
||||
|
||||
func (s *s3Store) objectKey(key string) string {
|
||||
cleanKey := sanitizeKey(key)
|
||||
if s.prefix == "" {
|
||||
return cleanKey
|
||||
}
|
||||
return s.prefix + "/" + cleanKey
|
||||
}
|
||||
|
||||
func objectKeyFromRequest(r *http.Request) string {
|
||||
encodedKey := strings.TrimPrefix(r.URL.Path, "/api/files/")
|
||||
decoded, err := url.PathUnescape(encodedKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return sanitizeKey(decoded)
|
||||
}
|
||||
|
||||
func sanitizeKey(value string) string {
|
||||
parts := strings.Split(strings.TrimSpace(value), "/")
|
||||
cleaned := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" || part == "." || part == ".." {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, part)
|
||||
}
|
||||
return strings.Join(cleaned, "/")
|
||||
}
|
||||
|
||||
func prefixValue(prefix string) *string {
|
||||
if prefix == "" {
|
||||
return nil
|
||||
}
|
||||
value := prefix + "/"
|
||||
return &value
|
||||
}
|
||||
|
||||
func envOrDefault(name, fallback string) string {
|
||||
if value := strings.TrimSpace(os.Getenv(name)); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(payload); err != nil {
|
||||
log.Printf("write json: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeAPIError(w http.ResponseWriter, status int, operation string, err error) {
|
||||
message := fmt.Sprintf("%s: %v", operation, err)
|
||||
log.Printf("error %s", message)
|
||||
writeJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
func withLogging(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
next.ServeHTTP(w, r)
|
||||
log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start).Round(time.Millisecond))
|
||||
})
|
||||
}
|
||||
|
||||
func withCORS(next http.Handler, allowedOrigin string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if allowedOrigin != "" && r.Header.Get("Origin") == allowedOrigin {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user