mirror of
https://codeberg.org/readeck/readeck.git
synced 2025-12-22 13:17:10 +00:00
Refactored OAuth Dynamic Client Registration
OAuth clients can only register using Dynamic Client Registration (DCR). The initial OAuth work implemented DCR and Dynamic Client Management (DCM). This removes DCM entirely and implements ephemaral OAuth clients. Each time a client starts an authorization process, it will need to register a short lived (10 minutes) client first. The reasoning behind this is that, since client registration is already public, there's no real need to store clients (their information is stored with the token though) and force developers to manage clients (updates, deletion after token revocation, etc.). Moreover, even with the current model, a client is necessary almost each time you create a new token (no real difference in database usage). This model also simplifies potential future implementations: - IndieAuth clients (https://indieauth.spec.indieweb.org/) - x509 signed software attestations with a Readeck CA With this commit, the new workflow is slightly easier: - register a client - use the client ID in the next 10 minutes to ask for authorization - that's it! When a client is created, it's stored in Readeck K/V with a 10 minute TTL. Every subsequent request that needs the client ID (authorize, device, token, etc.) will fetch the necessary client information from there. The client is removed when the process ends (or after the TTL). The created token contains the client information. This commit comes with a migration that removes a table and a field that only make it to nightly. Enforce oauth client grant_types during authorization - during authorization code or device code grant, check that the client accepts the matching grant type - requires that client redirect_uris is not empty when grant_types contains "authorization_code"
This commit is contained in:
@@ -33,7 +33,7 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
<h3 class="title text-lg mb-2">{{ gettext("Application information") }}</h3>
|
||||
<ul>
|
||||
<li>{{ gettext("Name") }}: <strong>{{ .Client.Name }}</strong></li>
|
||||
<li>{{ gettext("Website") }}: <a class="link" href="{{ .Client.Website }}" target="_blank">{{ .Client.Website }}</a></li>
|
||||
<li>{{ gettext("Website") }}: <a class="link" href="{{ .Client.URI }}" target="_blank">{{ .Client.URI }}</a></li>
|
||||
<li>{{ gettext("Version") }}: <strong>{{ .Client.SoftwareVersion }}</strong></li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
<div class="py-4 max-md:pl-4 grow">
|
||||
<h2 class="title text-lg mb-0">{{ .ClientName }}</h2>
|
||||
<p><a class="link" href="{{ .ClientURI }}" target="_blank">{{ .ClientURI }}</a></p>
|
||||
<p>{{ gettext("Version: %s", .ClientVersion) }}</p>
|
||||
|
||||
<p class="mt-1 text-sm">{{ gettext("Authorized on: %s", date(.Created, pgettext("datetime", "%e %B %Y"))) }}
|
||||
{{- if .LastUsed -}}
|
||||
|
||||
@@ -14,19 +14,17 @@ import (
|
||||
var Keys = KeyMaterial{}
|
||||
|
||||
const (
|
||||
keyToken = "api_token"
|
||||
keySession = "session"
|
||||
keyOauthClientToken = "oauth_client_token" //nolint:gosec
|
||||
keyOauthRequest = "oauth_request"
|
||||
keyToken = "api_token"
|
||||
keySession = "session"
|
||||
keyOauthRequest = "oauth_request"
|
||||
)
|
||||
|
||||
// KeyMaterial contains the signing and encryption keys.
|
||||
type KeyMaterial struct {
|
||||
prk []byte // Main pseudorandom key
|
||||
tokenKey SigningKey
|
||||
sessionKey []byte
|
||||
oauthClientTokenKey SigningKey
|
||||
oauthRequestKey []byte
|
||||
prk []byte // Main pseudorandom key
|
||||
tokenKey SigningKey
|
||||
sessionKey []byte
|
||||
oauthRequestKey []byte
|
||||
}
|
||||
|
||||
func hkdfHashFunc() hash.Hash {
|
||||
@@ -51,12 +49,6 @@ func (km KeyMaterial) SessionKey() []byte {
|
||||
return km.sessionKey
|
||||
}
|
||||
|
||||
// OauthClientTokenKey returns the key used to generate a client's
|
||||
// authorization token.
|
||||
func (km KeyMaterial) OauthClientTokenKey() SigningKey {
|
||||
return km.oauthClientTokenKey
|
||||
}
|
||||
|
||||
// OauthRequestKey returns a 256-bit key used to encode the oauth
|
||||
// authorization payload.
|
||||
func (km KeyMaterial) OauthRequestKey() []byte {
|
||||
@@ -86,6 +78,5 @@ func loadKeys() {
|
||||
Keys.tokenKey = Keys.mustExpand(keyToken, 32)
|
||||
Keys.sessionKey = Keys.mustExpand(keySession, 32)
|
||||
|
||||
Keys.oauthClientTokenKey = Keys.mustExpand(keyOauthClientToken, 32)
|
||||
Keys.oauthRequestKey = Keys.mustExpand(keyOauthRequest, 32)
|
||||
}
|
||||
|
||||
@@ -5,31 +5,22 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"codeberg.org/readeck/readeck/configs"
|
||||
"codeberg.org/readeck/readeck/internal/auth/tokens"
|
||||
"codeberg.org/readeck/readeck/internal/bus"
|
||||
"codeberg.org/readeck/readeck/internal/server"
|
||||
"codeberg.org/readeck/readeck/internal/server/urls"
|
||||
"codeberg.org/readeck/readeck/pkg/ctxr"
|
||||
"codeberg.org/readeck/readeck/pkg/forms"
|
||||
)
|
||||
|
||||
type (
|
||||
ctxClientKey struct{}
|
||||
const (
|
||||
clientTTL = time.Minute * 10
|
||||
)
|
||||
|
||||
var withClient, getClient = ctxr.WithGetter[*Client](ctxClientKey{})
|
||||
|
||||
type clientResponse struct {
|
||||
type oauthClient struct {
|
||||
ID string `json:"client_id"`
|
||||
ClientURI string `json:"registration_client_uri"`
|
||||
AccessToken string `json:"registration_access_token"`
|
||||
Name string `json:"client_name"`
|
||||
URI string `json:"client_uri"`
|
||||
Logo string `json:"logo_uri"`
|
||||
@@ -41,91 +32,48 @@ type clientResponse struct {
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
}
|
||||
|
||||
func newClientResponse(ctx context.Context, client *Client) clientResponse {
|
||||
token, _ := configs.Keys.OauthClientTokenKey().Encode(client.UID)
|
||||
func loadClient(id string, grantType string) (*oauthClient, error) {
|
||||
c := &oauthClient{}
|
||||
if err := bus.GetJSON("oauth:client:"+id, c); err != nil {
|
||||
return nil, errServerError.withError(err)
|
||||
}
|
||||
if c.ID == "" {
|
||||
return nil, errInvalidClient
|
||||
}
|
||||
|
||||
return clientResponse{
|
||||
ID: client.UID,
|
||||
ClientURI: urls.AbsoluteURLContext(ctx, "/api/oauth/client", client.UID).String(),
|
||||
AccessToken: token,
|
||||
Name: client.Name,
|
||||
URI: client.Website,
|
||||
RedirectURIs: client.RedirectURIs,
|
||||
Logo: client.Logo,
|
||||
SoftwareID: client.SoftwareID,
|
||||
SoftwareVersion: client.SoftwareVersion,
|
||||
TokenEndpointAuthMethod: "none",
|
||||
GrantTypes: []string{"authorization_code"},
|
||||
ResponseTypes: []string{"code"},
|
||||
if !slices.Contains(c.GrantTypes, grantType) {
|
||||
return nil, errUnauthorizedClient
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *oauthClient) store() error {
|
||||
return bus.SetJSON("oauth:client:"+c.ID, c, clientTTL)
|
||||
}
|
||||
|
||||
func (c *oauthClient) remove() error {
|
||||
if c.ID == "" {
|
||||
return nil
|
||||
}
|
||||
return bus.Store().Del("oauth:client:" + c.ID)
|
||||
}
|
||||
|
||||
func (c *oauthClient) toClientInfo() *tokens.ClientInfo {
|
||||
return &tokens.ClientInfo{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Website: c.URI,
|
||||
Logo: c.Logo,
|
||||
GrantTypes: c.GrantTypes,
|
||||
SoftwareID: c.SoftwareID,
|
||||
SoftwareVersion: c.SoftwareVersion,
|
||||
}
|
||||
}
|
||||
|
||||
type clientAPI struct {
|
||||
chi.Router
|
||||
}
|
||||
|
||||
func newClientAPI() *clientAPI {
|
||||
api := &clientAPI{chi.NewRouter()}
|
||||
|
||||
api.Post("/", api.clientCreate)
|
||||
api.With(
|
||||
api.withClient,
|
||||
).Route("/{uid}", func(r chi.Router) {
|
||||
r.Get("/", api.clientInfo)
|
||||
r.Put("/", api.clientUpdate)
|
||||
r.Delete("/", api.clientDelete)
|
||||
})
|
||||
|
||||
return api
|
||||
}
|
||||
|
||||
// withAuthenticatedClient retrieves the client from a provided
|
||||
// bearer token.
|
||||
func withAuthenticatedClient(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
accessToken, ok := strings.CutPrefix(r.Header.Get("authorization"), "Bearer ")
|
||||
if !ok {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
token, err := configs.Keys.OauthClientTokenKey().Decode(accessToken)
|
||||
if err != nil {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := Clients.GetOne(goqu.C("uid").Eq(token))
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
} else {
|
||||
server.Err(w, r, errServerError.withError(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx := withClient(r.Context(), client)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// withClient wraps [withAuthenticatedClient] and checks that the
|
||||
// path's client ID matches the authenticated client.
|
||||
func (api *clientAPI) withClient(next http.Handler) http.Handler {
|
||||
return withAuthenticatedClient(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if getClient(r.Context()).UID != chi.URLParam(r, "uid") {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func (api *clientAPI) clientCreate(w http.ResponseWriter, r *http.Request) {
|
||||
// clientCreate creates a new client that is stored in the K/V store
|
||||
// for [clientTTL] duration.
|
||||
func (api *oauthAPI) clientCreate(w http.ResponseWriter, r *http.Request) {
|
||||
f := newClientForm(server.Locale(r))
|
||||
forms.Bind(f, r)
|
||||
|
||||
@@ -140,44 +88,5 @@ func (api *clientAPI) clientCreate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
server.Render(w, r, http.StatusCreated, newClientResponse(r.Context(), client))
|
||||
}
|
||||
|
||||
func (api *clientAPI) clientInfo(w http.ResponseWriter, r *http.Request) {
|
||||
server.Render(w, r, http.StatusOK,
|
||||
newClientResponse(r.Context(), getClient(r.Context())),
|
||||
)
|
||||
}
|
||||
|
||||
func (api *clientAPI) clientUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
client := getClient(r.Context())
|
||||
|
||||
f := newClientForm(server.Locale(r))
|
||||
f.setClient(client)
|
||||
forms.Bind(f, r)
|
||||
|
||||
if !f.IsValid() {
|
||||
server.Err(w, r, f.getError())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := f.updateClient(client)
|
||||
if err != nil {
|
||||
server.Err(w, r, errServerError.withError(err))
|
||||
return
|
||||
}
|
||||
if len(res) > 0 {
|
||||
client, _ = Clients.GetOne(goqu.C("id").Eq(client.ID))
|
||||
}
|
||||
|
||||
server.Render(w, r, http.StatusOK, newClientResponse(r.Context(), client))
|
||||
}
|
||||
|
||||
func (api *clientAPI) clientDelete(w http.ResponseWriter, r *http.Request) {
|
||||
if err := getClient(r.Context()).Delete(); err != nil {
|
||||
server.Err(w, r, errServerError.withError(err))
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
server.Render(w, r, http.StatusCreated, client)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ var (
|
||||
errInvalidScope = oauthError{name: "invalid_scope"}
|
||||
errServerError = oauthError{name: "server_error", status: http.StatusInternalServerError}
|
||||
errSlowDown = oauthError{name: "slow_down"}
|
||||
errUnauthorizedClient = oauthError{name: "unauthorized_client"}
|
||||
)
|
||||
|
||||
// oauthError describes an OAuth error as specified in
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"image/png"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
@@ -19,19 +21,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
"codeberg.org/readeck/readeck/configs"
|
||||
"codeberg.org/readeck/readeck/internal/auth"
|
||||
"codeberg.org/readeck/readeck/internal/auth/tokens"
|
||||
"codeberg.org/readeck/readeck/internal/auth/users"
|
||||
"codeberg.org/readeck/readeck/internal/db/types"
|
||||
"codeberg.org/readeck/readeck/pkg/base58"
|
||||
"codeberg.org/readeck/readeck/pkg/forms"
|
||||
)
|
||||
|
||||
type (
|
||||
ctxClientFormKey struct{}
|
||||
)
|
||||
|
||||
const (
|
||||
grantTypeAuthCode = "authorization_code"
|
||||
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
@@ -45,45 +45,51 @@ func newClientForm(tr forms.Translator) *clientForm {
|
||||
return &clientForm{
|
||||
forms.Must(
|
||||
forms.WithTranslator(context.Background(), tr),
|
||||
forms.NewTextField("client_id", forms.Trim, forms.ValueValidatorFunc[string](func(f forms.Field, v string) error {
|
||||
c, _ := forms.GetForm(f).Context().Value(ctxClientFormKey{}).(*Client)
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Per RFC 7592:
|
||||
// The client MUST include its "client_id" field in the request, and it
|
||||
// MUST be the same as its currently issued client identifier.
|
||||
if c.UID != v {
|
||||
return errors.New("client ID doesn't match")
|
||||
}
|
||||
return nil
|
||||
})),
|
||||
forms.NewTextField("client_name", forms.Trim, forms.Required, forms.StrLen(0, 128)),
|
||||
forms.NewTextField("client_uri", forms.Trim, forms.Required, forms.StrLen(0, 256), forms.IsURL("https")),
|
||||
forms.NewTextField("client_uri", forms.Trim, forms.Required, forms.StrLen(0, 256), isValidClientURI),
|
||||
forms.NewTextField("logo_uri", forms.Trim, forms.StrLen(0, 8<<10), isValidLogoURI),
|
||||
forms.NewTextField("software_id", forms.Trim, forms.Required, forms.StrLen(0, 128)),
|
||||
forms.NewTextField("software_version", forms.Trim, forms.Required, forms.StrLen(0, 64)),
|
||||
forms.NewTextListField("redirect_uris", forms.Trim, forms.Required, isValidRedirectURI),
|
||||
forms.NewTextListField("redirect_uris",
|
||||
forms.Trim,
|
||||
forms.FieldValidatorFunc(func(f forms.Field) error {
|
||||
if !slices.Contains(
|
||||
forms.GetForm(f).Get("grant_types").(*forms.TextListField).V(),
|
||||
grantTypeAuthCode,
|
||||
) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(f.(*forms.ListField[string]).V()) == 0 {
|
||||
return forms.ErrRequired
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
isValidRedirectURI,
|
||||
),
|
||||
forms.NewTextListField("grant_types",
|
||||
forms.ChoicesPairs([][2]string{
|
||||
{grantTypeAuthCode, grantTypeAuthCode},
|
||||
{grantTypeDeviceCode, grantTypeDeviceCode},
|
||||
}),
|
||||
forms.Default([]string{grantTypeAuthCode, grantTypeDeviceCode}),
|
||||
),
|
||||
|
||||
// Ignored fields but we want to coerce their values
|
||||
forms.NewTextField("token_endpoint_auth_method", forms.ChoicesPairs([][2]string{{"none", "none"}})),
|
||||
forms.NewTextListField("grant_types", forms.ChoicesPairs([][2]string{
|
||||
{grantTypeAuthCode, grantTypeAuthCode},
|
||||
{grantTypeDeviceCode, grantTypeDeviceCode},
|
||||
})),
|
||||
forms.NewTextListField("response_types", forms.ChoicesPairs([][2]string{
|
||||
{"code", "code"},
|
||||
})),
|
||||
forms.NewTextField("token_endpoint_auth_method",
|
||||
forms.ChoicesPairs([][2]string{{"none", "none"}}),
|
||||
forms.Default("none"),
|
||||
),
|
||||
forms.NewTextListField("response_types",
|
||||
forms.ChoicesPairs([][2]string{
|
||||
{"code", "code"},
|
||||
}),
|
||||
forms.Default([]string{"code"}),
|
||||
),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *clientForm) setClient(c *Client) {
|
||||
ctx := context.WithValue(f.Context(), ctxClientFormKey{}, c)
|
||||
f.SetContext(ctx)
|
||||
}
|
||||
|
||||
func (f *clientForm) getError() oauthError {
|
||||
switch {
|
||||
case len(f.Get("redirect_uris").Errors()) > 0:
|
||||
@@ -93,60 +99,27 @@ func (f *clientForm) getError() oauthError {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *clientForm) createClient() (*Client, error) {
|
||||
client := &Client{
|
||||
Name: f.Get("client_name").String(),
|
||||
Website: f.Get("client_uri").String(),
|
||||
Logo: f.Get("logo_uri").String(),
|
||||
RedirectURIs: f.Get("redirect_uris").Value().([]string),
|
||||
SoftwareID: f.Get("software_id").String(),
|
||||
SoftwareVersion: f.Get("software_version").String(),
|
||||
func (f *clientForm) createClient() (*oauthClient, error) {
|
||||
client := &oauthClient{
|
||||
ID: uuid.New().URN(),
|
||||
Name: f.Get("client_name").String(),
|
||||
URI: f.Get("client_uri").String(),
|
||||
Logo: f.Get("logo_uri").String(),
|
||||
RedirectURIs: f.Get("redirect_uris").(*forms.TextListField).V(),
|
||||
GrantTypes: f.Get("grant_types").(*forms.TextListField).V(),
|
||||
TokenEndpointAuthMethod: f.Get("token_endpoint_auth_method").String(),
|
||||
ResponseTypes: f.Get("response_types").(*forms.TextListField).V(),
|
||||
SoftwareID: f.Get("software_id").String(),
|
||||
SoftwareVersion: f.Get("software_version").String(),
|
||||
}
|
||||
|
||||
if err := Clients.Create(client); err != nil {
|
||||
if err := client.store(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (f *clientForm) updateClient(client *Client) (res map[string]any, err error) {
|
||||
if !f.IsBound() {
|
||||
err = errors.New("form is not bound")
|
||||
return
|
||||
}
|
||||
|
||||
res = make(map[string]any)
|
||||
for _, field := range f.Fields() {
|
||||
if !field.IsBound() || field.IsNil() {
|
||||
continue
|
||||
}
|
||||
switch field.Name() {
|
||||
case "client_name":
|
||||
res["name"] = field.String()
|
||||
case "client_uri":
|
||||
res["website"] = field.String()
|
||||
case "logo_uri":
|
||||
res["logo"] = field.String()
|
||||
case "redirect_uris":
|
||||
res["redirect_uris"] = types.Strings(field.Value().([]string))
|
||||
case "software_version":
|
||||
res["software_version"] = field.String()
|
||||
}
|
||||
}
|
||||
|
||||
if len(res) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err = client.Update(res); err != nil {
|
||||
f.AddErrors("", forms.ErrUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type authorizationForm struct {
|
||||
*forms.Form
|
||||
}
|
||||
@@ -318,28 +291,62 @@ func newRevokeTokenForm(tr forms.Translator) *revokeTokenForm {
|
||||
)}
|
||||
}
|
||||
|
||||
func (f *revokeTokenForm) revoke(client *Client) error {
|
||||
func (f *revokeTokenForm) revoke(r *http.Request) error {
|
||||
tokenID, err := configs.Keys.TokenKey().Decode(f.Get("token").String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := tokens.Tokens.GetOne(goqu.C("uid").Eq(tokenID))
|
||||
if err != nil && !errors.Is(err, tokens.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
if token == nil {
|
||||
return nil
|
||||
// must be authenticated with the same token
|
||||
if tokenID != auth.GetRequestAuthInfo(r).Provider.ID {
|
||||
return errAccessDenied
|
||||
}
|
||||
|
||||
// A client can only remove its own tokens
|
||||
if *token.ClientID != client.ID {
|
||||
return errInvalidRequest
|
||||
token, err := tokens.Tokens.GetOne(goqu.C("uid").Eq(tokenID))
|
||||
if err != nil {
|
||||
if errors.Is(err, tokens.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return token.Delete()
|
||||
}
|
||||
|
||||
// isValidClientURI checks the given client URL.
|
||||
// It must be https only and resolve to an ip that is not
|
||||
// private or a loopback address.
|
||||
var isValidClientURI = forms.TypedValidator(func(v string) bool {
|
||||
u, err := url.Parse(v)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
if u.Hostname() == "" {
|
||||
return false
|
||||
}
|
||||
host, err := idna.ToASCII(u.Hostname())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Private and loopback is not allowed
|
||||
for _, ip := range ips {
|
||||
if ip.IsLoopback() || ip.IsPrivate() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}, errors.New("invalid client URI"))
|
||||
|
||||
var isValidLogoURI = forms.TypedValidator(func(v string) bool {
|
||||
if v == "" {
|
||||
return true
|
||||
|
||||
@@ -111,7 +111,6 @@ func newAuthorizeViewRouter() *authorizeViewRouter {
|
||||
server.WithSession(),
|
||||
server.WithRedirectLogin,
|
||||
auth.Required,
|
||||
router.withClient,
|
||||
).Route("/", func(r chi.Router) {
|
||||
r.Get("/", router.authorizeHandler)
|
||||
r.Post("/", router.authorizeHandler)
|
||||
@@ -120,29 +119,6 @@ func newAuthorizeViewRouter() *authorizeViewRouter {
|
||||
return router
|
||||
}
|
||||
|
||||
func (h *authorizeViewRouter) withClient(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientID := r.URL.Query().Get("client_id")
|
||||
if clientID == "" {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := Clients.GetOne(goqu.C("uid").Eq(clientID))
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
} else {
|
||||
server.Err(w, r, errServerError.withError(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx := withClient(r.Context(), client)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// authorizeHandler is the authorization page returned to a user.
|
||||
// It shows a form with an accept or deny action. Once the form is
|
||||
// submitted, it returns a 302 response with the redirection
|
||||
@@ -161,7 +137,11 @@ func (h *authorizeViewRouter) authorizeHandler(w http.ResponseWriter, r *http.Re
|
||||
forms.Bind(f, r)
|
||||
}
|
||||
|
||||
client := getClient(r.Context())
|
||||
client, err := loadClient(f.Get("client_id").String(), grantTypeAuthCode)
|
||||
if err != nil {
|
||||
server.Err(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate redirect URI first
|
||||
redir, _ := url.Parse(f.Get("redirect_uri").String())
|
||||
@@ -208,6 +188,7 @@ func (h *authorizeViewRouter) authorizeHandler(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
if !f.Get("granted").(*forms.BooleanField).V() {
|
||||
client.remove() //nolint:errcheck
|
||||
errAccessDenied.withDescription("access denied").redirect(w, r, redir, params)
|
||||
return
|
||||
}
|
||||
@@ -248,19 +229,20 @@ func (api *oauthAPI) authorizationCodeHandler(w http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
|
||||
client, err := Clients.GetOne(goqu.C("uid").Eq(req.ClientID))
|
||||
client, err := loadClient(req.ClientID, grantTypeAuthCode)
|
||||
if err != nil {
|
||||
server.Err(w, r, errInvalidClient.withDescription("client not found").withError(err))
|
||||
server.Err(w, r, err)
|
||||
return
|
||||
}
|
||||
defer client.remove() //nolint:errcheck
|
||||
|
||||
t := &tokens.Token{
|
||||
UID: req.TokenID,
|
||||
UserID: &user.ID,
|
||||
ClientID: &client.ID,
|
||||
IsEnabled: true,
|
||||
Application: client.Name,
|
||||
Roles: req.Scopes,
|
||||
ClientInfo: client.toClientInfo(),
|
||||
}
|
||||
if err = tokens.Tokens.Create(t); err != nil {
|
||||
server.Err(w, r,
|
||||
|
||||
@@ -7,7 +7,6 @@ package oauth2
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -80,7 +79,7 @@ type deviceAuthorizationResponse struct {
|
||||
// The request is stored in Readeck's key/value store with
|
||||
// a TTL of 5 minutes.
|
||||
type deviceAuthorizationRequest struct {
|
||||
ClientID int `json:"c"`
|
||||
ClientID string `json:"c"`
|
||||
UserID int `json:"u"`
|
||||
TokenID int `json:"t"`
|
||||
Expires time.Time `json:"e"`
|
||||
@@ -90,27 +89,15 @@ type deviceAuthorizationRequest struct {
|
||||
}
|
||||
|
||||
func (r *deviceAuthorizationRequest) store(code userCode) error {
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
duration := r.Expires.Sub(time.Now().UTC())
|
||||
|
||||
return bus.Store().Set(
|
||||
fmt.Sprintf("oauth:device-code:%s", code),
|
||||
string(data),
|
||||
duration,
|
||||
)
|
||||
return bus.SetJSON("oauth:device-code:"+string(code), r, r.Expires.Sub(time.Now().UTC()))
|
||||
}
|
||||
|
||||
func loadDeviceAuthorizationRequest(code userCode) (*deviceAuthorizationRequest, error) {
|
||||
data := bus.Store().Get(fmt.Sprintf("oauth:device-code:%s", code))
|
||||
if data == "" {
|
||||
return &deviceAuthorizationRequest{}, nil
|
||||
}
|
||||
r := &deviceAuthorizationRequest{}
|
||||
err := json.Unmarshal([]byte(data), r)
|
||||
return r, err
|
||||
if err := bus.GetJSON("oauth:device-code:"+string(code), r); err != nil {
|
||||
return nil, errServerError.withError(err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// newUserCode generate a random [userCode] of 8 letters from [deviceCodeAlphabet].
|
||||
@@ -210,7 +197,7 @@ func (h *deviceViewRouter) authorizeHandler(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
switch req.Status {
|
||||
case codeRequestPending:
|
||||
client, err := Clients.GetOne(goqu.C("id").Eq(req.ClientID))
|
||||
client, err := loadClient(req.ClientID, grantTypeDeviceCode)
|
||||
if err != nil {
|
||||
server.Err(w, r, err)
|
||||
return
|
||||
@@ -278,9 +265,9 @@ func (api *oauthAPI) deviceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := Clients.GetOne(goqu.C("uid").Eq(f.Get("client_id").String()))
|
||||
client, err := loadClient(f.Get("client_id").String(), grantTypeDeviceCode)
|
||||
if err != nil {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
server.Err(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -326,9 +313,9 @@ func (api *oauthAPI) deviceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// "urn:ietf:params:oauth:grant-type:device_code" grant_type.
|
||||
func (api *oauthAPI) deviceCodeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
f := getTokenForm(r.Context())
|
||||
client, err := Clients.GetOne(goqu.C("uid").Eq(f.Get("client_id").String()))
|
||||
client, err := loadClient(f.Get("client_id").String(), grantTypeDeviceCode)
|
||||
if err != nil {
|
||||
server.Err(w, r, errInvalidClient)
|
||||
server.Err(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -375,10 +362,10 @@ func (api *oauthAPI) deviceCodeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a token when it doesn't exist.
|
||||
t = &tokens.Token{
|
||||
UserID: &user.ID,
|
||||
ClientID: &client.ID,
|
||||
IsEnabled: true,
|
||||
Application: client.Name,
|
||||
Roles: req.Scopes,
|
||||
ClientInfo: client.toClientInfo(),
|
||||
}
|
||||
if err = tokens.Tokens.Create(t); err != nil {
|
||||
server.Err(w, r,
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
// SPDX-FileCopyrightText: © 2025 Olivier Meunier <olivier@neokraft.net>
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
|
||||
"codeberg.org/readeck/readeck/internal/db"
|
||||
"codeberg.org/readeck/readeck/internal/db/types"
|
||||
"codeberg.org/readeck/readeck/pkg/base58"
|
||||
)
|
||||
|
||||
const (
|
||||
// TableName is the database table.
|
||||
TableName = "oauth2_client"
|
||||
)
|
||||
|
||||
var (
|
||||
// Clients is the model manager for [Client] instances.
|
||||
Clients = Manager{}
|
||||
|
||||
// ErrNotFound is returned when a token record was not found.
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
// Client is an oauth2 client record in database.
|
||||
type Client struct {
|
||||
ID int `db:"id" goqu:"skipinsert,skipupdate"`
|
||||
UID string `db:"uid"`
|
||||
Created time.Time `db:"created" goqu:"skipupdate"`
|
||||
Name string `db:"name"`
|
||||
Website string `db:"website"`
|
||||
Logo string `db:"logo"`
|
||||
RedirectURIs types.Strings `db:"redirect_uris"`
|
||||
SoftwareID string `db:"software_id"`
|
||||
SoftwareVersion string `db:"software_version"`
|
||||
}
|
||||
|
||||
// Manager is a query helper for client entries.
|
||||
type Manager struct{}
|
||||
|
||||
// Query returns a prepared [goqu.SelectDataset] that can be extended later.
|
||||
func (m *Manager) Query() *goqu.SelectDataset {
|
||||
return db.Q().From(goqu.T(TableName).As("c")).Prepared(true)
|
||||
}
|
||||
|
||||
// GetOne executes the a select query and returns the first result or an error
|
||||
// when there's no result.
|
||||
func (m *Manager) GetOne(expressions ...goqu.Expression) (*Client, error) {
|
||||
var c Client
|
||||
found, err := m.Query().Where(expressions...).ScanStruct(&c)
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, err
|
||||
case !found:
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Create insert a new client in the database.
|
||||
func (m *Manager) Create(client *Client) error {
|
||||
client.Created = time.Now().UTC()
|
||||
client.UID = base58.NewUUID()
|
||||
|
||||
ds := db.Q().Insert(TableName).
|
||||
Rows(client).
|
||||
Prepared(true)
|
||||
|
||||
id, err := db.InsertWithID(ds, "id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates some bookmark values.
|
||||
func (c *Client) Update(v interface{}) error {
|
||||
if c.ID == 0 {
|
||||
return errors.New("no ID")
|
||||
}
|
||||
|
||||
_, err := db.Q().Update(TableName).Prepared(true).
|
||||
Set(v).
|
||||
Where(goqu.C("id").Eq(c.ID)).
|
||||
Executor().Exec()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Save updates all the token values.
|
||||
func (c *Client) Save() error {
|
||||
return c.Update(c)
|
||||
}
|
||||
|
||||
// Delete removes a token from the database.
|
||||
func (c *Client) Delete() error {
|
||||
_, err := db.Q().Delete(TableName).Prepared(true).
|
||||
Where(goqu.C("id").Eq(c.ID)).
|
||||
Executor().Exec()
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"codeberg.org/readeck/readeck/internal/auth"
|
||||
"codeberg.org/readeck/readeck/internal/server"
|
||||
"codeberg.org/readeck/readeck/pkg/ctxr"
|
||||
"codeberg.org/readeck/readeck/pkg/forms"
|
||||
@@ -35,10 +36,10 @@ type oauthAPI struct {
|
||||
|
||||
func newOAuthAPI() *oauthAPI {
|
||||
router := &oauthAPI{chi.NewRouter()}
|
||||
router.Mount("/client", newClientAPI())
|
||||
router.Post("/client", router.clientCreate)
|
||||
router.Post("/device", router.deviceHandler)
|
||||
router.Post("/token", router.tokenHandler)
|
||||
router.With(withAuthenticatedClient).Post("/revoke", router.revokeToken)
|
||||
router.With(auth.Required).Post("/revoke", router.revokeToken)
|
||||
|
||||
return router
|
||||
}
|
||||
@@ -75,8 +76,7 @@ func (api *oauthAPI) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
client := getClient(r.Context())
|
||||
if err := f.revoke(client); err != nil {
|
||||
if err := f.revoke(r); err != nil {
|
||||
switch err.(type) {
|
||||
case oauthError:
|
||||
server.Err(w, r, err)
|
||||
|
||||
@@ -14,13 +14,21 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"codeberg.org/readeck/readeck/configs"
|
||||
"codeberg.org/readeck/readeck/internal/auth/oauth2"
|
||||
"codeberg.org/readeck/readeck/internal/db/types"
|
||||
|
||||
. "codeberg.org/readeck/readeck/internal/testing" //revive:disable:dot-imports
|
||||
)
|
||||
|
||||
func registerClient(t *testing.T, client *Client) string {
|
||||
rsp := client.RequestJSON(http.MethodPost, "/api/oauth/client", map[string]any{
|
||||
"client_name": "Test App",
|
||||
"client_uri": "https://example.net/",
|
||||
"software_id": uuid.NewString(),
|
||||
"software_version": "1.0.2",
|
||||
"redirect_uris": []string{"http://[::1]:4000/callback"},
|
||||
})
|
||||
require.Equal(t, 201, rsp.StatusCode)
|
||||
return rsp.JSON.(map[string]any)["client_id"].(string)
|
||||
}
|
||||
|
||||
func TestServerMetadata(t *testing.T) {
|
||||
app := NewTestApp(t)
|
||||
defer app.Close(t)
|
||||
@@ -52,205 +60,261 @@ func TestAuthorizationCodeFlow(t *testing.T) {
|
||||
|
||||
client := NewClient(t, app)
|
||||
|
||||
appClient := &oauth2.Client{
|
||||
Name: "Test App",
|
||||
Website: "https://example.net",
|
||||
RedirectURIs: types.Strings{"http://[::1]:4000/callback"},
|
||||
SoftwareID: uuid.NewString(),
|
||||
SoftwareVersion: "1.0.2",
|
||||
}
|
||||
require.NoError(t, oauth2.Clients.Create(appClient))
|
||||
appToken, err := configs.Keys.OauthClientTokenKey().Encode(appClient.UID)
|
||||
require.NoError(t, err)
|
||||
t.Run("authorization form", func(t *testing.T) {
|
||||
clientID := registerClient(t, client)
|
||||
|
||||
codeVerifier := "210e967c91ae52d32bf414f1439769fb0eda1f828fc8d49c78e18ac5"
|
||||
h := sha256.New()
|
||||
h.Write([]byte(codeVerifier))
|
||||
codeVerifier := "210e967c91ae52d32bf414f1439769fb0eda1f828fc8d49c78e18ac5"
|
||||
h := sha256.New()
|
||||
h.Write([]byte(codeVerifier))
|
||||
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
params := url.Values{}
|
||||
params.Set("client_id", appClient.UID)
|
||||
params.Set("redirect_uri", "http://[::1]:4000/callback")
|
||||
params.Set("scope", "bookmarks:read bookmarks:write")
|
||||
params.Set("state", "random.state")
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
params := url.Values{}
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", "http://[::1]:4000/callback")
|
||||
params.Set("scope", "bookmarks:read bookmarks:write")
|
||||
params.Set("state", "random.state")
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
|
||||
tokenCode := ""
|
||||
accessToken := ""
|
||||
_ = tokenCode
|
||||
_ = accessToken
|
||||
RunRequestSequence(t, client, "", RequestTest{
|
||||
Target: "/authorize",
|
||||
ExpectStatus: 303,
|
||||
})
|
||||
|
||||
RunRequestSequence(t, client, "", RequestTest{
|
||||
Target: "/authorize",
|
||||
ExpectStatus: 303,
|
||||
RunRequestSequence(t, client, "user",
|
||||
RequestTest{
|
||||
Name: "authorize ko",
|
||||
Target: "/authorize",
|
||||
ExpectStatus: 401,
|
||||
ExpectJSON: `{"error": "invalid_client"}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "authorize form",
|
||||
Target: "/authorize?" + params.Encode(),
|
||||
ExpectContains: `Authorize</button>`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "authorize ok",
|
||||
Target: "{{ (index .History 0).URL }}",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"granted": []string{"1"},
|
||||
},
|
||||
ExpectStatus: http.StatusFound,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
assert := require.New(t)
|
||||
location := r.Header.Get("location")
|
||||
assert.NotEmpty(location)
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(err)
|
||||
|
||||
query := u.Query()
|
||||
assert.Contains(query, "code")
|
||||
assert.NotEmpty(query.Get("code"))
|
||||
assert.Equal("random.state", query.Get("state"))
|
||||
},
|
||||
},
|
||||
RequestTest{
|
||||
Name: "authorize deny",
|
||||
Target: "{{ (index .History 0).URL }}",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"granted": []string{"0"},
|
||||
},
|
||||
ExpectStatus: http.StatusFound,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
assert := require.New(t)
|
||||
location := r.Header.Get("location")
|
||||
assert.NotEmpty(location)
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(err)
|
||||
|
||||
query := u.Query()
|
||||
assert.Equal("access_denied", query.Get("error"))
|
||||
assert.Equal("access denied", query.Get("error_description"))
|
||||
assert.Contains(query, "code")
|
||||
assert.Empty(query.Get("code"))
|
||||
},
|
||||
},
|
||||
RequestTest{
|
||||
Name: "client gone after deny",
|
||||
Target: "{{ (index .History 0).URL }}",
|
||||
Method: http.MethodGet,
|
||||
ExpectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
RunRequestSequence(t, client, "user",
|
||||
RequestTest{
|
||||
Name: "authorize ko",
|
||||
Target: "/authorize",
|
||||
ExpectStatus: 401,
|
||||
ExpectJSON: `{"error": "invalid_client"}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "authorize form",
|
||||
Target: "/authorize?" + params.Encode(),
|
||||
ExpectContains: `Authorize</button>`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "authorize ok",
|
||||
Target: "{{ (index .History 0).URL }}",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"granted": []string{"1"},
|
||||
},
|
||||
ExpectStatus: http.StatusFound,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
assert := require.New(t)
|
||||
location := r.Header.Get("location")
|
||||
assert.NotEmpty(location)
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(err)
|
||||
t.Run("token", func(t *testing.T) {
|
||||
clientID := registerClient(t, client)
|
||||
|
||||
query := u.Query()
|
||||
assert.Contains(query, "code")
|
||||
assert.NotEmpty(query.Get("code"))
|
||||
assert.Equal("random.state", query.Get("state"))
|
||||
tokenCode = query.Get("code")
|
||||
},
|
||||
},
|
||||
RequestTest{
|
||||
Name: "authorize deny",
|
||||
Target: "{{ (index .History 0).URL }}",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"granted": []string{"0"},
|
||||
},
|
||||
ExpectStatus: http.StatusFound,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
assert := require.New(t)
|
||||
location := r.Header.Get("location")
|
||||
assert.NotEmpty(location)
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(err)
|
||||
codeVerifier := "210e967c91ae52d32bf414f1439769fb0eda1f828fc8d49c78e18ac5"
|
||||
h := sha256.New()
|
||||
h.Write([]byte(codeVerifier))
|
||||
|
||||
query := u.Query()
|
||||
assert.Equal("access_denied", query.Get("error"))
|
||||
assert.Equal("access denied", query.Get("error_description"))
|
||||
assert.Contains(query, "code")
|
||||
assert.Empty(query.Get("code"))
|
||||
},
|
||||
},
|
||||
)
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
params := url.Values{}
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", "http://[::1]:4000/callback")
|
||||
params.Set("scope", "bookmarks:read bookmarks:write")
|
||||
params.Set("state", "random.state")
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
RequestTest{
|
||||
Name: "token challenge ko",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"invalid_request",
|
||||
"error_description":"error on field \"grant_type\": field is required"
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "token challenge ko",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{"pnIEw47"},
|
||||
"code_verifier": []string{"FaRhB"},
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"invalid_grant",
|
||||
"error_description":"code is not valid"
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "token challenge ko",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{tokenCode},
|
||||
"code_verifier": []string{"8GrkY4Fk1XpnIEw47w71XDEoMFaRhB8SvLhs9ZLCCNI"},
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"invalid_grant",
|
||||
"error_description":"code is not valid"
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "token ok",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{tokenCode},
|
||||
"code_verifier": []string{codeVerifier},
|
||||
},
|
||||
ExpectStatus: http.StatusCreated,
|
||||
ExpectJSON: `{
|
||||
"access_token": "<<PRESENCE>>",
|
||||
"id": "<<PRESENCE>>",
|
||||
"scope": "bookmarks:read bookmarks:write",
|
||||
"token_type": "Bearer"
|
||||
}`,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
accessToken = r.JSON.(map[string]any)["access_token"].(string)
|
||||
t.Log(accessToken)
|
||||
},
|
||||
},
|
||||
)
|
||||
tokenCode := ""
|
||||
accessToken := ""
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
RequestTest{
|
||||
Name: "profile with new token",
|
||||
Target: "/api/profile",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
RunRequestSequence(t, client, "user",
|
||||
RequestTest{
|
||||
Name: "authorize form",
|
||||
Target: "/authorize?" + params.Encode(),
|
||||
ExpectContains: `Authorize</button>`,
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "revoke token",
|
||||
Target: "/api/oauth/revoke",
|
||||
Method: http.MethodPost,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + appToken,
|
||||
RequestTest{
|
||||
Name: "authorize ok",
|
||||
Target: "{{ (index .History 0).URL }}",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"granted": []string{"1"},
|
||||
},
|
||||
ExpectStatus: http.StatusFound,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
u, err := url.Parse(r.Header.Get("location"))
|
||||
require.NoError(t, err)
|
||||
|
||||
query := u.Query()
|
||||
tokenCode = query.Get("code")
|
||||
},
|
||||
},
|
||||
JSON: map[string]string{
|
||||
"token": accessToken,
|
||||
)
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
RequestTest{
|
||||
Name: "token challenge ko",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"invalid_request",
|
||||
"error_description":"error on field \"grant_type\": field is required"
|
||||
}`,
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "profile with revoked token",
|
||||
Target: "/api/profile",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
RequestTest{
|
||||
Name: "token challenge ko",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{"pnIEw47"},
|
||||
"code_verifier": []string{"FaRhB"},
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"invalid_grant",
|
||||
"error_description":"code is not valid"
|
||||
}`,
|
||||
},
|
||||
ExpectStatus: 401,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "revoke token already revoked",
|
||||
Target: "/api/oauth/revoke",
|
||||
Method: http.MethodPost,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + appToken,
|
||||
RequestTest{
|
||||
Name: "token challenge ko",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{tokenCode},
|
||||
"code_verifier": []string{"8GrkY4Fk1XpnIEw47w71XDEoMFaRhB8SvLhs9ZLCCNI"},
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"invalid_grant",
|
||||
"error_description":"code is not valid"
|
||||
}`,
|
||||
},
|
||||
JSON: map[string]string{
|
||||
"token": accessToken,
|
||||
RequestTest{
|
||||
Name: "token ok",
|
||||
Target: "/api/oauth/token",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{tokenCode},
|
||||
"code_verifier": []string{codeVerifier},
|
||||
},
|
||||
ExpectStatus: http.StatusCreated,
|
||||
ExpectJSON: `{
|
||||
"access_token": "<<PRESENCE>>",
|
||||
"id": "<<PRESENCE>>",
|
||||
"scope": "bookmarks:read bookmarks:write",
|
||||
"token_type": "Bearer"
|
||||
}`,
|
||||
Assert: func(t *testing.T, r *Response) {
|
||||
accessToken = r.JSON.(map[string]any)["access_token"].(string)
|
||||
t.Log(accessToken)
|
||||
require.Empty(t, Store().Get("oauth:client:"+clientID))
|
||||
},
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
RequestTest{
|
||||
Name: "profile with new token",
|
||||
Target: "/api/profile",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
},
|
||||
)
|
||||
|
||||
RunRequestSequence(t, client, "user",
|
||||
RequestTest{
|
||||
Name: "revoke token with session auth",
|
||||
Target: "/api/oauth/revoke",
|
||||
Method: http.MethodPost,
|
||||
JSON: map[string]string{
|
||||
"token": accessToken,
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{
|
||||
"error":"access_denied"
|
||||
}`,
|
||||
},
|
||||
)
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
RequestTest{
|
||||
Name: "revoke token",
|
||||
Target: "/api/oauth/revoke",
|
||||
Method: http.MethodPost,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
},
|
||||
JSON: map[string]string{
|
||||
"token": accessToken,
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "profile with revoked token",
|
||||
Target: "/api/profile",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
},
|
||||
ExpectStatus: 401,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "revoke token already revoked",
|
||||
Target: "/api/oauth/revoke",
|
||||
Method: http.MethodPost,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
},
|
||||
JSON: map[string]string{
|
||||
"token": accessToken,
|
||||
},
|
||||
ExpectStatus: 401,
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceCodeFlow(t *testing.T) {
|
||||
@@ -259,16 +323,8 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
|
||||
client := NewClient(t, app)
|
||||
|
||||
appClient := &oauth2.Client{
|
||||
Name: "Test App",
|
||||
Website: "https://example.net",
|
||||
RedirectURIs: types.Strings{"http://[::1]:4000/callback"},
|
||||
SoftwareID: uuid.NewString(),
|
||||
SoftwareVersion: "1.0.2",
|
||||
}
|
||||
require.NoError(t, oauth2.Clients.Create(appClient))
|
||||
|
||||
t.Run("granted access", func(t *testing.T) {
|
||||
clientID := registerClient(t, client)
|
||||
deviceCode := ""
|
||||
userCode := ""
|
||||
|
||||
@@ -278,7 +334,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
Target: "/api/oauth/device",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"client_id": {appClient.UID},
|
||||
"client_id": {clientID},
|
||||
"scope": {"bookmarks:read"},
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
@@ -311,7 +367,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
JSON: map[string]any{
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": "{{ (index .History 0).JSON.device_code }}",
|
||||
"client_id": appClient.UID,
|
||||
"client_id": clientID,
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{"error": "authorization_pending"}`,
|
||||
@@ -323,7 +379,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
JSON: map[string]any{
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": "{{ (index .History 1).JSON.device_code }}",
|
||||
"client_id": appClient.UID,
|
||||
"client_id": clientID,
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{"error": "slow_down"}`,
|
||||
@@ -387,7 +443,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
JSON: map[string]any{
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": deviceCode,
|
||||
"client_id": appClient.UID,
|
||||
"client_id": clientID,
|
||||
},
|
||||
ExpectStatus: 201,
|
||||
ExpectJSON: `{
|
||||
@@ -409,6 +465,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("denied access", func(t *testing.T) {
|
||||
clientID := registerClient(t, client)
|
||||
deviceCode := ""
|
||||
userCode := ""
|
||||
|
||||
@@ -418,7 +475,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
Target: "/api/oauth/device",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"client_id": {appClient.UID},
|
||||
"client_id": {clientID},
|
||||
"scope": {"bookmarks:read"},
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
@@ -465,7 +522,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
JSON: map[string]any{
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": deviceCode,
|
||||
"client_id": appClient.UID,
|
||||
"client_id": clientID,
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{"error": "access_denied"}`,
|
||||
@@ -474,6 +531,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("expired code", func(t *testing.T) {
|
||||
clientID := registerClient(t, client)
|
||||
deviceCode := ""
|
||||
userCode := ""
|
||||
|
||||
@@ -483,7 +541,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
Target: "/api/oauth/device",
|
||||
Method: http.MethodPost,
|
||||
Form: url.Values{
|
||||
"client_id": {appClient.UID},
|
||||
"client_id": {clientID},
|
||||
"scope": {"bookmarks:read"},
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
@@ -519,7 +577,7 @@ func TestDeviceCodeFlow(t *testing.T) {
|
||||
JSON: map[string]any{
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": deviceCode,
|
||||
"client_id": appClient.UID,
|
||||
"client_id": clientID,
|
||||
},
|
||||
ExpectStatus: 400,
|
||||
ExpectJSON: `{"error": "expired_token"}`,
|
||||
@@ -534,23 +592,6 @@ func TestClientRegistration(t *testing.T) {
|
||||
|
||||
client := NewClient(t, app)
|
||||
|
||||
appClient1 := &oauth2.Client{
|
||||
Name: "Test App 1",
|
||||
Website: "https://example.net",
|
||||
RedirectURIs: types.Strings{"http://[::1]:4000/callback"},
|
||||
SoftwareID: uuid.NewString(),
|
||||
SoftwareVersion: "1.0.2",
|
||||
}
|
||||
appClient2 := &oauth2.Client{
|
||||
Name: "Test App 2",
|
||||
Website: "https://example.net",
|
||||
RedirectURIs: types.Strings{"http://[::1]:4000/callback"},
|
||||
SoftwareID: uuid.NewString(),
|
||||
SoftwareVersion: "1.0.2",
|
||||
}
|
||||
require.NoError(t, oauth2.Clients.Create(appClient1))
|
||||
require.NoError(t, oauth2.Clients.Create(appClient2))
|
||||
|
||||
t.Run("invalid_redirect_uri", func(t *testing.T) {
|
||||
tests := []string{
|
||||
"http://test.localhost/",
|
||||
@@ -575,31 +616,33 @@ func TestClientRegistration(t *testing.T) {
|
||||
RunRequestSequence(t, client, "", seq...)
|
||||
})
|
||||
|
||||
t.Run("client bearer", func(t *testing.T) {
|
||||
t1, err := configs.Keys.OauthClientTokenKey().Encode(appClient1.UID)
|
||||
require.NoError(t, err)
|
||||
t.Run("invalid client_uri", func(t *testing.T) {
|
||||
tests := []string{
|
||||
"http://example.com/",
|
||||
"https://192.168.0.1:4000/test",
|
||||
"https://[fc00::1]/",
|
||||
"https://localhost/",
|
||||
"https://test.localhost/",
|
||||
}
|
||||
|
||||
t2, err := configs.Keys.OauthClientTokenKey().Encode(appClient2.UID)
|
||||
require.NoError(t, err)
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
RequestTest{
|
||||
Name: "ok",
|
||||
Target: "/api/oauth/client/" + appClient1.UID,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + t1,
|
||||
seq := []RequestTest{}
|
||||
for _, test := range tests {
|
||||
seq = append(seq, RequestTest{
|
||||
Name: "no client URI",
|
||||
Target: "/api/oauth/client",
|
||||
Method: "POST",
|
||||
JSON: map[string]any{
|
||||
"client_name": "test",
|
||||
"client_uri": test,
|
||||
"redirect_uris": []string{"https://example.org/callback"},
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "mismatch id and token",
|
||||
Target: "/api/oauth/client/" + appClient1.UID,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + t2,
|
||||
},
|
||||
ExpectStatus: 401,
|
||||
},
|
||||
)
|
||||
ExpectJSON: `{
|
||||
"error": "invalid_client_metadata",
|
||||
"error_description": "error on field \"client_uri\": invalid client URI"
|
||||
}`,
|
||||
})
|
||||
}
|
||||
RunRequestSequence(t, client, "", seq...)
|
||||
})
|
||||
|
||||
RunRequestSequence(t, client, "",
|
||||
@@ -616,7 +659,7 @@ func TestClientRegistration(t *testing.T) {
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "no client name",
|
||||
Name: "no client URI",
|
||||
Target: "/api/oauth/client",
|
||||
Method: "POST",
|
||||
JSON: map[string]any{
|
||||
@@ -650,8 +693,6 @@ func TestClientRegistration(t *testing.T) {
|
||||
ExpectStatus: 201,
|
||||
ExpectJSON: `{
|
||||
"client_id":"<<PRESENCE>>",
|
||||
"registration_access_token": "<<PRESENCE>>",
|
||||
"registration_client_uri": "<<PRESENCE>>",
|
||||
"client_name": "test",
|
||||
"client_uri": "https://example.org/",
|
||||
"logo_uri": "",
|
||||
@@ -666,97 +707,9 @@ func TestClientRegistration(t *testing.T) {
|
||||
"software_id":"10098d11-8b3f-4ebb-b519-cab2301975fa",
|
||||
"software_version":"1.0.0",
|
||||
"token_endpoint_auth_method":"none",
|
||||
"grant_types":["authorization_code"],
|
||||
"grant_types":["authorization_code","urn:ietf:params:oauth:grant-type:device_code"],
|
||||
"response_types":["code"]
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "client info",
|
||||
Target: "/api/oauth/client/{{ (index .History 0).JSON.client_id }}",
|
||||
ExpectStatus: 401,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "client info",
|
||||
Target: "/api/oauth/client/{{ (index .History 1).JSON.client_id }}",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer {{ (index .History 1).JSON.registration_access_token }}",
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
ExpectJSON: `{
|
||||
"client_id":"<<PRESENCE>>",
|
||||
"registration_access_token": "<<PRESENCE>>",
|
||||
"registration_client_uri": "<<PRESENCE>>",
|
||||
"client_name": "test",
|
||||
"client_uri": "https://example.org/",
|
||||
"logo_uri": "",
|
||||
"redirect_uris": [
|
||||
"https://example.org/callback",
|
||||
"https://example.org/callback2",
|
||||
"net.myapp:oauth-callback",
|
||||
"net.myapp:///oauth-callback",
|
||||
"http://127.0.0.8:8000/callback",
|
||||
"http://[::1]:8000/callback"
|
||||
],
|
||||
"software_id": "10098d11-8b3f-4ebb-b519-cab2301975fa",
|
||||
"software_version": "1.0.0",
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"]
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "client update",
|
||||
Target: "/api/oauth/client/{{ (index .History 0).JSON.client_id }}",
|
||||
Method: "PUT",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer {{ (index .History 0).JSON.registration_access_token }}",
|
||||
},
|
||||
JSON: map[string]any{
|
||||
"client_id": "{{ (index .History 0).JSON.client_id }}",
|
||||
"client_name": "test",
|
||||
"client_uri": "https://example.org/",
|
||||
"logo_uri": "",
|
||||
"redirect_uris": []string{"http://[::1]:8000/callback"},
|
||||
"software_id": "10098d11-8b3f-4ebb-b519-cab2301975fa",
|
||||
"software_version": "1.0.0",
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": []string{"authorization_code"},
|
||||
"response_types": []string{"code"},
|
||||
},
|
||||
ExpectStatus: 200,
|
||||
ExpectJSON: `{
|
||||
"client_id":"<<PRESENCE>>",
|
||||
"registration_access_token": "<<PRESENCE>>",
|
||||
"registration_client_uri": "<<PRESENCE>>",
|
||||
"client_name":"test",
|
||||
"client_uri":"https://example.org/",
|
||||
"logo_uri":"",
|
||||
"redirect_uris":[
|
||||
"http://[::1]:8000/callback"
|
||||
],
|
||||
"software_id":"10098d11-8b3f-4ebb-b519-cab2301975fa",
|
||||
"software_version":"1.0.0",
|
||||
"token_endpoint_auth_method":"none",
|
||||
"grant_types":["authorization_code"],
|
||||
"response_types":["code"]
|
||||
}`,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "client delete",
|
||||
Target: "/api/oauth/client/{{ (index .History 0).JSON.client_id }}",
|
||||
Method: "DELETE",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer {{ (index .History 0).JSON.registration_access_token }}",
|
||||
},
|
||||
ExpectStatus: 204,
|
||||
},
|
||||
RequestTest{
|
||||
Name: "client info",
|
||||
Target: "/api/oauth/client/{{ (index .History 1).JSON.client_id }}",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer {{ (index .History 1).JSON.registration_access_token }}",
|
||||
},
|
||||
ExpectStatus: 401,
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -7,6 +7,8 @@
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -37,7 +39,7 @@ type Token struct {
|
||||
ID int `db:"id" goqu:"skipinsert,skipupdate"`
|
||||
UID string `db:"uid"`
|
||||
UserID *int `db:"user_id"`
|
||||
ClientID *int `db:"client_id"`
|
||||
ClientInfo *ClientInfo `db:"client_info"`
|
||||
Created time.Time `db:"created" goqu:"skipupdate"`
|
||||
LastUsed *time.Time `db:"last_used"`
|
||||
Expires *time.Time `db:"expires"`
|
||||
@@ -158,6 +160,44 @@ func (t *Token) IsExpired() bool {
|
||||
return time.Now().UTC().After(*t.Expires)
|
||||
}
|
||||
|
||||
// ClientInfo contains a token's OAuth registered client.
|
||||
type ClientInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Website string `json:"website"`
|
||||
Logo string `json:"logo"`
|
||||
SoftwareID string `json:"software_id"`
|
||||
SoftwareVersion string `json:"software_version"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
}
|
||||
|
||||
// Scan loads a UserSettings instance from a column.
|
||||
func (s *ClientInfo) Scan(value any) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v, err := types.JSONBytes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
json.Unmarshal(v, s) //nolint:errcheck
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value encodes a UserSettings value for storage.
|
||||
func (s *ClientInfo) Value() (driver.Value, error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
v, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
// TokenAndUser is a result of a joint query on user and token tables.
|
||||
type TokenAndUser struct {
|
||||
Token *Token `db:"t"`
|
||||
|
||||
30
internal/bus/store.go
Normal file
30
internal/bus/store.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-FileCopyrightText: © 2025 Olivier Meunier <olivier@neokraft.net>
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
package bus
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SetJSON stores a value as a JSON string.
|
||||
func SetJSON(key string, value any, expiration time.Duration) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return store.Set(key, string(data), expiration)
|
||||
}
|
||||
|
||||
// GetJSON retrieves a value as a JSON string. It returns [ErrNotExists]
|
||||
// when the value was not in the store already.
|
||||
func GetJSON(key string, value any) error {
|
||||
data := store.Get(key)
|
||||
if data == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal([]byte(data), value)
|
||||
}
|
||||
@@ -101,5 +101,5 @@ var migrationList = []migrationEntry{
|
||||
newMigrationEntry(19, "bookmark_text_normalization", migrations.M19bookmarkTextNormalization),
|
||||
newMigrationEntry(20, "bookmark_removed", applyMigrationFile("20_bookmark_removed.sql")),
|
||||
newMigrationEntry(21, "token_roles", migrations.M21tokenRoles),
|
||||
newMigrationEntry(22, "oauth2", applyMigrationFile("22_oauth2.sql")),
|
||||
newMigrationEntry(23, "oauth2", migrations.M23oauth),
|
||||
}
|
||||
|
||||
59
internal/db/migrations/23_oauth2.go
Normal file
59
internal/db/migrations/23_oauth2.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-FileCopyrightText: © 2025 Olivier Meunier <olivier@neokraft.net>
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
)
|
||||
|
||||
// M23oauth adds a client_info column in the token table.
|
||||
// It also takes care of removing the never released (but in nightly builds)
|
||||
// oauth2_client table and the token.client_id column.
|
||||
func M23oauth(db *goqu.TxDatabase, _ fs.FS) error {
|
||||
var err error
|
||||
|
||||
// What should be the script without nightly clean-up
|
||||
switch db.Dialect() {
|
||||
case "sqlite3":
|
||||
_, err = db.Exec(`ALTER TABLE token ADD COLUMN client_info json NULL`)
|
||||
case "postgres":
|
||||
_, err = db.Exec(`ALTER TABLE token ADD COLUMN client_info jsonb NULL`)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove unreleased token.client_id column
|
||||
var found bool
|
||||
|
||||
switch db.Dialect() {
|
||||
case "sqlite3":
|
||||
var cname string
|
||||
found, err = db.ScanVal(&cname, `SELECT name FROM pragma_table_info('token') WHERE name = 'client_id'`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "postgres":
|
||||
var cname string
|
||||
found, err = db.ScanVal(&cname, `SELECT column_name FROM information_schema.columns WHERE table_name = 'token' AND column_name = 'client_id' `)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if found {
|
||||
if _, err = db.Exec(`DELETE FROM token WHERE client_id IS NOT NULL`); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = db.Exec(`ALTER TABLE token DROP COLUMN client_id`); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up unreleased oauth2_client table
|
||||
_, err = db.Exec(`DROP TABLE IF EXISTS oauth2_client`)
|
||||
return err
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
-- SPDX-FileCopyrightText: © 2025 Olivier Meunier <olivier@neokraft.net>
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
CREATE TABLE IF NOT EXISTS oauth2_client (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid varchar(32) UNIQUE NOT NULL,
|
||||
created timestamptz NOT NULL,
|
||||
name varchar(128) NOT NULL,
|
||||
website varchar(256) NULL,
|
||||
logo text NULL,
|
||||
redirect_uris jsonb NOT NULL,
|
||||
software_id varchar(128) NOT NULL,
|
||||
software_version varchar(128) NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE token ADD COLUMN client_id integer NULL
|
||||
CONSTRAINT fk_token_oauth2_client REFERENCES oauth2_client(id) ON DELETE CASCADE;
|
||||
@@ -21,32 +21,19 @@ CREATE TABLE IF NOT EXISTS "user" (
|
||||
seed integer NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS oauth2_client (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid varchar(32) UNIQUE NOT NULL,
|
||||
created timestamptz NOT NULL,
|
||||
name varchar(128) NOT NULL,
|
||||
website varchar(256) NULL,
|
||||
logo text NULL,
|
||||
redirect_uris jsonb NOT NULL,
|
||||
software_id varchar(128) NOT NULL,
|
||||
software_version varchar(128) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS token (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid varchar(32) UNIQUE NOT NULL,
|
||||
user_id integer NOT NULL,
|
||||
client_id integer NULL,
|
||||
created timestamptz NOT NULL,
|
||||
last_used timestamptz NULL,
|
||||
expires timestamptz NULL,
|
||||
is_enabled boolean NOT NULL DEFAULT true,
|
||||
application varchar(128) NOT NULL,
|
||||
roles jsonb NOT NULL DEFAULT '[]',
|
||||
client_info jsonb NULL,
|
||||
|
||||
CONSTRAINT fk_token_user FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_token_oauth2_client FOREIGN KEY (client_id) REFERENCES oauth2_client(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "credential" (
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
-- SPDX-FileCopyrightText: © 2025 Olivier Meunier <olivier@neokraft.net>
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
CREATE TABLE IF NOT EXISTS oauth2_client (
|
||||
id integer PRIMARY KEY AUTOINCREMENT,
|
||||
uid text UNIQUE NOT NULL,
|
||||
created datetime NOT NULL,
|
||||
name text NOT NULL,
|
||||
website text NULL,
|
||||
logo text NULL,
|
||||
redirect_uris json NOT NULL,
|
||||
software_id text NOT NULL,
|
||||
software_version text NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE token ADD COLUMN client_id integer NULL
|
||||
CONSTRAINT fk_token_oauth2_client REFERENCES oauth2_client(id) ON DELETE CASCADE;
|
||||
@@ -21,32 +21,19 @@ CREATE TABLE IF NOT EXISTS user (
|
||||
seed integer NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS oauth2_client (
|
||||
id integer PRIMARY KEY AUTOINCREMENT,
|
||||
uid text UNIQUE NOT NULL,
|
||||
created datetime NOT NULL,
|
||||
name text NOT NULL,
|
||||
website text NULL,
|
||||
logo text NULL,
|
||||
redirect_uris json NOT NULL,
|
||||
software_id text NOT NULL,
|
||||
software_version text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS token (
|
||||
id integer PRIMARY KEY AUTOINCREMENT,
|
||||
uid text UNIQUE NOT NULL,
|
||||
user_id integer NOT NULL,
|
||||
client_id integer NULL,
|
||||
created datetime NOT NULL,
|
||||
last_used datetime NULL,
|
||||
expires datetime NULL,
|
||||
is_enabled integer NOT NULL DEFAULT 1,
|
||||
application text NOT NULL,
|
||||
roles json NOT NULL DEFAULT "",
|
||||
client_info json NULL,
|
||||
|
||||
CONSTRAINT fk_token_user FOREIGN KEY (user_id) REFERENCES user(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_token_oauth2_client FOREIGN KEY (client_id) REFERENCES oauth2_client(id) ON DELETE CASCADE
|
||||
CONSTRAINT fk_token_user FOREIGN KEY (user_id) REFERENCES user(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS credential (
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"codeberg.org/readeck/readeck/internal/auth"
|
||||
"codeberg.org/readeck/readeck/internal/auth/oauth2"
|
||||
"codeberg.org/readeck/readeck/internal/auth/tokens"
|
||||
"codeberg.org/readeck/readeck/internal/auth/users"
|
||||
"codeberg.org/readeck/readeck/internal/db/scanner"
|
||||
@@ -171,7 +170,6 @@ func (api *profileAPI) withTokenList(t tokenType) func(next http.Handler) http.H
|
||||
}
|
||||
|
||||
ds := tokens.Tokens.Query().
|
||||
LeftOuterJoin(goqu.T(oauth2.TableName).As("c"), goqu.On(goqu.I("c.id").Eq(goqu.I("t.client_id")))).
|
||||
Where(
|
||||
goqu.C("user_id").Eq(auth.GetRequestUser(r).ID),
|
||||
).
|
||||
@@ -184,9 +182,9 @@ func (api *profileAPI) withTokenList(t tokenType) func(next http.Handler) http.H
|
||||
|
||||
switch t {
|
||||
case userToken:
|
||||
ds = ds.Where(goqu.C("client_id").Is(nil))
|
||||
ds = ds.Where(goqu.C("client_info").IsNull())
|
||||
case clientToken:
|
||||
ds = ds.Where(goqu.C("client_id").IsNot(nil))
|
||||
ds = ds.Where(goqu.C("client_info").IsNotNull())
|
||||
}
|
||||
|
||||
var res *tokenItemList
|
||||
@@ -208,7 +206,6 @@ func (api *profileAPI) withToken(t tokenType) func(next http.Handler) http.Handl
|
||||
uid := chi.URLParam(r, "uid")
|
||||
|
||||
ds := tokens.Tokens.Query().
|
||||
LeftOuterJoin(goqu.T(oauth2.TableName).As("c"), goqu.On(goqu.I("c.id").Eq(goqu.I("t.client_id")))).
|
||||
Where(
|
||||
goqu.C("uid").Table("t").Eq(uid),
|
||||
goqu.C("user_id").Eq(auth.GetRequestUser(r).ID),
|
||||
@@ -216,12 +213,12 @@ func (api *profileAPI) withToken(t tokenType) func(next http.Handler) http.Handl
|
||||
|
||||
switch t {
|
||||
case userToken:
|
||||
ds = ds.Where(goqu.C("client_id").Is(nil))
|
||||
ds = ds.Where(goqu.C("client_info").IsNull())
|
||||
case clientToken:
|
||||
ds = ds.Where(goqu.C("client_id").IsNot(nil))
|
||||
ds = ds.Where(goqu.C("client_info").IsNotNull())
|
||||
}
|
||||
|
||||
t := new(tokenAndClient)
|
||||
t := new(tokens.Token)
|
||||
found, err := ds.ScanStruct(t)
|
||||
|
||||
if !found {
|
||||
@@ -257,17 +254,6 @@ func (api *profileAPI) tokenDelete(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type tokenAndClient struct {
|
||||
*tokens.Token `db:"t"`
|
||||
Client struct {
|
||||
UID *string `db:"uid"`
|
||||
Name *string `db:"name"`
|
||||
URI *string `db:"website"`
|
||||
Logo *string `db:"logo"`
|
||||
Version *string `db:"software_version"`
|
||||
} `db:"c"`
|
||||
}
|
||||
|
||||
type tokenItemList struct {
|
||||
Count int64
|
||||
Pagination server.Pagination
|
||||
@@ -277,25 +263,24 @@ type tokenItemList struct {
|
||||
type tokenItem struct {
|
||||
*tokens.Token `json:"-"`
|
||||
|
||||
ID string `json:"id"`
|
||||
Href string `json:"href"`
|
||||
Created time.Time `json:"created"`
|
||||
LastUsed *time.Time `json:"last_used"`
|
||||
Expires *time.Time `json:"expires"`
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
IsDeleted bool `json:"is_deleted"`
|
||||
Roles []string `json:"roles"`
|
||||
RoleNames []string `json:"-"`
|
||||
ClientName string `json:"client_name"`
|
||||
ClientURI string `json:"client_uri"`
|
||||
ClientLogo string `json:"client_logo"`
|
||||
ClientVersion string `json:"client_version"`
|
||||
ID string `json:"id"`
|
||||
Href string `json:"href"`
|
||||
Created time.Time `json:"created"`
|
||||
LastUsed *time.Time `json:"last_used"`
|
||||
Expires *time.Time `json:"expires"`
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
IsDeleted bool `json:"is_deleted"`
|
||||
Roles []string `json:"roles"`
|
||||
RoleNames []string `json:"-"`
|
||||
ClientName string `json:"client_name"`
|
||||
ClientURI string `json:"client_uri"`
|
||||
ClientLogo string `json:"client_logo"`
|
||||
}
|
||||
|
||||
func newTokenItem(ctx context.Context, t *tokenAndClient) *tokenItem {
|
||||
func newTokenItem(ctx context.Context, t *tokens.Token) *tokenItem {
|
||||
tr := server.LocaleContext(ctx)
|
||||
res := &tokenItem{
|
||||
Token: t.Token,
|
||||
Token: t,
|
||||
ID: t.UID,
|
||||
Href: urls.AbsoluteURLContext(ctx, ".", t.UID).String(),
|
||||
Created: t.Created,
|
||||
@@ -307,11 +292,10 @@ func newTokenItem(ctx context.Context, t *tokenAndClient) *tokenItem {
|
||||
RoleNames: users.GroupNames(tr, t.Roles),
|
||||
}
|
||||
|
||||
if t.Client.UID != nil {
|
||||
res.ClientName = *t.Client.Name
|
||||
res.ClientURI = *t.Client.URI
|
||||
res.ClientLogo = *t.Client.Logo
|
||||
res.ClientVersion = *t.Client.Version
|
||||
if t.ClientInfo != nil && t.ClientInfo.ID != "" {
|
||||
res.ClientName = t.ClientInfo.Name
|
||||
res.ClientURI = t.ClientInfo.Website
|
||||
res.ClientLogo = t.ClientInfo.Logo
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user