aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/moto/py3/moto/ec2/models/vpn_connections.py
blob: 18f8d7b8564c267cb0041be081981ba8311fbd6f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from .core import TaggedEC2Resource
from ..exceptions import InvalidVpnConnectionIdError
from ..utils import generic_filter, random_vpn_connection_id


class VPNConnection(TaggedEC2Resource):
    """Model of a single EC2 VPN connection held by the mocked backend.

    A connection starts in the ``"available"`` state. The tunnel, option and
    static-route details are not populated by this model and stay ``None``.
    """

    def __init__(
        self,
        ec2_backend,
        vpn_connection_id,
        vpn_conn_type,
        customer_gateway_id,
        vpn_gateway_id=None,
        transit_gateway_id=None,
        tags=None,
    ):
        # Identity and lifecycle.
        self.ec2_backend = ec2_backend
        self.id = vpn_connection_id
        self.type = vpn_conn_type
        self.state = "available"

        # Gateways this connection is attached to (VPN / transit may be None).
        self.customer_gateway_id = customer_gateway_id
        self.vpn_gateway_id = vpn_gateway_id
        self.transit_gateway_id = transit_gateway_id

        # Details not modelled here: an empty configuration and no
        # tunnels/options/static routes.
        self.customer_gateway_configuration = {}
        self.tunnels = self.options = self.static_routes = None

        # Falsy tags (None, empty) are normalised to an empty dict.
        initial_tags = tags or {}
        self.add_tags(initial_tags)

    def get_filter_value(self, filter_name):
        """Resolve ``filter_name`` through the shared tagged-resource lookup,
        scoped to the DescribeVpnConnections API."""
        api_method = "DescribeVpnConnections"
        return super().get_filter_value(filter_name, api_method)


class VPNConnectionBackend:
    """In-memory registry of VPN connections for the mocked EC2 backend.

    ``create_vpn_connection`` passes ``self`` to ``VPNConnection`` as its
    ``ec2_backend``. This class is therefore presumably combined with the rest
    of the EC2 backend (e.g. as a mixin); confirm against the backend definition.
    """

    def __init__(self):
        # Maps VPN connection id -> VPNConnection. Deleted connections stay here
        # with state "deleted" rather than being removed.
        self.vpn_connections = {}

    def _missing_vpn_connection_ids(self, vpn_connection_ids):
        """Return the ids in ``vpn_connection_ids`` that are not registered.

        Caller order (and any duplicates) is preserved, so the first element
        is the first unknown id the caller asked for.
        """
        return [
            vpn_connection_id
            for vpn_connection_id in vpn_connection_ids
            if vpn_connection_id not in self.vpn_connections
        ]

    def create_vpn_connection(
        self,
        vpn_conn_type,
        customer_gateway_id,
        vpn_gateway_id=None,
        transit_gateway_id=None,
        static_routes_only=None,
        tags=None,
    ):
        """Create, register and return a new ``VPNConnection``.

        A fresh id is generated with ``random_vpn_connection_id``.
        ``static_routes_only`` is accepted for API compatibility but is not
        modelled: it has no effect on the created connection.
        """
        vpn_connection = VPNConnection(
            self,
            vpn_connection_id=random_vpn_connection_id(),
            vpn_conn_type=vpn_conn_type,
            customer_gateway_id=customer_gateway_id,
            vpn_gateway_id=vpn_gateway_id,
            transit_gateway_id=transit_gateway_id,
            tags=tags,
        )
        self.vpn_connections[vpn_connection.id] = vpn_connection
        return vpn_connection

    def delete_vpn_connection(self, vpn_connection_id):
        """Mark the connection as deleted and return it.

        The connection is not removed from the registry; only its ``state``
        becomes ``"deleted"``.

        Raises:
            InvalidVpnConnectionIdError: if ``vpn_connection_id`` is unknown.
        """
        vpn_connection = self.vpn_connections.get(vpn_connection_id)
        if vpn_connection is None:
            raise InvalidVpnConnectionIdError(vpn_connection_id)
        vpn_connection.state = "deleted"
        return vpn_connection

    def describe_vpn_connections(self, vpn_connection_ids=None):
        """Return the requested connections, or all of them if none are requested.

        With ids, returns a list in request order (duplicates repeated).
        Without ids (``None`` or empty), returns the registry's values view.

        Raises:
            InvalidVpnConnectionIdError: for the first requested id that is
                not registered.
        """
        requested_ids = vpn_connection_ids or []
        missing_ids = self._missing_vpn_connection_ids(requested_ids)
        if missing_ids:
            raise InvalidVpnConnectionIdError(missing_ids[0])
        vpn_connections = [
            self.vpn_connections[vpn_connection_id]
            for vpn_connection_id in requested_ids
        ]
        return vpn_connections or self.vpn_connections.values()

    def get_all_vpn_connections(self, vpn_connection_ids=None, filters=None):
        """Return connections, optionally restricted by id, after ``filters``.

        If ``vpn_connection_ids`` is given, only those connections are kept
        (in registry order; duplicate ids are tolerated). The result is then
        passed through ``generic_filter`` with ``filters``.

        Raises:
            InvalidVpnConnectionIdError: for the first requested id (in
                request order) that is not registered.
        """
        vpn_connections = self.vpn_connections.values()

        if vpn_connection_ids:
            # Check existence per id rather than comparing list lengths: a
            # length check mis-fires on duplicate ids and then indexes an
            # empty set difference.
            missing_ids = self._missing_vpn_connection_ids(vpn_connection_ids)
            if missing_ids:
                raise InvalidVpnConnectionIdError(missing_ids[0])
            requested_ids = set(vpn_connection_ids)
            vpn_connections = [
                vpn_connection
                for vpn_connection in vpn_connections
                if vpn_connection.id in requested_ids
            ]

        return generic_filter(filters, vpn_connections)