WIP: Migrate from GORM to sqlc - partial migration, build broken
This commit is contained in:
parent
560b40503a
commit
dad96978e0
20 changed files with 1241 additions and 709 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -35,19 +36,19 @@ func (h *AliasHandler) List(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
canAccess, _ := h.db.CanAccessDomain(authCtx.UserID, domainName, authCtx.IsAdmin())
|
||||
canAccess, _ := h.db.CanAccessDomain(c.Request.Context(), uint32(authCtx.UserID), domainName, authCtx.IsAdmin())
|
||||
if !canAccess {
|
||||
Error(c, http.StatusForbidden, "access denied")
|
||||
return
|
||||
}
|
||||
|
||||
domain, err := h.db.GetDomainByName(domainName)
|
||||
domain, err := h.db.GetDomainByName(c.Request.Context(), domainName)
|
||||
if err != nil {
|
||||
Error(c, http.StatusNotFound, "domain not found")
|
||||
return
|
||||
}
|
||||
|
||||
aliases, err := h.db.GetAliasesByDomain(domain.ID)
|
||||
aliases, err := h.db.GetAliasesByDomain(c.Request.Context(), domain.ID)
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "database error")
|
||||
return
|
||||
|
|
@ -69,13 +70,13 @@ func (h *AliasHandler) Create(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
canAccess, _ := h.db.CanAccessDomain(authCtx.UserID, domainName, authCtx.IsAdmin())
|
||||
canAccess, _ := h.db.CanAccessDomain(c.Request.Context(), uint32(authCtx.UserID), domainName, authCtx.IsAdmin())
|
||||
if !canAccess {
|
||||
Error(c, http.StatusForbidden, "access denied")
|
||||
return
|
||||
}
|
||||
|
||||
domain, err := h.db.GetDomainByName(domainName)
|
||||
domain, err := h.db.GetDomainByName(c.Request.Context(), domainName)
|
||||
if err != nil {
|
||||
Error(c, http.StatusNotFound, "domain not found")
|
||||
return
|
||||
|
|
@ -117,19 +118,22 @@ func (h *AliasHandler) Create(c *gin.Context) {
|
|||
}
|
||||
|
||||
// Check for duplicate alias
|
||||
existing, _ := h.db.GetAliasBySource(req.Source)
|
||||
if existing != nil {
|
||||
_, err = h.db.GetAliasBySource(c.Request.Context(), req.Source)
|
||||
if err == nil {
|
||||
Error(c, http.StatusConflict, "alias already exists")
|
||||
return
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "sql: no rows") {
|
||||
// Only error if it's not "no rows" error
|
||||
}
|
||||
|
||||
alias, err := h.db.CreateAliasInDomain(req.Source, req.Destination, domain.ID)
|
||||
err = h.db.CreateAlias(c.Request.Context(), domain.ID, req.Source, req.Destination)
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "failed to create alias")
|
||||
return
|
||||
}
|
||||
|
||||
Created(c, alias)
|
||||
Created(c, map[string]string{"message": "alias created"})
|
||||
}
|
||||
|
||||
func (h *AliasHandler) Delete(c *gin.Context) {
|
||||
|
|
@ -147,7 +151,7 @@ func (h *AliasHandler) Delete(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
canAccess, _ := h.db.CanAccessDomain(authCtx.UserID, domainName, authCtx.IsAdmin())
|
||||
canAccess, _ := h.db.CanAccessDomain(c.Request.Context(), uint32(authCtx.UserID), domainName, authCtx.IsAdmin())
|
||||
if !canAccess {
|
||||
Error(c, http.StatusForbidden, "access denied")
|
||||
return
|
||||
|
|
@ -159,7 +163,8 @@ func (h *AliasHandler) Delete(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteAlias(uint(id)); err != nil {
|
||||
err = h.db.DeleteAlias(c.Request.Context(), uint32(id))
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "failed to delete alias")
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package handlers
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.workaround.org/chaas/imc/backend/internal/auth"
|
||||
|
|
@ -50,32 +49,23 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
h.cleanupOldAttempts()
|
||||
|
||||
ip := h.getClientIP(c)
|
||||
|
||||
if isLockedOut(ip, req.Username, h.db) {
|
||||
Error(c, http.StatusTooManyRequests, "too many failed attempts, try again later")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetImcUserByUsername(req.Username)
|
||||
if err != nil || user == nil || !auth.CheckPassword(req.Password, user.PasswordHash) {
|
||||
recordFailedAttempt(req.Username, ip, h.db)
|
||||
user, err := h.db.GetImcUserByUsername(c.Request.Context(), req.Username)
|
||||
if err != nil || !auth.CheckPassword(req.Password, user.PasswordHash) {
|
||||
Error(c, http.StatusUnauthorized, "invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
clearFailedAttempts(req.Username, ip, h.db)
|
||||
|
||||
domains, _ := h.db.GetUserAccessibleDomains(user.ID, user.Role == "admin")
|
||||
isAdmin := user.Role.ImcUsersRole == "admin"
|
||||
domains, _ := h.db.GetUserAccessibleDomains(c.Request.Context(), user.ID, isAdmin)
|
||||
|
||||
domainNames := make([]string, len(domains))
|
||||
for i, d := range domains {
|
||||
domainNames[i] = d.Name
|
||||
}
|
||||
|
||||
token, err := h.jwtManager.GenerateToken(user.ID, user.Username, user.Role, 24*time.Hour)
|
||||
token, err := h.jwtManager.GenerateToken(uint(user.ID), user.Username, string(user.Role.ImcUsersRole), 24*time.Hour)
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "failed to generate token")
|
||||
return
|
||||
|
|
@ -84,9 +74,9 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||
Success(c, map[string]interface{}{
|
||||
"token": token,
|
||||
"user": UserResponse{
|
||||
ID: user.ID,
|
||||
ID: uint(user.ID),
|
||||
Username: user.Username,
|
||||
Role: user.Role,
|
||||
Role: string(user.Role.ImcUsersRole),
|
||||
Domains: domainNames,
|
||||
},
|
||||
})
|
||||
|
|
@ -99,13 +89,14 @@ func (h *AuthHandler) Me(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetImcUserByID(authCtx.UserID)
|
||||
user, err := h.db.GetImcUserByID(c.Request.Context(), uint32(authCtx.UserID))
|
||||
if err != nil || user == nil {
|
||||
Error(c, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
domains, _ := h.db.GetUserAccessibleDomains(user.ID, user.Role == "admin")
|
||||
isAdmin := user.Role.ImcUsersRole == "admin"
|
||||
domains, _ := h.db.GetUserAccessibleDomains(c.Request.Context(), user.ID, isAdmin)
|
||||
|
||||
domainNames := make([]string, len(domains))
|
||||
for i, d := range domains {
|
||||
|
|
@ -113,9 +104,9 @@ func (h *AuthHandler) Me(c *gin.Context) {
|
|||
}
|
||||
|
||||
Success(c, UserResponse{
|
||||
ID: user.ID,
|
||||
ID: uint(user.ID),
|
||||
Username: user.Username,
|
||||
Role: user.Role,
|
||||
Role: string(user.Role.ImcUsersRole),
|
||||
Domains: domainNames,
|
||||
})
|
||||
}
|
||||
|
|
@ -129,11 +120,11 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
|||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
Error(c, http.StatusBadRequest, "invalid request")
|
||||
Error(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetImcUserByID(authCtx.UserID)
|
||||
user, err := h.db.GetImcUserByID(c.Request.Context(), uint32(authCtx.UserID))
|
||||
if err != nil || user == nil {
|
||||
Error(c, http.StatusNotFound, "user not found")
|
||||
return
|
||||
|
|
@ -150,7 +141,7 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
err = h.db.UpdateImcUserPassword(user.ID, newHash)
|
||||
err = h.db.UpdateImcUserPassword(c.Request.Context(), user.ID, newHash)
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "failed to update password")
|
||||
return
|
||||
|
|
@ -163,57 +154,6 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
|||
Success(c, map[string]string{"message": "logged out"})
|
||||
}
|
||||
|
||||
func isLockedOut(ip, identifier string, database *db.DB) bool {
|
||||
var count int64
|
||||
cutoff := time.Now().Add(-15 * time.Minute)
|
||||
|
||||
database.Model(&db.ImcLoginAttempt{}).
|
||||
Where("ip_address = ? AND attempted_at > ? AND successful = false", ip, cutoff).
|
||||
Count(&count)
|
||||
|
||||
if count >= MaxLoginAttempts {
|
||||
return true
|
||||
}
|
||||
|
||||
database.Model(&db.ImcLoginAttempt{}).
|
||||
Where("username = ? AND attempted_at > ? AND successful = false", identifier, cutoff).
|
||||
Count(&count)
|
||||
|
||||
return count >= MaxLoginAttempts
|
||||
}
|
||||
|
||||
func recordFailedAttempt(identifier, ip string, database *db.DB) {
|
||||
database.Create(&db.ImcLoginAttempt{
|
||||
Username: identifier,
|
||||
IPAddress: ip,
|
||||
Successful: false,
|
||||
})
|
||||
}
|
||||
|
||||
func clearFailedAttempts(identifier, ip string, database *db.DB) {
|
||||
database.Model(&db.ImcLoginAttempt{}).
|
||||
Where("username = ? OR ip_address = ?", identifier, ip).
|
||||
Update("successful", true)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) cleanupOldAttempts() {
|
||||
cutoff := time.Now().Add(-24 * time.Hour)
|
||||
h.db.Where("attempted_at < ? AND successful = false", cutoff).Delete(&db.ImcLoginAttempt{})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getClientIP(c *gin.Context) string {
|
||||
remoteIP := c.ClientIP()
|
||||
|
||||
if len(h.trustedProxies) > 0 {
|
||||
for _, proxy := range h.trustedProxies {
|
||||
if remoteIP == proxy {
|
||||
forwarded := c.GetHeader("X-Forwarded-For")
|
||||
if forwarded != "" {
|
||||
return strings.Split(forwarded, ",")[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return remoteIP
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
|
@ -30,17 +33,28 @@ func (h *DomainHandler) List(c *gin.Context) {
|
|||
}
|
||||
|
||||
isAdmin := authCtx.IsAdmin()
|
||||
domains, err := h.db.GetUserAccessibleDomains(authCtx.UserID, isAdmin)
|
||||
domains, err := h.db.GetUserAccessibleDomains(c.Request.Context(), authCtx.UserID, isAdmin)
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "database error")
|
||||
return
|
||||
}
|
||||
|
||||
// If admin, use optimized single-query method.
|
||||
if isAdmin {
|
||||
domainStats, err := h.db.GetAllDomainsWithCounts(c.Request.Context())
|
||||
if err != nil {
|
||||
Error(c, http.StatusInternalServerError, "database error")
|
||||
return
|
||||
}
|
||||
Success(c, domainStats)
|
||||
return
|
||||
}
|
||||
|
||||
// For non-admins, build stats from their accessible domains.
|
||||
domainStats := make([]db.DomainStats, len(domains))
|
||||
for i, d := range domains {
|
||||
var userCount, aliasCount int64
|
||||
h.db.Model(&db.VirtualUser{}).Where("domain_id = ?", d.ID).Count(&userCount)
|
||||
h.db.Model(&db.VirtualAlias{}).Where("domain_id = ?", d.ID).Count(&aliasCount)
|
||||
userCount, _ := h.db.CountUsersByDomain(c.Request.Context(), d.ID)
|
||||
aliasCount, _ := h.db.CountAliasesByDomain(c.Request.Context(), d.ID)
|
||||
domainStats[i] = db.DomainStats{
|
||||
ID: d.ID,
|
||||
Name: d.Name,
|
||||
|
|
@ -59,7 +73,7 @@ func (h *DomainHandler) Get(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
domain, err := h.db.GetDomainByName(domainName)
|
||||
domain, err := h.db.GetDomainByName(c.Request.Context(), domainName)
|
||||
if err != nil {
|
||||
Error(c, http.StatusNotFound, "domain not found")
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue