panel/web/server.go

283 lines
7.7 KiB
Go
Raw Normal View History

2024-11-03 21:33:08 +01:00
package web
import (
"bufio"
"embed"
"encoding/gob"
"errors"
"fmt"
"github.com/a-h/templ"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/middleware/earlydata"
"github.com/gofiber/fiber/v2/middleware/etag"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/healthcheck"
"github.com/gofiber/fiber/v2/middleware/helmet"
"github.com/gofiber/fiber/v2/middleware/limiter"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/fiber/v2/middleware/session"
"github.com/gofiber/storage/postgres/v3"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/spf13/viper"
"github.com/uptrace/bun"
"go.uber.org/zap"
2024-12-20 23:29:45 +01:00
"golang.org/x/sys/unix"
2024-11-03 21:33:08 +01:00
"net/http"
"net/url"
2024-12-20 23:29:45 +01:00
_ "omnibill.net/omnibill/web/handlers"
2024-11-03 21:33:08 +01:00
"omnibill.net/omnibill/web/utils"
"omnibill.net/omnibill/web/views/layouts"
2024-12-20 23:29:45 +01:00
"os"
"os/signal"
2024-11-03 21:33:08 +01:00
"reflect"
"strings"
"time"
)
// assetDir embeds the static asset files into the binary. Start serves its
// assets/dist subtree under the /assets route.
//
// NOTE: go:embed patterns use path.Match semantics, so "**" is not a
// recursive glob; it behaves like "*". "assets/**/*" therefore matches paths
// of the form assets/<dir>/<name> (matched directories are embedded
// recursively), and files placed directly inside assets/ are not embedded.
//
//go:embed assets/**/*
var assetDir embed.FS
func Start(logger *zap.Logger, db *bun.DB, dbPool *pgxpool.Pool) {
panelURL, err := url.Parse(viper.GetString("omnibill.domain"))
if err != nil {
logger.Fatal("error parsing panel URL", zap.Error(err))
}
gob.Register(&webauthn.SessionData{})
webAuthnConfig := &webauthn.Config{
RPDisplayName: viper.GetString("omnibill.display_name"),
RPID: panelURL.Host,
RPOrigins: []string{panelURL.String()},
}
webAuthn, err := webauthn.New(webAuthnConfig)
if err != nil {
logger.Fatal("error creating webauthn", zap.Error(err))
}
appConfig := fiber.Config{
AppName: viper.GetString("omnibill.display_name"),
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
}
if len(viper.GetString("omnibill.webserver.proxy")) != 0 {
switch strings.ToLower(viper.GetString("omnibill.webserver.proxy")) {
case "cloudflare", "cf":
logger.Info("grabbing trusted proxy list")
var trustedProxies []string
v4Req, err := http.NewRequest("GET", "https://www.cloudflare.com/ips-v4/#", nil)
if err != nil {
logger.Fatal("error creating request", zap.Error(err))
}
v6Req, err := http.NewRequest("GET", "https://www.cloudflare.com/ips-v6/#", nil)
if err != nil {
logger.Fatal("error creating request", zap.Error(err))
}
client := &http.Client{}
v4Resp, err := client.Do(v4Req)
if err != nil {
logger.Fatal("error doing request", zap.Error(err))
}
defer v4Resp.Body.Close()
v4Scanner := bufio.NewScanner(v4Resp.Body)
v4Scanner.Split(bufio.ScanLines)
for v4Scanner.Scan() {
trustedProxies = append(trustedProxies, v4Scanner.Text())
}
v6Resp, err := client.Do(v6Req)
if err != nil {
logger.Fatal("error doing request", zap.Error(err))
}
defer v6Resp.Body.Close()
v6Scanner := bufio.NewScanner(v6Resp.Body)
v6Scanner.Split(bufio.ScanLines)
for v6Scanner.Scan() {
trustedProxies = append(trustedProxies, v6Scanner.Text())
}
appConfig.ProxyHeader = "X-Forwarded-For"
appConfig.TrustedProxies = trustedProxies
case "none":
default:
log.Warnf("Proxy '%s' is not supported", viper.GetString("omnibill.webserver.proxy"))
}
}
app := fiber.New(appConfig)
app.Use(recover.New())
app.Use(earlydata.New())
app.Use(healthcheck.New())
app.Use(helmet.New())
app.Use(etag.New())
app.Use(limiter.New(limiter.Config{
Max: 250,
Expiration: 3 * time.Second,
LimiterMiddleware: limiter.SlidingWindow{},
}))
app.Use("/assets", filesystem.New(filesystem.Config{
Root: http.FS(assetDir),
2024-12-20 23:29:45 +01:00
PathPrefix: "assets/dist",
2024-11-03 21:33:08 +01:00
Browse: false,
}))
storage := postgres.New(postgres.Config{
DB: dbPool,
Table: "sessions",
})
authSessionStore := session.New(session.Config{
Storage: storage,
})
sessionStore := session.New(session.Config{
KeyLookup: "cookie:osession",
})
for _, handler := range utils.Handlers {
handlerValue := reflect.ValueOf(handler).Elem()
2024-12-20 23:29:45 +01:00
pathField, ok := handlerValue.Type().FieldByName("Path")
2024-11-03 21:33:08 +01:00
if !ok {
fmt.Println("invalid handler")
continue
}
var requireAuth bool
omnibillTag := pathField.Tag.Get("omnibill")
for _, option := range strings.Split(omnibillTag, ",") {
switch option {
case "requireAuth":
requireAuth = true
}
}
var pathHandlers []fiber.Handler
if requireAuth {
pathHandlers = append(pathHandlers, nil)
}
handlerValue.FieldByName("Db").Set(reflect.ValueOf(db))
handlerValue.FieldByName("AuthSessionStore").Set(reflect.ValueOf(authSessionStore))
handlerValue.FieldByName("SessionStore").Set(reflect.ValueOf(sessionStore))
handlerValue.FieldByName("Logger").Set(reflect.ValueOf(logger))
handlerValue.FieldByName("WebAuthn").Set(reflect.ValueOf(webAuthn))
path := handlerValue.FieldByName("Path").String()
if path == "index" {
path = ""
}
path = "/" + path
if iHandler, ok := handler.(utils.GET); ok {
pathHandlers = append(pathHandlers, iHandler.Get)
app.Get(path, func(ctx *fiber.Ctx) error {
sess, err := sessionStore.Get(ctx)
if err != nil {
return fiber.ErrInternalServerError
}
handlerValue.FieldByName("Session").Set(reflect.ValueOf(sess))
for _, pathHandler := range pathHandlers {
err := pathHandler(ctx)
if err != nil {
var e *fiber.Error
if errors.As(err, &e) {
return utils.Render(ctx, layouts.Error(*e), templ.WithStatus(e.Code))
} else {
return err
}
}
}
return nil
})
}
if iHandler, ok := handler.(utils.POST); ok {
pathHandlers = append(pathHandlers, iHandler.Post)
app.Post(path, func(ctx *fiber.Ctx) error {
return genericPathHandler(ctx, handlerValue, sessionStore, pathHandlers)
})
}
if iHandler, ok := handler.(utils.PUT); ok {
pathHandlers = append(pathHandlers, iHandler.Put)
app.Put(path, func(ctx *fiber.Ctx) error {
return genericPathHandler(ctx, handlerValue, sessionStore, pathHandlers)
})
}
if iHandler, ok := handler.(utils.DELETE); ok {
pathHandlers = append(pathHandlers, iHandler.Delete)
app.Delete(path, func(ctx *fiber.Ctx) error {
return genericPathHandler(ctx, handlerValue, sessionStore, pathHandlers)
})
}
if iHandler, ok := handler.(utils.PATCH); ok {
pathHandlers = append(pathHandlers, iHandler.Patch)
app.Patch(path, func(ctx *fiber.Ctx) error {
return genericPathHandler(ctx, handlerValue, sessionStore, pathHandlers)
})
}
if iHandler, ok := handler.(utils.OPTIONS); ok {
pathHandlers = append(pathHandlers, iHandler.Options)
app.Options(path, func(ctx *fiber.Ctx) error {
return genericPathHandler(ctx, handlerValue, sessionStore, pathHandlers)
})
}
if iHandler, ok := handler.(utils.HEAD); ok {
pathHandlers = append(pathHandlers, iHandler.Head)
app.Head(path, func(ctx *fiber.Ctx) error {
return genericPathHandler(ctx, handlerValue, sessionStore, pathHandlers)
})
}
}
2024-12-20 23:29:45 +01:00
go func() {
if err := app.Listen(fmt.Sprintf("%s:%d", viper.GetString("omnibill.webserver.host"), viper.GetInt("omnibill.webserver.port"))); err != nil {
logger.Fatal("error running server", zap.Error(err))
}
}()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, unix.SIGTERM)
_ = <-sigChan
logger.Info("Shutting Down...")
_ = app.Shutdown()
2024-11-03 21:33:08 +01:00
}
// genericPathHandler runs the handler chain for a non-GET route.
//
// It loads the request's session from sessionStore, stores it in the handler
// struct's Session field, then runs handlers in order, stopping at the first
// error. If that error is (or wraps) a *fiber.Error, it is returned unchanged
// so fiber responds with its status code. Any other error, and a failure to
// load the session, is logged and replaced by fiber.ErrInternalServerError.
//
// NOTE(review): Start passes the same handler value for every request to a
// route, so writing Session here races between concurrent requests.
func genericPathHandler(ctx *fiber.Ctx, handler reflect.Value, sessionStore *session.Store, handlers []fiber.Handler) error {
	sess, err := sessionStore.Get(ctx)
	if err != nil {
		log.Errorf("loading session for %s %s: %v", ctx.Method(), ctx.Path(), err)
		return fiber.ErrInternalServerError
	}
	handler.FieldByName("Session").Set(reflect.ValueOf(sess))
	for _, pathHandler := range handlers {
		if err := pathHandler(ctx); err != nil {
			var fiberErr *fiber.Error
			if errors.As(err, &fiberErr) {
				return err
			}
			// The client only sees a generic 500. Keep the cause in the logs.
			log.Errorf("handling %s %s: %v", ctx.Method(), ctx.Path(), err)
			return fiber.ErrInternalServerError
		}
	}
	return nil
}