# 410 lines
# 15 KiB
# Python

from configparser import ConfigParser
from datetime import datetime
from pathlib import Path
from syslog import syslog
from time import sleep
import ipaddress
import json
import subprocess
from dateutil.relativedelta import relativedelta
from .exceptions import ClientLimitError, ValiditySpecificationError
from .qr import generate_qr
# Scratch directory for the generated per-tunnel server configs; it also
# holds the lock file used by call_with_lock().
workdir = Path('./work')
lockfile = workdir / 'lockfile.lock~'
# Suffixes of the three files kept per config id in each user directory:
# client config, server-side [Peer] snippet, and JSON metadata.
confsuffix = '.conf'
serversuffix = '.serverconf'
metasuffix = '.json'
def safe_join(*args) -> Path:
    '''
    Join path components onto a base directory without escaping it.

    Similar to flask's own safe_join, but uses Path objects instead of
    strings. The first argument is the base directory (str or Path); the
    remaining arguments are joined onto it.

    Returns the lexically normalised joined path, or None if the result
    would lie outside the base directory. That covers an absolute
    component pointing elsewhere, or ``..`` components that climb above
    the base.
    '''
    import os

    base_path = Path(os.path.normpath(args[0]))
    # normpath collapses '..' lexically. Without it 'base/../etc' would
    # still pass relative_to(base), because relative_to() only compares
    # path components literally.
    joined_path = Path(os.path.normpath(Path(args[0]).joinpath(*args[1:])))
    try:
        relative = joined_path.relative_to(base_path)
    except ValueError:
        return None
    # After normpath any remaining '..' sits at the front. If it survives
    # relative_to (e.g. base '.'), the path still leaves the base.
    if '..' in relative.parts:
        return None
    return joined_path
def call_with_lock(func: callable, args: tuple = (), timeout: float = None,
                   lock_path: Path = None):
    '''
    Run ``func(*args)`` while holding the file-based lock.

    The lock is a file created with ``open(..., 'x')``, which fails if the
    file already exists. It is removed again once ``func`` returns or
    raises. Acquisition is retried every 0.1 s until ``timeout`` seconds
    have been spent waiting, or forever if ``timeout`` is None.

    :param func: callable to run while the lock is held
    :param args: positional arguments passed to ``func``
    :param timeout: maximum number of seconds to wait for the lock, or None
    :param lock_path: lock file to use; defaults to the module's ``lockfile``
    :return: whatever ``func`` returns
    :raises TimeoutError: if the lock could not be acquired in time
    '''
    lock = lockfile if lock_path is None else Path(lock_path)
    sleep_time = 0.1
    slept = 0
    while timeout is None or slept < timeout:
        try:
            lock_handle = open(lock, 'x')
        except FileExistsError:
            # Someone else holds the lock; wait and retry.
            slept += sleep_time
            sleep(sleep_time)
            continue
        # Only the acquisition above is guarded by the FileExistsError
        # handler. A FileExistsError raised by func itself must propagate
        # rather than be mistaken for lock contention and retried.
        with lock_handle:
            print(f'Lock successful after {slept}s, '
                  f'tried on {sleep_time}s intervals.')
            try:
                return func(*args)
            finally:
                lock.unlink()
    raise TimeoutError(f'Unable to acquire lock within {timeout} seconds')
def dict_to_ini(indict: dict) -> str:
    '''
    Render a mapping of sections to INI text without using ConfigParser.write.

    ``indict`` maps section names to mappings of keys to values; a
    ConfigParser instance works as well. The ``DEFAULT`` section is
    skipped. A key starting with ``#`` whose value is None is emitted on a
    line of its own, so it acts as a comment. Every other pair becomes
    ``key = value``.
    '''
    lines = []
    for section_name, entries in indict.items():
        if section_name != 'DEFAULT':
            lines.append(f'[{section_name}]\n')
            for key, value in entries.items():
                is_comment = value is None and key.startswith('#')
                lines.append(f'{key}\n' if is_comment
                             else f'{key} = {value}\n')
    return ''.join(lines)
def generate_keypair():
    '''
    Create a new WireGuard key pair using the ``wg`` tool.

    Returns a ``(private_key, public_key)`` tuple with surrounding
    whitespace stripped from both keys.
    '''
    raw_private = run_wg('genkey')
    # The unstripped private key is fed to `wg pubkey` on stdin.
    raw_public = run_wg('pubkey', input=raw_private)
    return tuple(key.strip() for key in (raw_private, raw_public))
def generate_user_serverside_config(config_id: str,
                                    user_name: str,
                                    client_ip: ipaddress,
                                    client_pubkey: str):
    '''
    Build the server-side [Peer] section for one client config.

    The first entry is a valueless ``# <user_name>/<config_id>`` key.
    ``allow_no_value=True`` lets ConfigParser write it bare, so it acts as
    a comment identifying the peer. The client is allowed exactly its own
    address as a /32.

    Returns the ConfigParser holding the section.
    '''
    peer = {}
    peer[f'# {user_name}/{config_id}'] = None
    peer['PublicKey'] = client_pubkey
    peer['AllowedIPs'] = f'{client_ip}/32'
    parser = ConfigParser(allow_no_value=True)
    parser['Peer'] = peer
    return parser
def generate_user_clientside_config(client_ip: str,
                                    client_privkey: str,
                                    server_address: ipaddress,
                                    server_port: int,
                                    server_pubkey: str,
                                    fragment_file: Path=None,
                                    allowed_ips: str='0.0.0.0/0'):
    '''
    Build the client-side WireGuard config for one client.

    :param client_ip: the client's tunnel address; written with a /32 suffix
    :param client_privkey: the client's private key
    :param server_address: server endpoint address
    :param server_port: server endpoint port
    :param server_pubkey: the server's public key
    :param fragment_file: optional INI file whose keys are merged over the
        generated ones. Sections that only exist in the fragment are added.
        ConfigParser.read() silently ignores a file that cannot be opened.
    :param allowed_ips: value for the [Peer] AllowedIPs entry; the default
        routes all IPv4 traffic through the tunnel
    :return: the ConfigParser holding the config (interpolation disabled)
    '''
    config = ConfigParser(interpolation=None)
    config['Interface'] = {
        'Address': f'{client_ip}/32',
        'PrivateKey': client_privkey
    }
    config['Peer'] = {
        'AllowedIPs': allowed_ips,
        'Endpoint': f'{server_address}:{server_port}',
        'PublicKey': server_pubkey
    }
    if fragment_file:
        fragment = ConfigParser(interpolation=None)
        fragment.read(fragment_file)
        for section, contents in fragment.items():
            # Create sections that only the fragment defines instead of
            # failing with KeyError. DEFAULT is always "in" a ConfigParser,
            # so it is never added here.
            if section not in config:
                config.add_section(section)
            for key, value in contents.items():
                config[section][key] = value
    return config
def run_wg(*args, input: str=None):
    '''
    Run the ``wg`` command with ``args`` and return its standard output.

    :param input: text passed to the command on stdin, if given
    :return: the command's stdout as text
    :raises subprocess.CalledProcessError: if ``wg`` exits with a non-zero
        status. Without this check a failure could silently yield empty
        output, e.g. empty keys from generate_keypair() being written into
        configs.
    '''
    result = subprocess.run(['wg', *args],
                            input=input,
                            capture_output=True,
                            text=True,
                            check=True)
    return result.stdout
def run_command(command, *args) -> None:
    '''
    Run ``sudo commands.sh <command> <args...>``.

    The script is addressed by an absolute path, resolved against the
    current working directory. That makes it possible to set up a narrow
    sudoers rule for exactly this script. The exit status is not checked.
    '''
    script = Path('commands.sh').absolute()
    subprocess.run(['sudo', script, command, *args])
def create_route(client_ip: ipaddress) -> None:
    '''Register ``client_ip`` as a /32 host prefix via ``commands.sh add``.'''
    host_prefix = f'{client_ip}/32'
    run_command('add', host_prefix)
def delete_route(client_ip: str) -> None:
    '''
    Remove a client's route via ``commands.sh del``.

    ``client_ip`` already carries its prefix (e.g. ``10.0.0.2/32``) because
    it comes from the client config's Address entry, so no /32 is appended
    here.
    '''
    subcommand = 'del'
    run_command(subcommand, client_ip)
def parse_timestring(spec: str) -> relativedelta:
    '''
    Parse a validity period such as "1 year" or "3 weeks".

    ``spec`` must be exactly ``"<count> <unit>"``. ``count`` is a
    non-negative integer and ``unit`` is one of year(s), month(s), week(s)
    or day(s), matched case-insensitively.

    :return: a relativedelta spanning ``count`` units
    :raises ValiditySpecificationError: if ``spec`` does not have that form
    '''
    # Singular and plural unit names map to the relativedelta keyword.
    units = {'year': 'years', 'years': 'years',
             'month': 'months', 'months': 'months',
             'week': 'weeks', 'weeks': 'weeks',
             'day': 'days', 'days': 'days'}
    message = f"'{spec}' is not recognized as a valid time specification"
    parts = spec.split()
    # Check the token count explicitly so a malformed spec raises the
    # documented error instead of a bare ValueError from tuple unpacking.
    if len(parts) != 2:
        raise ValiditySpecificationError(message)
    count_text, unit = parts
    try:
        count = int(count_text)
    except ValueError as err:
        raise ValiditySpecificationError(message) from err
    keyword = units.get(unit.lower())
    # A negative period would make every existing config count as expired.
    if keyword is None or count < 0:
        raise ValiditySpecificationError(message)
    return relativedelta(**{keyword: count})
class WireGuard:
    '''
    Manage the client configs of one WireGuard tunnel.

    Per-user files live in ``configs_base/<user_name>/``. For each config id
    there is a client config (``confsuffix``), a server-side [Peer] snippet
    (``serversuffix``) and a JSON metadata file (``metasuffix``). All
    snippets are concatenated into the tunnel's server config, which is
    written to ``workdir/<tunnel_id>.conf``.

    Call set_user() before using the per-user methods. Methods named
    ``_unsafe_*`` modify files without taking the lock; their public
    counterparts run them through call_with_lock().
    '''

    def __init__(self, config: dict):
        '''
        Load the tunnel settings and write an initial server config.

        ``config`` must support ``[]``, ``in``, ``get(..., fallback=...)`` and
        ``getint(..., fallback=...)``. It is presumably a configparser
        section; TODO confirm against the caller.

        :raises ValueError: if tunnel_id does not give a file inside workdir
        '''
        self.tunnel_id = config['tunnel_id']
        self.server_address = ipaddress.ip_address(config['server_address'])
        self.server_port = int(config['server_port'])
        self.client_network = ipaddress.ip_network(config['client_network'])
        self.configs_base = Path(config['configs_base'])
        # 0 means "no limit" (see _unsafe_generate_config_files)
        self.max_clients = config.getint('user_client_limit', fallback=0)
        # A falsy value means configs never expire (see run_cleanup and
        # get_config); otherwise this is a relativedelta.
        self.client_validity = 0
        client_validity = config.get('user_client_validity', fallback=0)
        if client_validity:
            self.client_validity = parse_timestring(client_validity)
        self.server_config_base = None
        if 'server_extra_config' in config:
            self.server_config_base = Path(config['server_extra_config'])
        self.client_config_base = None
        if 'client_extra_config' in config:
            self.client_config_base = Path(config['client_extra_config'])
        self.server_config_file = safe_join(workdir,
                                            self.tunnel_id + confsuffix)
        if self.server_config_file is None:
            raise ValueError(f'Invalid tunnel_id: {self.tunnel_id!r}')
        with open(config['server_privkey_file'], 'r') as privkey_file:
            self.server_privkey = privkey_file.read().strip()
        with open(config['server_pubkey_file'], 'r') as pubkey_file:
            self.server_pubkey = pubkey_file.read().strip()
        self.user_name = None
        self.user_base = None
        # Set before the first update() so the attribute exists even if
        # that update fails part-way.
        self.wg_updated = False
        # Ensure a wg config exists on startup
        self.update(force=True)

    def log(self, context_id: str, message) -> None:
        '''Send ``[tunnel_id] context_id: message`` to syslog.'''
        syslog(f'[{self.tunnel_id}] {context_id}: {message}')

    def set_user(self, user_name: str) -> None:
        '''
        Select the user whose configs subsequent calls operate on.

        Creates the user's directory under configs_base if it is missing.

        :raises ValueError: if user_name would resolve outside configs_base
        '''
        user_base = safe_join(self.configs_base, user_name)
        if user_base is None:
            raise ValueError(f'Invalid user name: {user_name!r}')
        # exist_ok also covers the directory being created between a
        # separate exists() check and mkdir().
        user_base.mkdir(exist_ok=True)
        self.user_name = user_name
        self.user_base = user_base

    def get_used_ips(self):
        '''
        Return the client IPs that already appear as AllowedIPs entries.

        Parses the freshly generated server config. Each entry is expected
        to be a single ``<ip>/32``, as written by
        generate_user_serverside_config().
        '''
        latest_config = self.generate_server_config()
        ips = []
        for line in latest_config.split('\n'):
            if line.lower().startswith('allowedips'):
                ipstring = line.split('=')[1].strip().removesuffix('/32')
                ips.append(ipaddress.ip_address(ipstring))
        return ips

    def get_free_ip(self):
        '''
        Return the first host address of client_network that is not in use.

        :raises Exception: if every host address is taken
        '''
        # A set keeps the scan linear. The server's own Interface address is
        # excluded so it is never handed out to a client.
        used_ips = set(self.get_used_ips())
        used_ips.add(self.server_address)
        for addr in self.client_network.hosts():
            if addr not in used_ips:
                return addr
        raise Exception('No addresses available')

    def get_user_client_count(self) -> int:
        '''
        Return how many configs (metadata files) the current user has.

        :raises Exception: if set_user() has not been called
        '''
        if self.user_base is None:
            raise Exception('Not properly initialized')
        # Path.glob() returns a generator, so we have to count by iteration
        metafiles = self.user_base.glob(f'*{metasuffix}')
        return sum(1 for _ in metafiles)

    def filepath(self, config_filename: str) -> Path:
        '''
        Return the path of ``config_filename`` inside the current user's dir.

        :raises Exception: if set_user() has not been called
        :raises ValueError: if the name would resolve outside the user's dir
        '''
        if self.user_base is None:
            raise Exception('Not properly initialized')
        path = safe_join(self.user_base, config_filename)
        if path is None:
            raise ValueError(f'Invalid config file name: {config_filename!r}')
        return path

    def config_filepath(self, config_id: str) -> Path:
        '''Path of the client config for ``config_id``.'''
        return self.filepath(f'{config_id}{confsuffix}')

    def serverconfig_filepath(self, config_id: str) -> Path:
        '''Path of the server-side [Peer] snippet for ``config_id``.'''
        return self.filepath(f'{config_id}{serversuffix}')

    def meta_filepath(self, config_id: str) -> Path:
        '''Path of the JSON metadata file for ``config_id``.'''
        return self.filepath(f'{config_id}{metasuffix}')

    def generate_server_config(self):
        '''
        Build the complete server config text.

        Starts with an [Interface] section: server address, listen port and
        private key, overridden or extended by the [Interface] keys of
        server_extra_config if one is configured. The [Peer] snippets of
        every user's configs follow.
        '''
        config = ConfigParser(interpolation=None)
        config['Interface'] = {
            'Address': self.server_address,
            'ListenPort': self.server_port,
            'PrivateKey': self.server_privkey
        }
        if self.server_config_base:
            fragment = ConfigParser(interpolation=None)
            fragment.read(self.server_config_base)
            for key, value in fragment['Interface'].items():
                config['Interface'][key] = value
        # We need to leave configparser-land here due to lots
        # of duplicated sections when configuring peers
        server_config = dict_to_ini(config)
        # Sorted so the generated file does not depend on directory order.
        for conffile in sorted(self.configs_base.glob('*/*' + serversuffix)):
            with open(conffile, 'r') as cf:
                server_config += '\n' + cf.read()
        return server_config

    def list_configs(self) -> list:
        '''
        Return the config ids (file stems) of the current user.

        :raises Exception: if set_user() has not been called
        '''
        if self.user_base is None:
            raise Exception('Not properly initialized')
        return [p.stem for p
                in self.user_base.glob(f'*{confsuffix}')]

    def generate_config_files(self, *args) -> None:
        '''
        Create a new config under the lock, then add a route for its IP.

        ``args`` are passed on to _unsafe_generate_config_files() as
        (config_id, name, description, creation_time).
        '''
        client_ip = call_with_lock(self._unsafe_generate_config_files,
                                   args,
                                   10)
        create_route(client_ip)

    def _unsafe_generate_config_files(self,
                                      config_id: str,
                                      name: str,
                                      description: str,
                                      creation_time: datetime) -> ipaddress:
        '''
        Create the client config, server snippet and metadata for a config.

        Must be called with the lock held (see generate_config_files).

        :return: the IP address allocated to the new client
        :raises ClientLimitError: if the user already has max_clients configs
        :raises FileExistsError: if one of the three files already exists;
            files created by this call are removed again in that case
        '''
        if self.max_clients \
                and self.get_user_client_count() >= self.max_clients:
            raise ClientLimitError('client limit reached')
        client_privkey, client_pubkey = generate_keypair()
        client_ip = self.get_free_ip()
        config_path = self.config_filepath(config_id)
        serverconfig_path = self.serverconfig_filepath(config_id)
        meta_path = self.meta_filepath(config_id)
        metadata = {'name': name,
                    'description': description,
                    'created': creation_time.isoformat(' ', 'minutes')}
        client_config = generate_user_clientside_config(
            client_ip,
            client_privkey,
            self.server_address,
            self.server_port,
            self.server_pubkey,
            self.client_config_base
        )
        server_config = generate_user_serverside_config(
            config_id,
            self.user_name,
            client_ip,
            client_pubkey
        )
        # Files created by this call. They are removed again if a later step
        # fails, so no half-written config is left behind. Pre-existing files
        # are never listed because open(..., 'x') fails before append().
        created = []
        try:
            with open(config_path, 'x') as cf:
                created.append(config_path)
                # The client config holds the client's private key.
                config_path.chmod(0o600)
                client_config.write(cf)
            with open(serverconfig_path, 'x') as sf:
                created.append(serverconfig_path)
                server_config.write(sf)
            with open(meta_path, 'x') as mf:
                created.append(meta_path)
                json.dump(metadata, mf)
        except BaseException:
            for path in created:
                path.unlink(missing_ok=True)
            raise
        self.log(f'{self.user_name}/{config_id}', 'Created new config')
        self.log(f'{self.user_name}/{config_id}', f'IP: {client_ip}')
        self.wg_updated = True
        return client_ip

    def update_config(self, *args) -> None:
        '''
        Update a config's metadata under the lock.

        ``args`` are passed on to _unsafe_update_config() as
        (config_id, name, description).
        '''
        call_with_lock(self._unsafe_update_config,
                       args,
                       10)

    def _unsafe_update_config(self,
                              config_id: str,
                              name: str,
                              description: str) -> None:
        '''
        Rewrite the name and description in a config's metadata file.

        Must be called with the lock held (see update_config).
        '''
        with open(self.meta_filepath(config_id), 'r+') as mf:
            metadata = json.load(mf)
            metadata['name'] = name
            metadata['description'] = description
            mf.seek(0)
            json.dump(metadata, mf)
            # Drop leftover bytes if the new JSON is shorter than the old.
            mf.truncate()

    def get_config(self, config_id: str) -> dict:
        '''
        Return a config's data for display.

        The dict has the keys id, name, description, created, expires
        (``YYYY-MM-DD``, or None when configs never expire), qrcode (from
        generate_qr on the client config) and data (client config text).
        '''
        with open(self.config_filepath(config_id), 'r') as cf, \
                open(self.meta_filepath(config_id), 'r') as mf:
            metadata = json.load(mf)
            configdata = cf.read()
        creation_date = metadata['created']
        expiry_date = None
        if self.client_validity:
            expiry_date = (datetime.fromisoformat(creation_date)
                           + self.client_validity)
            expiry_date = expiry_date.strftime('%Y-%m-%d')
        return {'id': config_id,
                'name': metadata['name'],
                'description': metadata['description'],
                'created': creation_date,
                'expires': expiry_date,
                'qrcode': generate_qr(configdata),
                'data': configdata}

    def delete_config(self, *args) -> None:
        '''Delete a config under the lock; ``args`` is (config_id,).'''
        call_with_lock(self._unsafe_delete_config,
                       args,
                       10)

    def _unsafe_delete_config(self, config_id: str) -> None:
        '''
        Delete all three files of a config and remove its route.

        Must be called with the lock held (see delete_config).

        :raises FileNotFoundError: if any of the files is missing; nothing
            is deleted in that case
        '''
        config_path = self.config_filepath(config_id)
        paths = [config_path,
                 self.serverconfig_filepath(config_id),
                 self.meta_filepath(config_id)]
        for path in paths:
            if not path.exists():
                raise FileNotFoundError(path)
        # The client's Address entry already carries the /32 suffix.
        config = ConfigParser(interpolation=None)
        config.read(config_path)
        client_ip = config['Interface']['Address']
        for path in paths:
            path.unlink()
        delete_route(client_ip)
        self.log(f'{self.user_name}/{config_id}', 'Deleted config')
        self.wg_updated = True

    def delete_many_configs(self, *args) -> None:
        '''
        Delete configs of several users under the lock.

        ``args`` is (expired,), where ``expired`` maps a user name to a list
        of config ids.
        '''
        call_with_lock(self._unsafe_delete_many_configs,
                       args,
                       10)

    def _unsafe_delete_many_configs(self, expired: dict) -> None:
        '''
        Delete the given configs, switching users as needed.

        The previously selected user is restored afterwards, even on error,
        so this does not change which user the instance operates on.
        '''
        saved_user = (self.user_name, self.user_base)
        try:
            for user_name, configs in expired.items():
                self.set_user(user_name)
                for config_id in configs:
                    self._unsafe_delete_config(config_id)
        finally:
            self.user_name, self.user_base = saved_user

    def run_cleanup(self) -> None:
        '''
        Delete every config, of any user, whose validity period has passed.

        Does nothing if no validity period is configured.
        '''
        if not self.client_validity:
            return
        now = datetime.now()
        expired = {}
        # Collect expired configs
        for metafile in self.configs_base.glob('*/*' + metasuffix):
            with open(metafile, 'r') as cf:
                metadata = json.load(cf)
            created = datetime.fromisoformat(metadata['created'])
            if now > created + self.client_validity:
                user_name = metafile.parent.name
                expired.setdefault(user_name, []).append(metafile.stem)
        if not expired:
            # Nothing to delete, so there is no need to take the lock.
            return
        # Delete the expired configs in a separate step to minimize lock time
        self.delete_many_configs(expired)

    def update(self, force: bool = False) -> None:
        '''
        Regenerate the server config file and sync it to the interface.

        Does nothing unless a config was created or deleted since the last
        update (``wg_updated``) or ``force`` is true.
        '''
        if not force and not self.wg_updated:
            return
        # Build the content first. If generation fails, the existing server
        # config is left intact instead of being truncated to nothing.
        server_config = self.generate_server_config()
        # The file holds the server's private key, so restrict its
        # permissions before any content is written, not afterwards.
        self.server_config_file.touch(mode=0o600, exist_ok=True)
        self.server_config_file.chmod(0o600)
        with open(self.server_config_file, 'w') as sf:
            sf.write(server_config)
        # Sync updated settings to interface
        run_command('reload')
        self.wg_updated = False