1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
from .core import TaggedEC2Resource
from ..exceptions import InvalidVpnConnectionIdError
from ..utils import generic_filter, random_vpn_connection_id
class VPNConnection(TaggedEC2Resource):
    """In-memory record of a single EC2 VPN connection.

    A freshly built connection reports the ``"available"`` state. Tunnel,
    option and static-route details are not modelled here and start as
    ``None``; the customer gateway configuration starts as an empty dict.
    """

    def __init__(
        self,
        ec2_backend,
        vpn_connection_id,
        vpn_conn_type,
        customer_gateway_id,
        vpn_gateway_id=None,
        transit_gateway_id=None,
        tags=None,
    ):
        # Owning backend, identity and initial lifecycle state.
        self.ec2_backend = ec2_backend
        self.id = vpn_connection_id
        self.state = "available"

        # Connection type and the gateways it links, as supplied by the caller.
        self.type = vpn_conn_type
        self.customer_gateway_id = customer_gateway_id
        self.vpn_gateway_id = vpn_gateway_id
        self.transit_gateway_id = transit_gateway_id

        # Details that are not modelled yet.
        self.customer_gateway_configuration = {}
        self.tunnels = self.options = self.static_routes = None

        # Tagging runs last, once every attribute exists. add_tags presumably
        # reads self.ec2_backend / self.id (defined in core); confirm before
        # moving it.
        if not tags:
            tags = {}
        self.add_tags(tags)

    def get_filter_value(self, filter_name):
        """Resolve *filter_name* for this connection via the parent class.

        The ``"DescribeVpnConnections"`` argument is forwarded unchanged;
        presumably it names the API operation for error reporting — confirm
        against ``TaggedEC2Resource.get_filter_value``.
        """
        api_method = "DescribeVpnConnections"
        return super().get_filter_value(filter_name, api_method)
class VPNConnectionBackend:
    """Backend component that stores and manages ``VPNConnection`` objects.

    Connections live in ``self.vpn_connections``, a dict keyed by connection
    id. This class is presumably mixed into the full EC2 backend (it passes
    itself as the connection's ``ec2_backend``) — confirm against the caller.
    """

    def __init__(self):
        self.vpn_connections = {}

    def create_vpn_connection(
        self,
        vpn_conn_type,
        customer_gateway_id,
        vpn_gateway_id=None,
        transit_gateway_id=None,
        static_routes_only=None,
        tags=None,
    ):
        """Create a VPN connection, register it and return it.

        :param vpn_conn_type: connection type, stored as ``VPNConnection.type``.
        :param customer_gateway_id: id of the customer gateway.
        :param vpn_gateway_id: optional virtual private gateway id.
        :param transit_gateway_id: optional transit gateway id.
        :param static_routes_only: accepted for API compatibility; it is not
            currently recorded on the connection.
        :param tags: optional tags forwarded to ``VPNConnection``.
        :returns: the newly created ``VPNConnection``.
        """
        vpn_connection = VPNConnection(
            self,
            vpn_connection_id=random_vpn_connection_id(),
            vpn_conn_type=vpn_conn_type,
            customer_gateway_id=customer_gateway_id,
            vpn_gateway_id=vpn_gateway_id,
            transit_gateway_id=transit_gateway_id,
            tags=tags,
        )
        self.vpn_connections[vpn_connection.id] = vpn_connection
        return vpn_connection

    def delete_vpn_connection(self, vpn_connection_id):
        """Mark a connection as ``"deleted"`` and return it.

        The connection stays in ``self.vpn_connections``; only its state
        changes.

        :raises InvalidVpnConnectionIdError: if the id is unknown.
        """
        try:
            vpn_connection = self.vpn_connections[vpn_connection_id]
        except KeyError:
            raise InvalidVpnConnectionIdError(vpn_connection_id) from None
        vpn_connection.state = "deleted"
        return vpn_connection

    def describe_vpn_connections(self, vpn_connection_ids=None):
        """Return the requested connections, or all of them.

        With ids, returns a list in request order (repeated ids repeat their
        connection). Without ids (or with an empty list), returns a view of
        every stored connection.

        :raises InvalidVpnConnectionIdError: for the first unknown id.
        """
        found = []
        for vpn_connection_id in vpn_connection_ids or []:
            if vpn_connection_id not in self.vpn_connections:
                raise InvalidVpnConnectionIdError(vpn_connection_id)
            found.append(self.vpn_connections[vpn_connection_id])
        # An empty result falls back to every stored connection (a dict view).
        return found or self.vpn_connections.values()

    def get_all_vpn_connections(self, vpn_connection_ids=None, filters=None):
        """Return connections, optionally restricted by id, then filtered.

        Connections are kept in storage order. Duplicate ids in
        *vpn_connection_ids* are tolerated. *filters* is applied afterwards
        via ``generic_filter``.

        :raises InvalidVpnConnectionIdError: for the first requested id, in
            request order, that matches no stored connection.
        """
        vpn_connections = self.vpn_connections.values()
        if vpn_connection_ids:
            requested_ids = list(vpn_connection_ids)
            known_ids = {conn.id for conn in vpn_connections}
            # Check per id rather than comparing lengths: a length check
            # breaks on duplicate ids and picks an arbitrary bad id otherwise.
            for requested_id in requested_ids:
                if requested_id not in known_ids:
                    raise InvalidVpnConnectionIdError(requested_id)
            wanted = set(requested_ids)
            vpn_connections = [
                conn for conn in vpn_connections if conn.id in wanted
            ]
        return generic_filter(filters, vpn_connections)
|