Added a thread-safe solution to finding the next free ip

This commit is contained in:
Erik Thuning 2025-02-19 18:10:04 +01:00
parent 8c9842484f
commit e3d6e52217

@ -1,13 +1,20 @@
from datetime import datetime
from pathlib import Path
from textwrap import dedent
from time import sleep
import ipaddress
import json
import subprocess
confsuffix = '.conf'
serversuffix = '.serverconf'
metasuffix = '.json'
workdir = Path('./work')
lockfile = workdir.joinpath('lockfile.lock~')
def safe_join(*args) -> Path:
'''
Similar to flask's own safe_join, but uses Path objects
@ -21,11 +28,28 @@ def safe_join(*args) -> Path:
return None
return joined_path
def get_free_ip():
return '172.17.1.1'
def call_with_lock(func: callable, timeout: float=None):
sleep_time = 0.1
slept = 0
while timeout is None or slept < timeout:
try:
print('Attempting lock')
with open(lockfile, 'x') as lf:
print('Lock successful')
result = func()
print('Releasing lock')
lockfile.unlink()
return result
except FileExistsError:
print('Lock failed')
slept += sleep_time
sleep(sleep_time)
raise TimeoutError('Unable to aquire lock within {timeout} seconds')
def generate_keypair():
return 'privkey', 'pubkey'
privkey = run_wg('genkey')
pubkey = run_wg('pubkey', input=privkey)
return privkey, pubkey
def generate_user_serverside_config(config_id: str,
user_name: str,
@ -35,7 +59,7 @@ def generate_user_serverside_config(config_id: str,
[Peer]
# {user_name}/{config_id}
PublicKey = {client_pubkey}
AllowedIPs = {client_ip}
AllowedIPs = {client_ip}/32
''')
return config.lstrip()
@ -46,7 +70,7 @@ def generate_user_clientside_config(client_ip: str,
dns_server: str):
config = dedent(f'''
[Interface]
Address = {client_ip}
Address = {client_ip}/32
DNS = {dns_server}
PrivateKey = {client_privkey}
@ -57,14 +81,22 @@ def generate_user_clientside_config(client_ip: str,
''')
return config.lstrip()
def run_wg(*args, input: str=None):
result = subprocess.run(['wg', *args],
input=input,
capture_output=True,
text=True)
return result.stdout
class WireGuard:
def __init__(self, config: dict):
self.tunnel_id = config['tunnel_id']
self.server_endpoint = config['server_endpoint']
self.dns_server = config['dns_server']
self.dns_server = ipaddress.ip_address(config['dns_server'])
self.client_network = ipaddress.ip_network(config['client_network'])
self.configs_base = Path(config['configs_base'])
self.server_config_file = safe_join(workdir, self.tunnel_id + '.conf')
with open(config['server_privkey_file'], 'r') as privkey_file:
self.server_privkey = privkey_file.read().strip()
@ -83,6 +115,22 @@ class WireGuard:
self.user_name = user_name
self.user_base = user_base
def get_used_ips(self):
latest_config = self.generate_server_config()
ips = []
for line in latest_config.split('\n'):
if line.startswith('AllowedIPs'):
ipstring = line.split('=')[1].strip().removesuffix('/32')
ips.append(ipaddress.ip_address(ipstring))
return ips
def get_free_ip(self):
used_ips = self.get_used_ips()
for addr in self.client_network.hosts():
if addr not in used_ips:
return addr
raise Exception('No addresses available')
def filepath(self, config_filename: str) -> Path:
if self.user_base is None:
raise Exception('Not properly initialized')
@ -97,6 +145,13 @@ class WireGuard:
def meta_filepath(self, config_id: str) -> Path:
return self.filepath(f'{config_id}{metasuffix}')
def generate_server_config(self):
server_config = ''
for conffile in self.configs_base.glob('*/*'+serversuffix):
with open(conffile, 'r') as cf:
server_config += cf.read() + '\n'
return server_config
def list_configs(self) -> list:
if self.user_base is None:
raise Exception('Not properly initialized')
@ -109,7 +164,7 @@ class WireGuard:
description: str,
creation_time: datetime) -> None:
client_privkey, client_pubkey = generate_keypair()
client_ip = get_free_ip()
client_ip = call_with_lock(self.get_free_ip, 10)
with open(self.config_filepath(config_id), 'x') as cf, \
open(self.serverconfig_filepath(config_id), 'x') as sf, \
open(self.meta_filepath(config_id), 'x') as mf:
@ -128,8 +183,8 @@ class WireGuard:
sf.write(generate_user_serverside_config(config_id,
self.user_name,
client_pubkey,
client_ip))
client_ip,
client_pubkey))
self.wg_updated = True
def update_config(self,
@ -168,8 +223,8 @@ class WireGuard:
def update(self) -> None:
if not self.wg_updated:
return
with open(self.server_config_file, 'w') as sf:
sf.write(self.generate_server_config())
# TODO: update wg config and reload
# 1. join all *.serverconf into one file
# 2. place that file in appropriate location
# 3. reload wireguard service
pass