aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
blob: 969675fd2780fbf02732193e9dd1d14f90143a25 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package pgconn

import (
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5/pgproto3"
)

// GSS is the client side of a GSSAPI security context (for example Kerberos)
// that PgConn drives when the server requests GSS authentication.
//
// GetInitToken and GetInitTokenFromSPN each produce the first token sent to
// the server. The former names the target by host and service name; the
// latter names it by an explicit service principal name. Continue consumes a
// token received from the server and returns the next token to send. It sets
// done once the provider considers the exchange finished.
type GSS interface {
	GetInitToken(host, service string) ([]byte, error)
	GetInitTokenFromSPN(spn string) ([]byte, error)
	Continue(inToken []byte) (done bool, outToken []byte, err error)
}

// NewGSSFunc builds a fresh GSS provider. Values of this type are handed to
// RegisterGSSProvider. The registered function is called each time a
// connection performs GSS authentication.
type NewGSSFunc func() (GSS, error)

// newGSS is the provider constructor installed by RegisterGSSProvider. It is
// nil until a provider has been registered.
var newGSS NewGSSFunc

// RegisterGSSProvider installs newGSSArg as the constructor used to obtain a
// GSS provider for GSSAPI authentication. It replaces any previously
// registered constructor.
//
// The constructor is stored in a plain package-level variable with no
// locking. Register it during program initialization, before any connection
// is opened. For example, to authenticate with Kerberos, add this to your
// main package:
//
//	import "github.com/otan/gopgkrb5"
//
//	func init() {
//		pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
//	}
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
	newGSS = newGSSArg
}

// gssAuth performs the client side of PostgreSQL GSSAPI authentication using
// the provider registered with RegisterGSSProvider.
//
// It obtains an initial token, sends it in a GSSResponse, and then alternates
// between two steps until the provider reports completion:
//   - reading an AuthenticationGSSContinue message;
//   - passing its payload to GSS.Continue.
//
// The initial token comes from c.config.KerberosSpn when that is set.
// Otherwise it comes from c.config.Host and the service name, which defaults
// to "postgres" and can be overridden by c.config.KerberosSrvName.
//
// If Continue reports completion together with a non-empty output token,
// that final token is still sent to the server. GSSAPI requires any non-empty
// output token to be delivered to the peer, whatever the returned status, and
// libpq does the same.
//
// It returns an error when no provider is registered, when the provider
// fails, or when writing or reading protocol messages fails. A server
// ErrorResponse is surfaced via rxGSSContinue.
func (c *PgConn) gssAuth() error {
	if newGSS == nil {
		return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
	}
	cli, err := newGSS()
	if err != nil {
		return err
	}

	var nextData []byte
	if c.config.KerberosSpn != "" {
		// Use the supplied SPN if provided.
		nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
	} else {
		// Allow the kerberos service name to be overridden.
		service := "postgres"
		if c.config.KerberosSrvName != "" {
			service = c.config.KerberosSrvName
		}
		nextData, err = cli.GetInitToken(c.config.Host, service)
	}
	if err != nil {
		return err
	}

	// sendToken writes one GSSResponse carrying data and flushes it to the server.
	sendToken := func(data []byte) error {
		c.frontend.Send(&pgproto3.GSSResponse{Data: data})
		return c.frontend.Flush()
	}

	for {
		if err = sendToken(nextData); err != nil {
			return err
		}

		var resp *pgproto3.AuthenticationGSSContinue
		if resp, err = c.rxGSSContinue(); err != nil {
			return err
		}

		var done bool
		if done, nextData, err = cli.Continue(resp.Data); err != nil {
			return err
		}
		if done {
			// The mechanism may emit a last token as it completes; previously this
			// was silently dropped. Deliver it before finishing.
			if len(nextData) > 0 {
				return sendToken(nextData)
			}
			return nil
		}
	}
}

// rxGSSContinue reads the next backend message during GSS authentication.
//
// The result depends on the message received:
//   - AuthenticationGSSContinue: returned to the caller.
//   - ErrorResponse: converted with ErrorResponseToPgError and returned as
//     the error.
//   - Anything else: reported as an unexpected-message error naming its
//     dynamic type.
//
// Errors from reading the message are returned unchanged.
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
	received, readErr := c.receiveMessage()
	if readErr != nil {
		return nil, readErr
	}

	if cont, ok := received.(*pgproto3.AuthenticationGSSContinue); ok {
		return cont, nil
	}
	if errResp, ok := received.(*pgproto3.ErrorResponse); ok {
		return nil, ErrorResponseToPgError(errResp)
	}

	return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", received)
}