356 lines
10 KiB
Go
356 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// WgEasyClient holds one client (peer) record. insertNewUser writes its fields
// into clients_table and getLatestClient scans a subset of them back out; the
// JSON tags name the same columns in snake_case.
//
// ExpiresAt, JC, JMin, JMax and I1–I5 are not read or written by any code in
// this program — presumably they mirror extra upstream wg-easy columns
// (TODO confirm against the schema that defines clients_table).
type WgEasyClient struct {
	ID                  int    `json:"id"`
	UserID              int    `json:"user_id"`
	InterfaceID         string `json:"interface_id"`
	Name                string `json:"name"`
	IPv4Address         string `json:"ipv4_address"`
	IPv6Address         string `json:"ipv6_address"`
	PreUp               string `json:"pre_up"`
	PostUp              string `json:"post_up"`
	PreDown             string `json:"pre_down"`
	PostDown            string `json:"post_down"`
	PrivateKey          string `json:"private_key"`
	PublicKey           string `json:"public_key"`
	PreSharedKey        string `json:"pre_shared_key"`
	ExpiresAt           string `json:"expires_at,omitempty"`
	AllowedIps          string `json:"allowed_ips,omitempty"`
	ServerAllowedIps    string `json:"server_allowed_ips"`
	PersistentKeepalive int    `json:"persistent_keepalive"`
	MTU                 int    `json:"mtu"`
	JC                  int    `json:"j_c"`
	JMin                int    `json:"j_min"`
	JMax                int    `json:"j_max"`
	I1                  string `json:"i1,omitempty"`
	I2                  string `json:"i2,omitempty"`
	I3                  string `json:"i3,omitempty"`
	I4                  string `json:"i4,omitempty"`
	I5                  string `json:"i5,omitempty"`
	DNS                 string `json:"dns,omitempty"`
	ServerEndpoint      string `json:"server_endpoint"`
	Enabled             int    `json:"enabled"`
	CreatedAt           string `json:"created_at"`
	UpdatedAt           string `json:"updated_at"`
}
|
|
|
|
// WgEasyConfig holds general wg-easy settings, tagged for JSON encoding.
// Nothing in this program reads or writes it; the field set presumably mirrors
// an upstream settings table — TODO confirm.
type WgEasyConfig struct {
	ID                int       `json:"id"`
	SetupStep         int       `json:"setup_step"`
	SessionPassword   string    `json:"session_password"`
	SessionTimeout    int       `json:"session_timeout"`
	MetricsPrometheus int       `json:"metrics_prometheus"`
	MetricsJSON       int       `json:"metrics_json"`
	MetricsPassword   *string   `json:"metrics_password,omitempty"` // nil stands for a NULL value
	CreatedAt         time.Time `json:"created_at"`
	UpdatedAt         time.Time `json:"updated_at"`
}
|
|
|
|
// WgEasyInterface describes a WireGuard interface record. getNetwork fills
// only IPv4CIDR and MTU (from interfaces_table); the remaining fields are
// declared for completeness and tagged for JSON encoding.
type WgEasyInterface struct {
	Name       string    `json:"name"`
	Device     string    `json:"device"`
	Port       int       `json:"port"`
	PrivateKey string    `json:"private_key"`
	PublicKey  string    `json:"public_key"`
	IPv4CIDR   string    `json:"ipv4_cidr"`
	IPv6CIDR   string    `json:"ipv6_cidr"`
	MTU        int       `json:"mtu"`
	Enabled    int       `json:"enabled"`
	CreatedAt  time.Time `json:"created_at"`
	UpdatedAt  time.Time `json:"updated_at"`
}
|
|
|
|
// getEnv returns the value of the environment variable named key. The
// fallback is used only when the variable is absent from the environment; a
// variable that is present but set to "" yields "".
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present {
		return fallback
	}
	return value
}
|
|
|
|
func insertNewUser(client WgEasyClient) {
|
|
db_file := getEnv("DB_FILE", "wg-easy.db")
|
|
db, err := sql.Open("sqlite3", db_file)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
|
|
insertStmt := "INSERT INTO clients_table (id, user_id, interface_id, name, ipv4_address, ipv6_address, pre_up, post_up, pre_down, post_down, private_key, public_key, pre_shared_key, allowed_ips, server_allowed_ips, persistent_keepalive, mtu, dns, server_endpoint, enabled, created_at, updated_at) VALUES (?, ?, ?, ?, ?, '::1', '', '', '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, '" + strconv.FormatInt(time.Now().Unix(), 10) + "', '" + strconv.FormatInt(time.Now().Unix(), 10) + "')"
|
|
fmt.Println(insertStmt)
|
|
|
|
_, err = db.Exec(insertStmt,
|
|
// Replace these values with the actual data you want to insert
|
|
client.ID,
|
|
client.UserID,
|
|
client.InterfaceID,
|
|
client.Name,
|
|
client.IPv4Address,
|
|
client.IPv6Address,
|
|
client.PreUp,
|
|
client.PostUp,
|
|
client.PreDown,
|
|
client.PostDown,
|
|
client.PrivateKey,
|
|
client.PublicKey,
|
|
client.PreSharedKey,
|
|
client.AllowedIps,
|
|
client.ServerAllowedIps,
|
|
client.PersistentKeepalive,
|
|
client.MTU,
|
|
client.DNS,
|
|
client.ServerEndpoint,
|
|
// client.Enabled,
|
|
// client.CreatedAt,
|
|
// client.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
fmt.Println("Inserted new row successfully")
|
|
}
|
|
|
|
func getNetwork() WgEasyInterface {
|
|
db_file := getEnv("DB_FILE", "wg-easy.db")
|
|
db, err := sql.Open("sqlite3", db_file)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
queryString := "SELECT ipv4_cidr, mtu FROM interfaces_table WHERE name = 'wg0'"
|
|
|
|
rows, err := db.Query(queryString)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var network WgEasyInterface
|
|
for rows.Next() {
|
|
fmt.Println(rows)
|
|
err = rows.Scan(&network.IPv4CIDR, &network.MTU)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
return network
|
|
|
|
}
|
|
|
|
func getLatestClient() WgEasyClient {
|
|
var client WgEasyClient
|
|
db_file := getEnv("DB_FILE", "wg-easy.db")
|
|
db, err := sql.Open("sqlite3", db_file)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
|
|
queryString := "SELECT MAX(id), user_id, interface_id, name, ipv4_address, persistent_keepalive FROM clients_table"
|
|
|
|
rows, err := db.Query(queryString)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
err = rows.Scan(&client.ID, &client.UserID, &client.InterfaceID, &client.Name, &client.IPv4Address, &client.PersistentKeepalive)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return client
|
|
|
|
}
|
|
|
|
// getNextIp returns the IPv4 address that numerically follows ip, carrying
// into higher octets (10.8.0.255 -> 10.8.1.0); 255.255.255.255 wraps to
// 0.0.0.0. ip may be in 4-byte form or the 16-byte IPv4-mapped form produced
// by net.ParseIP. The result is always a fresh 4-byte slice, so the caller's
// ip is never modified. A non-IPv4 input terminates the process via log.Fatal.
func getNextIp(ip net.IP) net.IP {
	v4 := ip.To4()
	if v4 == nil {
		log.Fatal("non ipv4 address")
	}

	// To4 returns a sub-slice of ip for the 16-byte form, so incrementing it
	// in place would mutate the caller's address. Work on a copy.
	next := make(net.IP, net.IPv4len)
	copy(next, v4)

	// Increment the last octet and propagate the carry leftwards.
	for i := len(next) - 1; i >= 0; i-- {
		next[i]++
		if next[i] != 0 {
			break
		}
	}

	return next
}
|
|
|
|
// generatePrivateKey runs `wg genkey`, writes the command's stdout verbatim to
// the file at location with mode 0600, and returns those same bytes
// unmodified. Errors from running the command or writing the file are wrapped
// and returned with a nil key.
//
// Adapted from: https://github.com/icyflame/wireguard-configuration-generator/blob/81d955916a3f3e5a0bbab0a2c9ea328a46b29bc9/internal/keygen/keygen.go#L113
func generatePrivateKey(location string) ([]byte, error) {
	key, genErr := exec.Command("wg", "genkey").Output()
	if genErr != nil {
		return nil, fmt.Errorf("could not generate private key: %w", genErr)
	}

	if writeErr := os.WriteFile(location, key, 0600); writeErr != nil {
		return nil, fmt.Errorf("could not write private key to file: %w", writeErr)
	}

	return key, nil
}
|
|
|
|
// generatePublicKey derives the public key for privateKey by piping it into
// `wg pubkey`. It writes the command's stdout verbatim to the file at location
// with mode 0600 and returns those bytes. Errors from creating the pipe,
// running the command or writing the file are wrapped and returned with a nil
// key.
//
// Adapted (minor changes) from: https://github.com/icyflame/wireguard-configuration-generator/blob/81d955916a3f3e5a0bbab0a2c9ea328a46b29bc9/internal/keygen/keygen.go#L129
func generatePublicKey(privateKey []byte, location string) ([]byte, error) {
	cmd := exec.Command("wg", "pubkey")
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("could not get stdin of pubkey command: %w", err)
	}

	// Write the key from a separate goroutine so the write cannot block while
	// Output is collecting stdout. Closing stdin signals EOF to wg. The write
	// error is deliberately dropped: presumably a short write makes wg reject
	// the key, which then surfaces as an Output error below (TODO confirm).
	go func() {
		defer stdin.Close()
		_, _ = io.WriteString(stdin, string(privateKey))
	}()

	publicKey, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("could not generate public key: %w", err)
	}

	if err := os.WriteFile(location, publicKey, 0600); err != nil {
		return nil, fmt.Errorf("could not write public key to file: %w", err)
	}

	return publicKey, nil
}
|
|
|
|
// generatePreshareKey runs `wg genpsk` and returns the command's stdout
// unmodified. Unlike the private/public key helpers it writes nothing to
// disk. A command failure is wrapped and returned with a nil key.
func generatePreshareKey() ([]byte, error) {
	psk, err := exec.Command("wg", "genpsk").Output()
	if err == nil {
		return psk, nil
	}
	return nil, fmt.Errorf("could not generate preshared key: %w", err)
}
|
|
|
|
// getNextIPWithMask returns the IPv4 address that numerically follows ip,
// provided it still lies in the same network as ip under mask. If the
// increment would leave that network (e.g. 10.8.0.255 with a /24 mask), it
// returns nil.
//
// ip may be in 4-byte form or the 16-byte IPv4-mapped form. mask may be a
// 4-byte mask, or a 16-byte mask whose first 12 bytes are all 0xff. Any other
// ip or mask yields nil. The broadcast address of the network is NOT skipped.
// The result is a fresh 4-byte slice; ip is never modified.
func getNextIPWithMask(ip net.IP, mask net.IPMask) net.IP {
	v4 := ip.To4()
	if v4 == nil {
		return nil
	}

	// Normalise the mask to 4 bytes.
	m := mask
	if len(m) == net.IPv6len {
		for _, b := range m[:net.IPv6len-net.IPv4len] {
			if b != 0xff {
				return nil
			}
		}
		m = m[net.IPv6len-net.IPv4len:]
	}
	if len(m) != net.IPv4len {
		return nil
	}

	// Copy before incrementing: To4 may alias the caller's slice.
	next := make(net.IP, net.IPv4len)
	copy(next, v4)
	for i := len(next) - 1; i >= 0; i-- {
		next[i]++
		if next[i] != 0 {
			break
		}
	}

	// The network bits must be unchanged, otherwise we ran off the subnet.
	for i := range next {
		if next[i]&m[i] != v4[i]&m[i] {
			return nil
		}
	}

	return next
}
|
|
|
|
func main() {
|
|
db_file := getEnv("DB_FILE", "wg-easy.db")
|
|
db, err := sql.Open("sqlite3", db_file)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
|
|
defer db.Close()
|
|
fmt.Println("Connected to the SQLite database successfully.")
|
|
|
|
latestClient := getLatestClient()
|
|
nextClient := latestClient
|
|
nextClient.ID += 1
|
|
|
|
fmt.Printf("Next Client ID: %d\n", latestClient.ID)
|
|
lastIp := net.ParseIP(latestClient.IPv4Address)
|
|
|
|
if lastIp == nil {
|
|
fmt.Println("Invalid IP address")
|
|
} else {
|
|
nextClient.IPv4Address = getNextIp(lastIp).String()
|
|
fmt.Printf("Next Valid IP: %v\n", nextClient.IPv4Address)
|
|
}
|
|
|
|
nextClient.IPv6Address = "::1"
|
|
|
|
network := getNetwork()
|
|
fmt.Printf("Got Network: %v\n", network)
|
|
|
|
networkParts := strings.Split(network.IPv4CIDR, "/")
|
|
fmt.Printf("Network CIDR: %s\n", networkParts[1])
|
|
fmt.Printf("Last Client:\n%v\n", latestClient)
|
|
|
|
nextClient.Name = "NewUser"
|
|
privateKeyBytes, _ := generatePrivateKey(nextClient.Name + "-priv")
|
|
nextClient.PrivateKey = string(privateKeyBytes)
|
|
publicKeyBytes, _ := generatePublicKey(privateKeyBytes, nextClient.Name+"-pub")
|
|
nextClient.PublicKey = string(publicKeyBytes)
|
|
preshareKeyBytes, _ := generatePreshareKey()
|
|
nextClient.PreSharedKey = string(preshareKeyBytes)
|
|
nextClient.AllowedIps = "10.8.0.0/24"
|
|
nextClient.ServerAllowedIps = nextClient.IPv4Address
|
|
nextClient.PersistentKeepalive = 0
|
|
nextClient.MTU = network.MTU
|
|
nextClient.DNS = "1.1.1.1"
|
|
nextClient.ServerEndpoint = "1.2.3.4"
|
|
// nextClient.Enabled = 1
|
|
// nextClient.CreatedAt = time.Now().String()
|
|
// nextClient.UpdatedAt = time.Now().String()
|
|
fmt.Printf("New Client:\n%v\n", nextClient)
|
|
|
|
insertNewUser(nextClient)
|
|
}
|