diff --git a/main.go b/main.go index 7284654..6c4c8d3 100644 --- a/main.go +++ b/main.go @@ -3,46 +3,50 @@ package main import ( "database/sql" "fmt" + "io" "log" "net" "os" + "os/exec" + "strconv" + "strings" "time" _ "github.com/mattn/go-sqlite3" ) type WgEasyClient struct { - ID int `json:"id"` - UserID int `json:"user_id"` - InterfaceID string `json:"interface_id"` - Name string `json:"name"` - IPv4Address string `json:"ipv4_address"` - IPv6Address string `json:"ipv6_address"` - PreUp string `json:"pre_up"` - PostUp string `json:"post_up"` - PreDown string `json:"pre_down"` - PostDown string `json:"post_down"` - PrivateKey string `json:"private_key"` - PublicKey string `json:"public_key"` - PreSharedKey string `json:"pre_shared_key"` - ExpiresAt string `json:"expires_at,omitempty"` - AllowedIps []string `json:"allowed_ips,omitempty"` - ServerAllowedIps []string `json:"server_allowed_ips"` - PersistentKeepalive int `json:"persistent_keepalive"` - MTU int `json:"mtu"` - JC int `json:"j_c"` - JMin int `json:"j_min"` - JMax int `json:"j_max"` - I1 string `json:"i1,omitempty"` - I2 string `json:"i2,omitempty"` - I3 string `json:"i3,omitempty"` - I4 string `json:"i4,omitempty"` - I5 string `json:"i5,omitempty"` - DNS []string `json:"dns,omitempty"` - ServerEndpoint string `json:"server_endpoint"` - Enabled bool `json:"enabled"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int `json:"id"` + UserID int `json:"user_id"` + InterfaceID string `json:"interface_id"` + Name string `json:"name"` + IPv4Address string `json:"ipv4_address"` + IPv6Address string `json:"ipv6_address"` + PreUp string `json:"pre_up"` + PostUp string `json:"post_up"` + PreDown string `json:"pre_down"` + PostDown string `json:"post_down"` + PrivateKey string `json:"private_key"` + PublicKey string `json:"public_key"` + PreSharedKey string `json:"pre_shared_key"` + ExpiresAt string `json:"expires_at,omitempty"` + AllowedIps string `json:"allowed_ips,omitempty"` + ServerAllowedIps string `json:"server_allowed_ips"` + PersistentKeepalive int `json:"persistent_keepalive"` + MTU int `json:"mtu"` + JC int `json:"j_c"` + JMin int `json:"j_min"` + JMax int `json:"j_max"` + I1 string `json:"i1,omitempty"` + I2 string `json:"i2,omitempty"` + I3 string `json:"i3,omitempty"` + I4 string `json:"i4,omitempty"` + I5 string `json:"i5,omitempty"` + DNS string `json:"dns,omitempty"` + ServerEndpoint string `json:"server_endpoint"` + Enabled int `json:"enabled"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type WgEasyConfig struct { @@ -57,7 +61,7 @@ type WgEasyConfig struct { UpdatedAt time.Time `json:"updated_at"` } -type WgEasyInterfaceConfig struct { +type WgEasyInterface struct { Name string `json:"name"` Device string `json:"device"` Port int `json:"port"` @@ -87,10 +91,12 @@ func insertNewUser(client WgEasyClient) { return } - insertStmt := `INSERT INTO clients_table (userId, interfaceId, name, ipv4Address, ipv6Address, preUp, postUp, preDown, postDown, privateKey, publicKey, preSharedKey, expiresAt, allowedIps, serverAllowedIps, persistentKeepalive, mtu, jC, jMin, jMax, i1, i2, i3, i4, i5, dns, serverEndpoint, enabled) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + insertStmt := "INSERT INTO clients_table (id, user_id, interface_id, name, ipv4_address, ipv6_address, pre_up, post_up, pre_down, post_down, private_key, public_key, pre_shared_key, allowed_ips, server_allowed_ips, persistent_keepalive, mtu, dns, server_endpoint, enabled, created_at, updated_at) VALUES (?, ?, ?, ?, ?, '::1', '', '', '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, '" + strconv.FormatInt(time.Now().Unix(), 10) + "', '" + strconv.FormatInt(time.Now().Unix(), 10) + "')" + fmt.Println(insertStmt) _, err = db.Exec(insertStmt, // Replace these values with the actual data you want to insert + client.ID, client.UserID, client.InterfaceID, client.Name, @@ -103,24 +109,15 @@ func insertNewUser(client WgEasyClient) { client.PrivateKey, client.PublicKey, client.PreSharedKey, - client.ExpiresAt, client.AllowedIps, client.ServerAllowedIps, client.PersistentKeepalive, client.MTU, - client.JC, - client.JMin, - client.JMax, - client.I1, - client.I2, - client.I3, - client.I4, - client.I5, client.DNS, client.ServerEndpoint, - client.Enabled, - client.CreatedAt, - client.UpdatedAt, + // client.Enabled, + // client.CreatedAt, + // client.UpdatedAt, ) if err != nil { @@ -130,7 +127,34 @@ func insertNewUser(client WgEasyClient) { fmt.Println("Inserted new row successfully") } -func getLatestClient(rowName string) WgEasyClient { +func getNetwork() WgEasyInterface { + db_file := getEnv("DB_FILE", "wg-easy.db") + db, err := sql.Open("sqlite3", db_file) + if err != nil { + fmt.Println(err) + } + queryString := "SELECT ipv4_cidr, mtu FROM interfaces_table WHERE name = 'wg0'" + + rows, err := db.Query(queryString) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + + var network WgEasyInterface + for rows.Next() { + fmt.Println(rows) + err = rows.Scan(&network.IPv4CIDR, &network.MTU) + if err != nil { + log.Fatal(err) + } + } + + return network + +} + +func getLatestClient() WgEasyClient { var client WgEasyClient db_file := getEnv("DB_FILE", "wg-easy.db") db, err := sql.Open("sqlite3", db_file) @@ -138,16 +162,16 @@ func getLatestClient(rowName string) WgEasyClient { fmt.Println(err) } - queryString := "SELECT MAX(id), ipv4_address FROM clients_table" + queryString := "SELECT MAX(id), user_id, interface_id, name, ipv4_address, persistent_keepalive FROM clients_table" - rows, err := db.Query(queryString, rowName, rowName) + rows, err := db.Query(queryString) if err != nil { log.Fatal(err) } defer rows.Close() for rows.Next() { - err = rows.Scan(&client.ID, &client.IPv4Address) + err = rows.Scan(&client.ID, &client.UserID, &client.InterfaceID, &client.Name, &client.IPv4Address, &client.PersistentKeepalive) if err != nil { log.Fatal(err) } @@ -172,6 +196,109 @@ func getNextIp(ip net.IP) net.IP { return net.IP(ip) } +// Borrowed from here: https://github.com/icyflame/wireguard-configuration-generator/blob/81d955916a3f3e5a0bbab0a2c9ea328a46b29bc9/internal/keygen/keygen.go#L113 +func generatePrivateKey(location string) ([]byte, error) { + cmd := exec.Command("wg", "genkey") + privateKey, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("could not generate private key: %w", err) + } + + err = os.WriteFile(location, privateKey, 0600) + if err != nil { + return nil, fmt.Errorf("could not write private key to file: %w", err) + } + + return privateKey, nil +} + +// Borrowed (minor changes) from just a few lines below the other func: https://github.com/icyflame/wireguard-configuration-generator/blob/81d955916a3f3e5a0bbab0a2c9ea328a46b29bc9/internal/keygen/keygen.go#L129 +func generatePublicKey(privateKey []byte, location string) ([]byte, error) { + cmd := exec.Command("wg", "pubkey") + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("could not get stdin of pubkey command: %w", err) + } + + go func() { + defer stdin.Close() + io.WriteString(stdin, string(privateKey)) + }() + + publicKey, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("could not generate public key: %w", err) + } + + err = os.WriteFile(location, publicKey, 0600) + if err != nil { + return nil, fmt.Errorf("could not write private key to file: %w", err) + } + + return publicKey, nil +} + +func generatePreshareKey() ([]byte, error) { + cmd := exec.Command("wg", "genpsk") + presharedKey, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("could not generate preshared key: %w", err) + } + + return presharedKey, nil +} + +func getNextIPWithMask(ip net.IP, mask net.IPMask) net.IP { + // Ensure the IP is in 4-byte format + if len(ip) != net.IPv4len { + ip = ip.To4() + } + + // Convert the IP and mask to a slice of bytes + ipBytes := ip.To4() + fmt.Printf("IP first byte: %d\n", ipBytes[0]) + maskOnes, maskBits := mask.Size() + fmt.Printf("Mask Ones: %d\n", maskOnes) + fmt.Printf("Mask Bits: %d\n", maskBits) + // maskBytes := mask.To4() + + // Determine the number of bits in the host part + // var hostBits int + // for i := 0; i < net.IPv4len; i++ { + // if maskBytes[i] == 0xFF { + // // If the byte is all 1s, it means we have 8 bits of host bits. + // hostBits += 8 + // } else { + // // Otherwise, we need to count the trailing zeros. + // for j := 7; j >= 0; j-- { + // if (maskBytes[i] & (1 << j)) != 0 { + // // If the bit is set, it means we have counted all the host bits. + // hostBits++ + // } else { + // // If the bit is not set, it means we have reached the network bits. + // break + // } + // } + // // We only need to check one byte, so we can break out of the loop. + // break + // } + // } + + // // Increment only the least significant bits that belong to the host part + // for i := net.IPv4len - 1; i >= 0; i-- { + // if (maskBytes[i] & 0xFF) != 0xFF { + // incrementBits := 8 - ((i * 8) + (8 - hostBits)) + // ipBytes[i] += 1 << incrementBits + // break + // } + // } + + // Convert the bytes back to an IP address + nextIP := net.IP(ipBytes) + + return nextIP +} + func main() { db_file := getEnv("DB_FILE", "wg-easy.db") db, err := sql.Open("sqlite3", db_file) @@ -183,7 +310,9 @@ func main() { defer db.Close() fmt.Println("Connected to the SQLite database successfully.") - latestClient := getLatestClient("user_id") + latestClient := getLatestClient() + nextClient := latestClient + nextClient.ID += 1 fmt.Printf("Next Client ID: %d\n", latestClient.ID) lastIp := net.ParseIP(latestClient.IPv4Address) @@ -191,7 +320,36 @@ func main() { if lastIp == nil { fmt.Println("Invalid IP address") } else { - fmt.Printf("Next Valid IP: %v\n", getNextIp(lastIp)) + nextClient.IPv4Address = getNextIp(lastIp).String() + fmt.Printf("Next Valid IP: %v\n", nextClient.IPv4Address) } + nextClient.IPv6Address = "::1" + + network := getNetwork() + fmt.Printf("Got Network: %v\n", network) + + networkParts := strings.Split(network.IPv4CIDR, "/") + fmt.Printf("Network CIDR: %s\n", networkParts[1]) + fmt.Printf("Last Client:\n%v\n", latestClient) + + nextClient.Name = "NewUser" + privateKeyBytes, _ := generatePrivateKey(nextClient.Name + "-priv") + nextClient.PrivateKey = string(privateKeyBytes) + publicKeyBytes, _ := generatePublicKey(privateKeyBytes, nextClient.Name+"-pub") + nextClient.PublicKey = string(publicKeyBytes) + preshareKeyBytes, _ := generatePreshareKey() + nextClient.PreSharedKey = string(preshareKeyBytes) + nextClient.AllowedIps = "10.8.0.0/24" + nextClient.ServerAllowedIps = nextClient.IPv4Address + nextClient.PersistentKeepalive = 0 + nextClient.MTU = network.MTU + nextClient.DNS = "1.1.1.1" + nextClient.ServerEndpoint = "1.2.3.4" + // nextClient.Enabled = 1 + // nextClient.CreatedAt = time.Now().String() + // nextClient.UpdatedAt = time.Now().String() + fmt.Printf("New Client:\n%v\n", nextClient) + + insertNewUser(nextClient) }