transfer.sh/transfersh-server/vendor/github.com/PuerkitoBio/ghost/handlers/session.go
2016-07-11 22:15:19 +02:00

321 lines
9.4 KiB
Go

package handlers
import (
"encoding/json"
"errors"
"hash/crc32"
"net/http"
"strings"
"time"
"github.com/PuerkitoBio/ghost"
"github.com/gorilla/securecookie"
"github.com/nu7hatch/gouuid"
)
const defaultCookieName = "ghost.sid"
var (
ErrSessionSecretMissing = errors.New("session secret is missing")
ErrNoSessionID = errors.New("session ID could not be generated")
)
// The Session holds the data map that persists for the duration of the session.
// The information stored in this map should be marshalable for the target Session store
// format (i.e. json, sql, gob, etc. depending on how the store persists the data).
type Session struct {
isNew bool // keep private, not saved to JSON, will be false once read from the store
internalSession
}
// Use a separate private struct to hold the private fields of the Session,
// although those fields are exposed (public). This is a trick to simplify
// JSON encoding.
type internalSession struct {
Data map[string]interface{} // JSON cannot marshal a map[interface{}]interface{}
ID string
Created time.Time
MaxAge time.Duration
}
// Create a new Session instance. It panics in the unlikely event that a new random ID cannot be generated.
func newSession(maxAge int) *Session {
uid, err := uuid.NewV4()
if err != nil {
panic(ErrNoSessionID)
}
return &Session{
true, // is new
internalSession{
make(map[string]interface{}),
uid.String(),
time.Now(),
time.Duration(maxAge) * time.Second,
},
}
}
// Gets the ID of the session.
func (ø *Session) ID() string {
return ø.internalSession.ID
}
// Get the max age duration
func (ø *Session) MaxAge() time.Duration {
return ø.internalSession.MaxAge
}
// Get the creation time of the session.
func (ø *Session) Created() time.Time {
return ø.internalSession.Created
}
// Is this a new Session (created by the current request)
func (ø *Session) IsNew() bool {
return ø.isNew
}
// TODO : Resets the max age property of the session to its original value (sliding expiration).
func (ø *Session) resetMaxAge() {
}
// Marshal the session to JSON.
func (ø *Session) MarshalJSON() ([]byte, error) {
return json.Marshal(ø.internalSession)
}
// Unmarshal the JSON into the internal session struct.
func (ø *Session) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, &ø.internalSession)
}
// Options object for the session handler. It specified the Session store to use for
// persistence, the template for the session cookie (name, path, maxage, etc.),
// whether or not the proxy should be trusted to determine if the connection is secure,
// and the required secret to sign the session cookie.
type SessionOptions struct {
Store SessionStore
CookieTemplate http.Cookie
TrustProxy bool
Secret string
}
// Create a new SessionOptions struct, using default cookie and proxy values.
func NewSessionOptions(store SessionStore, secret string) *SessionOptions {
return &SessionOptions{
Store: store,
Secret: secret,
}
}
// The augmented ResponseWriter struct for the session handler. It holds the current
// Session object and Session store, as well as flags and function to send the actual
// session cookie at the end of the request.
type sessResponseWriter struct {
http.ResponseWriter
sess *Session
sessStore SessionStore
sessSent bool
sendCookieFn func()
}
// Implement the WrapWriter interface.
func (ø *sessResponseWriter) WrappedWriter() http.ResponseWriter {
return ø.ResponseWriter
}
// Intercept the Write() method to add the Set-Cookie header before it's too late.
func (ø *sessResponseWriter) Write(data []byte) (int, error) {
if !ø.sessSent {
ø.sendCookieFn()
ø.sessSent = true
}
return ø.ResponseWriter.Write(data)
}
// Intercept the WriteHeader() method to add the Set-Cookie header before it's too late.
func (ø *sessResponseWriter) WriteHeader(code int) {
if !ø.sessSent {
ø.sendCookieFn()
ø.sessSent = true
}
ø.ResponseWriter.WriteHeader(code)
}
// SessionHandlerFunc is the same as SessionHandler, it is just a convenience
// signature that accepts a func(http.ResponseWriter, *http.Request) instead of
// a http.Handler interface. It saves the boilerplate http.HandlerFunc() cast.
func SessionHandlerFunc(h http.HandlerFunc, opts *SessionOptions) http.HandlerFunc {
return SessionHandler(h, opts)
}
// Create a Session handler to offer the Session behaviour to the specified handler.
func SessionHandler(h http.Handler, opts *SessionOptions) http.HandlerFunc {
// Make sure the required cookie fields are set
if opts.CookieTemplate.Name == "" {
opts.CookieTemplate.Name = defaultCookieName
}
if opts.CookieTemplate.Path == "" {
opts.CookieTemplate.Path = "/"
}
// Secret is required
if opts.Secret == "" {
panic(ErrSessionSecretMissing)
}
// Return the actual handler
return func(w http.ResponseWriter, r *http.Request) {
if _, ok := getSessionWriter(w); ok {
// Self-awareness
h.ServeHTTP(w, r)
return
}
if strings.Index(r.URL.Path, opts.CookieTemplate.Path) != 0 {
// Session does not apply to this path
h.ServeHTTP(w, r)
return
}
// Create a new Session or retrieve the existing session based on the
// session cookie received.
var sess *Session
var ckSessId string
exCk, err := r.Cookie(opts.CookieTemplate.Name)
if err != nil {
sess = newSession(opts.CookieTemplate.MaxAge)
ghost.LogFn("ghost.session : error getting session cookie : %s", err)
} else {
ckSessId, err = parseSignedCookie(exCk, opts.Secret)
if err != nil {
sess = newSession(opts.CookieTemplate.MaxAge)
ghost.LogFn("ghost.session : error parsing signed cookie : %s", err)
} else if ckSessId == "" {
sess = newSession(opts.CookieTemplate.MaxAge)
ghost.LogFn("ghost.session : no existing session ID")
} else {
// Get the session
sess, err = opts.Store.Get(ckSessId)
if err != nil {
sess = newSession(opts.CookieTemplate.MaxAge)
ghost.LogFn("ghost.session : error getting session from store : %s", err)
} else if sess == nil {
sess = newSession(opts.CookieTemplate.MaxAge)
ghost.LogFn("ghost.session : nil session")
}
}
}
// Save the original hash of the session, used to compare if the contents
// have changed during the handling of the request, so that it has to be
// saved to the stored.
oriHash := hash(sess)
// Create the augmented ResponseWriter.
srw := &sessResponseWriter{w, sess, opts.Store, false, func() {
// This function is called when the header is about to be written, so that
// the session cookie is correctly set.
// Check if the connection is secure
proto := strings.Trim(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), " ")
tls := r.TLS != nil || (strings.HasPrefix(proto, "https") && opts.TrustProxy)
if opts.CookieTemplate.Secure && !tls {
ghost.LogFn("ghost.session : secure cookie on a non-secure connection, cookie not sent")
return
}
if !sess.IsNew() {
// If this is not a new session, no need to send back the cookie
// TODO : Handle expires?
return
}
// Send the session cookie
ck := opts.CookieTemplate
ck.Value = sess.ID()
err := signCookie(&ck, opts.Secret)
if err != nil {
ghost.LogFn("ghost.session : error signing cookie : %s", err)
return
}
http.SetCookie(w, &ck)
}}
// Call wrapped handler
h.ServeHTTP(srw, r)
// TODO : Expiration management? srw.sess.resetMaxAge()
// Do not save if content is the same, unless session is new (to avoid
// creating a new session and sending a cookie on each successive request).
if newHash := hash(sess); !sess.IsNew() && oriHash == newHash && newHash != 0 {
// No changes to the session, no need to save
ghost.LogFn("ghost.session : no changes to save to store")
return
}
err = opts.Store.Set(sess)
if err != nil {
ghost.LogFn("ghost.session : error saving session to store : %s", err)
}
}
}
// Helper function to retrieve the session for the current request.
func GetSession(w http.ResponseWriter) (*Session, bool) {
ss, ok := getSessionWriter(w)
if ok {
return ss.sess, true
}
return nil, false
}
// Helper function to retrieve the session store
func GetSessionStore(w http.ResponseWriter) (SessionStore, bool) {
ss, ok := getSessionWriter(w)
if ok {
return ss.sessStore, true
}
return nil, false
}
// Internal helper function to retrieve the session writer object.
func getSessionWriter(w http.ResponseWriter) (*sessResponseWriter, bool) {
ss, ok := GetResponseWriter(w, func(tst http.ResponseWriter) bool {
_, ok := tst.(*sessResponseWriter)
return ok
})
if ok {
return ss.(*sessResponseWriter), true
}
return nil, false
}
// Parse a signed cookie and return the cookie value
func parseSignedCookie(ck *http.Cookie, secret string) (string, error) {
var val string
sck := securecookie.New([]byte(secret), nil)
err := sck.Decode(ck.Name, ck.Value, &val)
if err != nil {
return "", err
}
return val, nil
}
// Sign the specified cookie's value
func signCookie(ck *http.Cookie, secret string) error {
sck := securecookie.New([]byte(secret), nil)
enc, err := sck.Encode(ck.Name, ck.Value)
if err != nil {
return err
}
ck.Value = enc
return nil
}
// Compute a CRC32 hash of the session's JSON-encoded contents.
func hash(s *Session) uint32 {
data, err := json.Marshal(s)
if err != nil {
ghost.LogFn("ghost.session : error hash : %s", err)
return 0 // 0 is always treated as "modified" session content
}
return crc32.ChecksumIEEE(data)
}