269 lines
6.5 KiB
Go
269 lines
6.5 KiB
Go
package handlers
|
|
|
|
import (
	"encoding/json"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/imc-vibe/backend/internal/auth"
	"github.com/imc-vibe/backend/internal/db"
)
|
|
|
|
// MaxLoginAttempts is the number of unsuccessful login attempts, counted per
// IP address or per identifier inside the window checked by isLockedOut, at
// which further logins are refused.
const MaxLoginAttempts = 5
|
|
|
|
type AuthHandler struct {
|
|
db *db.DB
|
|
jwtManager *auth.JWTManager
|
|
}
|
|
|
|
func NewAuthHandler(database *db.DB, jwtManager *auth.JWTManager) *AuthHandler {
|
|
return &AuthHandler{
|
|
db: database,
|
|
jwtManager: jwtManager,
|
|
}
|
|
}
|
|
|
|
// LoginRequest is the JSON body accepted by AuthHandler.Login.
type LoginRequest struct {
	// Username identifies the account to log in to.
	Username string `json:"username"`
	// Password is the plaintext password supplied by the client.
	Password string `json:"password"`
}
|
|
|
|
// ChangePasswordRequest is the JSON body accepted by
// AuthHandler.ChangePassword.
type ChangePasswordRequest struct {
	// OldPassword is the caller's current password, checked before any change.
	OldPassword string `json:"oldPassword"`
	// NewPassword is the replacement password.
	NewPassword string `json:"newPassword"`
}
|
|
|
|
// ForgotPasswordRequest is the JSON body accepted by
// AuthHandler.ForgotPassword.
type ForgotPasswordRequest struct {
	// Identifier names the account; the handler looks it up as a username.
	Identifier string `json:"identifier"`
}
|
|
|
|
// UserResponse is the user profile returned to clients by Login and Me.
type UserResponse struct {
	// ID is the user's database identifier.
	ID uint `json:"id"`
	// Username is the account's login name.
	Username string `json:"username"`
	// Role is the user's role string; the handlers compare it against "admin".
	Role string `json:"role"`
	// Domains lists the names of the domains the user can access.
	Domains []string `json:"domains"`
}
|
|
|
|
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
Error(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
var req LoginRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
Error(w, http.StatusBadRequest, "invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.Username == "" || req.Password == "" {
|
|
Error(w, http.StatusBadRequest, "username and password required")
|
|
return
|
|
}
|
|
|
|
ip := getClientIP(r)
|
|
|
|
if isLockedOut(ip, req.Username, h.db) {
|
|
Error(w, http.StatusTooManyRequests, "too many failed attempts, try again later")
|
|
return
|
|
}
|
|
|
|
user, err := h.db.GetImcUserByUsername(req.Username)
|
|
if err != nil || user == nil || !auth.CheckPassword(req.Password, user.PasswordHash) {
|
|
recordFailedAttempt(req.Username, ip, h.db)
|
|
Error(w, http.StatusUnauthorized, "invalid credentials")
|
|
return
|
|
}
|
|
|
|
clearFailedAttempts(req.Username, ip, h.db)
|
|
|
|
domains, _ := h.db.GetUserAccessibleDomains(user.ID, user.Role == "admin")
|
|
|
|
domainNames := make([]string, len(domains))
|
|
for i, d := range domains {
|
|
domainNames[i] = d.Name
|
|
}
|
|
|
|
token, err := h.jwtManager.GenerateToken(user.ID, user.Username, user.Role, 24*time.Hour)
|
|
if err != nil {
|
|
Error(w, http.StatusInternalServerError, "failed to generate token")
|
|
return
|
|
}
|
|
|
|
Success(w, map[string]interface{}{
|
|
"token": token,
|
|
"user": UserResponse{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Role: user.Role,
|
|
Domains: domainNames,
|
|
},
|
|
})
|
|
}
|
|
|
|
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
Error(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
authCtx := GetAuthContext(r)
|
|
if authCtx == nil {
|
|
Error(w, http.StatusUnauthorized, "not authenticated")
|
|
return
|
|
}
|
|
|
|
user, err := h.db.GetImcUserByID(authCtx.UserID)
|
|
if err != nil || user == nil {
|
|
Error(w, http.StatusNotFound, "user not found")
|
|
return
|
|
}
|
|
|
|
domains, _ := h.db.GetUserAccessibleDomains(user.ID, user.Role == "admin")
|
|
|
|
domainNames := make([]string, len(domains))
|
|
for i, d := range domains {
|
|
domainNames[i] = d.Name
|
|
}
|
|
|
|
Success(w, UserResponse{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Role: user.Role,
|
|
Domains: domainNames,
|
|
})
|
|
}
|
|
|
|
func (h *AuthHandler) ForgotPassword(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
Error(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
var req ForgotPasswordRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
Error(w, http.StatusBadRequest, "invalid request")
|
|
return
|
|
}
|
|
|
|
req.Identifier = strings.TrimSpace(strings.ToLower(req.Identifier))
|
|
if req.Identifier == "" {
|
|
Error(w, http.StatusBadRequest, "identifier required")
|
|
return
|
|
}
|
|
|
|
_, err := h.db.GetImcUserByUsername(req.Identifier)
|
|
if err == nil {
|
|
Success(w, map[string]string{
|
|
"message": "If the account exists, a password reset link will be sent",
|
|
})
|
|
return
|
|
}
|
|
|
|
Success(w, map[string]string{
|
|
"message": "If the account exists, a password reset link will be sent",
|
|
})
|
|
}
|
|
|
|
func (h *AuthHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
Error(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
authCtx := GetAuthContext(r)
|
|
if authCtx == nil {
|
|
Error(w, http.StatusUnauthorized, "authentication required")
|
|
return
|
|
}
|
|
|
|
var req ChangePasswordRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
Error(w, http.StatusBadRequest, "invalid request")
|
|
return
|
|
}
|
|
|
|
if req.OldPassword == "" || req.NewPassword == "" {
|
|
Error(w, http.StatusBadRequest, "old and new password required")
|
|
return
|
|
}
|
|
|
|
if len(req.NewPassword) < 8 {
|
|
Error(w, http.StatusBadRequest, "password must be at least 8 characters")
|
|
return
|
|
}
|
|
|
|
user, err := h.db.GetImcUserByID(authCtx.UserID)
|
|
if err != nil || user == nil {
|
|
Error(w, http.StatusNotFound, "user not found")
|
|
return
|
|
}
|
|
|
|
if !auth.CheckPassword(req.OldPassword, user.PasswordHash) {
|
|
Error(w, http.StatusUnauthorized, "current password is incorrect")
|
|
return
|
|
}
|
|
|
|
newHash, err := auth.HashPassword(req.NewPassword)
|
|
if err != nil {
|
|
Error(w, http.StatusInternalServerError, "failed to hash password")
|
|
return
|
|
}
|
|
|
|
err = h.db.UpdateImcUserPassword(user.ID, newHash)
|
|
if err != nil {
|
|
Error(w, http.StatusInternalServerError, "failed to update password")
|
|
return
|
|
}
|
|
|
|
Success(w, map[string]string{"message": "password updated successfully"})
|
|
}
|
|
|
|
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
|
Success(w, map[string]string{"message": "logged out"})
|
|
}
|
|
|
|
func isLockedOut(ip, identifier string, database *db.DB) bool {
|
|
var count int64
|
|
cutoff := time.Now().Add(-15 * time.Minute)
|
|
|
|
database.Model(&db.ImcLoginAttempt{}).
|
|
Where("ip_address = ? AND attempted_at > ? AND successful = false", ip, cutoff).
|
|
Count(&count)
|
|
|
|
if count >= MaxLoginAttempts {
|
|
return true
|
|
}
|
|
|
|
database.Model(&db.ImcLoginAttempt{}).
|
|
Where("email = ? AND attempted_at > ? AND successful = false", identifier, cutoff).
|
|
Count(&count)
|
|
|
|
return count >= MaxLoginAttempts
|
|
}
|
|
|
|
func recordFailedAttempt(identifier, ip string, database *db.DB) {
|
|
database.Create(&db.ImcLoginAttempt{
|
|
Email: identifier,
|
|
IPAddress: ip,
|
|
Successful: false,
|
|
})
|
|
}
|
|
|
|
func clearFailedAttempts(identifier, ip string, database *db.DB) {
|
|
database.Model(&db.ImcLoginAttempt{}).
|
|
Where("email = ? OR ip_address = ?", identifier, ip).
|
|
Update("successful", true)
|
|
}
|
|
|
|
// getClientIP returns the address a request should be attributed to.
//
// If the X-Forwarded-For header is present, its first comma-separated entry
// is used with surrounding whitespace removed. If that entry is empty, or the
// header is absent, the host part of r.RemoteAddr is returned; IPv6 hosts are
// returned without brackets. When RemoteAddr is not in host:port form it is
// returned unchanged.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy overwrites it, so callers using this value for rate limiting can be
// evaded by varying the header. Confirm the deployment sits behind a proxy
// that sets it.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			first = forwarded[:comma]
		}
		// Trim so " 10.0.0.1" and "10.0.0.1" map to the same client key.
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}

	// SplitHostPort handles bracketed IPv6 literals, which a plain
	// "cut at the last colon" leaves wrapped in brackets.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
|