Files
go-easy/main.go
2025-11-30 22:37:46 -05:00

356 lines
10 KiB
Go

package main
import (
"database/sql"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"strconv"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// WgEasyClient is one peer row of wg-easy's clients_table. The JSON tags
// repeat the column names used by the SQL in this file, but the database
// code in this file reads and writes the struct positionally (rows.Scan /
// db.Exec argument order), not through these tags.
type WgEasyClient struct {
	ID          int    `json:"id"`
	UserID      int    `json:"user_id"`
	InterfaceID string `json:"interface_id"`
	Name        string `json:"name"`
	// Tunnel addresses as plain strings (no prefix length).
	IPv4Address string `json:"ipv4_address"`
	IPv6Address string `json:"ipv6_address"`
	// Hook command lines for the peer.
	PreUp    string `json:"pre_up"`
	PostUp   string `json:"post_up"`
	PreDown  string `json:"pre_down"`
	PostDown string `json:"post_down"`
	// Key material as printed by `wg genkey` / `wg pubkey` / `wg genpsk`.
	PrivateKey   string `json:"private_key"`
	PublicKey    string `json:"public_key"`
	PreSharedKey string `json:"pre_shared_key"`
	// NOTE(review): ExpiresAt is not read or written anywhere in this file.
	ExpiresAt           string `json:"expires_at,omitempty"`
	AllowedIps          string `json:"allowed_ips,omitempty"`
	ServerAllowedIps    string `json:"server_allowed_ips"`
	PersistentKeepalive int    `json:"persistent_keepalive"`
	MTU                 int    `json:"mtu"`
	// NOTE(review): JC/JMin/JMax and I1-I5 are never touched by this file;
	// they look like AmneziaWG obfuscation settings — confirm.
	JC             int    `json:"j_c"`
	JMin           int    `json:"j_min"`
	JMax           int    `json:"j_max"`
	I1             string `json:"i1,omitempty"`
	I2             string `json:"i2,omitempty"`
	I3             string `json:"i3,omitempty"`
	I4             string `json:"i4,omitempty"`
	I5             string `json:"i5,omitempty"`
	DNS            string `json:"dns,omitempty"`
	ServerEndpoint string `json:"server_endpoint"`
	// Stored as an int; insertNewUser writes the enabled column itself
	// rather than taking it from this field.
	Enabled int `json:"enabled"`
	// Kept as strings; insertNewUser fills created_at/updated_at itself
	// rather than from these fields.
	CreatedAt string `json:"created_at"`
	UpdatedAt string `json:"updated_at"`
}
// WgEasyConfig describes wg-easy's general settings record.
// NOTE(review): nothing in this file constructs or reads this type; its
// exact table mapping is not shown here — confirm it is still needed.
type WgEasyConfig struct {
	ID                int     `json:"id"`
	SetupStep         int     `json:"setup_step"`
	SessionPassword   string  `json:"session_password"`
	SessionTimeout    int     `json:"session_timeout"`
	MetricsPrometheus int     `json:"metrics_prometheus"`
	MetricsJSON       int     `json:"metrics_json"`
	MetricsPassword   *string `json:"metrics_password,omitempty"` // Pointer to handle NULL values
	// Unlike WgEasyClient, timestamps here are typed as time.Time.
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
// WgEasyInterface describes a WireGuard interface row of interfaces_table.
// getNetwork populates only IPv4CIDR and MTU (for the interface named
// "wg0"); every other field is left at its zero value by this file.
type WgEasyInterface struct {
	Name       string `json:"name"`
	Device     string `json:"device"`
	Port       int    `json:"port"`
	PrivateKey string `json:"private_key"`
	PublicKey  string `json:"public_key"`
	// Interface network in CIDR notation; main splits it on "/".
	IPv4CIDR  string    `json:"ipv4_cidr"`
	IPv6CIDR  string    `json:"ipv6_cidr"`
	MTU       int       `json:"mtu"`
	Enabled   int       `json:"enabled"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
// getEnv returns the value of the environment variable named key. The
// fallback is used only when the variable is absent from the environment;
// a variable that is present but empty yields "".
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present {
		value = fallback
	}
	return value
}
// insertNewUser inserts client as a new row of clients_table in the SQLite
// database named by DB_FILE (default "wg-easy.db").
//
// Every column comes from the matching client field, except enabled, which
// is always 1, and created_at/updated_at, which both receive the current
// Unix time in seconds as a decimal string. A failure to open the database
// is printed and the function returns; a failed INSERT terminates the
// process via log.Fatal.
func insertNewUser(client WgEasyClient) {
	dbFile := getEnv("DB_FILE", "wg-easy.db")
	db, err := sql.Open("sqlite3", dbFile)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer db.Close()

	// Values are bound strictly by position, so each line of placeholders
	// below must line up with the matching line of columns and with the
	// argument list passed to Exec.
	const insertStmt = "INSERT INTO clients_table (" +
		"id, user_id, interface_id, name, ipv4_address, ipv6_address, " +
		"pre_up, post_up, pre_down, post_down, " +
		"private_key, public_key, pre_shared_key, " +
		"allowed_ips, server_allowed_ips, persistent_keepalive, mtu, dns, server_endpoint, " +
		"enabled, created_at, updated_at" +
		") VALUES (" +
		"?, ?, ?, ?, ?, ?, " +
		"?, ?, ?, ?, " +
		"?, ?, ?, " +
		"?, ?, ?, ?, ?, ?, " +
		"?, ?, ?)"
	fmt.Println(insertStmt)

	// Read the clock once so created_at and updated_at are always identical.
	// NOTE(review): this keeps the Unix-seconds text format the code has
	// always written; confirm it is what wg-easy expects in these columns.
	now := strconv.FormatInt(time.Now().Unix(), 10)

	_, err = db.Exec(insertStmt,
		client.ID,
		client.UserID,
		client.InterfaceID,
		client.Name,
		client.IPv4Address,
		client.IPv6Address,
		client.PreUp,
		client.PostUp,
		client.PreDown,
		client.PostDown,
		client.PrivateKey,
		client.PublicKey,
		client.PreSharedKey,
		client.AllowedIps,
		client.ServerAllowedIps,
		client.PersistentKeepalive,
		client.MTU,
		client.DNS,
		client.ServerEndpoint,
		1,   // enabled
		now, // created_at
		now, // updated_at
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Inserted new row successfully")
}
// getNetwork reads the "wg0" row of interfaces_table from the SQLite
// database named by DB_FILE (default "wg-easy.db").
//
// Only IPv4CIDR and MTU are populated; all other fields stay zero. If no
// row matches, the zero WgEasyInterface is returned; if several match, the
// last one read wins. Any database error terminates the process via
// log.Fatal.
func getNetwork() WgEasyInterface {
	dbFile := getEnv("DB_FILE", "wg-easy.db")
	db, err := sql.Open("sqlite3", dbFile)
	if err != nil {
		// db is nil on error; continuing would panic on the Query below.
		log.Fatal(err)
	}
	defer db.Close()

	const queryString = "SELECT ipv4_cidr, mtu FROM interfaces_table WHERE name = 'wg0'"
	rows, err := db.Query(queryString)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	var network WgEasyInterface
	for rows.Next() {
		if err := rows.Scan(&network.IPv4CIDR, &network.MTU); err != nil {
			log.Fatal(err)
		}
	}
	// rows.Next also returns false on an iteration error; surface it.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	return network
}
// getLatestClient returns the client with the highest id in clients_table
// of the SQLite database named by DB_FILE (default "wg-easy.db").
//
// Only ID, UserID, InterfaceID, Name, IPv4Address and PersistentKeepalive
// are populated; other fields stay zero. An empty table yields the zero
// WgEasyClient. Any database error terminates the process via log.Fatal.
func getLatestClient() WgEasyClient {
	dbFile := getEnv("DB_FILE", "wg-easy.db")
	db, err := sql.Open("sqlite3", dbFile)
	if err != nil {
		// db is nil on error; continuing would panic on the Query below.
		log.Fatal(err)
	}
	defer db.Close()

	// ORDER BY ... LIMIT 1 selects the same row as SQLite's
	// "SELECT MAX(id), <bare columns>" on a non-empty table, but returns no
	// row (instead of a row of NULLs that cannot be scanned into ints) when
	// the table is empty.
	const queryString = "SELECT id, user_id, interface_id, name, ipv4_address, persistent_keepalive " +
		"FROM clients_table ORDER BY id DESC LIMIT 1"
	rows, err := db.Query(queryString)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	var client WgEasyClient
	for rows.Next() {
		err := rows.Scan(&client.ID, &client.UserID, &client.InterfaceID, &client.Name, &client.IPv4Address, &client.PersistentKeepalive)
		if err != nil {
			log.Fatal(err)
		}
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	return client
}
// getNextIp returns the IPv4 address that numerically follows ip. The
// address is treated as a 32-bit number, so a full octet carries into the
// one before it (10.8.0.255 -> 10.8.1.0), and 255.255.255.255 wraps to
// 0.0.0.0. The result is a fresh 4-byte net.IP and ip itself is never
// modified. No subnet bounds are checked. The process exits via log.Fatal
// if ip is not an IPv4 address.
func getNextIp(ip net.IP) net.IP {
	v4 := ip.To4()
	if v4 == nil {
		log.Fatal("non ipv4 address")
	}
	// To4 can return a sub-slice of ip (or ip itself), so increment a copy
	// to avoid changing the caller's address.
	next := make(net.IP, net.IPv4len)
	copy(next, v4)
	for i := len(next) - 1; i >= 0; i-- {
		next[i]++
		if next[i] != 0 {
			break // no carry into the preceding octet
		}
	}
	return next
}
// generatePrivateKey runs `wg genkey`, saves its raw output to location
// with owner-only (0600) permissions, and returns that same output
// unmodified (including any trailing newline wg prints).
//
// Adapted from https://github.com/icyflame/wireguard-configuration-generator/blob/81d955916a3f3e5a0bbab0a2c9ea328a46b29bc9/internal/keygen/keygen.go#L113
func generatePrivateKey(location string) ([]byte, error) {
	key, genErr := exec.Command("wg", "genkey").Output()
	if genErr != nil {
		return nil, fmt.Errorf("could not generate private key: %w", genErr)
	}
	if writeErr := os.WriteFile(location, key, 0600); writeErr != nil {
		return nil, fmt.Errorf("could not write private key to file: %w", writeErr)
	}
	return key, nil
}
// generatePublicKey derives the public key for privateKey by piping it into
// `wg pubkey`, saves wg's output to location with owner-only (0600)
// permissions, and returns that output unmodified (including any trailing
// newline wg prints).
//
// Adapted (minor changes) from https://github.com/icyflame/wireguard-configuration-generator/blob/81d955916a3f3e5a0bbab0a2c9ea328a46b29bc9/internal/keygen/keygen.go#L129
func generatePublicKey(privateKey []byte, location string) ([]byte, error) {
	cmd := exec.Command("wg", "pubkey")
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("could not get stdin of pubkey command: %w", err)
	}
	// Feed the key from a goroutine: cmd.Output waits for wg to exit, and wg
	// reads its input until stdin is closed.
	go func() {
		defer stdin.Close()
		// A failed write is not reported separately; the outcome of the wg
		// run checked below is what matters.
		_, _ = io.WriteString(stdin, string(privateKey))
	}()
	publicKey, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("could not generate public key: %w", err)
	}
	if err := os.WriteFile(location, publicKey, 0600); err != nil {
		return nil, fmt.Errorf("could not write public key to file: %w", err)
	}
	return publicKey, nil
}
// generatePreshareKey runs `wg genpsk` and returns its raw output
// unmodified (including any trailing newline wg prints). Unlike the
// private/public key helpers, nothing is written to disk.
func generatePreshareKey() ([]byte, error) {
	psk, runErr := exec.Command("wg", "genpsk").Output()
	if runErr != nil {
		return nil, fmt.Errorf("could not generate preshared key: %w", runErr)
	}
	return psk, nil
}
// getNextIPWithMask returns the IPv4 address immediately after ip, provided
// it still lies in the same network as ip under mask. It returns nil when
// there is no such address: ip is not IPv4, mask cannot be applied to a
// 4-byte address, or ip is already the last address of its network (for a
// /24 that is x.y.z.255). The network's last address itself is returned
// like any other, so callers that want to skip it must do so. The result is
// a fresh 4-byte net.IP; ip is not modified.
func getNextIPWithMask(ip net.IP, mask net.IPMask) net.IP {
	v4 := ip.To4()
	if v4 == nil {
		return nil
	}
	network := v4.Mask(mask)
	if network == nil {
		// mask's length is incompatible with an IPv4 address.
		return nil
	}

	// Increment a copy as a 32-bit number; To4 may alias the caller's slice.
	next := make(net.IP, net.IPv4len)
	copy(next, v4)
	carry := true
	for i := len(next) - 1; i >= 0 && carry; i-- {
		next[i]++
		carry = next[i] == 0
	}

	// Overflowing 255.255.255.255, or carrying into a network bit, means ip
	// was the last address of its network.
	if carry || !next.Mask(mask).Equal(network) {
		return nil
	}
	return next
}
// main allocates one new WireGuard client in the wg-easy SQLite database
// named by DB_FILE (default "wg-easy.db") and inserts it.
//
// The new client starts as a copy of the most recent client, takes the next
// ID and the next IPv4 address after it, gets fresh keys from the `wg`
// tool, and inherits the MTU of interface wg0. Name, DNS, endpoint and
// allowed IPs are currently hard-coded below. Any failure that would lead
// to inserting a broken row terminates the process.
func main() {
	dbFile := getEnv("DB_FILE", "wg-easy.db")
	db, err := sql.Open("sqlite3", dbFile)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer db.Close()
	// sql.Open only validates its arguments; Ping makes the driver actually
	// open the database before we claim to be connected.
	if err := db.Ping(); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("Connected to the SQLite database successfully.")

	latestClient := getLatestClient()
	nextClient := latestClient
	nextClient.ID++
	fmt.Printf("Next Client ID: %d\n", nextClient.ID)

	lastIP := net.ParseIP(latestClient.IPv4Address)
	if lastIP == nil {
		// Carrying on would insert a second client with the previous
		// client's address.
		log.Fatalf("Invalid IP address %q", latestClient.IPv4Address)
	}
	nextClient.IPv4Address = getNextIp(lastIP).String()
	fmt.Printf("Next Valid IP: %v\n", nextClient.IPv4Address)
	nextClient.IPv6Address = "::1"

	network := getNetwork()
	fmt.Printf("Got Network: %v\n", network)
	_, prefixLen, ok := strings.Cut(network.IPv4CIDR, "/")
	if !ok {
		// Also catches a missing wg0 row, which leaves IPv4CIDR empty.
		log.Fatalf("wg0 ipv4_cidr %q is not in CIDR notation", network.IPv4CIDR)
	}
	fmt.Printf("Network CIDR: %s\n", prefixLen)
	fmt.Printf("Last Client:\n%v\n", latestClient)

	nextClient.Name = "NewUser"
	privateKey, err := generatePrivateKey(nextClient.Name + "-priv")
	if err != nil {
		log.Fatal(err)
	}
	publicKey, err := generatePublicKey(privateKey, nextClient.Name+"-pub")
	if err != nil {
		log.Fatal(err)
	}
	presharedKey, err := generatePreshareKey()
	if err != nil {
		log.Fatal(err)
	}
	// wg prints each key followed by a newline; store only the key text.
	nextClient.PrivateKey = strings.TrimSpace(string(privateKey))
	nextClient.PublicKey = strings.TrimSpace(string(publicKey))
	nextClient.PreSharedKey = strings.TrimSpace(string(presharedKey))

	// NOTE(review): these values are hard-coded placeholders.
	nextClient.AllowedIps = "10.8.0.0/24"
	nextClient.ServerAllowedIps = nextClient.IPv4Address
	nextClient.PersistentKeepalive = 0
	nextClient.MTU = network.MTU
	nextClient.DNS = "1.1.1.1"
	nextClient.ServerEndpoint = "1.2.3.4"
	fmt.Printf("New Client:\n%v\n", nextClient)
	insertNewUser(nextClient)
}