// Copyright 2024 Florian Beisel
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
|
|
"git.beisel.it/florian/hostname-service/models"
|
|
"github.com/golang-migrate/migrate/v4"
|
|
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// DB is the shared database handle used by every helper in this file.
// It is assigned by Init and is nil until Init has been called.
var DB *sql.DB
|
|
|
|
// Initialize the database and create tables if they don't exist
|
|
func Init(db string) {
|
|
var err error
|
|
DB, err = sql.Open("sqlite3", db)
|
|
if err != nil {
|
|
log.Fatalf("Error opening database: %v", err)
|
|
}
|
|
|
|
m, err := migrate.New(
|
|
"file://db/migrations/",
|
|
fmt.Sprintf("sqlite3://%s", db),
|
|
)
|
|
|
|
if err != nil {
|
|
log.Fatalf("Migration initialization failed: %v", err)
|
|
}
|
|
|
|
// Apply all up migrations
|
|
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
|
log.Fatalf("Migration up failed: %v", err)
|
|
}
|
|
|
|
// Check if users table is empty
|
|
var userCount int
|
|
err = DB.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
|
if err != nil {
|
|
log.Fatalf("Error checking users table: %v", err)
|
|
}
|
|
|
|
// If there are no users, create a default admin user
|
|
if userCount == 0 {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte("defaultPassword"), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
log.Fatalf("Error hashing password: %v", err)
|
|
}
|
|
|
|
_, err = DB.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", string(hashedPassword))
|
|
if err != nil {
|
|
log.Fatalf("Error creating default admin user: %v", err)
|
|
}
|
|
|
|
log.Println("Default admin user created")
|
|
}
|
|
|
|
log.Println("Database migrations applied successfully")
|
|
}
|
|
|
|
func CreateUser(user *models.User) error {
|
|
statement, err := DB.Prepare("INSERT INTO users(username, password) VALUES (?, ?)")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = statement.Exec(user.Username, user.Password)
|
|
return err
|
|
}
|
|
|
|
func InsertHostname(category string, hostname string, paramsJSON []byte) error {
|
|
_, err := DB.Exec("INSERT INTO hostnames (category, hostname, parameters) VALUES (?, ?, ?)", category, hostname, paramsJSON)
|
|
if err != nil {
|
|
log.Printf("Error inserting hostname into DB: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func UpdateHostname(category string, oldhostname string, hostname string, paramsJSON []byte) error {
|
|
_, err := DB.Exec("UPDATE hostnames set category = ?, hostname =?, parameters =? where category = ? and hostname = ?", category, hostname, paramsJSON, category, oldhostname)
|
|
if err != nil {
|
|
log.Printf("Error inserting hostname into DB: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func DeleteHostname(category string, hostname string) error {
|
|
log.Printf("Soft-Deleting hostname: %v in category: %v", category, hostname)
|
|
_, err := DB.Exec("UPDATE hostnames set deleted = true where category = ? and hostname = ?", category, hostname)
|
|
if err != nil {
|
|
log.Printf("Error deleting hostname from DB: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HostnameExists checks if a hostname exists within a given category
|
|
func HostnameExists(category string, hostname string) (bool, error) {
|
|
var exists bool
|
|
|
|
query := "SELECT EXISTS(SELECT 1 FROM hostnames WHERE category = ? AND hostname = ?)"
|
|
err := DB.QueryRow(query, category, hostname).Scan(&exists)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
// No rows found, meaning the hostname does not exist
|
|
return false, nil
|
|
}
|
|
// An actual error occurred
|
|
return false, err
|
|
}
|
|
|
|
return exists, nil
|
|
}
|
|
|
|
func GetMaxNumberForCategory(category string) (int, error) {
|
|
var maxResult sql.NullInt64
|
|
err := DB.QueryRow("SELECT MAX(CAST(json_extract(parameters, '$.Number') AS INTEGER)) FROM hostnames WHERE category = ?", category).Scan(&maxResult)
|
|
if err != nil {
|
|
log.Printf("Error querying max number for category %s: %v", category, err)
|
|
return 0, err
|
|
}
|
|
if !maxResult.Valid {
|
|
return 0, nil // No rows found, start with 0
|
|
}
|
|
return int(maxResult.Int64), nil
|
|
}
|
|
|
|
func GetHostnamesByCategory(category string) ([]models.Hostname, error) {
|
|
var hostnames []models.Hostname
|
|
|
|
rows, err := DB.Query("SELECT id, category, hostname, parameters, created_at FROM hostnames WHERE category = ? and deleted = false", category)
|
|
if err != nil {
|
|
log.Printf("Error querying hostnames: %v", err)
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var h models.Hostname
|
|
var paramsJSON string
|
|
err := rows.Scan(&h.ID, &h.Category, &h.Hostname, ¶msJSON, &h.CreatedAt)
|
|
if err != nil {
|
|
log.Printf("Error scanning hostname: %v", err)
|
|
return nil, err
|
|
}
|
|
// Unmarshal parameters JSON
|
|
err = json.Unmarshal([]byte(paramsJSON), &h.Parameters)
|
|
if err != nil {
|
|
log.Printf("Error unmarshaling parameters: %v", err)
|
|
return nil, err
|
|
}
|
|
hostnames = append(hostnames, h)
|
|
}
|
|
|
|
return hostnames, nil
|
|
}
|
|
|
|
func GetHostnameByCategoryAndName(category string, hostname string) (models.Hostname, error) {
|
|
var host models.Hostname
|
|
var paramsJSON string
|
|
|
|
row, err := DB.Query("SELECT id, category, hostname, parameters, created_at FROM hostnames where category = ? and hostname = ? and deleted = false", category, hostname)
|
|
if err != nil {
|
|
log.Printf("Error querying hostname: %v", err)
|
|
return models.Hostname{}, err
|
|
}
|
|
defer row.Close()
|
|
|
|
if !row.Next() {
|
|
log.Printf("no rows found for category %s and hostname %s", category, hostname)
|
|
return models.Hostname{}, errors.New("no rows found")
|
|
}
|
|
|
|
err = row.Scan(&host.ID, &host.Category, &host.Hostname, ¶msJSON, &host.CreatedAt)
|
|
if err != nil {
|
|
log.Printf("Error scanning hostname: %v", err)
|
|
return models.Hostname{}, err
|
|
}
|
|
err = json.Unmarshal([]byte(paramsJSON), &host.Parameters)
|
|
if err != nil {
|
|
log.Printf("Error unmarshaling parameters: %v", err)
|
|
return models.Hostname{}, nil
|
|
}
|
|
|
|
return host, nil
|
|
}
|